NT-Xent (Normalized Temperature-Scaled Cross-Entropy) Loss Explained and Implemented in PyTorch | by Dhruv Matani | Jun, 2023

Implementation of NT-BXent loss in PyTorch

All the code in this section can be found in this notebook.

Code Reuse: Similar to our implementation of the NT-Xent loss, we shall reuse the binary cross-entropy (BCE) loss function provided by PyTorch. The setup of our ground-truth labels will be similar to that of a multi-label classification problem, where BCE loss is used.

Predictions Tensor: We’ll use the same (8, 2) predictions tensor as we used for the implementation of the NT-Xent loss.

import torch
import torch.nn.functional as F

x = torch.randn(8, 2)

Cosine Similarity: Since the input tensor x is the same, the all-pairs cosine similarity tensor xcs will also be the same. Please see this page for a detailed explanation of what the line below does.

xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)
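In the line above, x[None,:,:] has shape (1, 8, 2) and x[:,None,:] has shape (8, 1, 2). Broadcasting pairs every row with every other row, and reducing over dim=-1 yields an (8, 8) matrix whose entry (i, j) is the cosine similarity of x[i] and x[j]. Here is a plain-Python sketch of the same computation; the eps guard for zero vectors is an assumption and not necessarily PyTorch's exact handling.

```python
import math
from typing import List, Sequence


def all_pairs_cosine(x: Sequence[Sequence[float]], eps: float = 1e-8) -> List[List[float]]:
    """Entry (i, j) is the cosine similarity of x[i] and x[j]."""
    norms = [math.sqrt(sum(v * v for v in row)) for row in x]
    n = len(x)
    sims = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            dot = sum(a * b for a, b in zip(x[i], x[j]))
            # eps guards against division by zero for all-zero vectors (assumed).
            sims[i][j] = dot / max(norms[i] * norms[j], eps)
    return sims


sims = all_pairs_cosine([[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]])
for row in sims:
    print([round(v, 4) for v in row])
```

The result is symmetric with ones on the principal diagonal, which is why the diagonal needs the special treatment described next.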

To ensure that the loss contributed by every element at position (i, i) is 0, we need a small trick: after the sigmoid is applied, xcs must contain the value 1 at every index (i, i). Since we’ll be using BCE loss, we mark the self-similarity score of every feature vector with the value infinity in tensor xcs. Applying the sigmoid function to xcs converts infinity to exactly 1, and we will set up our ground-truth labels so that every position (i, i) also has the value 1, which makes the BCE loss at those positions 0.
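Why not keep the diagonal as is? The cosine similarity of any non-zero vector with itself is 1, so without the infinity trick the diagonal logit would be 1/T. Its BCE against a target of 1 would then be a positive constant that depends on the temperature. A plain-Python sketch of that comparison:

```python
import math


def sigmoid(z: float) -> float:
    # Numerically stable logistic function; sigmoid(math.inf) is exactly 1.0.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bce_target_one(p: float) -> float:
    # With a target of 1, binary cross-entropy reduces to -log(p).
    return -math.log(p)


# Diagonal loss if the self-similarity of 1.0 were kept as a finite logit 1/T.
finite_diag_loss = {t: bce_target_one(sigmoid(1.0 / t)) for t in (0.1, 1.0, 10.0)}
# Diagonal loss once the logit is replaced with +inf.
inf_diag_loss = bce_target_one(sigmoid(math.inf))

for t, loss in finite_diag_loss.items():
    print(f"T={t:4.1f}: diagonal loss {loss:.6f}")
print("with +inf logit:", inf_diag_loss == 0.0)  # True
```

With a finite logit, the diagonal loss is tiny at T = 0.1 but well above 0.6 at T = 10. The infinite logit removes this term at every temperature.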

Let’s create a masking tensor that has the value True along the principal diagonal (xcs has self-similarity scores along the principal diagonal), and False everywhere else.

eye = torch.eye(8).bool()

Let’s clone the tensor “xcs” into a tensor named “y” so that we can reference the “xcs” tensor later.

y = xcs.clone()

Now, we will set the values along the principal diagonal of the all-pairs cosine similarity matrix to infinity, so that applying the (element-wise) sigmoid yields exactly 1 at these positions.

y[eye] = float("inf")
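An infinite logit is also the reason to feed probabilities, rather than logits, into the BCE loss. The plain-Python sketch below uses a common numerically stable BCE-with-logits formula, max(z, 0) - z*t + log(1 + exp(-|z|)), as an assumed stand-in. It is not PyTorch's kernel, but it shows how an infinite logit can turn into inf - inf = NaN. That is the failure the comment inside nt_bxent_loss attributes to binary_cross_entropy_with_logits().

```python
import math


def bce_from_logit_stable(z: float, t: float) -> float:
    # A common numerically stable BCE-with-logits formula (assumed stand-in):
    #   max(z, 0) - z * t + log(1 + exp(-|z|))
    return max(z, 0.0) - z * t + math.log1p(math.exp(-abs(z)))


def bce_from_prob(p: float, t: float) -> float:
    # BCE on a probability, with each log clamped at -100
    # (the clamp documented for torch.nn.BCELoss).
    log_p = max(math.log(p), -100.0) if p > 0.0 else -100.0
    log_1mp = max(math.log(1.0 - p), -100.0) if p < 1.0 else -100.0
    return -(t * log_p + (1.0 - t) * log_1mp)


z = math.inf  # diagonal logit after setting it to infinity
p = 1.0       # sigmoid(+inf)

logit_version = bce_from_logit_stable(z, 1.0)  # max(inf, 0) - inf * 1 -> inf - inf
prob_version = bce_from_prob(p, 1.0)           # -(1 * log(1) + 0 * clamped log(0))
print(math.isnan(logit_version), prob_version == 0.0)  # True True
```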

The tensor “y” scaled by a temperature parameter will be one of the inputs (predictions) to the BCE loss API in PyTorch. Next, we need to compute the ground-truth labels (target) that we need to feed to the BCE loss API.

Ground Truth labels (Target tensor): We expect the user to pass in all (x, y) index pairs that correspond to positive examples. This is a departure from what we did for the NT-Xent loss: there, the positive pairs were implicit, whereas here they are explicit.
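For a concrete picture of the difference, suppose a batch stores the two augmented views of each image next to each other, so that element i's positive partner is i ^ 1. That layout is an assumption made for illustration, and adjacent_pair_indices is a hypothetical helper. It spells out those implicit positives as the explicit (row, column) pairs this loss expects. Each pair is listed in both directions because the loss is computed row by row.

```python
from typing import List, Tuple


def adjacent_pair_indices(batch_size: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: explicit positive pairs for a batch laid out as
    [view_a(0), view_b(0), view_a(1), view_b(1), ...], where element i's
    positive partner is i ^ 1 (0<->1, 2<->3, ...)."""
    assert batch_size % 2 == 0, "expects an even batch of paired views"
    return [(i, i ^ 1) for i in range(batch_size)]


print(adjacent_pair_indices(4))  # [(0, 1), (1, 0), (2, 3), (3, 2)]
```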

In addition to the locations provided by the user, we will set all the diagonal elements as positive pairs as explained above. We will use the PyTorch tensor indexing API to pluck out all the elements at those locations and set them to 1, whereas the rest are initialized to 0.

target = torch.zeros(8, 8)
pos_indices = torch.tensor([
    (0, 0), (0, 2), (0, 4),
    (1, 4), (1, 6), (1, 1),
    (2, 3),
    (3, 7),
    (4, 3),
    (7, 6),
])
# Add indexes of the principal diagonal as positive indexes.
# This will be useful since we will use the BCELoss in PyTorch,
# which will expect a value for the elements on the principal
# diagonal as well.
pos_indices = torch.cat([pos_indices, torch.arange(8).reshape(8, 1).expand(-1, 2)], dim=0)
# Set the values in the target tensor to 1.
target[pos_indices[:,0], pos_indices[:,1]] = 1
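Here is the same target construction in plain Python, which makes it easy to check by hand which entries end up positive. build_target is a hypothetical helper that mirrors the torch.cat and indexing steps above. Duplicates such as (0, 0), which appears both in the user's list and on the diagonal, are harmless because the entry is simply set to 1 twice. The row sums are exactly the per-element positive counts that num_pos later computes via target.sum(dim=1).

```python
from typing import List, Sequence, Tuple


def build_target(n: int, pos_indices: Sequence[Tuple[int, int]]) -> List[List[float]]:
    # Start from all zeros (negatives), mark the user-supplied positives,
    # then mark the principal diagonal as positive too.
    target = [[0.0] * n for _ in range(n)]
    for i, j in pos_indices:
        target[i][j] = 1.0
    for i in range(n):
        target[i][i] = 1.0
    return target


pos = [(0, 0), (0, 2), (0, 4), (1, 4), (1, 6), (1, 1),
       (2, 3), (3, 7), (4, 3), (7, 6)]
target = build_target(8, pos)
num_pos = [sum(row) for row in target]
num_neg = [8 - p for p in num_pos]
print(num_pos)  # [3.0, 3.0, 2.0, 2.0, 2.0, 1.0, 1.0, 2.0]
```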

Binary cross-entropy (BCE) Loss: Unlike the NT-Xent loss, we can’t simply call the torch.nn.functional.binary_cross_entropy function with its default mean reduction, since we want to weight the positive and negative losses based on how many positive and negative pairs the element at index i has in the current mini-batch.

The first step, though, is to compute the element-wise BCE loss.

temperature = 0.1
loss = F.binary_cross_entropy((y / temperature).sigmoid(), target, reduction="none")
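Dividing by the temperature before the sigmoid rescales every off-diagonal logit. For a negative pair (target 0), the element-wise loss is log(1 + exp(s/T)), so it grows roughly like s/T as the temperature shrinks. The plain-Python sketch below uses a hypothetical similarity of 0.9 to illustrate this.

```python
import math


def bce_target_zero(z: float) -> float:
    # For target 0, BCE(sigmoid(z), 0) = -log(1 - sigmoid(z)) = log(1 + exp(z)),
    # evaluated here in the numerically stable softplus form.
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


similarity = 0.9  # cosine similarity of a hypothetical negative pair
results = []
for temperature in (0.01, 0.1, 1.0, 10.0):
    z = similarity / temperature
    results.append((temperature, bce_target_zero(z)))
    print(f"T={temperature:5.2f}  logit={z:6.2f}  loss={results[-1][1]:.4f}")
```

In the PyTorch code, the sigmoid is applied first and BCELoss clamps its logs at -100. Once the sigmoid saturates to exactly 1.0 for a negative pair, the element-wise loss is therefore capped at 100 rather than growing without bound.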

We’ll create a binary mask of positive and negative pairs and then create 2 tensors, loss_pos and loss_neg that contain only those elements from the computed loss that correspond to the positive and negative pairs.

target_pos = target.bool()
target_neg = ~target_pos
# loss_pos and loss_neg below contain non-zero values only for those elements
# that are positive pairs and negative pairs respectively.
loss_pos = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_pos, loss[target_pos])
loss_neg = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_neg, loss[target_neg])

Next, we’ll sum up the positive and negative pair loss (separately) corresponding to each element i in our mini-batch.

# loss_pos and loss_neg now contain the sum of positive and negative pair losses
# as computed relative to the i'th input.
loss_pos = loss_pos.sum(dim=1)
loss_neg = loss_neg.sum(dim=1)

To perform weighting, we need to track the number of positive and negative pairs corresponding to each element i in our mini-batch. Tensors “num_pos” and “num_neg” will store these values.

# num_pos and num_neg below contain the number of positive and negative pairs
# computed relative to the i'th input. In an actual setting, this number should
# be the same for every input element, but we let it vary here for maximum
# flexibility.
num_pos = target.sum(dim=1)
num_neg = target.size(0) - num_pos
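Why normalize per element? Consider a single row with many more negatives than positives. A plain mean lets the negatives dominate. Averaging each group separately instead gives the positive pairs as much weight as all negatives combined. The numbers in this plain-Python sketch are made up for illustration.

```python
from typing import Sequence


def plain_mean(losses: Sequence[float]) -> float:
    # What a default "mean" reduction over the row would compute.
    return sum(losses) / len(losses)


def pos_neg_weighted(losses: Sequence[float], targets: Sequence[int]) -> float:
    # Average positives and negatives separately, then add them, mirroring
    # loss_pos / num_pos + loss_neg / num_neg for a single row i.
    pos = [loss for loss, t in zip(losses, targets) if t == 1]
    neg = [loss for loss, t in zip(losses, targets) if t == 0]
    return sum(pos) / len(pos) + sum(neg) / len(neg)


# Hypothetical row (diagonal ignored for simplicity): one poorly scored
# positive pair and seven well-separated negative pairs.
row_losses = [2.0] + [0.1] * 7
row_targets = [1] + [0] * 7

print(plain_mean(row_losses))                     # about 0.3375
print(pos_neg_weighted(row_losses, row_targets))  # about 2.1
```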

We have all the ingredients we need to compute our loss! The only thing left to do is to divide the summed positive and negative losses of each element i by its number of positive and negative pairs, respectively, and then average the result across the mini-batch.
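The same recipe can be written out as equations (notation ours). Let n be the mini-batch size, s_ij the cosine similarity of elements i and j, τ the temperature, σ the sigmoid, and t_ij ∈ {0, 1} the target, with t_ii = 1. Then P_i counts the diagonal entry, whose loss is 0.

```latex
\ell_{ij} = -\Big[t_{ij}\log\sigma\!\big(s_{ij}/\tau\big) + (1 - t_{ij})\log\big(1 - \sigma(s_{ij}/\tau)\big)\Big],
\qquad s_{ii}/\tau := +\infty \;\Rightarrow\; \ell_{ii} = 0

P_i = \sum_{j=1}^{n} t_{ij}, \qquad N_i = n - P_i

\mathcal{L} = \frac{1}{n}\sum_{i=1}^{n}\left[\frac{1}{P_i}\sum_{j=1}^{n} t_{ij}\,\ell_{ij} + \frac{1}{N_i}\sum_{j=1}^{n}(1 - t_{ij})\,\ell_{ij}\right]
```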

def nt_bxent_loss(x, pos_indices, temperature):
    assert len(x.size()) == 2

    # Add indexes of the principal diagonal elements to pos_indices
    pos_indices = torch.cat([
        pos_indices,
        torch.arange(x.size(0)).reshape(x.size(0), 1).expand(-1, 2),
    ], dim=0)

    # Ground truth labels
    target = torch.zeros(x.size(0), x.size(0))
    target[pos_indices[:,0], pos_indices[:,1]] = 1.0

    # Cosine similarity
    xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)
    # Set logit of diagonal element to "inf" signifying complete
    # correlation. sigmoid(inf) = 1.0 so this will work out nicely
    # when computing the Binary cross-entropy Loss.
    xcs[torch.eye(x.size(0)).bool()] = float("inf")

    # Standard binary cross-entropy loss. We use binary_cross_entropy() here and not
    # binary_cross_entropy_with_logits() because of
    # https://github.com/pytorch/pytorch/issues/102894
    # The method *_with_logits() uses the log-sum-exp-trick, which causes inf and -inf values
    # to result in a NaN result.
    loss = F.binary_cross_entropy((xcs / temperature).sigmoid(), target, reduction="none")

    target_pos = target.bool()
    target_neg = ~target_pos

    # Keep only the positive-pair (resp. negative-pair) losses, zeros elsewhere.
    loss_pos = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_pos, loss[target_pos])
    loss_neg = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_neg, loss[target_neg])
    loss_pos = loss_pos.sum(dim=1)
    loss_neg = loss_neg.sum(dim=1)
    num_pos = target.sum(dim=1)
    num_neg = x.size(0) - num_pos

    return ((loss_pos / num_pos) + (loss_neg / num_neg)).mean()

pos_indices = torch.tensor([
    (0, 0), (0, 2), (0, 4),
    (1, 4), (1, 6), (1, 1),
    (2, 3),
    (3, 7),
    (4, 3),
    (7, 6),
])
for t in (0.01, 0.1, 1.0, 10.0, 20.0):
    print(f"Temperature: {t:5.2f}, Loss: {nt_bxent_loss(x, pos_indices, temperature=t)}")


Output (the exact values depend on the random tensor x):

Temperature:  0.01, Loss: 62.898780822753906
Temperature:  0.10, Loss: 4.851151943206787
Temperature:  1.00, Loss: 1.0727109909057617
Temperature: 10.00, Loss: 0.9827173948287964
Temperature: 20.00, Loss: 0.982099175453186
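For cross-checking nt_bxent_loss on small inputs, here is a dependency-free sketch of the same computation in plain Python (float64). It is our own reference, not the author's code. The eps guard in the cosine similarity is an assumption. The log clamp at -100 mirrors the one documented for torch.nn.BCELoss. Like the PyTorch version, it assumes every row has at least one negative pair, since otherwise num_neg is 0.

```python
import math
from typing import List, Sequence


def _sigmoid(z: float) -> float:
    # Numerically stable logistic function; _sigmoid(math.inf) == 1.0.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def _bce(p: float, t: float) -> float:
    # Binary cross-entropy on a probability, each log clamped at -100.
    log_p = max(math.log(p), -100.0) if p > 0.0 else -100.0
    log_1mp = max(math.log(1.0 - p), -100.0) if p < 1.0 else -100.0
    return -(t * log_p + (1.0 - t) * log_1mp)


def _cosine(a: Sequence[float], b: Sequence[float], eps: float = 1e-8) -> float:
    # Cosine similarity with an eps guard for zero vectors (assumed).
    dot = sum(u * v for u, v in zip(a, b))
    norm_a = math.sqrt(sum(u * u for u in a))
    norm_b = math.sqrt(sum(v * v for v in b))
    return dot / max(norm_a * norm_b, eps)


def nt_bxent_loss_reference(
    x: Sequence[Sequence[float]],
    pos_indices: Sequence[Sequence[int]],
    temperature: float,
) -> float:
    n = len(x)
    # Targets: user-supplied positives plus the principal diagonal.
    target: List[List[float]] = [[0.0] * n for _ in range(n)]
    for i, j in pos_indices:
        target[i][j] = 1.0
    for i in range(n):
        target[i][i] = 1.0

    total = 0.0
    for i in range(n):
        loss_pos = loss_neg = 0.0
        for j in range(n):
            # The diagonal logit is +inf, so its sigmoid is exactly 1.0.
            logit = math.inf if i == j else _cosine(x[i], x[j]) / temperature
            loss_ij = _bce(_sigmoid(logit), target[i][j])
            if target[i][j] == 1.0:
                loss_pos += loss_ij
            else:
                loss_neg += loss_ij
        num_pos = sum(target[i])
        num_neg = n - num_pos  # assumed > 0, as in the PyTorch version
        total += loss_pos / num_pos + loss_neg / num_neg
    return total / n


print(nt_bxent_loss_reference([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]], [(0, 1)], 1.0))  # about 0.6187
```

To compare against the PyTorch function, pass x.tolist() and the user-supplied pos_indices.tolist(). At moderate temperatures the two should agree up to float32 rounding. At very low temperatures they can diverge, because the sigmoid saturates to exactly 1.0 much sooner in float32 than in float64, at which point the -100 log clamp takes over.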
