debeir.training.utils

from typing import List, Union

import loguru
import transformers
from debeir.datasets.types import InputExample, RelevanceExample
from sentence_transformers import SentenceTransformer, losses, models
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SentenceEvaluator
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

import wandb

import datasets


# from sentence_transformers import InputExample

class LoggingScheduler:
    """Wraps a learning rate scheduler and logs each parameter group's last LR to wandb after every step."""

    def __init__(self, scheduler: LambdaLR):
        self.scheduler = scheduler

    def step(self, epoch=None):
        self.scheduler.step(epoch)

        last_lr = self.scheduler.get_last_lr()

        for i, lr in enumerate(last_lr):
            wandb.log({f"lr_{i}": lr})

    def __getattr__(self, attr):
        return getattr(self.scheduler, attr)

def get_scheduler_with_wandb(optimizer, scheduler: str, warmup_steps: int, t_total: int):
    """
    Returns the requested learning rate scheduler, wrapped so that learning rates are logged to wandb.
    Available schedulers: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts
    """
    scheduler = scheduler.lower()
    loguru.logger.info(f"Creating scheduler: {scheduler}")

    if scheduler == 'constantlr':
        sched = transformers.get_constant_schedule(optimizer)
    elif scheduler == 'warmupconstant':
        sched = transformers.get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps)
    elif scheduler == 'warmuplinear':
        sched = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
                                                             num_training_steps=t_total)
    elif scheduler == 'warmupcosine':
        sched = transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
                                                             num_training_steps=t_total)
    elif scheduler == 'warmupcosinewithhardrestarts':
        sched = transformers.get_cosine_with_hard_restarts_schedule_with_warmup(optimizer,
                                                                                num_warmup_steps=warmup_steps,
                                                                                num_training_steps=t_total)
    else:
        raise ValueError("Unknown scheduler {}".format(scheduler))

    return LoggingScheduler(sched)

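A minimal usage sketch, assuming a torch optimizer and an already-initialised wandb run; the warm-up and total step counts below are only illustrative:

from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=2e-5)  # assumption: `model` is any torch module
scheduler = get_scheduler_with_wandb(optimizer, "warmuplinear", warmup_steps=100, t_total=1000)
scheduler.step()  # steps the wrapped scheduler and logs lr_0, lr_1, ... to the active wandb run
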
class LoggingLoss:
    """Wraps a loss function and logs every computed training loss to wandb as 'train_loss'."""

    def __init__(self, loss_fn):
        self.loss_fn = loss_fn

    def __call__(self, *args, **kwargs):
        loss = self.loss_fn(*args, **kwargs)
        wandb.log({'train_loss': loss})
        return loss

    def __getattr__(self, attr):
        return getattr(self.loss_fn, attr)

class TokenizerOverload:
    """Wraps a tokenizer so that a fixed set of keyword arguments is applied on every call."""

    def __init__(self, tokenizer, tokenizer_kwargs, debug=False):
        self.tokenizer = tokenizer
        self.tokenizer_kwargs = tokenizer_kwargs
        self.debug = debug
        self.max_length = -1

    def __call__(self, *args, **kwargs):
        if self.debug:
            print(str(args), str(kwargs))

        # Forced kwargs take precedence over anything passed at call time
        kwargs.update(self.tokenizer_kwargs)
        output = self.tokenizer(*args, **kwargs)

        return output

    def __getattr__(self, attr):
        if self.debug:
            print(str(attr))

        return getattr(self.tokenizer, attr)

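For example, a sketch that pins truncation and padding on every call; the checkpoint name is purely illustrative:

from transformers import AutoTokenizer

tokenizer = TokenizerOverload(AutoTokenizer.from_pretrained("bert-base-uncased"),
                              {"truncation": True, "padding": "max_length", "max_length": 128})
batch = tokenizer(["a short query", "a longer candidate document"])  # forced kwargs applied on every call
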
class LoggingEvaluator:
    """Wraps an evaluator and logs its score to wandb as 'val_acc'."""

    def __init__(self, evaluator):
        self.evaluator = evaluator

    def __call__(self, *args, **kwargs):
        scores = self.evaluator(*args, **kwargs)
        wandb.log({'val_acc': scores})

        return scores

    def __getattr__(self, attr):
        return getattr(self.evaluator, attr)

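These wrappers compose with the usual sentence-transformers objects; a minimal sketch, assuming `model`, `eval_dataset` and a wandb run already exist:

loss_fn = LoggingLoss(losses.CosineSimilarityLoss(model=model))  # also emits 'train_loss'
evaluator = LoggingEvaluator(EmbeddingSimilarityEvaluator.from_input_examples(eval_dataset))  # also emits 'val_acc'
# Both behave like the objects they wrap; attribute access is forwarded via __getattr__.
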
class SentDataset:
    """Adapts a Hugging Face dataset into sentence-transformers InputExamples."""

    def __init__(self, dataset: datasets.Dataset, text_cols: List[str],
                 label_col: str = None, label=None):
        self.dataset = dataset
        self.text_cols = text_cols
        self.label_col = label_col
        self.label = label

    def __getitem__(self, idx):
        item = self.dataset[idx]

        texts = []

        for text_col in self.text_cols:
            texts.append(item[text_col])

        example = InputExample(texts=texts)

        if self.label_col:
            example.label = item[self.label_col]
        elif self.label:
            example.label = self.label

        return example

    def __len__(self):
        return len(self.dataset)

class SentDatasetList:
    """Concatenates multiple SentDatasets and maps a global index to the right sub-dataset."""

    def __init__(self, datasets: List[SentDataset]):
        self.datasets = datasets
        self.lengths = [len(dataset) for dataset in self.datasets]
        self.total_length = sum(self.lengths)

    def __getitem__(self, idx):
        i, c = 0, 0

        # Walk the datasets, keeping a running offset c, until idx falls inside dataset i
        for i, length in enumerate(self.lengths):
            if idx - c == 0:
                idx = 0
                break
            if idx - c < length:
                idx = idx - c
                break

            c = c + self.lengths[i]

        return self.datasets[i][idx]

    def __len__(self):
        return self.total_length

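A usage sketch for the dataset adapters, assuming a Hugging Face dataset split with hypothetical "query", "document" and "score" columns:

raw = datasets.load_dataset("csv", data_files={"train": "pairs.csv"})  # illustrative file name
train = SentDataset(raw["train"], text_cols=["query", "document"], label_col="score")
extra = SentDataset(raw["train"], text_cols=["query", "document"], label=1.0)  # fixed label instead
combined = SentDatasetList([train, extra])
example = combined[0]  # an InputExample with two texts and a label
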
def _train_sentence_transformer(model_fp_or_name: str, output_dir: str,
                                train_dataset: List[Union[RelevanceExample, InputExample]],
                                eval_dataset: List[Union[RelevanceExample, InputExample]],
                                train_batch_size=32, num_epochs=3,
                                warmup_steps=None, evaluate_every_n_step: int = 1000, special_tokens=None,
                                pooling_mode=None, loss_func=None, evaluator: SentenceEvaluator = None):
    """
    Train a sentence transformer model.

    Returns the trained model for evaluation.
    """

    encoder = models.Transformer(model_fp_or_name)

    if special_tokens:
        encoder.tokenizer.add_tokens(special_tokens, special_tokens=True)
        encoder.auto_model.resize_token_embeddings(len(encoder.tokenizer))

    pooling_model = models.Pooling(encoder.get_word_embedding_dimension(),
                                   pooling_mode=pooling_mode)

    model = SentenceTransformer(modules=[encoder, pooling_model])

    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)

    if loss_func is None:
        loss_func = losses.CosineSimilarityLoss(model=model)

    if evaluator is None:
        evaluator = EmbeddingSimilarityEvaluator.from_input_examples(eval_dataset)

    model.fit(train_objectives=[(train_dataloader, loss_func)],
              evaluator=evaluator,
              epochs=num_epochs,
              evaluation_steps=evaluate_every_n_step,
              # Default warm-up: 5% of the training examples seen across all epochs
              warmup_steps=warmup_steps if warmup_steps else (num_epochs * len(train_dataset)) // 20,
              output_path=output_dir)

    return model

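A call sketch; the model name and output directory are placeholders, and `train_examples` / `eval_examples` are assumed to be lists of InputExamples:

model = _train_sentence_transformer("sentence-transformers/all-MiniLM-L6-v2", "./output",
                                    train_dataset=train_examples, eval_dataset=eval_examples,
                                    num_epochs=1, pooling_mode="mean")
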
def tokenize_function(tokenizer, examples, padding_strategy, truncate):
    """
    Tokenizer function

    :param tokenizer: Tokenizer
    :param examples: Input examples to tokenize
    :param padding_strategy: Padding strategy
    :param truncate: Whether to truncate sentences
    :return:
        Returns the tokenized examples
    """
    return tokenizer(examples["text"],
                     padding=padding_strategy,
                     truncation=truncate)

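This is typically passed to `datasets.Dataset.map` in batched mode; a sketch assuming the dataset has a "text" column:

tokenized = raw_dataset.map(lambda batch: tokenize_function(tokenizer, batch, "max_length", True),
                            batched=True)
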
def get_max_seq_length(tokenizer, dataset, x_labels, dataset_key="train"):
    """
    Tokenizes the given text columns and returns the longest observed sequence length,
    measured as the largest attention-mask sum in the chosen split.
    """
    dataset = dataset.map(lambda example: tokenizer([example[x_label] for x_label in x_labels]))

    max_length = -1
    for example in dataset[dataset_key]['attention_mask']:
        length = max(sum(x) for x in example)
        if length > max_length:
            max_length = length

    return max_length