sparank.modules.NTXentLoss

class sparank.modules.NTXentLoss(temperature=0.5)[source]

Bases: Module

Normalised Temperature-scaled Cross-Entropy (NT-Xent) loss for contrastive learning on paired views.

Parameters:

temperature (float, default 0.5) – Softmax temperature used to scale the similarity scores. Lower values give a sharper probability distribution.
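
To make the effect of temperature concrete, here is a minimal, dependency-free sketch of a temperature-scaled softmax. The helper name softmax_with_temperature and the similarity values are hypothetical and chosen only for illustration; this is not code taken from sparank.

```python
import math


def softmax_with_temperature(similarities, temperature=0.5):
    """Softmax over similarity scores divided by a temperature (illustrative helper)."""
    scaled = [s / temperature for s in similarities]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical cosine similarities between one anchor and three candidates.
sims = [0.9, 0.5, 0.1]

sharp = softmax_with_temperature(sims, temperature=0.1)   # low temperature
smooth = softmax_with_temperature(sims, temperature=1.0)  # high temperature

print("temperature=0.1:", [round(p, 3) for p in sharp])
print("temperature=1.0:", [round(p, 3) for p in smooth])
```

With the lower temperature, almost all of the probability mass moves to the most similar candidate, which is the "sharper" behaviour described above.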

__init__(temperature=0.5)[source]
Parameters:

temperature (float)

Methods

__init__([temperature])

forward(z_a, z_b)

Compute the contrastive loss between two augmented views.

forward(z_a, z_b)[source]

Compute the contrastive loss between two augmented views.

Parameters:
  • z_a (torch.Tensor) – Latent embeddings of the first view, shape (B, D) where B is the batch size and D is the feature dimension.

  • z_b (torch.Tensor) – Latent embeddings of the second view, shape (B, D).

Returns:

A scalar tensor containing the computed NT-Xent loss.

Return type:

torch.Tensor
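
The computation can be sketched in plain Python to make the arithmetic explicit. This is a reference sketch of the commonly used NT-Xent formulation (L2-normalised embeddings, cosine similarity divided by the temperature, all other samples in the combined 2B batch treated as negatives, mean over 2B anchors). It is an assumption that sparank follows exactly this variant; the actual module operates on torch.Tensor inputs and may differ in details such as the choice of negatives or the reduction. The names nt_xent_reference and _l2_normalise are hypothetical.

```python
import math


def _l2_normalise(vec, eps=1e-12):
    """Scale a vector to unit length (eps guards against division by zero)."""
    norm = max(math.sqrt(sum(x * x for x in vec)), eps)
    return [x / norm for x in vec]


def nt_xent_reference(z_a, z_b, temperature=0.5):
    """Reference NT-Xent loss for two lists of B embeddings of length D.

    Assumed formulation, not sparank's source: row i of z_a and row i of z_b
    form the positive pair; every other row in the combined batch is a negative.
    """
    if len(z_a) != len(z_b):
        raise ValueError("z_a and z_b must contain the same number of embeddings")

    b = len(z_a)
    n = 2 * b
    z = [_l2_normalise(v) for v in list(z_a) + list(z_b)]

    total = 0.0
    for i in range(n):
        positive = (i + b) % n  # index of the other view of the same sample
        # Temperature-scaled cosine similarities to every row except itself.
        logits = {
            j: sum(p * q for p, q in zip(z[i], z[j])) / temperature
            for j in range(n)
            if j != i
        }
        peak = max(logits.values())
        log_denominator = peak + math.log(
            sum(math.exp(value - peak) for value in logits.values())
        )
        # Cross-entropy with the positive pair as the target class.
        total += log_denominator - logits[positive]

    return total / n  # mean over all 2B anchors, a single scalar


# Toy batch with B=2, D=2: the second view either matches or is swapped.
view_a = [[1.0, 0.0], [0.0, 1.0]]
aligned = nt_xent_reference(view_a, [[1.0, 0.0], [0.0, 1.0]])
misaligned = nt_xent_reference(view_a, [[0.0, 1.0], [1.0, 0.0]])
print("aligned loss:", aligned)
print("misaligned loss:", misaligned)
```

In this sketch, matching views yield a lower loss than swapped views, which is the behaviour a contrastive objective on paired views is meant to reward.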