debeir.rankers.transformer_sent_encoder

 1from hashlib import md5
 2from typing import List
 3
 4import sentence_transformers
 5import spacy
 6import torch
 7import torch.nn.functional as F
 8from analysis_tools_ir.utils import cache
 9
10EMBEDDING_DIM_SIZE = 768
11
12
class Encoder:
    """
    A wrapper for the Sentence Transformer Encoder used in Universal Sentence Embeddings (USE) for ranking or reranking.

    Instances compare equal (and hash identically) when their model path, spacy model name,
    normalize flag and max_length all match; the loaded model objects are not compared.

    :param model_path: The path to a sentence transformer or transformer model.
    :param normalize: Normalize the output vectors to unit length for dot product retrieval rather than cosine.
    :param spacy_model: the spacy or scispacy model to use for sentence boundary detection.
    :param max_length: Maximum input length for the spacy nlp model.
    """

    def __init__(
            self,
            model_path,
            normalize=False,
            spacy_model="en_core_sci_md",
            max_length=2000000,
    ):
        # Plain configuration values; __eq__ and __hash__ are built from these four.
        self.model_path = model_path
        self.spacy_model = spacy_model
        self.normalize = normalize
        self.max_length = max_length

        # Both models are loaded eagerly at construction time.
        self.model = sentence_transformers.SentenceTransformer(model_path)
        self.nlp = spacy.load(spacy_model)
        # Apply the configured input-length limit to the spacy pipeline.
        self.nlp.max_length = max_length

    @cache.Cache(hash_self=True, cache_dir="./cache/embedding_cache/")
    def encode(self, topic: str) -> List:
        """
        Computes a sentence embedding for a given topic, using spacy for sentence segmentation.
        By default, uses a cache to store previously computed vectors. Pass "disable_cache" as a kwarg to disable this.

        :param topic: The topic to encode. Should be a raw string; it is split into sentences with the spacy model.
        :param disable_cache: keyword argument, pass as True to disable encoding caching.
            NOTE(review): this function does not read it itself; presumably the cache decorator
            consumes it - confirm against cache.Cache.
        :return:
            The embedding converted from a tensor to a plain list. When the topic is split into
            several sentences, only the first sentence's vector is returned. An empty list is
            returned when the topic contains no non-blank sentences.
        """
        # Collapse whitespace runs inside each sentence and drop blank sentences.
        sentences = [
            " ".join(sent.text.split())
            for sent in self.nlp(topic).sents
            if sent.text.strip()
        ]

        if not sentences:
            # Nothing to encode: return an empty vector instead of failing later on
            # ``embeddings[0]`` (or inside the encoder) with an unrelated error.
            return []

        embeddings = self.model.encode(sentences, convert_to_tensor=True,
                                       show_progress_bar=False)

        if len(embeddings.size()) == 1:
            # NOTE(review): adding a leading axis and averaging over it leaves a 1-D tensor's
            # values unchanged. If mean pooling over sentences was intended, the mean would
            # have to run for the 2-D case as well - confirm before changing, since that
            # would alter every embedding produced for multi-sentence topics.
            embeddings = torch.unsqueeze(embeddings, dim=0)
            embeddings = torch.mean(embeddings, axis=0)

        if self.normalize:
            embeddings = F.normalize(embeddings, dim=-1)

        embeddings = embeddings.tolist()

        # A nested list means one row per sentence; only the first row is returned.
        if isinstance(embeddings, list) and embeddings and isinstance(embeddings[0], list):
            return embeddings[0]

        return embeddings

    def __call__(self, topic, *args, **kwargs) -> List:
        # NOTE(review): extra positional and keyword arguments are accepted but not forwarded,
        # so a "disable_cache" kwarg passed here has no effect - confirm whether that is intended.
        return self.encode(topic)

    def __eq__(self, other):
        # Defer to the other operand for unrelated types instead of raising AttributeError.
        if not isinstance(other, Encoder):
            return NotImplemented
        return (
                self.model_path == other.model_path
                and self.spacy_model == other.spacy_model
                and self.normalize == other.normalize
                and self.max_length == other.max_length
        )

    def __hash__(self):
        # Built from the same four fields as __eq__. An md5 digest does not vary between
        # interpreter runs, unlike the built-in hash of a str; presumably this keeps
        # on-disk cache keys stable (hash_self=True above) - confirm against cache.Cache.
        return int(
            md5(
                (self.model_path
                 + self.spacy_model
                 + str(self.normalize)
                 + str(self.max_length)
                 ).encode()
            ).hexdigest(),
            16,
        )
