debeir.training.train_reranker

from typing import List

from sentence_transformers.evaluation import SentenceEvaluator

from debeir.datasets.types import RelevanceExample
from debeir.training.utils import _train_sentence_transformer


def train_cross_encoder_reranker(model_fp_or_name: str, output_dir: str, train_dataset: List[RelevanceExample],
                                 dev_dataset: List[RelevanceExample], train_batch_size=32, num_epochs=3,
                                 warmup_steps=None,
                                 evaluate_every_n_step: int = 1000,
                                 special_tokens=None, pooling_mode=None, loss_func=None,
                                 evaluator: SentenceEvaluator = None,
                                 *args, **kwargs):
    """
    Trains a reranker with relevance signals

    :param model_fp_or_name: The model name or path to the model
    :param output_dir: Output directory to save the model, logs, etc.
    :param train_dataset: Training examples
    :param dev_dataset: Dev examples
    :param train_batch_size: Training batch size
    :param num_epochs: Number of epochs
    :param warmup_steps: Warmup steps for the scheduler
    :param evaluate_every_n_step: Evaluate the model every n steps
    :param special_tokens: Special tokens to add; defaults to the [DOC] and [QRY] tokens (bi-encoder)
    :param pooling_mode: Pooling mode for a sentence transformer model
    :param loss_func: Loss function(s) to use
    :param evaluator: Evaluator to use
    """

    if special_tokens is None:
        special_tokens = ["[DOC]", "[QRY]"]

    return _train_sentence_transformer(model_fp_or_name, output_dir, train_dataset,
                                       dev_dataset, train_batch_size,
                                       num_epochs, warmup_steps, evaluate_every_n_step,
                                       special_tokens, pooling_mode, loss_func,
                                       evaluator)
def train_cross_encoder_reranker(model_fp_or_name: str, output_dir: str, train_dataset: List[RelevanceExample], dev_dataset: List[RelevanceExample], train_batch_size=32, num_epochs=3, warmup_steps=None, evaluate_every_n_step: int = 1000, special_tokens=None, pooling_mode=None, loss_func=None, evaluator: SentenceEvaluator = None, *args, **kwargs)

Trains a reranker with relevance signals

Parameters
  • model_fp_or_name: The model name or path to the model
  • output_dir: Output directory to save the model, logs, etc.
  • train_dataset: Training examples
  • dev_dataset: Dev examples
  • train_batch_size: Training batch size
  • num_epochs: Number of epochs
  • warmup_steps: Warmup steps for the scheduler
  • evaluate_every_n_step: Evaluate the model every n steps
  • special_tokens: Special tokens to add; defaults to the [DOC] and [QRY] tokens (bi-encoder)
  • pooling_mode: Pooling mode for a sentence transformer model
  • loss_func: Loss function(s) to use
  • evaluator: Evaluator to use
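
A minimal usage sketch follows. It assumes the RelevanceExample lists are built elsewhere (the load_examples helper is a hypothetical placeholder), that RelevanceExample exposes query/document/label attributes (an assumption, not confirmed by this module), and that the model name and paths are illustrative only:

from typing import List

from sentence_transformers.evaluation import BinaryClassificationEvaluator

from debeir.datasets.types import RelevanceExample
from debeir.training.train_reranker import train_cross_encoder_reranker


def load_examples(split: str) -> List[RelevanceExample]:
    # Hypothetical loader: build RelevanceExample objects from your own
    # queries, documents and relevance judgements (qrels).
    raise NotImplementedError


train_examples = load_examples("train")
dev_examples = load_examples("dev")

# Optional: a standard sentence-transformers evaluator that scores binary
# relevance over (query, document) pairs. The .query/.document/.label
# attribute names below are assumptions about RelevanceExample.
evaluator = BinaryClassificationEvaluator(
    sentences1=[ex.query for ex in dev_examples],
    sentences2=[ex.document for ex in dev_examples],
    labels=[ex.label for ex in dev_examples],
    name="dev-relevance",
)

train_cross_encoder_reranker(
    model_fp_or_name="sentence-transformers/all-MiniLM-L6-v2",  # any HF name or local path
    output_dir="./output/reranker",
    train_dataset=train_examples,
    dev_dataset=dev_examples,
    train_batch_size=16,
    num_epochs=1,
    evaluate_every_n_step=500,
    evaluator=evaluator,
)

If evaluator is left as None, evaluation behaviour falls back to whatever _train_sentence_transformer does by default; passing special_tokens=None adds the [DOC] and [QRY] tokens automatically, as shown in the source above.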