debeir.core.query

  1import dataclasses
  2from typing import Dict, Optional, Union
  3
  4import loguru
  5from debeir.engines.elasticsearch.generate_script_score import generate_script
  6from debeir.core.config import GenericConfig, apply_config
  7from debeir.utils.scaler import get_z_value
  8
  9
 10@dataclasses.dataclass(init=True)
 11class Query:
 12    """
 13    A query interface class
 14    :param topics: Topics that the query will be composed of
 15    :param config: Config object that contains the settings for querying
 16    """
 17    topics: Dict[int, Dict[str, str]]
 18    config: GenericConfig
 19
 20
 21class GenericElasticsearchQuery(Query):
 22    """
 23    A generic elasticsearch query. Contains methods for NIR-style (embedding) queries and normal BM25 queries.
 24    Requires topics, configs to be included
 25    """
 26    id_mapping: str = "Id"
 27
 28    def __init__(self, topics, config, top_bm25_scores=None, mappings=None, id_mapping=None, *args, **kwargs):
 29        super().__init__(topics, config)
 30
 31        if id_mapping is None:
 32            self.id_mapping = "id"
 33
 34        if mappings is None:
 35            self.mappings = ["Text"]
 36        else:
 37            self.mappings = mappings
 38
 39        self.topics = topics
 40        self.config = config
 41        self.query_type = self.config.query_type
 42
 43        self.embed_mappings = ["Text_Embedding"]
 44
 45        self.query_funcs = {
 46            "query": self.generate_query,
 47            "embedding": self.generate_query_embedding,
 48        }
 49
 50        self.top_bm25_scores = top_bm25_scores
 51
 52    def _generate_base_query(self, topic_num):
 53        qfield = list(self.topics[topic_num].keys())[0]
 54        query = self.topics[topic_num][qfield]
 55        should = {"should": []}
 56
 57        for i, field in enumerate(self.mappings):
 58            should["should"].append(
 59                {
 60                    "match": {
 61                        f"{field}": {
 62                            "query": query,
 63                        }
 64                    }
 65                }
 66            )
 67
 68        return qfield, query, should
 69
 70    def generate_query(self, topic_num, *args, **kwargs):
 71        """
 72        Generates a simple BM25 query based off the query facets. Searches over all the document facets.
 73        :param topic_num:
 74        :param args:
 75        :param kwargs:
 76        :return:
 77        """
 78        _, _, should = self._generate_base_query(topic_num)
 79
 80        query = {
 81            "query": {
 82                "bool": should,
 83            }
 84        }
 85
 86        return query
 87
 88    def set_bm25_scores(self, scores: Dict[Union[str, int], Union[int, float]]):
 89        """
 90        Sets BM25 scores that are used for NIR-style scoring. The top BM25 score for each topic is used
 91        for log normalization.
 92
 93        Score = log(bm25)/log(z) + embed_score
 94        :param scores: Top BM25 Scores of the form {topic_num: top_bm25_score}
 95        """
 96        self.top_bm25_scores = scores
 97
 98    def has_bm25_scores(self):
 99        """
100        Checks if BM25 scores have been set
101        :return:
102        """
103        return self.top_bm25_scores is not None
104
105    @apply_config
106    def generate_query_embedding(
107            self, topic_num, encoder, *args, norm_weight=2.15, ablations=False, cosine_ceiling=Optional[float],
108            cosine_offset: float = 1.0, **kwargs):
109        """
110        Generates an embedding script score query for Elasticsearch as part of the NIR scoring function.
111
112        :param topic_num: The topic number to search for
113        :param encoder: The encoder that will be used for encoding the topics
114        :param norm_weight: The BM25 log normalization constant
115        :param ablations: Whether to execute ablation style queries (i.e. one query facet
116                          or one document facet at a time)
117        :param cosine_ceiling: Cosine ceiling used for automatic z-log normalization parameter calculation
118        :param args:
119        :param kwargs: Pass disable_cache to disable encoder caching
120        :return:
121            An elasticsearch script_score query
122        """
123
124        qfields = list(self.topics[topic_num].keys())
125        should = {"should": []}
126
127        if self.has_bm25_scores():
128            cosine_ceiling = len(self.embed_mappings) * len(qfields) if cosine_ceiling is None else cosine_ceiling
129            norm_weight = get_z_value(
130                cosine_ceiling=cosine_ceiling,
131                bm25_ceiling=self.top_bm25_scores[topic_num],
132            )
133            loguru.logger.debug(f"Automatic norm_weight: {norm_weight}")
134
135        params = {
136            "weights": [1] * (len(self.embed_mappings) * len(self.mappings)),
137            "offset": cosine_offset,
138            "norm_weight": norm_weight,
139            "disable_bm25": ablations,
140        }
141
142        embed_fields = []
143
144        for qfield in qfields:
145            for field in self.mappings:
146                should["should"].append(
147                    {
148                        "match": {
149                            f"{field}": {
150                                "query": self.topics[topic_num][qfield],
151                            }
152                        }
153                    }
154                )
155
156            params[f"{qfield}_eb"] = encoder.encode(topic=self.topics[topic_num][qfield])
157            embed_fields.append(f"{qfield}_eb")
158
159        query = {
160            "query": {
161                "script_score": {
162                    "query": {
163                        "bool": should,
164                    },
165                    "script": generate_script(
166                        self.embed_mappings, params, qfields=embed_fields
167                    ),
168                }
169            }
170        }
171
172        loguru.logger.debug(query)
173        return query
174
175    @classmethod
176    def get_id_mapping(cls, hit):
177        """
178        Get the document ID
179
180        :param hit: The raw document result
181        :return:
182            The document's ID
183        """
184        return hit[cls.id_mapping]
@dataclasses.dataclass(init=True)
class Query:
@dataclasses.dataclass
class Query:
    """
    Base interface for query builders: a plain record of topics plus configuration.

    :param topics: Topics that the query will be composed of, keyed by topic number;
                   each value maps a query field name to its text
    :param config: Config object that contains the settings for querying
    """
    # Annotated as {topic_num: {query_field: query_text}}.
    topics: Dict[int, Dict[str, str]]
    # Settings object consulted when building queries.
    config: GenericConfig

A query interface class

Parameters
  • topics: Topics that the query will be composed of
  • config: Config object that contains the settings for querying
Query( topics: Dict[int, Dict[str, str]], config: debeir.core.config.GenericConfig)
class GenericElasticsearchQuery(Query):
 22class GenericElasticsearchQuery(Query):
 23    """
 24    A generic elasticsearch query. Contains methods for NIR-style (embedding) queries and normal BM25 queries.
 25    Requires topics, configs to be included
 26    """
 27    id_mapping: str = "Id"
 28
 29    def __init__(self, topics, config, top_bm25_scores=None, mappings=None, id_mapping=None, *args, **kwargs):
 30        super().__init__(topics, config)
 31
 32        if id_mapping is None:
 33            self.id_mapping = "id"
 34
 35        if mappings is None:
 36            self.mappings = ["Text"]
 37        else:
 38            self.mappings = mappings
 39
 40        self.topics = topics
 41        self.config = config
 42        self.query_type = self.config.query_type
 43
 44        self.embed_mappings = ["Text_Embedding"]
 45
 46        self.query_funcs = {
 47            "query": self.generate_query,
 48            "embedding": self.generate_query_embedding,
 49        }
 50
 51        self.top_bm25_scores = top_bm25_scores
 52
 53    def _generate_base_query(self, topic_num):
 54        qfield = list(self.topics[topic_num].keys())[0]
 55        query = self.topics[topic_num][qfield]
 56        should = {"should": []}
 57
 58        for i, field in enumerate(self.mappings):
 59            should["should"].append(
 60                {
 61                    "match": {
 62                        f"{field}": {
 63                            "query": query,
 64                        }
 65                    }
 66                }
 67            )
 68
 69        return qfield, query, should
 70
 71    def generate_query(self, topic_num, *args, **kwargs):
 72        """
 73        Generates a simple BM25 query based off the query facets. Searches over all the document facets.
 74        :param topic_num:
 75        :param args:
 76        :param kwargs:
 77        :return:
 78        """
 79        _, _, should = self._generate_base_query(topic_num)
 80
 81        query = {
 82            "query": {
 83                "bool": should,
 84            }
 85        }
 86
 87        return query
 88
 89    def set_bm25_scores(self, scores: Dict[Union[str, int], Union[int, float]]):
 90        """
 91        Sets BM25 scores that are used for NIR-style scoring. The top BM25 score for each topic is used
 92        for log normalization.
 93
 94        Score = log(bm25)/log(z) + embed_score
 95        :param scores: Top BM25 Scores of the form {topic_num: top_bm25_score}
 96        """
 97        self.top_bm25_scores = scores
 98
 99    def has_bm25_scores(self):
100        """
101        Checks if BM25 scores have been set
102        :return:
103        """
104        return self.top_bm25_scores is not None
105
106    @apply_config
107    def generate_query_embedding(
108            self, topic_num, encoder, *args, norm_weight=2.15, ablations=False, cosine_ceiling=Optional[float],
109            cosine_offset: float = 1.0, **kwargs):
110        """
111        Generates an embedding script score query for Elasticsearch as part of the NIR scoring function.
112
113        :param topic_num: The topic number to search for
114        :param encoder: The encoder that will be used for encoding the topics
115        :param norm_weight: The BM25 log normalization constant
116        :param ablations: Whether to execute ablation style queries (i.e. one query facet
117                          or one document facet at a time)
118        :param cosine_ceiling: Cosine ceiling used for automatic z-log normalization parameter calculation
119        :param args:
120        :param kwargs: Pass disable_cache to disable encoder caching
121        :return:
122            An elasticsearch script_score query
123        """
124
125        qfields = list(self.topics[topic_num].keys())
126        should = {"should": []}
127
128        if self.has_bm25_scores():
129            cosine_ceiling = len(self.embed_mappings) * len(qfields) if cosine_ceiling is None else cosine_ceiling
130            norm_weight = get_z_value(
131                cosine_ceiling=cosine_ceiling,
132                bm25_ceiling=self.top_bm25_scores[topic_num],
133            )
134            loguru.logger.debug(f"Automatic norm_weight: {norm_weight}")
135
136        params = {
137            "weights": [1] * (len(self.embed_mappings) * len(self.mappings)),
138            "offset": cosine_offset,
139            "norm_weight": norm_weight,
140            "disable_bm25": ablations,
141        }
142
143        embed_fields = []
144
145        for qfield in qfields:
146            for field in self.mappings:
147                should["should"].append(
148                    {
149                        "match": {
150                            f"{field}": {
151                                "query": self.topics[topic_num][qfield],
152                            }
153                        }
154                    }
155                )
156
157            params[f"{qfield}_eb"] = encoder.encode(topic=self.topics[topic_num][qfield])
158            embed_fields.append(f"{qfield}_eb")
159
160        query = {
161            "query": {
162                "script_score": {
163                    "query": {
164                        "bool": should,
165                    },
166                    "script": generate_script(
167                        self.embed_mappings, params, qfields=embed_fields
168                    ),
169                }
170            }
171        }
172
173        loguru.logger.debug(query)
174        return query
175
176    @classmethod
177    def get_id_mapping(cls, hit):
178        """
179        Get the document ID
180
181        :param hit: The raw document result
182        :return:
183            The document's ID
184        """
185        return hit[cls.id_mapping]

A generic elasticsearch query. Contains methods for NIR-style (embedding) queries and normal BM25 queries. Requires topics, configs to be included

GenericElasticsearchQuery( topics, config, top_bm25_scores=None, mappings=None, id_mapping=None, *args, **kwargs)
29    def __init__(self, topics, config, top_bm25_scores=None, mappings=None, id_mapping=None, *args, **kwargs):
30        super().__init__(topics, config)
31
32        if id_mapping is None:
33            self.id_mapping = "id"
34
35        if mappings is None:
36            self.mappings = ["Text"]
37        else:
38            self.mappings = mappings
39
40        self.topics = topics
41        self.config = config
42        self.query_type = self.config.query_type
43
44        self.embed_mappings = ["Text_Embedding"]
45
46        self.query_funcs = {
47            "query": self.generate_query,
48            "embedding": self.generate_query_embedding,
49        }
50
51        self.top_bm25_scores = top_bm25_scores
def generate_query(self, topic_num, *args, **kwargs):
71    def generate_query(self, topic_num, *args, **kwargs):
72        """
73        Generates a simple BM25 query based off the query facets. Searches over all the document facets.
74        :param topic_num:
75        :param args:
76        :param kwargs:
77        :return:
78        """
79        _, _, should = self._generate_base_query(topic_num)
80
81        query = {
82            "query": {
83                "bool": should,
84            }
85        }
86
87        return query

Generates a simple BM25 query based off the query facets. Searches over all the document facets.

Parameters
  • topic_num:
  • args:
  • kwargs:
Returns
def set_bm25_scores(self, scores: Dict[Union[str, int], Union[int, float]]):
89    def set_bm25_scores(self, scores: Dict[Union[str, int], Union[int, float]]):
90        """
91        Sets BM25 scores that are used for NIR-style scoring. The top BM25 score for each topic is used
92        for log normalization.
93
94        Score = log(bm25)/log(z) + embed_score
95        :param scores: Top BM25 Scores of the form {topic_num: top_bm25_score}
96        """
97        self.top_bm25_scores = scores

Sets BM25 scores that are used for NIR-style scoring. The top BM25 score for each topic is used for log normalization.

Score = log(bm25)/log(z) + embed_score

Parameters
  • scores: Top BM25 Scores of the form {topic_num: top_bm25_score}
def has_bm25_scores(self):
 99    def has_bm25_scores(self):
100        """
101        Checks if BM25 scores have been set
102        :return:
103        """
104        return self.top_bm25_scores is not None

Checks if BM25 scores have been set

Returns
def generate_query_embedding(self, *args, **kwargs):
    def use_config(self, *args, **kwargs):
        """
        Replaces keywords and args passed to the function with ones from self.config.

        When ``self.config`` is not None, the keyword arguments are passed through
        ``self.config.__update__`` and its return value is used in their place; the
        positional arguments are forwarded unchanged.

        :param self: The instance the wrapped method is bound to; must expose ``config``
        :param args: Positional arguments forwarded as-is
        :param kwargs: Keyword arguments, replaced by the result of ``config.__update__``
        :return:
            Whatever the wrapped ``func`` returns
        """
        # NOTE(review): `func` is not defined in this listing -- presumably the method
        # wrapped by the enclosing decorator; confirm against debeir.core.config.
        if self.config is not None:
            kwargs = self.config.__update__(**kwargs)

        return func(self, *args, **kwargs)

Generates an embedding script score query for Elasticsearch as part of the NIR scoring function.

Parameters
  • topic_num: The topic number to search for
  • encoder: The encoder that will be used for encoding the topics
  • norm_weight: The BM25 log normalization constant
  • ablations: Whether to execute ablation style queries (i.e. one query facet or one document facet at a time)
  • cosine_ceiling: Cosine ceiling used for automatic z-log normalization parameter calculation
  • args:
  • kwargs: Pass disable_cache to disable encoder caching
Returns
An elasticsearch script_score query
@classmethod
def get_id_mapping(cls, hit):
176    @classmethod
177    def get_id_mapping(cls, hit):
178        """
179        Get the document ID
180
181        :param hit: The raw document result
182        :return:
183            The document's ID
184        """
185        return hit[cls.id_mapping]

Get the document ID

Parameters
  • hit: The raw document result
Returns
The document's ID