Source code for sparank.modules.losses

"""Loss functions for SpaRank training."""

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F


class NTXentLoss(nn.Module):
    """Normalised Temperature-scaled Cross-Entropy (NT-Xent) loss.

    Contrastive loss on paired views. The two views are stacked into a batch
    of ``2B`` L2-normalised embeddings. For every embedding, its counterpart
    from the other view is the positive. The remaining ``2B - 2`` embeddings
    act as negatives.

    Parameters
    ----------
    temperature : float, default 0.5
        Softmax temperature (lower means a sharper probability distribution).
        Must be strictly positive.

    Raises
    ------
    ValueError
        If ``temperature`` is not strictly positive.
    """

    def __init__(self, temperature: float = 0.5):
        super().__init__()
        # `not x > 0` also rejects NaN, which would silently poison the loss.
        if not temperature > 0:
            raise ValueError(
                f"temperature must be strictly positive, got {temperature!r}"
            )
        self.temperature = temperature

    def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:
        """Compute the contrastive loss between two augmented views.

        Parameters
        ----------
        z_a : torch.Tensor
            Latent embeddings of the first view, shape ``(B, D)`` where B is
            the batch size and D is the feature dimension.
        z_b : torch.Tensor
            Latent embeddings of the second view, shape ``(B, D)``. Row ``i``
            of ``z_b`` is the positive for row ``i`` of ``z_a``.

        Returns
        -------
        torch.Tensor
            A scalar tensor containing the NT-Xent loss, averaged over all
            ``2B`` anchors.

        Raises
        ------
        ValueError
            If ``z_a`` and ``z_b`` do not have the same shape.
        """
        if z_a.shape != z_b.shape:
            raise ValueError(
                "z_a and z_b must have the same shape, got "
                f"{tuple(z_a.shape)} and {tuple(z_b.shape)}"
            )
        batch_size = z_a.shape[0]
        n = 2 * batch_size

        # Normalise along the feature dimension so the dot products below
        # are cosine similarities.
        z = F.normalize(torch.cat([z_a, z_b], dim=0), dim=1)
        sim = torch.mm(z, z.T) / self.temperature

        # Exclude each embedding's similarity with itself. Use the most
        # negative value representable in sim's dtype rather than a large
        # fixed constant: a constant like -9e15 overflows float16 (e.g. under
        # autocast) and makes masked_fill raise.
        self_mask = torch.eye(n, dtype=torch.bool, device=z.device)
        sim = sim.masked_fill(self_mask, torch.finfo(sim.dtype).min)

        # The positive for anchor i is i + B in the first view and i - B in
        # the second, i.e. the indices 0..2B-1 rotated right by B.
        targets = torch.arange(n, device=z.device).roll(batch_size)
        return F.cross_entropy(sim, targets)
class DeconvCrossEntropy(nn.Module):
    """Soft-label cross-entropy for cell-type proportion deconvolution.

    The per-sample target proportions are used directly as a probability
    distribution over classes. They are compared with the model's raw logits
    through :class:`torch.nn.CrossEntropyLoss`, and the result is averaged
    over the batch.
    """

    def __init__(self):
        super().__init__()
        # CrossEntropyLoss accepts class-probability (soft) targets when the
        # target tensor is floating point and has the same shape as the logits.
        self.criterion = nn.CrossEntropyLoss(reduction="mean")

    def forward(
        self, logits: torch.Tensor, target_proportions: torch.Tensor
    ) -> torch.Tensor:
        """Return the batch-mean soft cross-entropy of the predicted logits.

        Parameters
        ----------
        logits : torch.Tensor
            Unnormalized predicted logits from the model, shape
            ``(B, num_classes)``.
        target_proportions : torch.Tensor
            Ground truth proportions (soft labels) summing to 1 per sample,
            shape ``(B, num_classes)``.

        Returns
        -------
        torch.Tensor
            A scalar tensor holding the cross-entropy loss.
        """
        loss_fn = self.criterion
        batch_loss = loss_fn(logits, target_proportions)
        return batch_loss