debeir.datasets.marco
"""MS MARCO dataset bindings: an Elasticsearch executor and its query configuration."""

from dataclasses import dataclass
from typing import Dict, Optional, Union

from debeir.core.config import GenericConfig
from debeir.core.executor import GenericElasticsearchExecutor
from debeir.core.query import GenericElasticsearchQuery
from debeir.rankers.transformer_sent_encoder import Encoder
from elasticsearch import AsyncElasticsearch as Elasticsearch


class MarcoElasticsearchExecutor(GenericElasticsearchExecutor):
    """Executor for running MS MARCO topics against an Elasticsearch index."""

    query: GenericElasticsearchQuery

    def __init__(
        self,
        topics: Dict[Union[str, int], Dict[str, str]],
        client: Elasticsearch,
        index_name: str,
        output_file: str,
        query: GenericElasticsearchQuery,
        encoder: Optional[Encoder] = None,
        config=None,
        *args,
        **kwargs,
    ):
        super().__init__(
            topics,
            client,
            index_name,
            output_file,
            query,
            encoder,
            config=config,
            *args,
            **kwargs,
        )

        # Dispatch table: query_type -> generator method.
        self.query_fns = {
            "query": self.generate_query,
            "embedding": self.generate_embedding_query,
        }

    def generate_query(self, topic_num, best_fields=True, **kwargs):
        """Generates a standard BM25 query for the given topic number.

        :param topic_num: Query topic number to generate.
        :param best_fields: Accepted for interface compatibility; unused here.
        :return: The query body produced by the underlying query object.
        """
        return self.query.generate_query(topic_num)

    def generate_embedding_query(
        self,
        topic_num,
        cosine_weights=None,
        query_weights=None,
        norm_weight=2.15,
        automatic_scores=None,
        **kwargs,
    ):
        """Generates an NIR-style embedding query with combined scoring.

        Fix: the super() call previously hard-coded norm_weight=2.15 and
        automatic_scores=None, silently discarding the caller's arguments;
        the received parameters are now forwarded.
        """
        return super().generate_embedding_query(
            topic_num,
            cosine_weights=cosine_weights,
            query_weights=query_weights,
            norm_weight=norm_weight,
            automatic_scores=automatic_scores,
            **kwargs,
        )

    async def execute_query(
        self, query=None, topic_num=None, ablation=False, query_type="query", **kwargs
    ):
        """Executes a query with the given parameters.

        NOTE(review): the base implementation is assumed to be a coroutine
        (the client is AsyncElasticsearch), so it is awaited here; previously
        the un-awaited coroutine object was returned to the caller.
        """
        return await super().execute_query(
            query, topic_num, ablation, query_type=query_type, **kwargs
        )


@dataclass(init=True, unsafe_hash=True)
class MarcoQueryConfig(GenericConfig):
    """Configuration for MS MARCO query generation and execution."""

    def validate(self):
        """Validates that the config is consistent for the chosen query type.

        :raises AssertionError: if embedding mode lacks an encoder/path, or if
            neither a norm weight nor automatic scoring is configured.
        """
        if self.query_type == "embedding":
            assert (
                self.encoder_fp and self.encoder
            ), "Must provide encoder path for embedding model"
            assert self.norm_weight is not None or self.automatic is not None, (
                "Norm weight must be specified or automatic scoring enabled"
            )

    @classmethod
    def from_toml(cls, fp: str, *args, **kwargs) -> "MarcoQueryConfig":
        """Instantiates a MarcoQueryConfig from a TOML file at ``fp``."""
        return super().from_toml(fp, cls, *args, **kwargs)

    @classmethod
    def from_dict(cls, **kwargs) -> "MarcoQueryConfig":
        """Instantiates a MarcoQueryConfig from keyword arguments."""
        return super().from_dict(cls, **kwargs)
class MarcoElasticsearchExecutor(GenericElasticsearchExecutor):
    """Executor for running MS MARCO topics against an Elasticsearch index."""

    query: GenericElasticsearchQuery

    def __init__(
        self,
        topics: Dict[Union[str, int], Dict[str, str]],
        client: Elasticsearch,
        index_name: str,
        output_file: str,
        query: GenericElasticsearchQuery,
        encoder: Optional[Encoder] = None,
        config=None,
        *args,
        **kwargs,
    ):
        super().__init__(
            topics,
            client,
            index_name,
            output_file,
            query,
            encoder,
            config=config,
            *args,
            **kwargs,
        )

        # Dispatch table: query_type -> generator method.
        self.query_fns = {
            "query": self.generate_query,
            "embedding": self.generate_embedding_query,
        }

    def generate_query(self, topic_num, best_fields=True, **kwargs):
        """Generates a standard BM25 query for the given topic number.

        :param topic_num: Query topic number to generate.
        :param best_fields: Accepted for interface compatibility; unused here.
        :return: The query body produced by the underlying query object.
        """
        return self.query.generate_query(topic_num)

    def generate_embedding_query(
        self,
        topic_num,
        cosine_weights=None,
        query_weights=None,
        norm_weight=2.15,
        automatic_scores=None,
        **kwargs,
    ):
        """Generates an NIR-style embedding query with combined scoring.

        Fix: the super() call previously hard-coded norm_weight=2.15 and
        automatic_scores=None, silently discarding the caller's arguments;
        the received parameters are now forwarded.
        """
        return super().generate_embedding_query(
            topic_num,
            cosine_weights=cosine_weights,
            query_weights=query_weights,
            norm_weight=norm_weight,
            automatic_scores=automatic_scores,
            **kwargs,
        )

    async def execute_query(
        self, query=None, topic_num=None, ablation=False, query_type="query", **kwargs
    ):
        """Executes a query with the given parameters.

        NOTE(review): the base implementation is assumed to be a coroutine
        (the client is AsyncElasticsearch), so it is awaited here; previously
        the un-awaited coroutine object was returned to the caller.
        """
        return await super().execute_query(
            query, topic_num, ablation, query_type=query_type, **kwargs
        )
Executor class for running MS MARCO topics against Elasticsearch
MarcoElasticsearchExecutor( topics: Dict[Union[str, int], Dict[str, str]], client: elasticsearch.AsyncElasticsearch, index_name: str, output_file: str, query: debeir.core.query.GenericElasticsearchQuery, encoder: Optional[debeir.rankers.transformer_sent_encoder.Encoder] = None, config=None, *args, **kwargs)
15 def __init__( 16 self, 17 topics: Dict[Union[str, int], Dict[str, str]], 18 client: Elasticsearch, 19 index_name: str, 20 output_file: str, 21 query: GenericElasticsearchQuery, 22 encoder: Optional[Encoder] = None, 23 config=None, 24 *args, 25 **kwargs, 26 ): 27 super().__init__( 28 topics, 29 client, 30 index_name, 31 output_file, 32 query, 33 encoder, 34 config=config, 35 *args, 36 **kwargs, 37 ) 38 39 self.query_fns = { 40 "query": self.generate_query, 41 "embedding": self.generate_embedding_query, 42 }
def
generate_query(self, topic_num, best_fields=True, **kwargs):
44 def generate_query(self, topic_num, best_fields=True, **kwargs): 45 return self.query.generate_query(topic_num)
Generates a standard BM25 query given the topic number
Parameters
- topic_num: Query topic number to generate
- best_fields: Whether to use a curated list of fields
- kwargs:
Returns
def
generate_embedding_query( self, topic_num, cosine_weights=None, query_weights=None, norm_weight=2.15, automatic_scores=None, **kwargs):
47 def generate_embedding_query( 48 self, 49 topic_num, 50 cosine_weights=None, 51 query_weights=None, 52 norm_weight=2.15, 53 automatic_scores=None, 54 **kwargs, 55 ): 56 return super().generate_embedding_query( 57 topic_num, 58 cosine_weights=cosine_weights, 59 query_weights=query_weights, 60 norm_weight=2.15, 61 automatic_scores=None, 62 **kwargs, 63 )
Executes an NIR-style query with combined scoring.
Parameters
- topic_num:
- cosine_weights:
- query_weights:
- norm_weight:
- automatic_scores:
- kwargs:
Returns
async def
execute_query( self, query=None, topic_num=None, ablation=False, query_type='query', **kwargs):
65 async def execute_query( 66 self, query=None, topic_num=None, ablation=False, query_type="query", **kwargs 67 ): 68 return super().execute_query( 69 query, topic_num, ablation, query_type=query_type, **kwargs 70 )
Executes a query with the given parameters
Parameters
- args:
- kwargs:
@dataclass(init=True, unsafe_hash=True)
class MarcoQueryConfig(GenericConfig):
    """Configuration for MS MARCO query generation and execution."""

    def validate(self):
        """Validates that the config is consistent for the chosen query type.

        Fix: the second assertion message was garbled ("Norm weight be
        specified or be automatic"); it now reads as intended.

        :raises AssertionError: if embedding mode lacks an encoder/path, or if
            neither a norm weight nor automatic scoring is configured.
        """
        if self.query_type == "embedding":
            assert (
                self.encoder_fp and self.encoder
            ), "Must provide encoder path for embedding model"
            assert self.norm_weight is not None or self.automatic is not None, (
                "Norm weight must be specified or automatic scoring enabled"
            )

    @classmethod
    def from_toml(cls, fp: str, *args, **kwargs) -> "MarcoQueryConfig":
        """Instantiates a MarcoQueryConfig from a TOML file at ``fp``."""
        return super().from_toml(fp, cls, *args, **kwargs)

    @classmethod
    def from_dict(cls, **kwargs) -> "MarcoQueryConfig":
        """Instantiates a MarcoQueryConfig from keyword arguments."""
        return super().from_dict(cls, **kwargs)
MarcoQueryConfig( query_type: str, index: str = None, encoder_normalize: bool = True, ablations: bool = False, norm_weight: float = None, automatic: bool = None, encoder: object = None, encoder_fp: str = None, query_weights: List[float] = None, cosine_weights: List[float] = None, evaluate: bool = False, qrels: str = None, config_fn: str = None, query_fn: str = None, parser_fn: str = None, executor_fn: str = None, cosine_ceiling: float = None, topics_path: str = None, return_id_only: bool = False, overwrite_output_if_exists: bool = False, output_file: str = None, run_name: str = None)
def
validate(self):
75 def validate(self): 76 if self.query_type == "embedding": 77 assert ( 78 self.encoder_fp and self.encoder 79 ), "Must provide encoder path for embedding model" 80 assert self.norm_weight is not None or self.automatic is not None, ( 81 "Norm weight be " "specified or be automatic" 82 )
Validates that the config is correct. Must be implemented by subclasses.
@classmethod
def
from_toml(cls, fp: str, *args, **kwargs) -> debeir.datasets.marco.MarcoQueryConfig:
84 @classmethod 85 def from_toml(cls, fp: str, *args, **kwargs) -> "MarcoQueryConfig": 86 return super().from_toml(fp, cls, *args, **kwargs)
Instantiates a Config object from a toml file
Parameters
- fp: File path of the Config TOML file
- field_class: Class of the Config object to be instantiated
- args: Arguments to be passed to Config
- kwargs: Keyword arguments to be passed
Returns
An instantiated and validated Config object.
88 @classmethod 89 def from_dict(cls, **kwargs) -> "MarcoQueryConfig": 90 return super().from_dict(cls, **kwargs)
Instantiates a Config object from a dictionary
Parameters
- data_class:
- kwargs: