debeir.rankers.transformer_sent_encoder
from hashlib import md5
from typing import List

import sentence_transformers
import spacy
import torch
import torch.nn.functional as F
from analysis_tools_ir.utils import cache

EMBEDDING_DIM_SIZE = 768


class Encoder:
    """
    A wrapper for the Sentence Transformer Encoder used in Universal Sentence Embeddings (USE) for ranking or reranking.

    Two encoders compare equal, and hash to the same value, when they were created with the
    same ``model_path``, ``spacy_model``, ``normalize`` and ``max_length``. The loaded model
    objects themselves are not compared.

    :param model_path: The path to a sentence transformer or transformer model.
    :param normalize: Normalize the output vectors to unit length for dot product retrieval rather than cosine.
    :param spacy_model: the spacy or scispacy model to use for sentence boundary detection.
    :param max_length: Maximum input length for the spacy nlp model.
    """

    def __init__(
        self,
        model_path,
        normalize=False,
        spacy_model="en_core_sci_md",
        max_length=2000000,
    ):
        self.model = sentence_transformers.SentenceTransformer(model_path)
        self.model_path = model_path
        self.nlp = spacy.load(spacy_model)
        self.spacy_model = spacy_model
        self.max_length = max_length
        # Raise spaCy's input-length limit so long topics can be segmented.
        self.nlp.max_length = max_length
        self.normalize = normalize

    @cache.Cache(hash_self=True, cache_dir="./cache/embedding_cache/")
    def encode(self, topic: str) -> List:
        """
        Computes a single embedding for a given topic, uses spacy for sentence segmentation.
        By default, uses a cache to store previously computed vectors. Pass "disable_cache" as a kwarg to disable this.

        :param topic: The topic text to encode, as one raw string; it is split into sentences internally.
        :param disable_cache: keyword argument, pass as True to disable encoding caching.
        :raises ValueError: If the topic contains no non-blank sentence.
        :return:
            The topic embedding as a list of floats.
        """
        # Collapse runs of whitespace inside each sentence and drop blank sentences.
        sentences = [
            " ".join(sent.text.split())
            for sent in self.nlp(topic).sents
            if sent.text.strip()
        ]

        if not sentences:
            # An empty sentence list would otherwise reach the `embeddings[0]`
            # lookup below with nothing to index; fail early with a clear message.
            raise ValueError("Cannot encode a topic that contains no sentences.")

        embeddings = self.model.encode(sentences, convert_to_tensor=True,
                                       show_progress_bar=False)

        # Guarantee a leading "sentence" dimension before reducing over it.
        if len(embeddings.size()) == 1:
            embeddings = torch.unsqueeze(embeddings, dim=0)
        # NOTE(review): indentation was lost in the listing this was taken from;
        # the mean is assumed to apply to every input, not only the 1-D case - confirm.
        embeddings = torch.mean(embeddings, dim=0)

        if self.normalize:
            embeddings = F.normalize(embeddings, dim=-1)

        embeddings = embeddings.tolist()

        # If the result is still nested, return only the first row.
        if isinstance(embeddings, list) and isinstance(embeddings[0], list):
            return embeddings[0]

        return embeddings

    def __call__(self, topic, *args, **kwargs) -> List:
        """Encode ``topic`` via :meth:`encode`; extra arguments are accepted but not forwarded."""
        return self.encode(topic)

    def __eq__(self, other):
        """Compare encoders by their configuration; defer to ``other`` for foreign types."""
        if not isinstance(other, Encoder):
            # Let Python try the reflected comparison instead of raising AttributeError.
            return NotImplemented
        return (
            self.model_path == other.model_path
            and self.spacy_model == other.spacy_model
            and self.normalize == other.normalize
            and self.max_length == other.max_length
        )

    def __hash__(self):
        """Return an integer derived from the MD5 digest of the configuration fields."""
        # The concatenation format is kept exactly as-is: encode() is cached with
        # hash_self=True, so changing this value could presumably invalidate
        # existing cache entries - verify before altering.
        return int(
            md5(
                (self.model_path
                 + self.spacy_model
                 + str(self.normalize)
                 + str(self.max_length)
                 ).encode()
            ).hexdigest(),
            16,
        )
class
Encoder:
class Encoder:
    """
    A wrapper for the Sentence Transformer Encoder used in Universal Sentence Embeddings (USE) for ranking or reranking.

    Two encoders compare equal, and hash to the same value, when they were created with the
    same ``model_path``, ``spacy_model``, ``normalize`` and ``max_length``. The loaded model
    objects themselves are not compared.

    :param model_path: The path to a sentence transformer or transformer model.
    :param normalize: Normalize the output vectors to unit length for dot product retrieval rather than cosine.
    :param spacy_model: the spacy or scispacy model to use for sentence boundary detection.
    :param max_length: Maximum input length for the spacy nlp model.
    """

    def __init__(
        self,
        model_path,
        normalize=False,
        spacy_model="en_core_sci_md",
        max_length=2000000,
    ):
        self.model = sentence_transformers.SentenceTransformer(model_path)
        self.model_path = model_path
        self.nlp = spacy.load(spacy_model)
        self.spacy_model = spacy_model
        self.max_length = max_length
        # Raise spaCy's input-length limit so long topics can be segmented.
        self.nlp.max_length = max_length
        self.normalize = normalize

    @cache.Cache(hash_self=True, cache_dir="./cache/embedding_cache/")
    def encode(self, topic: str) -> List:
        """
        Computes a single embedding for a given topic, uses spacy for sentence segmentation.
        By default, uses a cache to store previously computed vectors. Pass "disable_cache" as a kwarg to disable this.

        :param topic: The topic text to encode, as one raw string; it is split into sentences internally.
        :param disable_cache: keyword argument, pass as True to disable encoding caching.
        :raises ValueError: If the topic contains no non-blank sentence.
        :return:
            The topic embedding as a list of floats.
        """
        # Collapse runs of whitespace inside each sentence and drop blank sentences.
        sentences = [
            " ".join(sent.text.split())
            for sent in self.nlp(topic).sents
            if sent.text.strip()
        ]

        if not sentences:
            # An empty sentence list would otherwise reach the `embeddings[0]`
            # lookup below with nothing to index; fail early with a clear message.
            raise ValueError("Cannot encode a topic that contains no sentences.")

        embeddings = self.model.encode(sentences, convert_to_tensor=True,
                                       show_progress_bar=False)

        # Guarantee a leading "sentence" dimension before reducing over it.
        if len(embeddings.size()) == 1:
            embeddings = torch.unsqueeze(embeddings, dim=0)
        # NOTE(review): indentation was lost in the listing this was taken from;
        # the mean is assumed to apply to every input, not only the 1-D case - confirm.
        embeddings = torch.mean(embeddings, dim=0)

        if self.normalize:
            embeddings = F.normalize(embeddings, dim=-1)

        embeddings = embeddings.tolist()

        # If the result is still nested, return only the first row.
        if isinstance(embeddings, list) and isinstance(embeddings[0], list):
            return embeddings[0]

        return embeddings

    def __call__(self, topic, *args, **kwargs) -> List:
        """Encode ``topic`` via :meth:`encode`; extra arguments are accepted but not forwarded."""
        return self.encode(topic)

    def __eq__(self, other):
        """Compare encoders by their configuration; defer to ``other`` for foreign types."""
        if not isinstance(other, Encoder):
            # Let Python try the reflected comparison instead of raising AttributeError.
            return NotImplemented
        return (
            self.model_path == other.model_path
            and self.spacy_model == other.spacy_model
            and self.normalize == other.normalize
            and self.max_length == other.max_length
        )

    def __hash__(self):
        """Return an integer derived from the MD5 digest of the configuration fields."""
        # The concatenation format is kept exactly as-is: encode() is cached with
        # hash_self=True, so changing this value could presumably invalidate
        # existing cache entries - verify before altering.
        return int(
            md5(
                (self.model_path
                 + self.spacy_model
                 + str(self.normalize)
                 + str(self.max_length)
                 ).encode()
            ).hexdigest(),
            16,
        )
A wrapper for the Sentence Transformer Encoder used in Universal Sentence Embeddings (USE) for ranking or reranking.
Parameters
- model_path: The path to a sentence transformer or transformer model.
- normalize: Normalize the output vectors to unit length for dot product retrieval rather than cosine.
- spacy_model: the spacy or scispacy model to use for sentence boundary detection.
- max_length: Maximum input length for the spacy nlp model.
Encoder( model_path, normalize=False, spacy_model='en_core_sci_md', max_length=2000000)
def __init__(
    self,
    model_path,
    normalize=False,
    spacy_model="en_core_sci_md",
    max_length=2000000,
):
    """
    Store the configuration and load the sentence-transformer and spaCy models.

    :param model_path: The path to a sentence transformer or transformer model.
    :param normalize: Normalize the output vectors to unit length for dot product retrieval rather than cosine.
    :param spacy_model: the spacy or scispacy model to use for sentence boundary detection.
    :param max_length: Maximum input length for the spacy nlp model.
    """
    # Plain configuration values first; these are what __eq__/__hash__ rely on.
    self.model_path = model_path
    self.spacy_model = spacy_model
    self.normalize = normalize
    self.max_length = max_length

    # Heavy resources: the sentence transformer is loaded before the spaCy pipeline.
    self.model = sentence_transformers.SentenceTransformer(self.model_path)
    pipeline = spacy.load(self.spacy_model)
    # Raise spaCy's input-length limit so long topics can be segmented.
    pipeline.max_length = self.max_length
    self.nlp = pipeline
@cache.Cache(hash_self=True, cache_dir='./cache/embedding_cache/')
def
encode(self, topic: str) -> List:
@cache.Cache(hash_self=True, cache_dir="./cache/embedding_cache/")
def encode(self, topic: str) -> List:
    """
    Computes a single embedding for a given topic, uses spacy for sentence segmentation.
    By default, uses a cache to store previously computed vectors. Pass "disable_cache" as a kwarg to disable this.

    :param topic: The topic text to encode, as one raw string; it is split into sentences internally.
    :param disable_cache: keyword argument, pass as True to disable encoding caching.
    :raises ValueError: If the topic contains no non-blank sentence.
    :return:
        The topic embedding as a list of floats.
    """
    # Collapse runs of whitespace inside each sentence and drop blank sentences.
    sentences = [
        " ".join(sent.text.split())
        for sent in self.nlp(topic).sents
        if sent.text.strip()
    ]

    if not sentences:
        # An empty sentence list would otherwise reach the `embeddings[0]`
        # lookup below with nothing to index; fail early with a clear message.
        raise ValueError("Cannot encode a topic that contains no sentences.")

    embeddings = self.model.encode(sentences, convert_to_tensor=True,
                                   show_progress_bar=False)

    # Guarantee a leading "sentence" dimension before reducing over it.
    if len(embeddings.size()) == 1:
        embeddings = torch.unsqueeze(embeddings, dim=0)
    # NOTE(review): indentation was lost in the listing this was taken from;
    # the mean is assumed to apply to every input, not only the 1-D case - confirm.
    embeddings = torch.mean(embeddings, dim=0)

    if self.normalize:
        embeddings = F.normalize(embeddings, dim=-1)

    embeddings = embeddings.tolist()

    # If the result is still nested, return only the first row.
    if isinstance(embeddings, list) and isinstance(embeddings[0], list):
        return embeddings[0]

    return embeddings
Computes sentence embeddings for a given topic, uses spacy for sentence segmentation. By default, uses a cache to store previously computed vectors. Pass "disable_cache" as a kwarg to disable this.
Parameters
- topic: The topic text to encode, passed as a single raw string. It is split into sentences internally.
- disable_cache: keyword argument, pass as True to disable encoding caching.
Returns
The embedding of the topic, returned as a list of floats.