class Encoder:
class Encoder:
    """
    A wrapper for the Sentence Transformer Encoder used in Universal Sentence Embeddings (USE) for ranking or reranking.

    Instances compare equal (and hash identically) when their model path, spacy model name,
    normalize flag and max_length all match; the loaded model objects are not compared.

    :param model_path: The path to a sentence transformer or transformer model.
    :param normalize: Normalize the output vectors to unit length for dot product retrieval rather than cosine.
    :param spacy_model: the spacy or scispacy model to use for sentence boundary detection.
    :param max_length: Maximum input length for the spacy nlp model.
    """

    def __init__(
            self,
            model_path,
            normalize=False,
            spacy_model="en_core_sci_md",
            max_length=2000000,
    ):
        # Plain configuration values; __eq__ and __hash__ are built from these four.
        self.model_path = model_path
        self.spacy_model = spacy_model
        self.normalize = normalize
        self.max_length = max_length

        # Both models are loaded eagerly at construction time.
        self.model = sentence_transformers.SentenceTransformer(model_path)
        self.nlp = spacy.load(spacy_model)
        # Apply the configured input-length limit to the spacy pipeline.
        self.nlp.max_length = max_length

    @cache.Cache(hash_self=True, cache_dir="./cache/embedding_cache/")
    def encode(self, topic: str) -> List:
        """
        Computes a sentence embedding for a given topic, using spacy for sentence segmentation.
        By default, uses a cache to store previously computed vectors. Pass "disable_cache" as a kwarg to disable this.

        :param topic: The topic to encode. Should be a raw string; it is split into sentences with the spacy model.
        :param disable_cache: keyword argument, pass as True to disable encoding caching.
            NOTE(review): this function does not read it itself; presumably the cache decorator
            consumes it - confirm against cache.Cache.
        :return:
            The embedding converted from a tensor to a plain list. When the topic is split into
            several sentences, only the first sentence's vector is returned. An empty list is
            returned when the topic contains no non-blank sentences.
        """
        # Collapse whitespace runs inside each sentence and drop blank sentences.
        sentences = [
            " ".join(sent.text.split())
            for sent in self.nlp(topic).sents
            if sent.text.strip()
        ]

        if not sentences:
            # Nothing to encode: return an empty vector instead of failing later on
            # ``embeddings[0]`` (or inside the encoder) with an unrelated error.
            return []

        embeddings = self.model.encode(sentences, convert_to_tensor=True,
                                       show_progress_bar=False)

        if len(embeddings.size()) == 1:
            # NOTE(review): adding a leading axis and averaging over it leaves a 1-D tensor's
            # values unchanged. If mean pooling over sentences was intended, the mean would
            # have to run for the 2-D case as well - confirm before changing, since that
            # would alter every embedding produced for multi-sentence topics.
            embeddings = torch.unsqueeze(embeddings, dim=0)
            embeddings = torch.mean(embeddings, axis=0)

        if self.normalize:
            embeddings = F.normalize(embeddings, dim=-1)

        embeddings = embeddings.tolist()

        # A nested list means one row per sentence; only the first row is returned.
        if isinstance(embeddings, list) and embeddings and isinstance(embeddings[0], list):
            return embeddings[0]

        return embeddings

    def __call__(self, topic, *args, **kwargs) -> List:
        # NOTE(review): extra positional and keyword arguments are accepted but not forwarded,
        # so a "disable_cache" kwarg passed here has no effect - confirm whether that is intended.
        return self.encode(topic)

    def __eq__(self, other):
        # Defer to the other operand for unrelated types instead of raising AttributeError.
        if not isinstance(other, Encoder):
            return NotImplemented
        return (
                self.model_path == other.model_path
                and self.spacy_model == other.spacy_model
                and self.normalize == other.normalize
                and self.max_length == other.max_length
        )

    def __hash__(self):
        # Built from the same four fields as __eq__. An md5 digest does not vary between
        # interpreter runs, unlike the built-in hash of a str; presumably this keeps
        # on-disk cache keys stable (hash_self=True above) - confirm against cache.Cache.
        return int(
            md5(
                (self.model_path
                 + self.spacy_model
                 + str(self.normalize)
                 + str(self.max_length)
                 ).encode()
            ).hexdigest(),
            16,
        )

A wrapper for the Sentence Transformer Encoder used in Universal Sentence Embeddings (USE) for ranking or reranking.

Parameters
  • model_path: The path to a sentence transformer or transformer model.
  • normalize: Normalize the output vectors to unit length for dot product retrieval rather than cosine.
  • spacy_model: the spacy or scispacy model to use for sentence boundary detection.
  • max_length: Maximum input length for the spacy nlp model.
Encoder( model_path, normalize=False, spacy_model='en_core_sci_md', max_length=2000000)
24    def __init__(
25            self,
26            model_path,
27            normalize=False,
28            spacy_model="en_core_sci_md",
29            max_length=2000000,
30    ):
31        self.model = sentence_transformers.SentenceTransformer(model_path)
32        self.model_path = model_path
33        self.nlp = spacy.load(spacy_model)
34        self.spacy_model = spacy_model
35        self.max_length = max_length
36        self.nlp.max_length = max_length
37        self.normalize = normalize
@cache.Cache(hash_self=True, cache_dir='./cache/embedding_cache/')
def encode(self, topic: str) -> List:
39    @cache.Cache(hash_self=True, cache_dir="./cache/embedding_cache/")
40    def encode(self, topic: str) -> List:
41        """
42        Computes sentence embeddings for a given topic, uses spacy for sentence segmentation.
43        By default, uses a cache to store previously computed vectors. Pass "disable_cache" as a kwarg to disable this.
44
45        :param topic: The topic (a list of sentences) to encode. Should be a raw string.
46        :param disable_cache: keyword argument, pass as True to disable encoding caching.
47        :return:
48            Returns a list of encoded tensors is returned.
49        """
50        sentences = [
51            " ".join(sent.text.split())
52            for sent in self.nlp(topic).sents
53            if sent.text.strip()
54        ]
55
56        embeddings = self.model.encode(sentences, convert_to_tensor=True,
57                                       show_progress_bar=False)
58
59        if len(embeddings.size()) == 1:
60            embeddings = torch.unsqueeze(embeddings, dim=0)
61            embeddings = torch.mean(embeddings, axis=0)
62
63        if self.normalize:
64            embeddings = F.normalize(embeddings, dim=-1)
65
66        embeddings = embeddings.tolist()
67
68        if isinstance(embeddings, list) and isinstance(embeddings[0], list):
69            return embeddings[0]
70
71        return embeddings

Computes sentence embeddings for a given topic, using spaCy for sentence segmentation. By default, a cache stores previously computed vectors; pass "disable_cache" as a keyword argument to disable this.

Parameters
  • topic: The topic to encode, given as a raw string; it is split into sentences internally.
  • disable_cache: keyword argument, pass as True to disable encoding caching.
Returns
The encoded embedding, converted from a tensor to a plain Python list.