debeir.training.utils
from typing import List, Union

import loguru
import transformers
from debeir.datasets.types import InputExample, RelevanceExample
from sentence_transformers import SentenceTransformer, losses, models
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SentenceEvaluator
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from wandb import wandb

import datasets


# from sentence_transformers import InputExample


class LoggingScheduler:
    def __init__(self, scheduler: LambdaLR):
        self.scheduler = scheduler

    def step(self, epoch=None):
        self.scheduler.step(epoch)

        last_lr = self.scheduler.get_last_lr()

        for i, lr in enumerate(last_lr):
            wandb.log({f"lr_{i}": lr})

    def __getattr__(self, attr):
        return getattr(self.scheduler, attr)


def get_scheduler_with_wandb(optimizer, scheduler: str, warmup_steps: int, t_total: int):
    """
    Returns the correct learning rate scheduler. Available scheduler: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts
    """
    scheduler = scheduler.lower()
    loguru.logger.info(f"Creating scheduler: {scheduler}")

    if scheduler == 'constantlr':
        sched = transformers.get_constant_schedule(optimizer)
    elif scheduler == 'warmupconstant':
        sched = transformers.get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps)
    elif scheduler == 'warmuplinear':
        sched = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
                                                             num_training_steps=t_total)
    elif scheduler == 'warmupcosine':
        sched = transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
                                                             num_training_steps=t_total)
    elif scheduler == 'warmupcosinewithhardrestarts':
        sched = transformers.get_cosine_with_hard_restarts_schedule_with_warmup(optimizer,
                                                                                num_warmup_steps=warmup_steps,
                                                                                num_training_steps=t_total)
    else:
        raise ValueError("Unknown scheduler {}".format(scheduler))

    return LoggingScheduler(sched)


class LoggingLoss:
    def __init__(self, loss_fn):
        self.loss_fn = loss_fn

    def __call__(self, *args, **kwargs):
        loss = self.loss_fn(*args, **kwargs)
        wandb.log({'train_loss': loss})
        return loss

    def __getattr__(self, attr):
        return getattr(self.loss_fn, attr)


class TokenizerOverload:
    def __init__(self, tokenizer, tokenizer_kwargs, debug=False):
        self.tokenizer = tokenizer
        self.tokenizer_kwargs = tokenizer_kwargs
        self.debug = debug
        self.max_length = -1

    def __call__(self, *args, **kwargs):
        if self.debug:
            print(str(args), str(kwargs))

        kwargs.update(self.tokenizer_kwargs)
        output = self.tokenizer(*args, **kwargs)

        return output

    def __getattr__(self, attr):
        if self.debug:
            print(str(attr))

        return getattr(self.tokenizer, attr)


class LoggingEvaluator:
    def __init__(self, evaluator):
        self.evaluator = evaluator

    def __call__(self, *args, **kwargs):
        scores = self.evaluator(*args, **kwargs)
        wandb.log({'val_acc': scores})

        return scores

    def __getattr__(self, attr):
        return getattr(self.evaluator, attr)


class SentDataset:
    def __init__(self, dataset: datasets.Dataset, text_cols: List[str],
                 label_col: str = None, label=None):
        self.dataset = dataset
        self.text_cols = text_cols
        self.label_col = label_col
        self.label = label

    def __getitem__(self, idx):
        item = self.dataset[idx]

        texts = []

        for text_col in self.text_cols:
            texts.append(item[text_col])

        example = InputExample(texts=texts)

        if self.label_col:
            example.label = item[self.label_col]
        else:
            if self.label:
                example.label = self.label

        return example

    def __len__(self):
        return len(self.dataset)


class SentDatasetList:
    def __init__(self, datasets: List[SentDataset]):
        self.datasets = datasets
        self.lengths = [len(dataset) for dataset in self.datasets]
        self.total_length = sum(self.lengths)

    def __getitem__(self, idx):
        i, c = 0, 0

        for i, length in enumerate(self.lengths):
            if idx - c == 0:
                idx = 0
                break
            if idx - c < length:
                idx = idx - c
                break

            c = c + self.lengths[i]

        return self.datasets[i][idx]

    def __len__(self):
        return self.total_length


def _train_sentence_transformer(model_fp_or_name: str, output_dir: str,
                                train_dataset: List[Union[RelevanceExample, InputExample]],
                                eval_dataset: List[Union[RelevanceExample, InputExample]],
                                train_batch_size=32, num_epochs=3,
                                warmup_steps=None, evaluate_every_n_step: int = 1000, special_tokens=None,
                                pooling_mode=None, loss_func=None, evaluator: SentenceEvaluator = None):
    """
    Train a sentence transformer model

    Returns the model for evaluation
    """

    encoder = models.Transformer(model_fp_or_name)

    if special_tokens:
        encoder.tokenizer.add_tokens(special_tokens, special_tokens=True)
        encoder.auto_model.resize_token_embeddings(len(encoder.tokenizer))

    pooling_model = models.Pooling(encoder.get_word_embedding_dimension(),
                                   pooling_mode=pooling_mode)

    model = SentenceTransformer(modules=[encoder, pooling_model])

    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)

    if loss_func is None:
        loss_func = losses.CosineSimilarityLoss(model=model)

    if evaluator is None:
        evaluator = EmbeddingSimilarityEvaluator.from_input_examples(eval_dataset)

    model.fit(train_objectives=[(train_dataloader, loss_func)],
              evaluator=evaluator,
              epochs=num_epochs,
              evaluation_steps=evaluate_every_n_step,
              warmup_steps=warmup_steps if warmup_steps else (num_epochs * len(train_dataset)) // 20,
              output_path=output_dir)

    return model


def tokenize_function(tokenizer, examples, padding_strategy, truncate):
    """
    Tokenizer function

    :param tokenizer: Tokenizer
    :param examples: Input examples to tokenize
    :param padding_strategy: Padding strategy
    :param truncate: Truncate sentences
    :return:
        Returns a list of tokenized examples
    """
    return tokenizer(examples["text"],
                     padding=padding_strategy,
                     truncation=truncate)


def get_max_seq_length(tokenizer, dataset, x_labels, dataset_key="train"):
    dataset = dataset.map(lambda example: tokenizer([example[x_label] for x_label in x_labels]))

    max_length = -1
    for example in dataset[dataset_key]['attention_mask']:
        length = max(sum(x) for x in example)
        if length > max_length:
            max_length = length

    return max_length

class LoggingScheduler:
    def __init__(self, scheduler: LambdaLR):
        self.scheduler = scheduler

    def step(self, epoch=None):
        self.scheduler.step(epoch)

        last_lr = self.scheduler.get_last_lr()

        for i, lr in enumerate(last_lr):
            wandb.log({f"lr_{i}": lr})

    def __getattr__(self, attr):
        return getattr(self.scheduler, attr)

def get_scheduler_with_wandb(optimizer, scheduler: str, warmup_steps: int, t_total: int):
    """
    Returns the correct learning rate scheduler. Available scheduler: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts
    """
    scheduler = scheduler.lower()
    loguru.logger.info(f"Creating scheduler: {scheduler}")

    if scheduler == 'constantlr':
        sched = transformers.get_constant_schedule(optimizer)
    elif scheduler == 'warmupconstant':
        sched = transformers.get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps)
    elif scheduler == 'warmuplinear':
        sched = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
                                                             num_training_steps=t_total)
    elif scheduler == 'warmupcosine':
        sched = transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
                                                             num_training_steps=t_total)
    elif scheduler == 'warmupcosinewithhardrestarts':
        sched = transformers.get_cosine_with_hard_restarts_schedule_with_warmup(optimizer,
                                                                                num_warmup_steps=warmup_steps,
                                                                                num_training_steps=t_total)
    else:
        raise ValueError("Unknown scheduler {}".format(scheduler))

    return LoggingScheduler(sched)

Returns the requested learning rate scheduler, wrapped in LoggingScheduler so the current learning rate(s) are logged to wandb on every step. Available schedulers: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts.
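
A minimal usage sketch (not taken from the debeir codebase): the model, learning rate, step counts, and wandb project name below are illustrative placeholders.

import torch
import wandb

from debeir.training.utils import get_scheduler_with_wandb

wandb.init(project="debeir-example")                        # placeholder project; wandb must be initialised before logging

model = torch.nn.Linear(768, 1)                             # stand-in for a real model
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

total_steps = 1000
scheduler = get_scheduler_with_wandb(optimizer, scheduler="warmuplinear",
                                     warmup_steps=100, t_total=total_steps)

for step in range(total_steps):
    # ... forward pass, loss.backward() and optimizer.step() would go here ...
    scheduler.step()                                        # steps the LR and logs lr_0, lr_1, ... to wandb
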
class LoggingLoss:
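
LoggingLoss wraps a loss callable, logs every returned loss value to wandb under 'train_loss', and proxies all other attribute access to the wrapped object. A minimal sketch, assuming a wandb run has been initialised; the checkpoint name is a placeholder.

import wandb
from sentence_transformers import SentenceTransformer, losses

from debeir.training.utils import LoggingLoss

wandb.init(project="debeir-example")                                      # placeholder project name

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")     # placeholder checkpoint
loss_func = LoggingLoss(losses.CosineSimilarityLoss(model=model))

# loss_func can then be used wherever the unwrapped loss was expected, e.g. as the loss in a
# sentence-transformers train objective: model.fit(train_objectives=[(dataloader, loss_func)], ...)
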
class TokenizerOverload:
    def __init__(self, tokenizer, tokenizer_kwargs, debug=False):
        self.tokenizer = tokenizer
        self.tokenizer_kwargs = tokenizer_kwargs
        self.debug = debug
        self.max_length = -1

    def __call__(self, *args, **kwargs):
        if self.debug:
            print(str(args), str(kwargs))

        kwargs.update(self.tokenizer_kwargs)
        output = self.tokenizer(*args, **kwargs)

        return output

    def __getattr__(self, attr):
        if self.debug:
            print(str(attr))

        return getattr(self.tokenizer, attr)

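
The wrapper forces a fixed set of keyword arguments onto every tokenizer call (the stored kwargs are applied last, so they override anything passed at call time). An illustrative sketch; the checkpoint and kwargs are placeholders.

from transformers import AutoTokenizer

from debeir.training.utils import TokenizerOverload

base_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")       # placeholder checkpoint
tokenizer = TokenizerOverload(base_tokenizer,
                              tokenizer_kwargs={"padding": "max_length",
                                                "truncation": True,
                                                "max_length": 128})

batch = tokenizer(["a query", "another query"])      # stored kwargs are applied automatically
print(len(batch["input_ids"][0]))                    # 128, because of the forced padding/max_length
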
class LoggingEvaluator:
    def __init__(self, evaluator):
        self.evaluator = evaluator

    def __call__(self, *args, **kwargs):
        scores = self.evaluator(*args, **kwargs)
        wandb.log({'val_acc': scores})

        return scores

    def __getattr__(self, attr):
        return getattr(self.evaluator, attr)

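
LoggingEvaluator forwards calls to the wrapped sentence-transformers evaluator and logs the returned score to wandb as 'val_acc'. A minimal sketch, assuming a wandb run has been initialised; the checkpoint and toy evaluation pairs are placeholders, and labels are attached the same way SentDataset does it.

import wandb
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

from debeir.datasets.types import InputExample
from debeir.training.utils import LoggingEvaluator

wandb.init(project="debeir-example")                                      # placeholder project name

pairs = [("a query", "a relevant passage", 1.0),
         ("a query", "a borderline passage", 0.5),
         ("a query", "an unrelated passage", 0.0)]

eval_examples = []
for query, passage, label in pairs:
    example = InputExample(texts=[query, passage])
    example.label = label
    eval_examples.append(example)

evaluator = LoggingEvaluator(EmbeddingSimilarityEvaluator.from_input_examples(eval_examples))

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")     # placeholder checkpoint
score = evaluator(model)                                    # evaluates, logs 'val_acc', returns the score
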
class SentDataset:
    def __init__(self, dataset: datasets.Dataset, text_cols: List[str],
                 label_col: str = None, label=None):
        self.dataset = dataset
        self.text_cols = text_cols
        self.label_col = label_col
        self.label = label

    def __getitem__(self, idx):
        item = self.dataset[idx]

        texts = []

        for text_col in self.text_cols:
            texts.append(item[text_col])

        example = InputExample(texts=texts)

        if self.label_col:
            example.label = item[self.label_col]
        else:
            if self.label:
                example.label = self.label

        return example

    def __len__(self):
        return len(self.dataset)

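
SentDataset adapts a Hugging Face datasets.Dataset so that indexing returns InputExample objects built from the configured text columns, with the label read from label_col (or the fixed label, if given). An illustrative sketch with placeholder column names and toy rows.

import datasets

from debeir.training.utils import SentDataset

hf_dataset = datasets.Dataset.from_dict({
    "query": ["what causes migraines?", "symptoms of flu"],
    "document": ["Migraines can be triggered by ...", "Common flu symptoms include ..."],
    "relevance": [1.0, 0.0],
})

sent_dataset = SentDataset(hf_dataset, text_cols=["query", "document"], label_col="relevance")

example = sent_dataset[0]    # InputExample with texts=[query, document] and label=1.0
print(len(sent_dataset))     # 2
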
class SentDatasetList:
    def __init__(self, datasets: List[SentDataset]):
        self.datasets = datasets
        self.lengths = [len(dataset) for dataset in self.datasets]
        self.total_length = sum(self.lengths)

    def __getitem__(self, idx):
        i, c = 0, 0

        for i, length in enumerate(self.lengths):
            if idx - c == 0:
                idx = 0
                break
            if idx - c < length:
                idx = idx - c
                break

            c = c + self.lengths[i]

        return self.datasets[i][idx]

    def __len__(self):
        return self.total_length

SentDatasetList(datasets: List[debeir.training.utils.SentDataset])
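
SentDatasetList concatenates several SentDataset objects behind a single index space, walking the per-dataset lengths to work out which underlying dataset an index falls into. An illustrative sketch with placeholder toy data.

import datasets

from debeir.training.utils import SentDataset, SentDatasetList

ds_a = datasets.Dataset.from_dict({"text": ["a", "b", "c"]})
ds_b = datasets.Dataset.from_dict({"text": ["d", "e"]})

combined = SentDatasetList([SentDataset(ds_a, text_cols=["text"]),
                            SentDataset(ds_b, text_cols=["text"])])

print(len(combined))     # 5
example = combined[3]    # resolves to the first item of the second dataset
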
def tokenize_function(tokenizer, examples, padding_strategy, truncate):
    """
    Tokenizer function

    :param tokenizer: Tokenizer
    :param examples: Input examples to tokenize
    :param padding_strategy: Padding strategy
    :param truncate: Truncate sentences
    :return:
        Returns a list of tokenized examples
    """
    return tokenizer(examples["text"],
                     padding=padding_strategy,
                     truncation=truncate)

Tokenizer function
Parameters
- tokenizer: Tokenizer
- examples: Input examples to tokenize
- padding_strategy: Padding strategy
- truncate: Truncate sentences
Returns
Returns a list of tokenized examples
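
An illustrative sketch of applying tokenize_function over a dataset with datasets.map; the checkpoint and rows are placeholders, and note that the function reads the "text" column of each batch.

from functools import partial

import datasets
from transformers import AutoTokenizer

from debeir.training.utils import tokenize_function

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")            # placeholder checkpoint
dataset = datasets.Dataset.from_dict({"text": ["first example", "second example"]})

tokenized = dataset.map(
    partial(tokenize_function, tokenizer, padding_strategy="max_length", truncate=True),
    batched=True,
)
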
def get_max_seq_length(tokenizer, dataset, x_labels, dataset_key="train"):
    dataset = dataset.map(lambda example: tokenizer([example[x_label] for x_label in x_labels]))

    max_length = -1
    for example in dataset[dataset_key]['attention_mask']:
        length = max(sum(x) for x in example)
        if length > max_length:
            max_length = length

    return max_length
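
get_max_seq_length tokenizes the listed columns of the chosen split and returns the largest attention-mask sum it sees, which is a cheap way to pick a max_length for later tokenization. An illustrative sketch with a placeholder checkpoint and a toy DatasetDict.

import datasets
from transformers import AutoTokenizer

from debeir.training.utils import get_max_seq_length

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")            # placeholder checkpoint

dataset = datasets.DatasetDict({
    "train": datasets.Dataset.from_dict({
        "query": ["short query", "a somewhat longer query about neural retrieval models"],
        "document": ["short document", "a longer document body about dense retrieval"],
    })
})

max_len = get_max_seq_length(tokenizer, dataset, x_labels=["query", "document"])
print(max_len)    # longest tokenized length observed across both columns of the train split
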