sparank.modules.NTXentLoss
- class sparank.modules.NTXentLoss(temperature=0.5)
Bases: Module
Normalised Temperature-scaled Cross-Entropy (NT-Xent) loss for contrastive learning on paired views.
- Parameters:
temperature (float, default 0.5) – Softmax temperature; lower values give a sharper probability distribution.
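For reference, the standard NT-Xent objective (as introduced for SimCLR) is written out below, with τ = temperature. That this class implements exactly this form is an assumption, not something this page states. In particular, the use of cosine similarity, row-wise pairing of z_a and z_b, and the mean over all 2B anchors are not confirmed. The stacked rows u_k and the partner index p(k) are notation introduced here for illustration.

```latex
\begin{aligned}
u_k &= \begin{cases} z_{a,k}, & 1 \le k \le B \\ z_{b,k-B}, & B < k \le 2B \end{cases}
\qquad
\operatorname{sim}(u, v) = \frac{u^\top v}{\lVert u \rVert \, \lVert v \rVert} \\
p(k) &= \begin{cases} k + B, & k \le B \\ k - B, & k > B \end{cases} \\
\ell_k &= -\log \frac{\exp\!\big(\operatorname{sim}(u_k, u_{p(k)}) / \tau\big)}
{\sum_{m=1,\, m \ne k}^{2B} \exp\!\big(\operatorname{sim}(u_k, u_m) / \tau\big)} \\
\mathcal{L} &= \frac{1}{2B} \sum_{k=1}^{2B} \ell_k
\end{aligned}
```

Dividing the similarities by a smaller τ widens the gaps between logits, which is why a lower temperature sharpens the softmax over candidates.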
Methods
__init__([temperature])
forward(z_a, z_b) – Compute the contrastive loss between two augmented views.
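The sketch below is a minimal NumPy reference for the loss that forward computes, useful as an independent cross-check. It assumes the standard SimCLR-style formulation: cosine similarity, row i of z_a paired with row i of z_b as the positive, self-similarity excluded from the denominator, and a mean over all 2B anchors. None of these details is confirmed by this page. nt_xent_reference is a hypothetical helper and is not part of sparank.

```python
import numpy as np


def nt_xent_reference(z_a: np.ndarray, z_b: np.ndarray, temperature: float = 0.5) -> float:
    """Hypothetical NumPy reference for the standard (SimCLR-style) NT-Xent loss.

    Not part of sparank; assumes positives are paired row-by-row and the loss is
    averaged over all 2B anchors.
    """
    if z_a.ndim != 2 or z_a.shape != z_b.shape:
        raise ValueError("z_a and z_b must both have shape (B, D)")
    b = z_a.shape[0]
    # Stack both views: rows 0..B-1 come from z_a, rows B..2B-1 from z_b.
    z = np.concatenate([z_a, z_b], axis=0).astype(np.float64)
    # L2-normalise each row so that dot products are cosine similarities.
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    # (2B, 2B) matrix of temperature-scaled similarities.
    logits = z @ z.T / temperature
    # Exclude each anchor's similarity with itself from the softmax denominator.
    np.fill_diagonal(logits, -np.inf)
    # Index of the positive partner for each anchor: k <-> k + B.
    pos = np.concatenate([np.arange(b, 2 * b), np.arange(0, b)])
    # Numerically stable log-sum-exp over each row (exp(-inf) contributes 0).
    row_max = logits.max(axis=1, keepdims=True)
    log_denom = row_max[:, 0] + np.log(np.exp(logits - row_max).sum(axis=1))
    log_prob_pos = logits[np.arange(2 * b), pos] - log_denom
    # Average the per-anchor negative log-likelihoods.
    return float(-log_prob_pos.mean())
```

When every embedding is identical, each anchor spreads its probability evenly over its 2B - 1 candidates, so the loss equals log(2B - 1). Checks like this one make it easy to compare the reference with the value returned by forward on the same inputs converted to tensors, provided the class follows the same formulation.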
- forward(z_a, z_b)
Compute the contrastive loss between two augmented views.
- Parameters:
z_a (torch.Tensor) – Latent embeddings of the first view, shape (B, D), where B is the batch size and D is the feature dimension.
z_b (torch.Tensor) – Latent embeddings of the second view, shape (B, D).
- Returns:
A scalar tensor containing the computed NT-Xent loss.
- Return type: