debeir.datasets.marco

 1from dataclasses import dataclass
 2from typing import Dict, Optional, Union
 3
 4from debeir.core.config import GenericConfig
 5from debeir.core.executor import GenericElasticsearchExecutor
 6from debeir.core.query import GenericElasticsearchQuery
 7from debeir.rankers.transformer_sent_encoder import Encoder
 8from elasticsearch import AsyncElasticsearch as Elasticsearch
 9
10
class MarcoElasticsearchExecutor(GenericElasticsearchExecutor):
    """Elasticsearch executor for the MARCO dataset.

    Thin specialisation of ``GenericElasticsearchExecutor`` that registers the
    two query generators used for this dataset in ``self.query_fns`` under the
    keys ``"query"`` and ``"embedding"``.
    """

    query: GenericElasticsearchQuery

    def __init__(
            self,
            topics: Dict[Union[str, int], Dict[str, str]],
            client: Elasticsearch,
            index_name: str,
            output_file: str,
            query: GenericElasticsearchQuery,
            encoder: Optional[Encoder] = None,
            config=None,
            *args,
            **kwargs,
    ):
        """Forward every argument to the parent executor and register query generators.

        :param topics: Mapping of topic id to that topic's string fields.
        :param client: Asynchronous Elasticsearch client.
        :param index_name: Name of the index, forwarded to the parent.
        :param output_file: Output file path, forwarded to the parent.
        :param query: Query object used by ``generate_query``.
        :param encoder: Optional encoder, forwarded to the parent.
        :param config: Optional config object, forwarded to the parent.
        """
        super().__init__(
            topics,
            client,
            index_name,
            output_file,
            query,
            encoder,
            config=config,
            *args,
            **kwargs,
        )

        # Dispatch table: query-type name -> bound generator method.
        self.query_fns = {
            "query": self.generate_query,
            "embedding": self.generate_embedding_query,
        }

    def generate_query(self, topic_num, best_fields=True, **kwargs):
        """Build the query for ``topic_num`` via ``self.query.generate_query``.

        ``best_fields`` and ``kwargs`` are accepted for signature
        compatibility but are not used by this override.
        """
        return self.query.generate_query(topic_num)

    def generate_embedding_query(
            self,
            topic_num,
            cosine_weights=None,
            query_weights=None,
            norm_weight=2.15,
            automatic_scores=None,
            **kwargs,
    ):
        """Delegate to the parent's ``generate_embedding_query``.

        All arguments, including ``norm_weight`` and ``automatic_scores``, are
        forwarded as given. Previously these two were replaced by the literals
        ``2.15`` and ``None``, so caller-supplied values were silently ignored.
        The defaults are unchanged, so calls relying on them behave as before.

        :return: Whatever the parent implementation returns.
        """
        return super().generate_embedding_query(
            topic_num,
            cosine_weights=cosine_weights,
            query_weights=query_weights,
            norm_weight=norm_weight,
            automatic_scores=automatic_scores,
            **kwargs,
        )

    async def execute_query(
            self, query=None, topic_num=None, ablation=False, query_type="query", **kwargs
    ):
        """Run the parent's ``execute_query`` and return its resolved result.

        The only difference from the parent visible here is the default
        ``query_type="query"``.

        :return: The parent's result; if the parent returns an awaitable it is
            awaited first so callers do not receive a bare coroutine.
        """
        import inspect  # local import keeps this block self-contained

        # NOTE(review): the first three arguments are forwarded positionally;
        # confirm they line up with the parent's parameter order.
        result = super().execute_query(
            query, topic_num, ablation, query_type=query_type, **kwargs
        )
        # Without this, `await executor.execute_query(...)` would yield the
        # parent's un-awaited coroutine instead of the query result.
        if inspect.isawaitable(result):
            result = await result
        return result
70
71
@dataclass(init=True, unsafe_hash=True)
class MarcoQueryConfig(GenericConfig):
    """Query configuration for the MARCO dataset, built on ``GenericConfig``."""

    def validate(self):
        """Check the embedding-specific settings.

        Only configs whose ``query_type`` is ``"embedding"`` are checked:
        both ``encoder_fp`` and ``encoder`` must be truthy, and at least one
        of ``norm_weight`` / ``automatic`` must not be ``None``.

        :raises AssertionError: If one of the checks above fails.
        """
        if self.query_type != "embedding":
            return

        has_encoder = self.encoder_fp and self.encoder
        assert has_encoder, "Must provide encoder path for embedding model"

        has_norm_setting = (
            self.norm_weight is not None or self.automatic is not None
        )
        assert has_norm_setting, "Norm weight be specified or be automatic"

    @classmethod
    def from_toml(cls, fp: str, *args, **kwargs) -> "MarcoQueryConfig":
        """Build a config from the TOML file at ``fp`` via the parent loader.

        This class is passed to the parent as the class to instantiate.
        """
        loaded = super().from_toml(fp, cls, *args, **kwargs)
        return loaded

    @classmethod
    def from_dict(cls, **kwargs) -> "MarcoQueryConfig":
        """Build a config from keyword arguments via the parent ``from_dict``.

        This class is passed to the parent as the class to instantiate.
        """
        built = super().from_dict(cls, **kwargs)
        return built
class MarcoElasticsearchExecutor(debeir.core.executor.GenericElasticsearchExecutor):
class MarcoElasticsearchExecutor(GenericElasticsearchExecutor):
    """Elasticsearch executor for the MARCO dataset.

    Thin specialisation of ``GenericElasticsearchExecutor`` that registers the
    two query generators used for this dataset in ``self.query_fns`` under the
    keys ``"query"`` and ``"embedding"``.
    """

    query: GenericElasticsearchQuery

    def __init__(
            self,
            topics: Dict[Union[str, int], Dict[str, str]],
            client: Elasticsearch,
            index_name: str,
            output_file: str,
            query: GenericElasticsearchQuery,
            encoder: Optional[Encoder] = None,
            config=None,
            *args,
            **kwargs,
    ):
        """Forward every argument to the parent executor and register query generators."""
        super().__init__(
            topics,
            client,
            index_name,
            output_file,
            query,
            encoder,
            config=config,
            *args,
            **kwargs,
        )

        # Dispatch table: query-type name -> bound generator method.
        self.query_fns = {
            "query": self.generate_query,
            "embedding": self.generate_embedding_query,
        }

    def generate_query(self, topic_num, best_fields=True, **kwargs):
        """Build the query for ``topic_num`` via ``self.query.generate_query``.

        ``best_fields`` and ``kwargs`` are accepted but not used by this override.
        """
        return self.query.generate_query(topic_num)

    def generate_embedding_query(
            self,
            topic_num,
            cosine_weights=None,
            query_weights=None,
            norm_weight=2.15,
            automatic_scores=None,
            **kwargs,
    ):
        """Delegate to the parent's ``generate_embedding_query``.

        ``norm_weight`` and ``automatic_scores`` are forwarded as given rather
        than being replaced by the literals ``2.15`` and ``None``; the defaults
        are unchanged, so default calls behave as before.
        """
        return super().generate_embedding_query(
            topic_num,
            cosine_weights=cosine_weights,
            query_weights=query_weights,
            norm_weight=norm_weight,
            automatic_scores=automatic_scores,
            **kwargs,
        )

    async def execute_query(
            self, query=None, topic_num=None, ablation=False, query_type="query", **kwargs
    ):
        """Run the parent's ``execute_query`` and return its resolved result.

        If the parent returns an awaitable it is awaited here, so callers do
        not receive a bare coroutine.
        """
        import inspect  # local import keeps this block self-contained

        # NOTE(review): the first three arguments are forwarded positionally;
        # confirm they line up with the parent's parameter order.
        result = super().execute_query(
            query, topic_num, ablation, query_type=query_type, **kwargs
        )
        if inspect.isawaitable(result):
            result = await result
        return result

Generic Executor class for Elasticsearch

MarcoElasticsearchExecutor( topics: Dict[Union[str, int], Dict[str, str]], client: elasticsearch.AsyncElasticsearch, index_name: str, output_file: str, query: debeir.core.query.GenericElasticsearchQuery, encoder: Optional[debeir.rankers.transformer_sent_encoder.Encoder] = None, config=None, *args, **kwargs)
15    def __init__(
16            self,
17            topics: Dict[Union[str, int], Dict[str, str]],
18            client: Elasticsearch,
19            index_name: str,
20            output_file: str,
21            query: GenericElasticsearchQuery,
22            encoder: Optional[Encoder] = None,
23            config=None,
24            *args,
25            **kwargs,
26    ):
27        super().__init__(
28            topics,
29            client,
30            index_name,
31            output_file,
32            query,
33            encoder,
34            config=config,
35            *args,
36            **kwargs,
37        )
38
39        self.query_fns = {
40            "query": self.generate_query,
41            "embedding": self.generate_embedding_query,
42        }
def generate_query(self, topic_num, best_fields=True, **kwargs):
44    def generate_query(self, topic_num, best_fields=True, **kwargs):
45        return self.query.generate_query(topic_num)

Generates a standard BM25 query given the topic number

Parameters
  • topic_num: Query topic number to generate
  • best_fields: Whether to use a curated list of fields
  • kwargs:
Returns
def generate_embedding_query( self, topic_num, cosine_weights=None, query_weights=None, norm_weight=2.15, automatic_scores=None, **kwargs):
47    def generate_embedding_query(
48            self,
49            topic_num,
50            cosine_weights=None,
51            query_weights=None,
52            norm_weight=2.15,
53            automatic_scores=None,
54            **kwargs,
55    ):
56        return super().generate_embedding_query(
57            topic_num,
58            cosine_weights=cosine_weights,
59            query_weights=query_weights,
60            norm_weight=2.15,
61            automatic_scores=None,
62            **kwargs,
63        )

Executes an NIR-style query with combined scoring.

Parameters
  • topic_num:
  • cosine_weights:
  • query_weights:
  • norm_weight:
  • automatic_scores:
  • kwargs:
Returns
async def execute_query( self, query=None, topic_num=None, ablation=False, query_type='query', **kwargs):
65    async def execute_query(
66            self, query=None, topic_num=None, ablation=False, query_type="query", **kwargs
67    ):
68        return super().execute_query(
69            query, topic_num, ablation, query_type=query_type, **kwargs
70        )

Execute a query given parameters

Parameters
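  • query: Defaults to None
  • topic_num: Defaults to None
  • ablation: Defaults to False
  • query_type: Defaults to "query"
  • kwargs: Additional keyword arguments forwarded to the parent executor's execute_query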
@dataclass(init=True, unsafe_hash=True)
class MarcoQueryConfig(debeir.core.config.GenericConfig):
@dataclass(init=True, unsafe_hash=True)
class MarcoQueryConfig(GenericConfig):
    """Query configuration for the MARCO dataset, built on ``GenericConfig``."""

    def validate(self):
        """Check the embedding-specific settings.

        Only configs whose ``query_type`` is ``"embedding"`` are checked:
        both ``encoder_fp`` and ``encoder`` must be truthy, and at least one
        of ``norm_weight`` / ``automatic`` must not be ``None``.

        :raises AssertionError: If one of the checks above fails.
        """
        if self.query_type != "embedding":
            return

        has_encoder = self.encoder_fp and self.encoder
        assert has_encoder, "Must provide encoder path for embedding model"

        has_norm_setting = (
            self.norm_weight is not None or self.automatic is not None
        )
        assert has_norm_setting, "Norm weight be specified or be automatic"

    @classmethod
    def from_toml(cls, fp: str, *args, **kwargs) -> "MarcoQueryConfig":
        """Build a config from the TOML file at ``fp`` via the parent loader."""
        loaded = super().from_toml(fp, cls, *args, **kwargs)
        return loaded

    @classmethod
    def from_dict(cls, **kwargs) -> "MarcoQueryConfig":
        """Build a config from keyword arguments via the parent ``from_dict``."""
        built = super().from_dict(cls, **kwargs)
        return built
MarcoQueryConfig( query_type: str, index: str = None, encoder_normalize: bool = True, ablations: bool = False, norm_weight: float = None, automatic: bool = None, encoder: object = None, encoder_fp: str = None, query_weights: List[float] = None, cosine_weights: List[float] = None, evaluate: bool = False, qrels: str = None, config_fn: str = None, query_fn: str = None, parser_fn: str = None, executor_fn: str = None, cosine_ceiling: float = None, topics_path: str = None, return_id_only: bool = False, overwrite_output_if_exists: bool = False, output_file: str = None, run_name: str = None)
def validate(self):
75    def validate(self):
76        if self.query_type == "embedding":
77            assert (
78                    self.encoder_fp and self.encoder
79            ), "Must provide encoder path for embedding model"
80            assert self.norm_weight is not None or self.automatic is not None, (
81                "Norm weight be " "specified or be automatic"
82            )

Validates if the config is correct. Must be implemented by inherited classes.

@classmethod
def from_toml(cls, fp: str, *args, **kwargs) -> debeir.datasets.marco.MarcoQueryConfig:
84    @classmethod
85    def from_toml(cls, fp: str, *args, **kwargs) -> "MarcoQueryConfig":
86        return super().from_toml(fp, cls, *args, **kwargs)

Instantiates a Config object from a toml file

Parameters
  • fp: File path of the Config TOML file
  • field_class: Class of the Config object to be instantiated
  • args: Arguments to be passed to Config
  • kwargs: Keyword arguments to be passed
Returns
An instantiated and validated Config object.
@classmethod
def from_dict(cls, **kwargs) -> debeir.datasets.marco.MarcoQueryConfig:
88    @classmethod
89    def from_dict(cls, **kwargs) -> "MarcoQueryConfig":
90        return super().from_dict(cls, **kwargs)

Instantiates a Config object from a dictionary

Parameters
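  • data_class: Class of the Config object to be instantiated (this override passes its own class)
  • kwargs: Keyword arguments forwarded to the parent from_dict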
Returns