debeir.datasets.clinical_trials

  1import csv
  2from dataclasses import dataclass
  3from typing import Dict, List, Optional, Union
  4
  5import loguru
  6from debeir.engines.elasticsearch.generate_script_score import generate_script
  7from debeir.core.config import GenericConfig, apply_config
  8from debeir.core.executor import GenericElasticsearchExecutor
  9from debeir.core.parser import Parser
 10from debeir.core.query import GenericElasticsearchQuery
 11from debeir.rankers.transformer_sent_encoder import Encoder
 12from debeir.utils.scaler import get_z_value
 13from elasticsearch import AsyncElasticsearch as Elasticsearch
 14
 15
 @dataclass(init=True, unsafe_hash=True)
 class TrialsQueryConfig(GenericConfig):
     """
     Query configuration for the clinical trials dataset.

     Extends GenericConfig with the field-selection options read by
     ``validate`` below.
     """
     # Name of the field group searched by the text query
     # (presumably a key of TrialsElasticsearchQuery.field_usage -- confirm).
     query_field_usage: Optional[str] = None
     # Name of the field group used for embedding scoring
     # (presumably a key of TrialsElasticsearchQuery.field_usage -- confirm).
     embed_field_usage: Optional[str] = None
     # Explicit field selection, accepted as an alternative to query_field_usage.
     fields: Optional[List[str]] = None

     def validate(self):
         """
         Checks if query type is included, and checks if an encoder is included for embedding queries

         Explicit raises are used instead of ``assert`` statements so that the
         checks still run when Python is started with ``-O``.

         :raises AssertionError: if any check fails. The exception type and
             messages are the same as the assert statements this replaces.
         """
         if self.query_type == "embedding":
             if not (self.query_field_usage and self.embed_field_usage):
                 raise AssertionError("Must have both field usages if embedding query")
             if not (self.encoder_fp and self.encoder):
                 raise AssertionError("Must provide encoder path for embedding model")
             if self.norm_weight is None and self.automatic is None:
                 raise AssertionError("Norm weight be specified or be automatic ")

         if self.query_field_usage is None and self.fields is None:
             raise AssertionError("Must have a query field")

         if self.query_type not in ("ablation", "query", "query_best", "embedding"):
             raise AssertionError("Check your query type")

     @classmethod
     def from_toml(cls, fp: str, *args, **kwargs) -> "GenericConfig":
         """
         Instantiates this config class from a TOML file.

         Delegates to the parent ``from_toml``, passing ``cls`` as the class
         to instantiate.

         :param fp: File path of the Config TOML file
         :param args: Extra positional arguments forwarded to the parent
         :param kwargs: Extra keyword arguments forwarded to the parent
         :return: Whatever the parent ``from_toml`` returns
         """
         return super().from_toml(fp, cls, *args, **kwargs)

     @classmethod
     def from_dict(cls, **kwargs) -> "GenericConfig":
         """
         Instantiates this config class from keyword arguments.

         Delegates to the parent ``from_dict``, passing ``cls`` as the class
         to instantiate.

         :param kwargs: Keyword arguments forwarded to the parent
         :return: Whatever the parent ``from_dict`` returns
         """
         return super().from_dict(cls, **kwargs)
 54
 55
 class TrialsElasticsearchQuery(GenericElasticsearchQuery):
     """
     Elasticsearch Query object for the Clinical Trials Index

     Holds the named groups of index fields (``field_usage``) and builds one
     of three query bodies, selected by ``query_type``:

     * ``"query"``     -> :meth:`generate_query`
     * ``"ablation"``  -> :meth:`generate_query_ablation`
     * ``"embedding"`` -> :meth:`generate_query_embedding`
     """
     topics: Dict[int, Dict[str, str]]
     query_type: str
     fields: List[int]  # indices into ``mappings`` (see generate_query_ablation)
     query_funcs: Dict
     id_mapping: str = "_id"
     mappings: List[str]
     config: TrialsQueryConfig

     def __init__(self, topics, query_type, config=None, *args, **kwargs):
         """
         :param topics: Mapping of topic number -> {topic field name: text}
         :param query_type: Key into ``query_funcs`` ("query", "ablation" or "embedding")
         :param config: Configuration object, also forwarded to the parent constructor
         :param args: Extra positional arguments forwarded to the parent constructor
         :param kwargs: Extra keyword arguments forwarded to the parent constructor
         """
         super().__init__(topics, config, *args, **kwargs)
         self.query_type = query_type
         self.config = config
         self.topics = topics
         self.fields = []

         # Every field name known for the index; exposed as the "all" group.
         self.mappings = [
             "HasExpandedAccess",
             "BriefSummary.Textblock",
             "CompletionDate.Type",
             "OversightInfo.Text",
             "OverallContactBackup.PhoneExt",
             "RemovedCountries.Text",
             "SecondaryOutcome",
             "Sponsors.LeadSponsor.Text",
             "BriefTitle",
             "IDInfo.NctID",
             "IDInfo.SecondaryID",
             "OverallContactBackup.Phone",
             "Eligibility.StudyPop.Textblock",
             "DetailedDescription.Textblock",
             "Eligibility.MinimumAge",
             "Sponsors.Collaborator",
             "Reference",
             "Eligibility.Criteria.Textblock",
             "XMLName.Space",
             "Rank",
             "OverallStatus",
             "InterventionBrowse.Text",
             "Eligibility.Text",
             "Intervention",
             "BiospecDescr.Textblock",
             "ResponsibleParty.NameTitle",
             "NumberOfArms",
             "ResponsibleParty.ResponsiblePartyType",
             "IsSection801",
             "Acronym",
             "Eligibility.MaximumAge",
             "DetailedDescription.Text",
             "StudyDesign",
             "OtherOutcome",
             "VerificationDate",
             "ConditionBrowse.MeshTerm",
             "Enrollment.Text",
             "IDInfo.Text",
             "ConditionBrowse.Text",
             "FirstreceivedDate",
             "NumberOfGroups",
             "OversightInfo.HasDmc",
             "PrimaryCompletionDate.Text",
             "ResultsReference",
             "Eligibility.StudyPop.Text",
             "IsFdaRegulated",
             "WhyStopped",
             "ArmGroup",
             "OverallContact.LastName",
             "Phase",
             "RemovedCountries.Country",
             "InterventionBrowse.MeshTerm",
             "Eligibility.HealthyVolunteers",
             "Location",
             "OfficialTitle",
             "OverallContact.Email",
             "RequiredHeader.Text",
             "RequiredHeader.URL",
             "LocationCountries.Country",
             "OverallContact.PhoneExt",
             "Condition",
             "PrimaryOutcome",
             "LocationCountries.Text",
             "BiospecDescr.Text",
             "IDInfo.OrgStudyID",
             "Link",
             "OverallContact.Phone",
             "Source",
             "ResponsibleParty.InvestigatorAffiliation",
             "StudyType",
             "FirstreceivedResultsDate",
             "Enrollment.Type",
             "Eligibility.Gender",
             "OverallContactBackup.LastName",
             "Keyword",
             "BiospecRetention",
             "CompletionDate.Text",
             "OverallContact.Text",
             "RequiredHeader.DownloadDate",
             "Sponsors.Text",
             "Text",
             "Eligibility.SamplingMethod",
             "LastchangedDate",
             "ResponsibleParty.InvestigatorFullName",
             "StartDate",
             "RequiredHeader.LinkText",
             "OverallOfficial",
             "Sponsors.LeadSponsor.AgencyClass",
             "OverallContactBackup.Text",
             "Eligibility.Criteria.Text",
             "XMLName.Local",
             "OversightInfo.Authority",
             "PrimaryCompletionDate.Type",
             "ResponsibleParty.Organization",
             "IDInfo.NctAlias",
             "ResponsibleParty.Text",
             "TargetDuration",
             "Sponsors.LeadSponsor.Agency",
             "BriefSummary.Text",
             "OverallContactBackup.Email",
             "ResponsibleParty.InvestigatorTitle",
         ]

         self.best_recall_fields = [
             "LocationCountries.Country",
             "BiospecRetention",
             "DetailedDescription.Textblock",
             "HasExpandedAccess",
             "ConditionBrowse.MeshTerm",
             "RequiredHeader.LinkText",
             "WhyStopped",
             "BriefSummary.Textblock",
             "Eligibility.Criteria.Textblock",
             "OfficialTitle",
             "Eligibility.MaximumAge",
             "Eligibility.StudyPop.Textblock",
             "BiospecDescr.Textblock",
             "BriefTitle",
             "Eligibility.MinimumAge",
             "ResponsibleParty.Organization",
             "TargetDuration",
             "Condition",
             "IDInfo.OrgStudyID",
             "Keyword",
             "Source",
             "Sponsors.LeadSponsor.Agency",
             "ResponsibleParty.InvestigatorAffiliation",
             "OversightInfo.Authority",
             "OversightInfo.HasDmc",
             "OverallContact.Phone",
             "Phase",
             "OverallContactBackup.LastName",
             "Acronym",
             "InterventionBrowse.MeshTerm",
             "RemovedCountries.Country",
         ]
         self.best_map_fields = [
             "Eligibility.Gender",
             "LocationCountries.Country",
             "DetailedDescription.Textblock",
             "BriefSummary.Textblock",
             "ConditionBrowse.MeshTerm",
             "Eligibility.Criteria.Textblock",
             "InterventionBrowse.MeshTerm",
             "StudyType",
             "IsFdaRegulated",
             "HasExpandedAccess",
             "RequiredHeader.LinkText",
             "BiospecRetention",
             "OfficialTitle",
             "Eligibility.SamplingMethod",
             "Eligibility.StudyPop.Textblock",
             "Condition",
             "Eligibility.MinimumAge",
             "Keyword",
             "Eligibility.MaximumAge",
             "BriefTitle",
         ]
         self.best_embed_fields = [
             "WhyStopped",
             "HasExpandedAccess",
             "BiospecRetention",
             "BriefSummary.Textblock",
             "LocationCountries.Country",
             "ConditionBrowse.MeshTerm",
             "DetailedDescription.Textblock",
             "RequiredHeader.LinkText",
             "Eligibility.Criteria.Textblock",
         ]

         # NOTE: the first two entries used to be written without a comma
         # between them, which Python concatenates into the single string
         # "BriefSummary.TextblockBriefTitle" -- a field that is not in
         # ``mappings``. They are two separate fields.
         self.sensible = [
             "BriefSummary.Textblock",
             "BriefTitle",
             "Eligibility.StudyPop.Textblock",
             "DetailedDescription.Textblock",
             "Eligibility.MinimumAge",
             "Eligibility.Criteria.Textblock",
             "InterventionBrowse.Text",
             "Eligibility.Text",
             "BiospecDescr.Textblock",
             "Eligibility.MaximumAge",
             "DetailedDescription.Text",
             "ConditionBrowse.MeshTerm",
             "ConditionBrowse.Text",
             "Eligibility.StudyPop.Text",
             "InterventionBrowse.MeshTerm",
             "OfficialTitle",
             "Condition",
             "PrimaryOutcome",
             "BiospecDescr.Text",
             "Eligibility.Gender",
             "Keyword",
             "BiospecRetention",
             "Eligibility.Criteria.Text",
             "BriefSummary.Text",
         ]

         # Same missing-comma fix as in ``sensible`` above.
         self.sensible_embed = [
             "BriefSummary.Textblock",
             "BriefTitle",
             "Eligibility.StudyPop.Textblock",
             "DetailedDescription.Textblock",
             "Eligibility.Criteria.Textblock",
             "InterventionBrowse.Text",
             "Eligibility.Text",
             "BiospecDescr.Textblock",
             "DetailedDescription.Text",
             "ConditionBrowse.MeshTerm",
             "ConditionBrowse.Text",
             "Eligibility.StudyPop.Text",
             "InterventionBrowse.MeshTerm",
             "OfficialTitle",
             "Condition",
             "PrimaryOutcome",
             "BiospecDescr.Text",
             "Keyword",
             "BiospecRetention",
             "Eligibility.Criteria.Text",
             "BriefSummary.Text",
         ]

         # Fields present in both best_recall_fields and sensible_embed.
         # Built in best_recall_fields order (rather than from a set) so the
         # order is stable between runs; per-field weights are matched to
         # these lists by position.
         sensible_embed_lookup = set(self.sensible_embed)
         self.sensible_embed_safe = [
             name for name in self.best_recall_fields if name in sensible_embed_lookup
         ]

         # Dispatch table used by get_query_type.
         self.query_funcs = {
             "query": self.generate_query,
             "ablation": self.generate_query_ablation,
             "embedding": self.generate_query_embedding,
         }

         loguru.logger.debug(self.sensible_embed_safe)

         # Field-group name -> list of index fields.
         self.field_usage = {
             "best_recall_fields": self.best_recall_fields,
             "all": self.mappings,
             "best_map_fields": self.best_map_fields,
             "best_embed_fields": self.best_embed_fields,
             "sensible": self.sensible,
             "sensible_embed": self.sensible_embed,
             "sensible_embed_safe": self.sensible_embed_safe,
         }

     @apply_config
     def generate_query(self, topic_num, query_field_usage, **kwargs) -> Dict:
         """
         Generates a query for the clinical trials index

         One ``match`` clause is emitted per field of the selected group and
         all clauses are combined under ``bool.should``.

         :param topic_num: Topic number to search
         :param query_field_usage: Which document facets to search over
             (a key of ``self.field_usage``)
         :param kwargs: Ignored
         :return:
             A basic elasticsearch query for clinical trials
         :raises KeyError: if ``query_field_usage`` or ``topic_num`` is unknown
         """
         fields = self.field_usage[query_field_usage]

         # The query text is the value of the first key of the topic.
         qfield = list(self.topics[topic_num].keys())[0]
         text = self.topics[topic_num][qfield]

         clauses = [{"match": {f"{field}": {"query": text}}} for field in fields]

         return {
             "query": {
                 "bool": {"should": clauses},
             }
         }

     def generate_query_ablation(self, topic_num, **kwargs):
         """
         Only search one document facet at a time

         Each entry of ``self.fields`` is an index into ``self.mappings``. For
         every such entry, the values of the topic are concatenated (without
         a separator) and appended to that field's match text.

         :param topic_num: Topic number to search
         :param kwargs: Ignored
         :return: A ``{"query": {"match": {field name: text}}}`` dictionary
         """
         match = {self.mappings[index]: "" for index in self.fields}

         for index in self.fields:
             # The topic is only looked up when there is a field to fill.
             match[self.mappings[index]] += "".join(self.topics[topic_num].values())

         return {"query": {"match": match}}

     @apply_config
     def generate_query_embedding(
             self,
             topic_num,
             encoder,
             query_field_usage,
             embed_field_usage,
             cosine_weights: List[float] = None,
             query_weight: List[float] = None,
             norm_weight=2.15,
             ablations=False,
             automatic_scores=None,
             **kwargs,
     ):
         """
         Computes the NIR score for a given topic

         Score = log(BM25)/log(norm_weight) + embedding_score

         :param topic_num: Topic number to search
         :param encoder: Object whose ``encode(text)`` result is passed to the
             script as ``q_eb``
         :param query_field_usage: Key of ``self.field_usage`` selecting the
             fields for the text ``match`` clauses
         :param embed_field_usage: Key of ``self.field_usage`` selecting the
             fields passed to ``generate_script``
         :param cosine_weights: Per-embed-field weights; defaults to 1 for
             every embed field
         :param query_weight: Per-query-field boosts (indexed like the query
             fields); defaults to a boost of 1
         :param norm_weight: Passed to the script as ``norm_weight`` unless
             ``automatic_scores`` is given
         :param ablations: Passed to the script as ``disable_bm25``
         :param automatic_scores: Optional mapping topic number -> value used as
             ``bm25_ceiling`` in ``get_z_value``; when given, the result
             replaces ``norm_weight``
         :param kwargs: Ignored
         :return: A ``script_score`` elasticsearch query
         :raises AssertionError: if neither ``norm_weight`` nor
             ``automatic_scores`` is provided
         """
         # Explicit raise (rather than assert) so the check survives ``-O``.
         if not (norm_weight or automatic_scores):
             raise AssertionError("Either norm_weight or automatic_scores is required")

         query_fields = self.field_usage[query_field_usage]
         embed_fields = self.field_usage[embed_field_usage]

         # The query text is the value of the first key of the topic.
         qfield = list(self.topics[topic_num].keys())[0]
         text = self.topics[topic_num][qfield]

         should = {
             "should": [
                 {
                     "match": {
                         f"{field}": {
                             "query": text,
                             "boost": query_weight[i] if query_weight else 1,
                         }
                     }
                 }
                 for i, field in enumerate(query_fields)
             ]
         }

         if automatic_scores is not None:
             norm_weight = get_z_value(
                 cosine_ceiling=len(embed_fields) * len(query_fields),
                 bm25_ceiling=automatic_scores[topic_num],
             )

         params = {
             "weights": cosine_weights if cosine_weights else [1] * len(embed_fields),
             "q_eb": encoder.encode(text),
             "offset": 1.0,
             "norm_weight": norm_weight,
             "disable_bm25": ablations,
         }

         return {
             "query": {
                 "script_score": {
                     "query": {
                         "bool": should,
                     },
                     # Use the fields selected by ``embed_field_usage``: the
                     # default weights above are sized to ``embed_fields``,
                     # so the script must be built from the same list.
                     "script": generate_script(embed_fields, params=params),
                 },
             }
         }

     def get_query_type(self, *args, **kwargs):
         """
         Builds the query for the configured ``query_type``.

         :param args: Positional arguments forwarded to the query builder
         :param kwargs: Keyword arguments forwarded to the query builder
         :return: The result of the selected query builder
         :raises KeyError: if ``self.query_type`` is not a key of ``self.query_funcs``
         """
         return self.query_funcs[self.query_type](*args, **kwargs)

     def get_id_mapping(self, hit):
         """
         Returns the document identifier of a hit.

         :param hit: A mapping containing the ``self.id_mapping`` key
             ("_id" by default)
         :return: ``hit[self.id_mapping]``
         """
         return hit[self.id_mapping]
457
458
class ClinicalTrialsElasticsearchExecutor(GenericElasticsearchExecutor):
    """
    Executor that runs the queries produced by a TrialsElasticsearchQuery.
    """
    query: TrialsElasticsearchQuery

    def __init__(
        self,
        topics: Dict[Union[str, int], Dict[str, str]],
        client: Elasticsearch,
        index_name: str,
        output_file: str,
        query: TrialsElasticsearchQuery,
        encoder: Optional[Encoder] = None,
        config=None,
        *args,
        **kwargs,
    ):
        """
        Forwards every argument to the parent executor and registers the
        query-type dispatch table.

        :param topics: Mapping of topic id -> {topic field name: text}
        :param client: Elasticsearch client (the async client is imported under this name)
        :param index_name: Name of the index, forwarded to the parent
        :param output_file: Output file path, forwarded to the parent
        :param query: Query object for the clinical trials index
        :param encoder: Optional encoder, forwarded to the parent
        :param config: Optional configuration, forwarded as a keyword argument
        :param args: Extra positional arguments forwarded to the parent
        :param kwargs: Extra keyword arguments forwarded to the parent
        """
        leading = (topics, client, index_name, output_file, query, encoder)
        super().__init__(*leading, *args, config=config, **kwargs)

        # Query type -> bound handler. The handlers are not defined in this
        # class; presumably they are inherited from the parent executor.
        self.query_fns = dict(
            query=self.generate_query,
            ablation=self.generate_query_ablation,
            embedding=self.generate_embedding_query,
        )
494
495
class ClinicalTrialParser(Parser):
    """
    Topic parser for the Clinical Trials dataset.
    """

    @classmethod
    def get_topics(cls, csvfile) -> Dict[int, Dict[str, str]]:
        """
        Reads topics from CSV input.

        The first row is treated as a header and skipped. For every other
        row, column 0 is the topic id and column 1 the topic text. A later
        row with the same id replaces the earlier one.

        :param csvfile: Any iterable of lines accepted by ``csv.reader``
            (e.g. an open text file)
        :return: Mapping of topic id (as read from the CSV) -> {"text": topic text}
        :raises IndexError: if a non-header row has fewer than two columns
        """
        rows = csv.reader(csvfile)
        next(rows, None)  # drop the header row, if there is one
        return {row[0]: {"text": row[1]} for row in rows}
@dataclass(init=True, unsafe_hash=True)
class TrialsQueryConfig(debeir.core.config.GenericConfig):
@dataclass(init=True, unsafe_hash=True)
class TrialsQueryConfig(GenericConfig):
    """
    Query configuration for the clinical trials dataset.

    Extends GenericConfig with the field-selection options read by
    ``validate`` below.
    """
    # Name of the field group searched by the text query
    # (presumably a key of TrialsElasticsearchQuery.field_usage -- confirm).
    query_field_usage: Optional[str] = None
    # Name of the field group used for embedding scoring
    # (presumably a key of TrialsElasticsearchQuery.field_usage -- confirm).
    embed_field_usage: Optional[str] = None
    # Explicit field selection, accepted as an alternative to query_field_usage.
    fields: Optional[List[str]] = None

    def validate(self):
        """
        Checks if query type is included, and checks if an encoder is included for embedding queries

        Explicit raises are used instead of ``assert`` statements so that the
        checks still run when Python is started with ``-O``.

        :raises AssertionError: if any check fails. The exception type and
            messages are the same as the assert statements this replaces.
        """
        if self.query_type == "embedding":
            if not (self.query_field_usage and self.embed_field_usage):
                raise AssertionError("Must have both field usages if embedding query")
            if not (self.encoder_fp and self.encoder):
                raise AssertionError("Must provide encoder path for embedding model")
            if self.norm_weight is None and self.automatic is None:
                raise AssertionError("Norm weight be specified or be automatic ")

        if self.query_field_usage is None and self.fields is None:
            raise AssertionError("Must have a query field")

        if self.query_type not in ("ablation", "query", "query_best", "embedding"):
            raise AssertionError("Check your query type")

    @classmethod
    def from_toml(cls, fp: str, *args, **kwargs) -> "GenericConfig":
        """
        Instantiates this config class from a TOML file.

        Delegates to the parent ``from_toml``, passing ``cls`` as the class
        to instantiate.

        :param fp: File path of the Config TOML file
        :param args: Extra positional arguments forwarded to the parent
        :param kwargs: Extra keyword arguments forwarded to the parent
        :return: Whatever the parent ``from_toml`` returns
        """
        return super().from_toml(fp, cls, *args, **kwargs)

    @classmethod
    def from_dict(cls, **kwargs) -> "GenericConfig":
        """
        Instantiates this config class from keyword arguments.

        Delegates to the parent ``from_dict``, passing ``cls`` as the class
        to instantiate.

        :param kwargs: Keyword arguments forwarded to the parent
        :return: Whatever the parent ``from_dict`` returns
        """
        return super().from_dict(cls, **kwargs)
TrialsQueryConfig( query_type: str, index: str = None, encoder_normalize: bool = True, ablations: bool = False, norm_weight: float = None, automatic: bool = None, encoder: object = None, encoder_fp: str = None, query_weights: List[float] = None, cosine_weights: List[float] = None, evaluate: bool = False, qrels: str = None, config_fn: str = None, query_fn: str = None, parser_fn: str = None, executor_fn: str = None, cosine_ceiling: float = None, topics_path: str = None, return_id_only: bool = False, overwrite_output_if_exists: bool = False, output_file: str = None, run_name: str = None, query_field_usage: str = None, embed_field_usage: str = None, fields: List[str] = None)
def validate(self):
def validate(self):
    """
    Checks if query type is included, and checks if an encoder is included for embedding queries

    Explicit raises are used instead of ``assert`` statements so that the
    checks still run when Python is started with ``-O``.

    :raises AssertionError: if any check fails. The exception type and
        messages are the same as the assert statements this replaces.
    """
    if self.query_type == "embedding":
        if not (self.query_field_usage and self.embed_field_usage):
            raise AssertionError("Must have both field usages if embedding query")
        if not (self.encoder_fp and self.encoder):
            raise AssertionError("Must provide encoder path for embedding model")
        if self.norm_weight is None and self.automatic is None:
            raise AssertionError("Norm weight be specified or be automatic ")

    if self.query_field_usage is None and self.fields is None:
        raise AssertionError("Must have a query field")

    if self.query_type not in ("ablation", "query", "query_best", "embedding"):
        raise AssertionError("Check your query type")

Checks if query type is included, and checks if an encoder is included for embedding queries

@classmethod
def from_toml(cls, fp: str, *args, **kwargs) -> debeir.core.config.GenericConfig:
@classmethod
def from_toml(cls, fp: str, *args, **kwargs) -> "GenericConfig":
    """
    Instantiates this config class from a TOML file by delegating to the
    parent ``from_toml`` with ``cls`` as the class to instantiate.

    :param fp: File path of the Config TOML file
    :param args: Extra positional arguments forwarded to the parent
    :param kwargs: Extra keyword arguments forwarded to the parent
    :return: Whatever the parent ``from_toml`` returns
    """
    parent_loader = super().from_toml
    return parent_loader(fp, cls, *args, **kwargs)

Instantiates a Config object from a toml file

Parameters
  • fp: File path of the Config TOML file
  • field_class: Class of the Config object to be instantiated
  • args: Arguments to be passed to Config
  • kwargs: Keyword arguments to be passed
Returns
An instantiated and validated Config object.
@classmethod
def from_dict(cls, **kwargs) -> debeir.core.config.GenericConfig:
@classmethod
def from_dict(cls, **kwargs) -> "GenericConfig":
    """
    Instantiates this config class from keyword arguments by delegating to
    the parent ``from_dict`` with ``cls`` as the class to instantiate.

    :param kwargs: Keyword arguments forwarded to the parent
    :return: Whatever the parent ``from_dict`` returns
    """
    parent_builder = super().from_dict
    return parent_builder(cls, **kwargs)

Instantiates a Config object from a dictionary

Parameters
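  • data_class: Class of the Config object to be instantiated
  • kwargs: Keyword arguments to be passed to the Config
Returns
An instantiated Config object.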
class TrialsElasticsearchQuery(debeir.core.query.GenericElasticsearchQuery):
 57class TrialsElasticsearchQuery(GenericElasticsearchQuery):
 58    """
 59    Elasticsearch Query object for the Clinical Trials Index
 60    """
 61    topics: Dict[int, Dict[str, str]]
 62    query_type: str
 63    fields: List[int]
 64    query_funcs: Dict
 65    config: GenericConfig
 66    id_mapping: str = "_id"
 67    mappings: List[str]
 68    config: TrialsQueryConfig
 69
 70    def __init__(self, topics, query_type, config=None, *args, **kwargs):
 71        super().__init__(topics, config, *args, **kwargs)
 72        self.query_type = query_type
 73        self.config = config
 74        self.topics = topics
 75        self.fields = []
 76        self.mappings = [
 77            "HasExpandedAccess",
 78            "BriefSummary.Textblock",
 79            "CompletionDate.Type",
 80            "OversightInfo.Text",
 81            "OverallContactBackup.PhoneExt",
 82            "RemovedCountries.Text",
 83            "SecondaryOutcome",
 84            "Sponsors.LeadSponsor.Text",
 85            "BriefTitle",
 86            "IDInfo.NctID",
 87            "IDInfo.SecondaryID",
 88            "OverallContactBackup.Phone",
 89            "Eligibility.StudyPop.Textblock",
 90            "DetailedDescription.Textblock",
 91            "Eligibility.MinimumAge",
 92            "Sponsors.Collaborator",
 93            "Reference",
 94            "Eligibility.Criteria.Textblock",
 95            "XMLName.Space",
 96            "Rank",
 97            "OverallStatus",
 98            "InterventionBrowse.Text",
 99            "Eligibility.Text",
100            "Intervention",
101            "BiospecDescr.Textblock",
102            "ResponsibleParty.NameTitle",
103            "NumberOfArms",
104            "ResponsibleParty.ResponsiblePartyType",
105            "IsSection801",
106            "Acronym",
107            "Eligibility.MaximumAge",
108            "DetailedDescription.Text",
109            "StudyDesign",
110            "OtherOutcome",
111            "VerificationDate",
112            "ConditionBrowse.MeshTerm",
113            "Enrollment.Text",
114            "IDInfo.Text",
115            "ConditionBrowse.Text",
116            "FirstreceivedDate",
117            "NumberOfGroups",
118            "OversightInfo.HasDmc",
119            "PrimaryCompletionDate.Text",
120            "ResultsReference",
121            "Eligibility.StudyPop.Text",
122            "IsFdaRegulated",
123            "WhyStopped",
124            "ArmGroup",
125            "OverallContact.LastName",
126            "Phase",
127            "RemovedCountries.Country",
128            "InterventionBrowse.MeshTerm",
129            "Eligibility.HealthyVolunteers",
130            "Location",
131            "OfficialTitle",
132            "OverallContact.Email",
133            "RequiredHeader.Text",
134            "RequiredHeader.URL",
135            "LocationCountries.Country",
136            "OverallContact.PhoneExt",
137            "Condition",
138            "PrimaryOutcome",
139            "LocationCountries.Text",
140            "BiospecDescr.Text",
141            "IDInfo.OrgStudyID",
142            "Link",
143            "OverallContact.Phone",
144            "Source",
145            "ResponsibleParty.InvestigatorAffiliation",
146            "StudyType",
147            "FirstreceivedResultsDate",
148            "Enrollment.Type",
149            "Eligibility.Gender",
150            "OverallContactBackup.LastName",
151            "Keyword",
152            "BiospecRetention",
153            "CompletionDate.Text",
154            "OverallContact.Text",
155            "RequiredHeader.DownloadDate",
156            "Sponsors.Text",
157            "Text",
158            "Eligibility.SamplingMethod",
159            "LastchangedDate",
160            "ResponsibleParty.InvestigatorFullName",
161            "StartDate",
162            "RequiredHeader.LinkText",
163            "OverallOfficial",
164            "Sponsors.LeadSponsor.AgencyClass",
165            "OverallContactBackup.Text",
166            "Eligibility.Criteria.Text",
167            "XMLName.Local",
168            "OversightInfo.Authority",
169            "PrimaryCompletionDate.Type",
170            "ResponsibleParty.Organization",
171            "IDInfo.NctAlias",
172            "ResponsibleParty.Text",
173            "TargetDuration",
174            "Sponsors.LeadSponsor.Agency",
175            "BriefSummary.Text",
176            "OverallContactBackup.Email",
177            "ResponsibleParty.InvestigatorTitle",
178        ]
179
180        self.best_recall_fields = [
181            "LocationCountries.Country",
182            "BiospecRetention",
183            "DetailedDescription.Textblock",
184            "HasExpandedAccess",
185            "ConditionBrowse.MeshTerm",
186            "RequiredHeader.LinkText",
187            "WhyStopped",
188            "BriefSummary.Textblock",
189            "Eligibility.Criteria.Textblock",
190            "OfficialTitle",
191            "Eligibility.MaximumAge",
192            "Eligibility.StudyPop.Textblock",
193            "BiospecDescr.Textblock",
194            "BriefTitle",
195            "Eligibility.MinimumAge",
196            "ResponsibleParty.Organization",
197            "TargetDuration",
198            "Condition",
199            "IDInfo.OrgStudyID",
200            "Keyword",
201            "Source",
202            "Sponsors.LeadSponsor.Agency",
203            "ResponsibleParty.InvestigatorAffiliation",
204            "OversightInfo.Authority",
205            "OversightInfo.HasDmc",
206            "OverallContact.Phone",
207            "Phase",
208            "OverallContactBackup.LastName",
209            "Acronym",
210            "InterventionBrowse.MeshTerm",
211            "RemovedCountries.Country",
212        ]
213        self.best_map_fields = [
214            "Eligibility.Gender",
215            "LocationCountries.Country",
216            "DetailedDescription.Textblock",
217            "BriefSummary.Textblock",
218            "ConditionBrowse.MeshTerm",
219            "Eligibility.Criteria.Textblock",
220            "InterventionBrowse.MeshTerm",
221            "StudyType",
222            "IsFdaRegulated",
223            "HasExpandedAccess",
224            "RequiredHeader.LinkText",
225            "BiospecRetention",
226            "OfficialTitle",
227            "Eligibility.SamplingMethod",
228            "Eligibility.StudyPop.Textblock",
229            "Condition",
230            "Eligibility.MinimumAge",
231            "Keyword",
232            "Eligibility.MaximumAge",
233            "BriefTitle",
234        ]
235        self.best_embed_fields = [
236            "WhyStopped",
237            "HasExpandedAccess",
238            "BiospecRetention",
239            "BriefSummary.Textblock",
240            "LocationCountries.Country",
241            "ConditionBrowse.MeshTerm",
242            "DetailedDescription.Textblock",
243            "RequiredHeader.LinkText",
244            "Eligibility.Criteria.Textblock",
245        ]
246
247        self.sensible = [
248            "BriefSummary.Textblock" "BriefTitle",
249            "Eligibility.StudyPop.Textblock",
250            "DetailedDescription.Textblock",
251            "Eligibility.MinimumAge",
252            "Eligibility.Criteria.Textblock",
253            "InterventionBrowse.Text",
254            "Eligibility.Text",
255            "BiospecDescr.Textblock",
256            "Eligibility.MaximumAge",
257            "DetailedDescription.Text",
258            "ConditionBrowse.MeshTerm",
259            "ConditionBrowse.Text",
260            "Eligibility.StudyPop.Text",
261            "InterventionBrowse.MeshTerm",
262            "OfficialTitle",
263            "Condition",
264            "PrimaryOutcome",
265            "BiospecDescr.Text",
266            "Eligibility.Gender",
267            "Keyword",
268            "BiospecRetention",
269            "Eligibility.Criteria.Text",
270            "BriefSummary.Text",
271        ]
272
273        self.sensible_embed = [
274            "BriefSummary.Textblock" "BriefTitle",
275            "Eligibility.StudyPop.Textblock",
276            "DetailedDescription.Textblock",
277            "Eligibility.Criteria.Textblock",
278            "InterventionBrowse.Text",
279            "Eligibility.Text",
280            "BiospecDescr.Textblock",
281            "DetailedDescription.Text",
282            "ConditionBrowse.MeshTerm",
283            "ConditionBrowse.Text",
284            "Eligibility.StudyPop.Text",
285            "InterventionBrowse.MeshTerm",
286            "OfficialTitle",
287            "Condition",
288            "PrimaryOutcome",
289            "BiospecDescr.Text",
290            "Keyword",
291            "BiospecRetention",
292            "Eligibility.Criteria.Text",
293            "BriefSummary.Text",
294        ]
295
296        self.sensible_embed_safe = list(
297            set(self.best_recall_fields).intersection(set(self.sensible_embed))
298        )
299
300        self.query_funcs = {
301            "query": self.generate_query,
302            "ablation": self.generate_query_ablation,
303            "embedding": self.generate_query_embedding,
304        }
305
306        loguru.logger.debug(self.sensible_embed_safe)
307
308        self.field_usage = {
309            "best_recall_fields": self.best_recall_fields,
310            "all": self.mappings,
311            "best_map_fields": self.best_map_fields,
312            "best_embed_fields": self.best_embed_fields,
313            "sensible": self.sensible,
314            "sensible_embed": self.sensible_embed,
315            "sensible_embed_safe": self.sensible_embed_safe,
316        }
317
@apply_config
def generate_query(self, topic_num, query_field_usage, **kwargs) -> Dict:
    """
    Build a lexical (match-only) query for the clinical trials index.

    The text stored under the first key of the topic is matched against every
    document facet of the field group selected by ``query_field_usage``; the
    per-facet matches are combined as ``should`` clauses of one bool query.

    :param topic_num: Key of the topic in ``self.topics``
    :param query_field_usage: Name of the field group in ``self.field_usage``
    :param kwargs: Accepted for signature compatibility; not read here
    :return:
        A query body shaped ``{"query": {"bool": {"should": [...]}}}``
    """
    facets = self.field_usage[query_field_usage]

    # Only the first field of the topic supplies the query text.
    topic = self.topics[topic_num]
    text = topic[list(topic.keys())[0]]

    clauses = [{"match": {f"{facet}": {"query": text}}} for facet in facets]

    return {"query": {"bool": {"should": clauses}}}
353
def generate_query_ablation(self, topic_num, **kwargs):
    """
    Build a ``match`` query restricted to the facets selected in ``self.fields``.

    Every entry of ``self.fields`` is used as an index into ``self.mappings``
    to obtain a facet name. Each selected facet is matched against the same
    text: all field values of the topic joined with single spaces.

    :param topic_num: Key of the topic in ``self.topics``
    :param kwargs: Accepted for signature compatibility; not read here
    :return:
        A query body shaped ``{"query": {"match": {facet: text, ...}}}``;
        the ``match`` dict is empty when ``self.fields`` is empty
    """
    # NOTE(review): the original summary says "one document facet at a time";
    # presumably self.fields is expected to hold a single entry -- confirm.
    facets = [self.mappings[field] for field in self.fields]
    if not facets:
        # Nothing selected: do not touch the topic at all.
        return {"query": {"match": {}}}

    # Join with a space so the last word of one topic field does not fuse
    # with the first word of the next one. Building the text once also keeps
    # a facet listed twice in self.fields from receiving doubled text.
    text = " ".join(self.topics[topic_num].values())

    return {"query": {"match": {facet: text for facet in facets}}}
372
@apply_config
def generate_query_embedding(
        self,
        topic_num,
        encoder,
        query_field_usage,
        embed_field_usage,
        cosine_weights: List[float] = None,
        query_weight: List[float] = None,
        norm_weight=2.15,
        ablations=False,
        automatic_scores=None,
        **kwargs,
):
    """
    Build the script-score (NIR) query for a given topic.

    Score = log(BM25)/log(norm_weight) + embedding_score

    The lexical part is a bool query with one ``match`` clause per query
    field; it is wrapped in a ``script_score`` query whose script comes from
    ``generate_script``.

    :param topic_num: Key of the topic in ``self.topics``
    :param encoder: Object whose ``encode(text)`` result is sent as ``q_eb``
    :param query_field_usage: Field group (key of ``self.field_usage``) for the match clauses
    :param embed_field_usage: Field group (key of ``self.field_usage``) used to size the
        default weights and the cosine ceiling
    :param cosine_weights: Script parameter ``weights``; defaults to one ``1`` per embed field
    :param query_weight: Per-query-field boosts, indexed by position; defaults to ``1``
    :param norm_weight: Script parameter ``norm_weight``; replaced by ``get_z_value(...)``
        when ``automatic_scores`` is given
    :param ablations: Sent to the script as ``disable_bm25``
    :param automatic_scores: Indexed by ``topic_num`` to obtain the BM25 ceiling
    :param kwargs: Accepted for signature compatibility; not read here
    :return:
        A query body shaped ``{"query": {"script_score": {"query": ..., "script": ...}}}``
    """
    assert norm_weight or automatic_scores

    query_fields = self.field_usage[query_field_usage]
    embed_fields = self.field_usage[embed_field_usage]

    # Only the first field of the topic supplies the query text.
    topic = self.topics[topic_num]
    text = topic[list(topic.keys())[0]]

    clauses = [
        {
            "match": {
                f"{field}": {
                    "query": text,
                    "boost": query_weight[position] if query_weight else 1,
                }
            }
        }
        for position, field in enumerate(query_fields)
    ]

    if automatic_scores is not None:
        norm_weight = get_z_value(
            cosine_ceiling=len(embed_fields) * len(query_fields),
            bm25_ceiling=automatic_scores[topic_num],
        )

    weights = cosine_weights if cosine_weights else [1] * len(embed_fields)
    script_params = {
        "weights": weights,
        "q_eb": encoder.encode(text),
        "offset": 1.0,
        "norm_weight": norm_weight,
        "disable_bm25": ablations,
    }

    # NOTE(review): the script is generated for self.best_embed_fields, not for
    # the embed_fields selected above, while the default weights are sized by
    # embed_fields -- confirm this is intended.
    script = generate_script(self.best_embed_fields, params=script_params)

    return {
        "query": {
            "script_score": {
                "query": {"bool": {"should": clauses}},
                "script": script,
            },
        }
    }
452
def get_query_type(self, *args, **kwargs):
    """
    Run the query builder registered for ``self.query_type``.

    :param args: Positional arguments forwarded to the builder
    :param kwargs: Keyword arguments forwarded to the builder
    :return: Whatever the selected entry of ``self.query_funcs`` returns
    """
    builder = self.query_funcs[self.query_type]
    return builder(*args, **kwargs)
455
def get_id_mapping(self, hit):
    """
    Get the document ID

    :param hit: The raw document result
    :return: The value stored in ``hit`` under the key ``self.id_mapping``
    """
    # id_mapping is not assigned in this class's visible code; presumably it
    # is provided by the base class or the configuration.
    id_key = self.id_mapping
    return hit[id_key]

Elasticsearch Query object for the Clinical Trials Index

TrialsElasticsearchQuery(topics, query_type, config=None, *args, **kwargs)
 def __init__(self, topics, query_type, config=None, *args, **kwargs):
     """
     Set up the field groups and query builders for the Clinical Trials index.

     :param topics: Mapping from topic key to a dict of topic fields
     :param query_type: Key into ``self.query_funcs`` selecting the query builder
     :param config: Optional configuration object, also passed to the base class
     :param args: Forwarded to the base class constructor
     :param kwargs: Forwarded to the base class constructor
     """
     super().__init__(topics, config, *args, **kwargs)
     self.query_type = query_type
     self.config = config
     self.topics = topics
     self.fields = []

     # Field names exposed as the "all" group of self.field_usage.
     self.mappings = [
         "HasExpandedAccess",
         "BriefSummary.Textblock",
         "CompletionDate.Type",
         "OversightInfo.Text",
         "OverallContactBackup.PhoneExt",
         "RemovedCountries.Text",
         "SecondaryOutcome",
         "Sponsors.LeadSponsor.Text",
         "BriefTitle",
         "IDInfo.NctID",
         "IDInfo.SecondaryID",
         "OverallContactBackup.Phone",
         "Eligibility.StudyPop.Textblock",
         "DetailedDescription.Textblock",
         "Eligibility.MinimumAge",
         "Sponsors.Collaborator",
         "Reference",
         "Eligibility.Criteria.Textblock",
         "XMLName.Space",
         "Rank",
         "OverallStatus",
         "InterventionBrowse.Text",
         "Eligibility.Text",
         "Intervention",
         "BiospecDescr.Textblock",
         "ResponsibleParty.NameTitle",
         "NumberOfArms",
         "ResponsibleParty.ResponsiblePartyType",
         "IsSection801",
         "Acronym",
         "Eligibility.MaximumAge",
         "DetailedDescription.Text",
         "StudyDesign",
         "OtherOutcome",
         "VerificationDate",
         "ConditionBrowse.MeshTerm",
         "Enrollment.Text",
         "IDInfo.Text",
         "ConditionBrowse.Text",
         "FirstreceivedDate",
         "NumberOfGroups",
         "OversightInfo.HasDmc",
         "PrimaryCompletionDate.Text",
         "ResultsReference",
         "Eligibility.StudyPop.Text",
         "IsFdaRegulated",
         "WhyStopped",
         "ArmGroup",
         "OverallContact.LastName",
         "Phase",
         "RemovedCountries.Country",
         "InterventionBrowse.MeshTerm",
         "Eligibility.HealthyVolunteers",
         "Location",
         "OfficialTitle",
         "OverallContact.Email",
         "RequiredHeader.Text",
         "RequiredHeader.URL",
         "LocationCountries.Country",
         "OverallContact.PhoneExt",
         "Condition",
         "PrimaryOutcome",
         "LocationCountries.Text",
         "BiospecDescr.Text",
         "IDInfo.OrgStudyID",
         "Link",
         "OverallContact.Phone",
         "Source",
         "ResponsibleParty.InvestigatorAffiliation",
         "StudyType",
         "FirstreceivedResultsDate",
         "Enrollment.Type",
         "Eligibility.Gender",
         "OverallContactBackup.LastName",
         "Keyword",
         "BiospecRetention",
         "CompletionDate.Text",
         "OverallContact.Text",
         "RequiredHeader.DownloadDate",
         "Sponsors.Text",
         "Text",
         "Eligibility.SamplingMethod",
         "LastchangedDate",
         "ResponsibleParty.InvestigatorFullName",
         "StartDate",
         "RequiredHeader.LinkText",
         "OverallOfficial",
         "Sponsors.LeadSponsor.AgencyClass",
         "OverallContactBackup.Text",
         "Eligibility.Criteria.Text",
         "XMLName.Local",
         "OversightInfo.Authority",
         "PrimaryCompletionDate.Type",
         "ResponsibleParty.Organization",
         "IDInfo.NctAlias",
         "ResponsibleParty.Text",
         "TargetDuration",
         "Sponsors.LeadSponsor.Agency",
         "BriefSummary.Text",
         "OverallContactBackup.Email",
         "ResponsibleParty.InvestigatorTitle",
     ]

     # Named field groups selectable through query_field_usage / embed_field_usage.
     self.best_recall_fields = [
         "LocationCountries.Country",
         "BiospecRetention",
         "DetailedDescription.Textblock",
         "HasExpandedAccess",
         "ConditionBrowse.MeshTerm",
         "RequiredHeader.LinkText",
         "WhyStopped",
         "BriefSummary.Textblock",
         "Eligibility.Criteria.Textblock",
         "OfficialTitle",
         "Eligibility.MaximumAge",
         "Eligibility.StudyPop.Textblock",
         "BiospecDescr.Textblock",
         "BriefTitle",
         "Eligibility.MinimumAge",
         "ResponsibleParty.Organization",
         "TargetDuration",
         "Condition",
         "IDInfo.OrgStudyID",
         "Keyword",
         "Source",
         "Sponsors.LeadSponsor.Agency",
         "ResponsibleParty.InvestigatorAffiliation",
         "OversightInfo.Authority",
         "OversightInfo.HasDmc",
         "OverallContact.Phone",
         "Phase",
         "OverallContactBackup.LastName",
         "Acronym",
         "InterventionBrowse.MeshTerm",
         "RemovedCountries.Country",
     ]
     self.best_map_fields = [
         "Eligibility.Gender",
         "LocationCountries.Country",
         "DetailedDescription.Textblock",
         "BriefSummary.Textblock",
         "ConditionBrowse.MeshTerm",
         "Eligibility.Criteria.Textblock",
         "InterventionBrowse.MeshTerm",
         "StudyType",
         "IsFdaRegulated",
         "HasExpandedAccess",
         "RequiredHeader.LinkText",
         "BiospecRetention",
         "OfficialTitle",
         "Eligibility.SamplingMethod",
         "Eligibility.StudyPop.Textblock",
         "Condition",
         "Eligibility.MinimumAge",
         "Keyword",
         "Eligibility.MaximumAge",
         "BriefTitle",
     ]
     self.best_embed_fields = [
         "WhyStopped",
         "HasExpandedAccess",
         "BiospecRetention",
         "BriefSummary.Textblock",
         "LocationCountries.Country",
         "ConditionBrowse.MeshTerm",
         "DetailedDescription.Textblock",
         "RequiredHeader.LinkText",
         "Eligibility.Criteria.Textblock",
     ]

     # "BriefSummary.Textblock" and "BriefTitle" are two separate entries: a
     # missing comma used to fuse them into one field name that does not
     # appear in self.mappings.
     self.sensible = [
         "BriefSummary.Textblock",
         "BriefTitle",
         "Eligibility.StudyPop.Textblock",
         "DetailedDescription.Textblock",
         "Eligibility.MinimumAge",
         "Eligibility.Criteria.Textblock",
         "InterventionBrowse.Text",
         "Eligibility.Text",
         "BiospecDescr.Textblock",
         "Eligibility.MaximumAge",
         "DetailedDescription.Text",
         "ConditionBrowse.MeshTerm",
         "ConditionBrowse.Text",
         "Eligibility.StudyPop.Text",
         "InterventionBrowse.MeshTerm",
         "OfficialTitle",
         "Condition",
         "PrimaryOutcome",
         "BiospecDescr.Text",
         "Eligibility.Gender",
         "Keyword",
         "BiospecRetention",
         "Eligibility.Criteria.Text",
         "BriefSummary.Text",
     ]

     self.sensible_embed = [
         "BriefSummary.Textblock",
         "BriefTitle",
         "Eligibility.StudyPop.Textblock",
         "DetailedDescription.Textblock",
         "Eligibility.Criteria.Textblock",
         "InterventionBrowse.Text",
         "Eligibility.Text",
         "BiospecDescr.Textblock",
         "DetailedDescription.Text",
         "ConditionBrowse.MeshTerm",
         "ConditionBrowse.Text",
         "Eligibility.StudyPop.Text",
         "InterventionBrowse.MeshTerm",
         "OfficialTitle",
         "Condition",
         "PrimaryOutcome",
         "BiospecDescr.Text",
         "Keyword",
         "BiospecRetention",
         "Eligibility.Criteria.Text",
         "BriefSummary.Text",
     ]

     # Fields present in both sensible_embed and best_recall_fields. Filtering
     # the list (instead of intersecting two sets) keeps the field order
     # identical from run to run, which matters for positional weights.
     recall_fields = set(self.best_recall_fields)
     self.sensible_embed_safe = [
         field for field in self.sensible_embed if field in recall_fields
     ]

     # Query builders selectable through self.query_type.
     self.query_funcs = {
         "query": self.generate_query,
         "ablation": self.generate_query_ablation,
         "embedding": self.generate_query_embedding,
     }

     loguru.logger.debug(self.sensible_embed_safe)

     # Field group name -> list of document fields.
     self.field_usage = {
         "best_recall_fields": self.best_recall_fields,
         "all": self.mappings,
         "best_map_fields": self.best_map_fields,
         "best_embed_fields": self.best_embed_fields,
         "sensible": self.sensible,
         "sensible_embed": self.sensible_embed,
         "sensible_embed_safe": self.sensible_embed_safe,
     }
def generate_query(self, *args, **kwargs):
def use_config(self, *args, **kwargs):
    """
    Call the wrapped function with keyword arguments taken from self.config.

    When ``self.config`` is set, ``kwargs`` is replaced by the result of
    ``self.config.__update__(**kwargs)``; positional arguments are passed
    through unchanged.

    :param self: Instance whose ``config`` attribute is consulted
    :param args: Positional arguments, forwarded as-is
    :param kwargs: Keyword arguments to be updated
    :return: The return value of the wrapped function
    """
    cfg = self.config
    if cfg is not None:
        kwargs = cfg.__update__(**kwargs)

    return func(self, *args, **kwargs)

Generates a query for the clinical trials index

Parameters
  • topic_num: Topic number to search
  • query_field_usage: Which document facets to search over
  • kwargs:
Returns
A basic elasticsearch query for clinical trials
def generate_query_ablation(self, topic_num, **kwargs):
def generate_query_ablation(self, topic_num, **kwargs):
    """
    Build a ``match`` query restricted to the facets selected in ``self.fields``.

    Every entry of ``self.fields`` is used as an index into ``self.mappings``
    to obtain a facet name. Each selected facet is matched against the same
    text: all field values of the topic joined with single spaces.

    :param topic_num: Key of the topic in ``self.topics``
    :param kwargs: Accepted for signature compatibility; not read here
    :return:
        A query body shaped ``{"query": {"match": {facet: text, ...}}}``;
        the ``match`` dict is empty when ``self.fields`` is empty
    """
    # NOTE(review): the original summary says "one document facet at a time";
    # presumably self.fields is expected to hold a single entry -- confirm.
    facets = [self.mappings[field] for field in self.fields]
    if not facets:
        # Nothing selected: do not touch the topic at all.
        return {"query": {"match": {}}}

    # Join with a space so the last word of one topic field does not fuse
    # with the first word of the next one. Building the text once also keeps
    # a facet listed twice in self.fields from receiving doubled text.
    text = " ".join(self.topics[topic_num].values())

    return {"query": {"match": {facet: text for facet in facets}}}

Only search one document facet at a time

Parameters
  • topic_num:
  • kwargs:
Returns
def generate_query_embedding(self, *args, **kwargs):
def use_config(self, *args, **kwargs):
    """
    Call the wrapped function with keyword arguments taken from self.config.

    When ``self.config`` is set, ``kwargs`` is replaced by the result of
    ``self.config.__update__(**kwargs)``; positional arguments are passed
    through unchanged.

    :param self: Instance whose ``config`` attribute is consulted
    :param args: Positional arguments, forwarded as-is
    :param kwargs: Keyword arguments to be updated
    :return: The return value of the wrapped function
    """
    cfg = self.config
    if cfg is not None:
        kwargs = cfg.__update__(**kwargs)

    return func(self, *args, **kwargs)

Computes the NIR score for a given topic

Score = log(BM25)/log(norm_weight) + embedding_score

Parameters
  • topic_num:
  • encoder:
  • query_field_usage:
  • embed_field_usage:
  • cosine_weights:
  • query_weight:
  • norm_weight:
  • ablations:
  • automatic_scores:
  • kwargs:
Returns
def get_query_type(self, *args, **kwargs):
def get_query_type(self, *args, **kwargs):
    """
    Run the query builder registered for ``self.query_type``.

    :param args: Positional arguments forwarded to the builder
    :param kwargs: Keyword arguments forwarded to the builder
    :return: Whatever the selected entry of ``self.query_funcs`` returns
    """
    builder = self.query_funcs[self.query_type]
    return builder(*args, **kwargs)
def get_id_mapping(self, hit):
def get_id_mapping(self, hit):
    """
    Get the document ID

    :param hit: The raw document result
    :return: The value stored in ``hit`` under the key ``self.id_mapping``
    """
    # id_mapping is not assigned in this class's visible code; presumably it
    # is provided by the base class or the configuration.
    id_key = self.id_mapping
    return hit[id_key]

Get the document ID

Parameters
  • hit: The raw document result
Returns
The document's ID
class ClinicalTrialsElasticsearchExecutor(debeir.core.executor.GenericElasticsearchExecutor):
class ClinicalTrialsElasticsearchExecutor(GenericElasticsearchExecutor):
    """
    Executes queries given a query object.

    All constructor arguments are handed to ``GenericElasticsearchExecutor``;
    this class only adds the ``query_fns`` lookup table on top.
    """
    query: TrialsElasticsearchQuery

    def __init__(
            self,
            topics: Dict[Union[str, int], Dict[str, str]],
            client: Elasticsearch,
            index_name: str,
            output_file: str,
            query: TrialsElasticsearchQuery,
            encoder: Optional[Encoder] = None,
            config=None,
            *args,
            **kwargs,
    ):
        """
        :param topics: Mapping from topic key to a dict of topic fields
        :param client: Asynchronous Elasticsearch client
        :param index_name: Forwarded to the base executor
        :param output_file: Forwarded to the base executor
        :param query: Query object used to build the request bodies
        :param encoder: Optional sentence encoder
        :param config: Optional configuration object
        :param args: Extra positional arguments for the base executor
        :param kwargs: Extra keyword arguments for the base executor
        """
        super().__init__(
            topics,
            client,
            index_name,
            output_file,
            query,
            encoder,
            *args,
            config=config,
            **kwargs,
        )

        # Query type name -> bound method producing that kind of query. These
        # methods are not defined in this class; presumably they are inherited.
        self.query_fns = dict(
            query=self.generate_query,
            ablation=self.generate_query_ablation,
            embedding=self.generate_embedding_query,
        )

Executes queries given a query object.

ClinicalTrialsElasticsearchExecutor( topics: Dict[Union[str, int], Dict[str, str]], client: elasticsearch.AsyncElasticsearch, index_name: str, output_file: str, query: debeir.datasets.clinical_trials.TrialsElasticsearchQuery, encoder: Optional[debeir.rankers.transformer_sent_encoder.Encoder] = None, config=None, *args, **kwargs)
def __init__(
        self,
        topics: Dict[Union[str, int], Dict[str, str]],
        client: Elasticsearch,
        index_name: str,
        output_file: str,
        query: TrialsElasticsearchQuery,
        encoder: Optional[Encoder] = None,
        config=None,
        *args,
        **kwargs,
):
    """
    Initialise the base executor and register the query builders.

    :param topics: Mapping from topic key to a dict of topic fields
    :param client: Asynchronous Elasticsearch client
    :param index_name: Forwarded to the base executor
    :param output_file: Forwarded to the base executor
    :param query: Query object used to build the request bodies
    :param encoder: Optional sentence encoder
    :param config: Optional configuration object
    :param args: Extra positional arguments for the base executor
    :param kwargs: Extra keyword arguments for the base executor
    """
    super().__init__(
        topics,
        client,
        index_name,
        output_file,
        query,
        encoder,
        *args,
        config=config,
        **kwargs,
    )

    # Query type name -> bound method producing that kind of query. These
    # methods are not defined in this class; presumably they are inherited.
    self.query_fns = dict(
        query=self.generate_query,
        ablation=self.generate_query_ablation,
        embedding=self.generate_embedding_query,
    )
class ClinicalTrialParser(debeir.core.parser.Parser):
class ClinicalTrialParser(Parser):
    """
    Parser for Clinical Trials topics
    """

    @classmethod
    def get_topics(cls, csvfile) -> Dict[Union[str, int], Dict[str, str]]:
        """
        Read topics from an open CSV file.

        The first row is treated as a header and skipped. For every other
        non-empty row, column 0 is the topic id and column 1 the topic text.

        :param csvfile: Iterable of CSV lines, e.g. an open text file
        :return: ``{topic_id: {"text": topic_text}}``; ids are kept exactly as
            read from the file (strings, no numeric conversion)
        :raises IndexError: If a non-empty row has fewer than two columns
        """
        reader = csv.reader(csvfile)
        next(reader, None)  # the header row carries no topic

        topics = {}
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines; they hold no topic.
                continue
            topics[row[0]] = {"text": row[1]}

        return topics

Parser for Clinical Trials topics

@classmethod
def get_topics(cls, csvfile) -> Dict[int, Dict[str, str]]:
@classmethod
def get_topics(cls, csvfile) -> Dict[Union[str, int], Dict[str, str]]:
    """
    Read topics from an open CSV file.

    The first row is treated as a header and skipped. For every other
    non-empty row, column 0 is the topic id and column 1 the topic text.

    :param csvfile: Iterable of CSV lines, e.g. an open text file
    :return: ``{topic_id: {"text": topic_text}}``; ids are kept exactly as
        read from the file (strings, no numeric conversion)
    :raises IndexError: If a non-empty row has fewer than two columns
    """
    reader = csv.reader(csvfile)
    next(reader, None)  # the header row carries no topic

    topics = {}
    for row in reader:
        if not row:
            # csv.reader yields [] for blank lines; they hold no topic.
            continue
        topics[row[0]] = {"text": row[1]}

    return topics

Instance method for getting topics; it forwards the instance's own parameters to the _get_topics class method.