"""Loss functions for SpaRank training."""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
class NTXentLoss(nn.Module):
    """Normalised Temperature-scaled Cross-Entropy (NT-Xent) loss
    for contrastive learning on paired views.

    Row ``i`` of the first view and row ``i`` of the second view form a
    positive pair; every other embedding in the combined ``2B`` batch is
    treated as a negative for that anchor.

    Parameters
    ----------
    temperature : float, default 0.5
        Softmax temperature (lower means a sharper probability distribution).
        Must be strictly positive.

    Raises
    ------
    ValueError
        If ``temperature`` is not strictly positive.
    """

    def __init__(self, temperature: float = 0.5):
        super().__init__()
        # `not x > 0` also rejects NaN; zero/negative temperatures would give
        # inf/NaN similarities or invert the contrastive objective.
        if not temperature > 0:
            raise ValueError(f"temperature must be > 0, got {temperature!r}")
        self.temperature = temperature

    def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:
        """Compute the contrastive loss between two augmented views.

        Parameters
        ----------
        z_a : torch.Tensor
            Latent embeddings of the first view, shape ``(B, D)`` where B is
            the batch size and D is the feature dimension.
        z_b : torch.Tensor
            Latent embeddings of the second view, shape ``(B, D)``. Row ``i``
            is the positive partner of row ``i`` of ``z_a``.

        Returns
        -------
        torch.Tensor
            A scalar tensor containing the NT-Xent loss, averaged over all
            ``2B`` anchors.

        Raises
        ------
        ValueError
            If ``z_a`` and ``z_b`` are not 2-D tensors of identical shape.
        """
        if z_a.dim() != 2 or z_a.shape != z_b.shape:
            raise ValueError(
                "z_a and z_b must both be 2-D with the same shape (B, D); "
                f"got {tuple(z_a.shape)} and {tuple(z_b.shape)}"
            )
        B = z_a.shape[0]
        # L2-normalise along the feature dimension so the dot products below
        # are cosine similarities.
        z = F.normalize(torch.cat([z_a, z_b], dim=0), dim=1)
        sim = torch.mm(z, z.T) / self.temperature
        # Exclude each embedding's similarity with itself. Use the most
        # negative finite value of the working dtype: a fixed constant such as
        # -9e15 is not representable in float16, and masked_fill_ raises on it
        # (e.g. under mixed precision).
        self_mask = torch.eye(2 * B, dtype=torch.bool, device=z.device)
        sim.masked_fill_(self_mask, torch.finfo(sim.dtype).min)
        # Anchor i (first view) is paired with row i + B (second view), and
        # anchor B + i is paired with row i.
        pos = torch.cat([
            torch.arange(B, 2 * B, device=z.device),
            torch.arange(0, B, device=z.device),
        ])
        return F.cross_entropy(sim, pos)
class DeconvCrossEntropy(nn.Module):
    """Cross-entropy loss for cell-type proportion deconvolution.

    Treats the target proportions as a soft probability distribution
    and uses standard cross-entropy against the predicted logits, i.e. per
    sample ``-sum_k p_k * log_softmax(logits)_k``.

    Parameters
    ----------
    reduction : {"mean", "sum", "none"}, default "mean"
        How per-sample losses are combined; forwarded to
        :class:`torch.nn.CrossEntropyLoss`. ``"none"`` returns one loss per
        sample (useful for per-sample weighting).
    """

    def __init__(self, reduction: str = "mean"):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss(reduction=reduction)

    def forward(
        self, logits: torch.Tensor, target_proportions: torch.Tensor
    ) -> torch.Tensor:
        """Compute the cross-entropy loss for deconvolution predictions.

        Parameters
        ----------
        logits : torch.Tensor
            Unnormalized predicted logits from the model, shape ``(B, num_classes)``.
        target_proportions : torch.Tensor
            Ground truth proportions (soft labels) summing to 1 per sample,
            shape ``(B, num_classes)``. A floating-point target with the same
            shape as ``logits`` is interpreted by ``nn.CrossEntropyLoss`` as
            class probabilities; the values are used as given and are not
            renormalised here.

        Returns
        -------
        torch.Tensor
            The cross-entropy loss: a scalar for ``reduction="mean"``/``"sum"``,
            or a tensor of shape ``(B,)`` for ``reduction="none"``.
        """
        return self.criterion(logits, target_proportions)