debeir.training.evaluate_reranker
"""Evaluate a sentence-encoder reranker.

Encodes parsed topics and dataset documents once up front, scores each
document against its topic's query embeddings, and produces per-topic
ranked lists.
"""

from collections import defaultdict
from typing import Dict, List, Union

import numpy as np
from debeir.evaluation.evaluator import Evaluator
from debeir.rankers.transformer_sent_encoder import Encoder
from sklearn.metrics.pairwise import cosine_similarity

from datasets import Dataset

# Similarity functions selectable by name in SentenceEvaluator.
# NOTE(review): np.dot accepts 1-D vectors but sklearn's cosine_similarity
# expects 2-D arrays — confirm the encoder's output shape matches the fn used.
distance_fns = {
    "dot_score": np.dot,
    "cos_sim": cosine_similarity
}


class SentenceEvaluator(Evaluator):
    """Scores encoded documents against encoded topic queries and ranks them."""

    def __init__(self, model: Encoder, dataset: Dataset, parsed_topics: Dict[Union[str, int], Dict],
                 text_cols: List[str], query_cols: List[str], id_col: str,
                 distance_fn: str,
                 qrels: str, metrics: List[str]):
        """
        :param model: Sentence encoder; called directly on raw text.
        :param dataset: Iterable of records holding the document text columns.
        :param parsed_topics: topic_num -> {query_col: query text}.
        :param text_cols: Document text columns to encode.
        :param query_cols: Topic query columns to encode.
        :param id_col: Column whose value has the form "<topic_num>_<doc_id>".
        :param distance_fn: Key into ``distance_fns`` ("dot_score" or "cos_sim").
        :param qrels: Qrels file passed through to the Evaluator base class.
        :param metrics: Metric names passed through to the Evaluator base class.
        """
        super().__init__(qrels, metrics)
        self.encoder = model
        self.dataset = dataset
        self.parsed_topics = parsed_topics
        self.distance_fn = distance_fns[distance_fn]
        self.query_cols = query_cols
        self.text_cols = text_cols

        # Embed everything eagerly so ranking needs no further encoder calls.
        self._get_topic_embeddings(query_cols)
        self.document_ebs = self._get_document_embedding_and_mapping(id_col, text_cols)

    def _get_topic_embeddings(self, query_cols):
        """Encode each query column of each topic, stored in-place under "<col>_eb"."""
        for topic_num, topic in self.parsed_topics.items():
            for query_col in query_cols:
                query = topic[query_col]
                query_eb = self.encoder(query)

                topic[query_col + "_eb"] = query_eb

    def _get_document_embedding_and_mapping(self, id_col, text_cols):
        """Encode every text column of every document.

        :return: topic_num -> doc_id -> [[text_col, embedding], ...]
        """
        document_ebs = defaultdict(lambda: defaultdict(list))

        for datum in self.dataset:
            for text_col in text_cols:
                embedding = self.encoder(datum[text_col])
                # id_col is expected to be "<topic_num>_<doc_id>".
                topic_num, doc_id = datum[id_col].split("_")
                document_ebs[topic_num][doc_id].append([text_col, embedding])

        return document_ebs

    def _get_score(self, a, b, aggregate="sum"):
        """Score every pairing of embeddings in ``a`` and ``b``, then aggregate.

        :param a: A single embedding or a list of embeddings.
        :param b: A single embedding or a list of embeddings.
        :param aggregate: One of "max", "min", "sum", "avg".
        :return: Aggregated float score over all (a, b) pairings.
        """
        scores = []

        aggs = {
            "max": max,
            "min": min,
            "sum": sum,
            "avg": lambda k: sum(k) / len(k)
        }

        # Promote a single embedding to a one-element list of embeddings.
        if not isinstance(a[0], list):
            a = [a]

        if not isinstance(b[0], list):
            b = [b]

        for _a in a:
            for _b in b:
                scores.append(float(self.distance_fn(_a, _b)))

        return aggs[aggregate](scores)

    def produce_ranked_lists(self):
        """Rank each topic's documents by query/document embedding similarity.

        :return: topic_num -> [[doc_id, score], ...] sorted by score, descending.
        """
        topics = defaultdict(list)  # topic_num -> [[document_id, score], ...]

        for topic_num, doc_topics in self.document_ebs.items():
            for doc_id, document_repr in doc_topics.items():
                doc_txt_cols, doc_embeddings = list(zip(*document_repr))

                # Bug fix: topic embeddings live under parsed_topics[topic_num]
                # and are stored for *query* columns (see _get_topic_embeddings),
                # not text columns — the old lookup raised KeyError.
                query_ebs = [self.parsed_topics[topic_num][query_col + "_eb"]
                             for query_col in self.query_cols]
                topics[topic_num].append([doc_id, self._get_score(query_ebs, doc_embeddings)])

        for topic_num in topics:
            topics[topic_num].sort(key=lambda k: k[1], reverse=True)

        return topics
class SentenceEvaluator(Evaluator):
    """Evaluate a sentence encoder: embed queries and documents, then rank
    each topic's documents by embedding similarity.
    """

    def __init__(self, model: Encoder, dataset: Dataset, parsed_topics: Dict[Union[str, int], Dict],
                 text_cols: List[str], query_cols: List[str], id_col: str,
                 distance_fn: str,
                 qrels: str, metrics: List[str]):
        """
        :param model: Sentence encoder; called directly on raw text.
        :param dataset: Iterable of records holding the document text columns.
        :param parsed_topics: topic_num -> {query_col: query text}.
        :param text_cols: Document text columns to encode.
        :param query_cols: Topic query columns to encode.
        :param id_col: Column whose value has the form "<topic_num>_<doc_id>".
        :param distance_fn: Key into ``distance_fns`` ("dot_score" or "cos_sim").
        :param qrels: Qrels file passed through to the Evaluator base class.
        :param metrics: Metric names passed through to the Evaluator base class.
        """
        super().__init__(qrels, metrics)
        self.encoder = model
        self.dataset = dataset
        self.parsed_topics = parsed_topics
        self.distance_fn = distance_fns[distance_fn]
        self.query_cols = query_cols
        self.text_cols = text_cols

        # Pre-compute all embeddings at construction time.
        self._get_topic_embeddings(query_cols)
        self.document_ebs = self._get_document_embedding_and_mapping(id_col, text_cols)

    def _get_topic_embeddings(self, query_cols):
        """Encode each query column of each topic, stored in-place under "<col>_eb"."""
        for topic_num, topic in self.parsed_topics.items():
            for query_col in query_cols:
                query = topic[query_col]
                query_eb = self.encoder(query)

                topic[query_col + "_eb"] = query_eb

    def _get_document_embedding_and_mapping(self, id_col, text_cols):
        """Encode every text column of every document.

        :return: topic_num -> doc_id -> [[text_col, embedding], ...]
        """
        document_ebs = defaultdict(lambda: defaultdict(list))

        for datum in self.dataset:
            for text_col in text_cols:
                embedding = self.encoder(datum[text_col])
                # id_col is expected to be "<topic_num>_<doc_id>".
                topic_num, doc_id = datum[id_col].split("_")
                document_ebs[topic_num][doc_id].append([text_col, embedding])

        return document_ebs

    def _get_score(self, a, b, aggregate="sum"):
        """Score every pairing of embeddings in ``a`` and ``b``, then aggregate.

        :param a: A single embedding or a list of embeddings.
        :param b: A single embedding or a list of embeddings.
        :param aggregate: One of "max", "min", "sum", "avg".
        :return: Aggregated float score over all (a, b) pairings.
        """
        scores = []

        aggs = {
            "max": max,
            "min": min,
            "sum": sum,
            "avg": lambda k: sum(k) / len(k)
        }

        # Promote a single embedding to a one-element list of embeddings.
        if not isinstance(a[0], list):
            a = [a]

        if not isinstance(b[0], list):
            b = [b]

        for _a in a:
            for _b in b:
                scores.append(float(self.distance_fn(_a, _b)))

        return aggs[aggregate](scores)

    def produce_ranked_lists(self):
        """Rank each topic's documents by query/document embedding similarity.

        :return: topic_num -> [[doc_id, score], ...] sorted by score, descending.
        """
        topics = defaultdict(list)  # topic_num -> [[document_id, score], ...]

        for topic_num, doc_topics in self.document_ebs.items():
            for doc_id, document_repr in doc_topics.items():
                doc_txt_cols, doc_embeddings = list(zip(*document_repr))

                # Bug fix: topic embeddings live under parsed_topics[topic_num]
                # and are stored for *query* columns (see _get_topic_embeddings),
                # not text columns — the old lookup raised KeyError.
                query_ebs = [self.parsed_topics[topic_num][query_col + "_eb"]
                             for query_col in self.query_cols]
                topics[topic_num].append([doc_id, self._get_score(query_ebs, doc_embeddings)])

        for topic_num in topics:
            topics[topic_num].sort(key=lambda k: k[1], reverse=True)

        return topics
Evaluation class for computing metrics from TREC-style files
SentenceEvaluator( model: debeir.rankers.transformer_sent_encoder.Encoder, dataset: datasets.arrow_dataset.Dataset, parsed_topics: Dict[Union[str, int], Dict], text_cols: List[str], query_cols: List[str], id_col: str, distance_fn: str, qrels: str, metrics: List[str])
def __init__(self, model: Encoder, dataset: Dataset, parsed_topics: Dict[Union[str, int], Dict],
             text_cols: List[str], query_cols: List[str], id_col: str,
             distance_fn: str,
             qrels: str, metrics: List[str]):
    """Initialise the evaluator and eagerly embed all queries and documents.

    :param model: Sentence encoder callable applied to raw text.
    :param dataset: Source of document records.
    :param parsed_topics: Mapping of topic number to its query-column texts.
    :param text_cols: Document columns to embed.
    :param query_cols: Topic query columns to embed.
    :param id_col: Column holding "<topic_num>_<doc_id>" identifiers.
    :param distance_fn: Similarity-function name ("dot_score" or "cos_sim").
    :param qrels: Qrels file handed to the Evaluator base class.
    :param metrics: Metric names handed to the Evaluator base class.
    """
    super().__init__(qrels, metrics)

    self.encoder, self.dataset = model, dataset
    self.parsed_topics = parsed_topics
    self.text_cols, self.query_cols = text_cols, query_cols
    self.distance_fn = distance_fns[distance_fn]

    # Embed everything now so ranking needs no further encoder calls.
    self._get_topic_embeddings(query_cols)
    self.document_ebs = self._get_document_embedding_and_mapping(id_col, text_cols)
def produce_ranked_lists(self):
    """Rank each topic's documents by embedding similarity to the topic's queries.

    Bug fix: topic embeddings are stored per-topic under
    ``parsed_topics[topic_num][query_col + "_eb"]`` (keyed by *query* columns),
    so index by topic and iterate ``self.query_cols`` — the previous lookup
    ``self.parsed_topics[text_col + "_eb"]`` skipped the topic level and used
    the wrong column set, raising KeyError.

    :return: topic_num -> [[doc_id, score], ...] sorted by score, descending.
    """
    topics = defaultdict(list)  # topic_num -> [[document_id, score], ...]

    for topic_num, doc_topics in self.document_ebs.items():
        for doc_id, document_repr in doc_topics.items():
            doc_txt_cols, doc_embeddings = list(zip(*document_repr))

            query_ebs = [self.parsed_topics[topic_num][query_col + "_eb"]
                         for query_col in self.query_cols]
            topics[topic_num].append([doc_id, self._get_score(query_ebs, doc_embeddings)])

    for topic_num in topics:
        topics[topic_num].sort(key=lambda k: k[1], reverse=True)

    return topics