debeir.training.losses.contrastive

Author: Yonglong Tian (yonglong@mit.edu) Date: May 07, 2020

Code imported from: https://github.com/HobbitLong/SupContrast/blob/master/losses.py

  1"""
  2Author: Yonglong Tian (yonglong@mit.edu)
  3Date: May 07, 2020
  4
  5
  6Code imported from: https://github.com/HobbitLong/SupContrast/blob/master/losses.py
  7"""
  8
  9from enum import Enum
 10from typing import Dict, Iterable
 11
 12import torch
 13import torch.nn.functional as F
 14from torch import Tensor, nn
 15
 16
 17class SupConLoss(nn.Module):
 18    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
 19    It also supports the unsupervised contrastive loss in SimCLR"""
 20
 21    # def __init__(self, temperature=0.07, contrast_mode='all',
 22    #             base_temperature=0.07):
 23    def __init__(self, temperature=1.0, contrast_mode='all',
 24                 base_temperature=1.0):
 25        super(SupConLoss, self).__init__()
 26        self.temperature = temperature
 27        self.base_temperature = base_temperature
 28        self.contrast_mode = contrast_mode
 29
 30    def forward(self, features, labels=None, mask=None):
 31        """Compute loss for model. If both `labels` and `mask` are None,
 32        it degenerates to SimCLR unsupervised loss:
 33        https://arxiv.org/pdf/2002.05709.pdf
 34        Args:
 35            features: hidden vector of shape [bsz, n_views, ...].
 36            labels: ground truth of shape [bsz].
 37            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
 38                has the same class as sample i. Can be asymmetric.
 39        Returns:
 40            A loss scalar.
 41        """
 42        device = (torch.device('cuda')
 43                  if features.is_cuda
 44                  else torch.device('cpu'))
 45
 46        if len(features.shape) < 3:
 47            raise ValueError('`features` needs to be [bsz, n_views, ...],'
 48                             'at least 3 dimensions are required')
 49        if len(features.shape) > 3:
 50            features = features.view(features.shape[0], features.shape[1], -1)
 51
 52        batch_size = features.shape[0]
 53        if labels is not None and mask is not None:
 54            raise ValueError('Cannot define both `labels` and `mask`')
 55        elif labels is None and mask is None:
 56            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
 57        elif labels is not None:
 58            labels = labels.contiguous().view(-1, 1)
 59            if labels.shape[0] != batch_size:
 60                raise ValueError('Num of labels does not match num of features')
 61            mask = torch.eq(labels, labels.T).float().to(device)
 62        else:
 63            mask = mask.float().to(device)
 64
 65        contrast_count = features.shape[1]
 66        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
 67        if self.contrast_mode == 'one':
 68            anchor_feature = features[:, 0]
 69            anchor_count = 1
 70        elif self.contrast_mode == 'all':
 71            anchor_feature = contrast_feature
 72            anchor_count = contrast_count
 73        else:
 74            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
 75
 76        # compute logits
 77        anchor_dot_contrast = torch.div(
 78            torch.matmul(anchor_feature, contrast_feature.T),
 79            self.temperature)
 80        # for numerical stability
 81        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
 82        logits = anchor_dot_contrast - logits_max.detach()
 83
 84        # tile mask
 85        mask = mask.repeat(anchor_count, contrast_count)
 86        # mask-out self-contrast cases
 87        logits_mask = torch.scatter(
 88            torch.ones_like(mask),
 89            1,
 90            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
 91            0
 92        )
 93        mask = mask * logits_mask
 94
 95        # compute log_prob
 96        exp_logits = torch.exp(logits) * logits_mask
 97        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
 98
 99        # compute mean of log-likelihood over positive
100        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
101
102        # loss
103        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
104        loss = loss.view(anchor_count, batch_size).mean()
105
106        return loss
107
108
class SiameseDistanceMetric(Enum):
    """
    The metric for the contrastive loss

    NOTE: the attributes below are plain functions (lambdas). Enum does not
    turn functions into members, so e.g. ``SiameseDistanceMetric.COSINE_DISTANCE``
    is the function itself and is called directly as ``metric(x, y)``.
    """
    # Euclidean (p=2) distance between paired vectors via F.pairwise_distance.
    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
    # Manhattan (p=1) distance between paired vectors via F.pairwise_distance.
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
    # 1 - cosine similarity; F.cosine_similarity reduces over dim=1 by default.
    COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y)
116
117
class ContrastiveSentLoss(nn.Module):
    """
    Contrastive loss. Expects as input two texts and a label of either 0 or 1. If the label == 1, then the distance between the
    two embeddings is reduced. If the label == 0, then the distance between the embeddings is increased.
    Further information: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    For a pair with distance ``d`` and label ``y`` the loss is
    ``0.5 * (y * d**2 + (1 - y) * max(0, margin - d)**2)``.

    :param model: SentenceTransformer model
    :param distance_metric: Function that returns a distance between two embeddings. The class SiameseDistanceMetric
        contains predefined metrics that can be used.
    :param margin: Negative samples (label == 0) should have a distance of at least the margin value.
    :param size_average: Average by the size of the mini-batch (otherwise the per-pair losses are summed).

    Example::

        from sentence_transformers import SentenceTransformer, InputExample
        from torch.utils.data import DataLoader
        model = SentenceTransformer('all-MiniLM-L6-v2')
        train_examples = [
            InputExample(texts=['This is a positive pair', 'Where the distance will be minimized'], label=1),
            InputExample(texts=['This is a negative pair', 'Their distance will be increased'], label=0)]
        train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
        train_loss = ContrastiveSentLoss(model=model)
        model.fit([(train_dataloader, train_loss)], show_progress_bar=True)
    """

    def __init__(self, model, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE,
                 margin: float = 0.5, size_average: bool = True):
        super(ContrastiveSentLoss, self).__init__()
        self.distance_metric = distance_metric
        self.margin = margin
        self.model = model
        self.size_average = size_average

    def get_config_dict(self):
        """Return a serialisable description of this loss's configuration.

        The metric is reported as ``"SiameseDistanceMetric.<NAME>"`` when it is one of the
        predefined metrics. Otherwise it is reported by its ``__name__``, or by its type name for
        callables that have none (e.g. ``nn.PairwiseDistance()`` instances).
        """
        distance_metric_name = getattr(self.distance_metric, '__name__',
                                       type(self.distance_metric).__name__)
        for name, value in vars(SiameseDistanceMetric).items():
            if value == self.distance_metric:
                distance_metric_name = "SiameseDistanceMetric.{}".format(name)
                break

        return {'distance_metric': distance_metric_name, 'margin': self.margin, 'size_average': self.size_average}

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        """Compute the contrastive loss for a batch of sentence pairs.

        :param sentence_features: exactly two feature dicts (anchor, other); each is passed to
            ``self.model`` and its ``'sentence_embedding'`` output is used.
        :param labels: per-pair labels, 1 for similar and 0 for dissimilar pairs. Integer,
            float and bool tensors are accepted.
        :return: the mean per-pair loss if ``size_average`` else the sum.
        """
        reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
        assert len(reps) == 2
        rep_anchor, rep_other = reps
        distances = self.distance_metric(rep_anchor, rep_other)
        # Convert once up front: ``1 - labels`` is not defined for bool tensors.
        labels = labels.float()
        positive_term = labels * distances.pow(2)
        negative_term = (1 - labels) * F.relu(self.margin - distances).pow(2)
        losses = 0.5 * (positive_term + negative_term)
        return losses.mean() if self.size_average else losses.sum()
class SupConLoss(torch.nn.modules.module.Module):
 18class SupConLoss(nn.Module):
 19    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
 20    It also supports the unsupervised contrastive loss in SimCLR"""
 21
 22    # def __init__(self, temperature=0.07, contrast_mode='all',
 23    #             base_temperature=0.07):
 24    def __init__(self, temperature=1.0, contrast_mode='all',
 25                 base_temperature=1.0):
 26        super(SupConLoss, self).__init__()
 27        self.temperature = temperature
 28        self.base_temperature = base_temperature
 29        self.contrast_mode = contrast_mode
 30
 31    def forward(self, features, labels=None, mask=None):
 32        """Compute loss for model. If both `labels` and `mask` are None,
 33        it degenerates to SimCLR unsupervised loss:
 34        https://arxiv.org/pdf/2002.05709.pdf
 35        Args:
 36            features: hidden vector of shape [bsz, n_views, ...].
 37            labels: ground truth of shape [bsz].
 38            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
 39                has the same class as sample i. Can be asymmetric.
 40        Returns:
 41            A loss scalar.
 42        """
 43        device = (torch.device('cuda')
 44                  if features.is_cuda
 45                  else torch.device('cpu'))
 46
 47        if len(features.shape) < 3:
 48            raise ValueError('`features` needs to be [bsz, n_views, ...],'
 49                             'at least 3 dimensions are required')
 50        if len(features.shape) > 3:
 51            features = features.view(features.shape[0], features.shape[1], -1)
 52
 53        batch_size = features.shape[0]
 54        if labels is not None and mask is not None:
 55            raise ValueError('Cannot define both `labels` and `mask`')
 56        elif labels is None and mask is None:
 57            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
 58        elif labels is not None:
 59            labels = labels.contiguous().view(-1, 1)
 60            if labels.shape[0] != batch_size:
 61                raise ValueError('Num of labels does not match num of features')
 62            mask = torch.eq(labels, labels.T).float().to(device)
 63        else:
 64            mask = mask.float().to(device)
 65
 66        contrast_count = features.shape[1]
 67        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
 68        if self.contrast_mode == 'one':
 69            anchor_feature = features[:, 0]
 70            anchor_count = 1
 71        elif self.contrast_mode == 'all':
 72            anchor_feature = contrast_feature
 73            anchor_count = contrast_count
 74        else:
 75            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
 76
 77        # compute logits
 78        anchor_dot_contrast = torch.div(
 79            torch.matmul(anchor_feature, contrast_feature.T),
 80            self.temperature)
 81        # for numerical stability
 82        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
 83        logits = anchor_dot_contrast - logits_max.detach()
 84
 85        # tile mask
 86        mask = mask.repeat(anchor_count, contrast_count)
 87        # mask-out self-contrast cases
 88        logits_mask = torch.scatter(
 89            torch.ones_like(mask),
 90            1,
 91            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
 92            0
 93        )
 94        mask = mask * logits_mask
 95
 96        # compute log_prob
 97        exp_logits = torch.exp(logits) * logits_mask
 98        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
 99
100        # compute mean of log-likelihood over positive
101        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
102
103        # loss
104        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
105        loss = loss.view(anchor_count, batch_size).mean()
106
107        return loss

Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. It also supports the unsupervised contrastive loss in SimCLR

SupConLoss(temperature=1.0, contrast_mode='all', base_temperature=1.0)
24    def __init__(self, temperature=1.0, contrast_mode='all',
25                 base_temperature=1.0):
26        super(SupConLoss, self).__init__()
27        self.temperature = temperature
28        self.base_temperature = base_temperature
29        self.contrast_mode = contrast_mode

Initializes the base nn.Module state and stores temperature, contrast_mode and base_temperature for use in forward.

def forward(self, features, labels=None, mask=None):
 31    def forward(self, features, labels=None, mask=None):
 32        """Compute loss for model. If both `labels` and `mask` are None,
 33        it degenerates to SimCLR unsupervised loss:
 34        https://arxiv.org/pdf/2002.05709.pdf
 35        Args:
 36            features: hidden vector of shape [bsz, n_views, ...].
 37            labels: ground truth of shape [bsz].
 38            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
 39                has the same class as sample i. Can be asymmetric.
 40        Returns:
 41            A loss scalar.
 42        """
 43        device = (torch.device('cuda')
 44                  if features.is_cuda
 45                  else torch.device('cpu'))
 46
 47        if len(features.shape) < 3:
 48            raise ValueError('`features` needs to be [bsz, n_views, ...],'
 49                             'at least 3 dimensions are required')
 50        if len(features.shape) > 3:
 51            features = features.view(features.shape[0], features.shape[1], -1)
 52
 53        batch_size = features.shape[0]
 54        if labels is not None and mask is not None:
 55            raise ValueError('Cannot define both `labels` and `mask`')
 56        elif labels is None and mask is None:
 57            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
 58        elif labels is not None:
 59            labels = labels.contiguous().view(-1, 1)
 60            if labels.shape[0] != batch_size:
 61                raise ValueError('Num of labels does not match num of features')
 62            mask = torch.eq(labels, labels.T).float().to(device)
 63        else:
 64            mask = mask.float().to(device)
 65
 66        contrast_count = features.shape[1]
 67        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
 68        if self.contrast_mode == 'one':
 69            anchor_feature = features[:, 0]
 70            anchor_count = 1
 71        elif self.contrast_mode == 'all':
 72            anchor_feature = contrast_feature
 73            anchor_count = contrast_count
 74        else:
 75            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
 76
 77        # compute logits
 78        anchor_dot_contrast = torch.div(
 79            torch.matmul(anchor_feature, contrast_feature.T),
 80            self.temperature)
 81        # for numerical stability
 82        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
 83        logits = anchor_dot_contrast - logits_max.detach()
 84
 85        # tile mask
 86        mask = mask.repeat(anchor_count, contrast_count)
 87        # mask-out self-contrast cases
 88        logits_mask = torch.scatter(
 89            torch.ones_like(mask),
 90            1,
 91            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
 92            0
 93        )
 94        mask = mask * logits_mask
 95
 96        # compute log_prob
 97        exp_logits = torch.exp(logits) * logits_mask
 98        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
 99
100        # compute mean of log-likelihood over positive
101        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
102
103        # loss
104        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
105        loss = loss.view(anchor_count, batch_size).mean()
106
107        return loss

Compute loss for model. If both labels and mask are None, it degenerates to the SimCLR unsupervised loss (https://arxiv.org/pdf/2002.05709.pdf). Args: features — hidden vector of shape [bsz, n_views, ...]; labels — ground truth of shape [bsz]; mask — contrastive mask of shape [bsz, bsz], where mask_{i,j}=1 if sample j has the same class as sample i (the mask can be asymmetric). Returns: a scalar loss.

Inherited Members
torch.nn.modules.module.Module
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
class SiameseDistanceMetric(enum.Enum):
class SiameseDistanceMetric(Enum):
    """
    The metric for the contrastive loss

    NOTE: the attributes below are plain functions (lambdas). Enum does not
    turn functions into members, so e.g. ``SiameseDistanceMetric.COSINE_DISTANCE``
    is the function itself and is called directly as ``metric(x, y)``.
    """
    # Euclidean (p=2) distance between paired vectors via F.pairwise_distance.
    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
    # Manhattan (p=1) distance between paired vectors via F.pairwise_distance.
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
    # 1 - cosine similarity; F.cosine_similarity reduces over dim=1 by default.
    COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y)

The metric for the contrastive loss

def EUCLIDEAN(x, y):
    # Euclidean (p=2) distance between paired vectors via F.pairwise_distance.
    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
def MANHATTAN(x, y):
    # Manhattan (p=1) distance between paired vectors via F.pairwise_distance.
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
def COSINE_DISTANCE(x, y):
    # 1 - cosine similarity; F.cosine_similarity reduces over dim=1 by default.
    COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y)
Inherited Members
enum.Enum
name
value
class ContrastiveSentLoss(torch.nn.modules.module.Module):
class ContrastiveSentLoss(nn.Module):
    """
    Contrastive loss. Expects as input two texts and a label of either 0 or 1. If the label == 1, then the distance between the
    two embeddings is reduced. If the label == 0, then the distance between the embeddings is increased.
    Further information: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    For a pair with distance ``d`` and label ``y`` the loss is
    ``0.5 * (y * d**2 + (1 - y) * max(0, margin - d)**2)``.

    :param model: SentenceTransformer model
    :param distance_metric: Function that returns a distance between two embeddings. The class SiameseDistanceMetric
        contains predefined metrics that can be used.
    :param margin: Negative samples (label == 0) should have a distance of at least the margin value.
    :param size_average: Average by the size of the mini-batch (otherwise the per-pair losses are summed).

    Example::

        from sentence_transformers import SentenceTransformer, InputExample
        from torch.utils.data import DataLoader
        model = SentenceTransformer('all-MiniLM-L6-v2')
        train_examples = [
            InputExample(texts=['This is a positive pair', 'Where the distance will be minimized'], label=1),
            InputExample(texts=['This is a negative pair', 'Their distance will be increased'], label=0)]
        train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
        train_loss = ContrastiveSentLoss(model=model)
        model.fit([(train_dataloader, train_loss)], show_progress_bar=True)
    """

    def __init__(self, model, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE,
                 margin: float = 0.5, size_average: bool = True):
        super(ContrastiveSentLoss, self).__init__()
        self.distance_metric = distance_metric
        self.margin = margin
        self.model = model
        self.size_average = size_average

    def get_config_dict(self):
        """Return a serialisable description of this loss's configuration.

        The metric is reported as ``"SiameseDistanceMetric.<NAME>"`` when it is one of the
        predefined metrics. Otherwise it is reported by its ``__name__``, or by its type name for
        callables that have none (e.g. ``nn.PairwiseDistance()`` instances).
        """
        distance_metric_name = getattr(self.distance_metric, '__name__',
                                       type(self.distance_metric).__name__)
        for name, value in vars(SiameseDistanceMetric).items():
            if value == self.distance_metric:
                distance_metric_name = "SiameseDistanceMetric.{}".format(name)
                break

        return {'distance_metric': distance_metric_name, 'margin': self.margin, 'size_average': self.size_average}

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        """Compute the contrastive loss for a batch of sentence pairs.

        :param sentence_features: exactly two feature dicts (anchor, other); each is passed to
            ``self.model`` and its ``'sentence_embedding'`` output is used.
        :param labels: per-pair labels, 1 for similar and 0 for dissimilar pairs. Integer,
            float and bool tensors are accepted.
        :return: the mean per-pair loss if ``size_average`` else the sum.
        """
        reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
        assert len(reps) == 2
        rep_anchor, rep_other = reps
        distances = self.distance_metric(rep_anchor, rep_other)
        # Convert once up front: ``1 - labels`` is not defined for bool tensors.
        labels = labels.float()
        positive_term = labels * distances.pow(2)
        negative_term = (1 - labels) * F.relu(self.margin - distances).pow(2)
        losses = 0.5 * (positive_term + negative_term)
        return losses.mean() if self.size_average else losses.sum()

Contrastive loss. Expects as input two texts and a label of either 0 or 1. If the label == 1, then the distance between the two embeddings is reduced. If the label == 0, then the distance between the embeddings is increased. Further information: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

Parameters
  • model: SentenceTransformer model
  • distance_metric: Function that returns a distance between two embeddings. The class SiameseDistanceMetric contains predefined metrics that can be used.
  • margin: Negative samples (label == 0) should have a distance of at least the margin value.
  • size_average: Average the loss over the mini-batch (otherwise the per-pair losses are summed).

Example::

    from sentence_transformers import SentenceTransformer, InputExample
    from torch.utils.data import DataLoader
    model = SentenceTransformer('all-MiniLM-L6-v2')
    train_examples = [
        InputExample(texts=['This is a positive pair', 'Where the distance will be minimized'], label=1),
        InputExample(texts=['This is a negative pair', 'Their distance will be increased'], label=0)]
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
    train_loss = ContrastiveSentLoss(model=model)
    model.fit([(train_dataloader, train_loss)], show_progress_bar=True)
ContrastiveSentLoss( model, distance_metric=<function SiameseDistanceMetric.<lambda>>, margin: float = 0.5, size_average: bool = True)
140    def __init__(self, model, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE,
141                 margin: float = 0.5, size_average: bool = True):
142        super(ContrastiveSentLoss, self).__init__()
143        self.distance_metric = distance_metric
144        self.margin = margin
145        self.model = model
146        self.size_average = size_average

Initializes the base nn.Module state and stores model, distance_metric, margin and size_average on the module.

def get_config_dict(self):
148    def get_config_dict(self):
149        distance_metric_name = self.distance_metric.__name__
150        for name, value in vars(SiameseDistanceMetric).items():
151            if value == self.distance_metric:
152                distance_metric_name = "SiameseDistanceMetric.{}".format(name)
153                break
154
155        return {'distance_metric': distance_metric_name, 'margin': self.margin, 'size_average': self.size_average}
def forward( self, sentence_features: Iterable[Dict[str, torch.Tensor]], labels: torch.Tensor):
157    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
158        reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
159        assert len(reps) == 2
160        rep_anchor, rep_other = reps
161        distances = self.distance_metric(rep_anchor, rep_other)
162        losses = 0.5 * (
163                    labels.float() * distances.pow(2) + (1 - labels).float() * F.relu(self.margin - distances).pow(2))
164        return losses.mean() if self.size_average else losses.sum()

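Embeds both inputs of each pair with model (using its 'sentence_embedding' output), computes their distance d with distance_metric, and returns 0.5 · (y · d² + (1 − y) · max(0, margin − d)²) per pair with label y. The result is averaged over the batch when size_average is True and summed otherwise.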

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Inherited Members
torch.nn.modules.module.Module
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr