debeir.training.hparm_tuning.trainer
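Hyperparameter-tuning trainers for SentenceTransformer models: an abstract Trainer base, an Optuna-driven SentenceTransformerHparamTrainer, and a SentenceTransformerTrainer that trains from a fixed hyperparameter configuration with optional Weights & Biases logging.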

import abc
from collections import defaultdict
from functools import partial
from typing import Dict, Sequence, Union

import loguru
import optuna
import torch
import torch_optimizer
from debeir.training.hparm_tuning.config import HparamConfig
from debeir.training.hparm_tuning.types import Hparam
from debeir.training.utils import LoggingEvaluator, LoggingLoss
from sentence_transformers import SentenceTransformer, losses
from torch.utils.data import DataLoader
from wandb import wandb

from datasets import Dataset, DatasetDict


class OptimizersWrapper:
    def __getattr__(self, name):
        if name in torch.optim.__dict__:
            return getattr(torch.optim, name)
        elif name in torch_optimizer.__dict__:
            return getattr(torch_optimizer, name)
        else:
            raise ModuleNotFoundError("Optimizer is not implemented, doesn't exist or is not supported.")

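OptimizersWrapper resolves an optimizer name to its class, looking first in torch.optim and then in torch_optimizer; build_kwargs_and_model below uses it to turn the "optimizer" hyperparameter into an optimizer_class for SentenceTransformer.fit. A minimal, self-contained sketch (the optimizer name and toy module are illustrative only):

# Illustrative sketch, not part of the module source.
import torch

from debeir.training.hparm_tuning.trainer import OptimizersWrapper

opt_cls = getattr(OptimizersWrapper(), "AdamW")      # found in torch.optim.__dict__
optimizer = opt_cls(torch.nn.Linear(4, 2).parameters(), lr=2e-5)
print(type(optimizer).__name__)                      # -> AdamW
# Names absent from torch.optim, e.g. "Lamb", fall through to torch_optimizer.
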
class Trainer:
    """
    Abstract wrapper class for a trainer. Subclasses implement fit().
    """

    def __init__(self, model, evaluator_fn, dataset_loading_fn):
        self.evaluator_fn = evaluator_fn
        self.model_cls = model  # Model class or factory we will initialize later
        self.dataset_loading_fn = dataset_loading_fn

    @abc.abstractmethod
    def fit(self, in_trial: optuna.Trial, train_dataset, val_dataset):
        raise NotImplementedError()


class SentenceTransformerHparamTrainer(Trainer):
    """See the Optuna documentation for the hyperparameter types!"""
    model: SentenceTransformer

    def __init__(self, dataset_loading_fn, evaluator_fn, hparams_config: HparamConfig):
        super().__init__(SentenceTransformer, evaluator_fn, dataset_loading_fn)
        self.loss_fn = None
        self.hparams = hparams_config.parse_config_to_py() if hparams_config else None

    def get_optuna_hparams(self, trial: optuna.Trial, hparams: Sequence[Hparam] = None):
        """
        Get hyperparameters suggested by the Optuna library.

        :param trial: The optuna trial object
        :param hparams: Optional, pass a dictionary of Hparam objects; defaults to the parsed config
        :return: A mapping of hyperparameter names to suggested (or fixed) values
        """

        loguru.logger.info("Fitting the trainer.")

        hparam_values = defaultdict(lambda: 0.0)

        hparams = hparams if hparams else self.hparams

        if hparams is None:
            raise RuntimeError("No hyperparameters were specified")

        for key, hparam in hparams.items():
            if hasattr(hparam, 'suggest'):
                hparam_values[hparam.name] = hparam.suggest(trial)
                loguru.logger.info(f"Using {hparam_values[hparam.name]} for {hparam.name}.")
            else:
                hparam_values[key] = hparam

        return hparam_values

    def build_kwargs_and_model(self, hparams: Dict):
        kwargs = {}

        for hparam, hparam_value in list(hparams.items()):
            loguru.logger.info(f"Building model with {hparam}: {hparam_value}")

            if hparam == "lr":
                kwargs["optimizer_params"] = {
                    "lr": hparam_value
                }
            elif hparam == "model_name":
                self.model = self.model_cls(hparam_value)
            elif hparam == "optimizer":
                kwargs["optimizer_class"] = getattr(OptimizersWrapper(), hparam_value)
            elif hparam == "loss_fn":
                self.loss_fn = getattr(losses, hparam_value)
            else:
                kwargs[hparam] = hparam_value

        return kwargs

    def fit(self, in_trial: optuna.Trial, train_dataset, val_dataset):
        hparams = self.get_optuna_hparams(in_trial)
        kwargs = self.build_kwargs_and_model(hparams)

        evaluator = self.evaluator_fn.from_input_examples(val_dataset)
        loss = self.loss_fn(model=self.model)
        train_dataloader = DataLoader(train_dataset, shuffle=True,
                                      batch_size=int(kwargs.pop("batch_size")), drop_last=True)

        self.model.fit(
            train_objectives=[(train_dataloader, loss)],
            **kwargs,
            evaluator=evaluator,
            use_amp=True,
            callback=partial(trial_callback, in_trial)
        )

        return self.model.evaluate(evaluator)


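build_kwargs_and_model treats a handful of keys specially: "lr" becomes optimizer_params, "model_name" instantiates the SentenceTransformer, "optimizer" is resolved through OptimizersWrapper, "loss_fn" is looked up on sentence_transformers.losses, and every other key is forwarded verbatim to SentenceTransformer.fit (with "batch_size" later popped off to build the DataLoader). A hedged sketch of a plain-value hyperparameter dictionary; the model name and numbers are illustrative, not defaults:

# Illustrative sketch, not part of the module source.
from debeir.training.hparm_tuning.trainer import SentenceTransformerHparamTrainer

example_hparams = {
    "model_name": "sentence-transformers/all-MiniLM-L6-v2",  # -> self.model = SentenceTransformer(...)
    "lr": 2e-5,                          # -> kwargs["optimizer_params"] = {"lr": 2e-5}
    "optimizer": "AdamW",                # -> kwargs["optimizer_class"] via OptimizersWrapper
    "loss_fn": "CosineSimilarityLoss",   # -> self.loss_fn = losses.CosineSimilarityLoss
    "epochs": 3,                         # passed straight through to SentenceTransformer.fit
    "batch_size": 32,                    # popped later to build the training DataLoader
}

trainer = SentenceTransformerHparamTrainer(dataset_loading_fn=None, evaluator_fn=None,
                                           hparams_config=None)
fit_kwargs = trainer.build_kwargs_and_model(example_hparams)
# fit_kwargs now contains optimizer_params, optimizer_class, epochs and batch_size;
# trainer.model and trainer.loss_fn were set as side effects.
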
def trial_callback(trial, score, epoch, *args, **kwargs):
    trial.report(score, epoch)
    # Handle pruning based on the intermediate value
    if trial.should_prune():
        raise optuna.exceptions.TrialPruned()


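During fit, SentenceTransformer.fit invokes the callback with (score, epoch, steps); partial(trial_callback, in_trial) prepends the trial, so each epoch's evaluation score is reported to Optuna and unpromising trials can be pruned early. A hedged sketch of the surrounding Optuna study, assuming an EmbeddingSimilarityEvaluator as evaluator_fn; load_splits and HparamConfig.from_json are hypothetical stand-ins for your own dataset loader and config constructor:

# Hedged sketch, not part of the module source.
import optuna
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

from debeir.training.hparm_tuning.config import HparamConfig
from debeir.training.hparm_tuning.trainer import SentenceTransformerHparamTrainer


def objective(trial: optuna.Trial) -> float:
    train_examples, val_examples = load_splits()          # hypothetical helper returning InputExample lists
    trainer = SentenceTransformerHparamTrainer(
        dataset_loading_fn=load_splits,
        evaluator_fn=EmbeddingSimilarityEvaluator,
        hparams_config=HparamConfig.from_json("hparam_config.json"),  # assumed constructor
    )
    # fit() suggests values from the trial, trains, and reports per-epoch scores
    # through trial_callback so the pruner can intervene.
    return trainer.fit(trial, train_examples, val_examples)


study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
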
class SentenceTransformerTrainer(SentenceTransformerHparamTrainer):
    def __init__(self, dataset: Union[DatasetDict, Dict[str, Dataset]], hparams_config: HparamConfig,
                 evaluator_fn=None, evaluator=None, use_wandb=False):
        super().__init__(None, evaluator_fn, hparams_config)
        self.evaluator = evaluator
        self.use_wandb = use_wandb
        self.dataset = dataset

    def fit(self, **extra_kwargs):
        kwargs = self.build_kwargs_and_model(self.hparams)

        if not self.evaluator:
            self.evaluator = LoggingEvaluator(self.evaluator_fn.from_input_examples(self.dataset['val']), wandb)

        loss = self.loss_fn(model=self.model)

        if self.use_wandb:
            wandb.watch(self.model)
            loss = LoggingLoss(loss, wandb)

        train_dataloader = DataLoader(self.dataset['train'], shuffle=True,
                                      batch_size=int(kwargs.pop("batch_size")),
                                      drop_last=True)

        self.model.fit(
            train_objectives=[(train_dataloader, loss)],
            **kwargs,
            evaluator=self.evaluator,
            use_amp=True,
            **extra_kwargs
        )

        return self.model.evaluate(self.evaluator)