sparank.modules.DeconvCrossEntropy

class sparank.modules.DeconvCrossEntropy(*args, **kwargs)

Bases: Module

Cross-entropy loss for cell-type proportion deconvolution.

Treats the target proportions as a soft probability distribution and computes the standard cross-entropy between that distribution and the softmax of the predicted logits.
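As a worked illustration, the pure-Python sketch below computes the soft-label cross-entropy −Σ_c p_c · log softmax(z)_c for each sample and averages it over the batch. The helper names (`log_softmax`, `soft_cross_entropy`, `batch_loss`) are hypothetical. The batch-mean reduction is also an assumption, since this page only states that a scalar is returned.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Shift by the max logit before exponentiating so exp() cannot overflow.
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]


def soft_cross_entropy(logits: Sequence[float], proportions: Sequence[float]) -> float:
    # Cross-entropy between the target distribution p and softmax(z):
    #   -sum_c p_c * log softmax(z)_c
    return -sum(p * lp for p, lp in zip(proportions, log_softmax(logits)))


def batch_loss(batch_logits: Sequence[Sequence[float]],
               batch_proportions: Sequence[Sequence[float]]) -> float:
    # Assumption: per-sample losses are averaged over the B samples in the batch.
    losses = [soft_cross_entropy(z, p) for z, p in zip(batch_logits, batch_proportions)]
    return sum(losses) / len(losses)


# One sample with three cell types (hypothetical numbers).
print(batch_loss([[2.0, 1.0, 0.0]], [[0.7, 0.2, 0.1]]))
```

Because the targets are soft, the loss is bounded below by the entropy of the target proportions rather than by zero. It reaches that minimum exactly when softmax(z) equals p. Subtracting the (logit-independent) entropy gives the KL divergence, which has the same gradients with respect to the logits.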

__init__()

Methods

__init__()

forward(logits, target_proportions) – Compute the cross-entropy loss for deconvolution predictions.

forward(logits, target_proportions)

Compute the cross-entropy loss for deconvolution predictions.

Parameters:
  • logits (torch.Tensor) – Unnormalized predicted logits from the model, shape (B, num_classes).

  • target_proportions (torch.Tensor) – Ground-truth proportions (soft labels), with each row summing to 1, shape (B, num_classes).
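The parameter contract above can be checked before calling forward. The helper `check_deconv_inputs` below is hypothetical and not part of sparank. It verifies the documented shape (B, num_classes), the non-negativity implied by proportions, and that each row of target_proportions sums to 1 within a tolerance.

```python
import torch


def check_deconv_inputs(logits: torch.Tensor,
                        target_proportions: torch.Tensor,
                        atol: float = 1e-5) -> None:
    """Hypothetical helper: raise ValueError if inputs break the forward() contract."""
    if logits.dim() != 2:
        raise ValueError(f"logits must be 2-D (B, num_classes), got shape {tuple(logits.shape)}")
    if target_proportions.shape != logits.shape:
        raise ValueError(
            f"target_proportions shape {tuple(target_proportions.shape)} "
            f"does not match logits shape {tuple(logits.shape)}"
        )
    if (target_proportions < 0).any():
        raise ValueError("target_proportions must be non-negative")
    row_sums = target_proportions.sum(dim=1)
    if not torch.allclose(row_sums, torch.ones_like(row_sums), atol=atol):
        raise ValueError("each row of target_proportions must sum to 1")


# Example: normalize non-negative abundances into proportions, then check.
abundances = torch.rand(4, 6) + 1e-3  # hypothetical non-negative values
proportions = abundances / abundances.sum(dim=1, keepdim=True)
check_deconv_inputs(torch.randn(4, 6), proportions)  # passes without raising
```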

Returns:

A scalar tensor containing the computed cross-entropy loss.

Return type:

torch.Tensor
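A typical training step might look like the sketch below. The import path and the no-argument constructor follow this page. The fallback stand-in class, the linear model, and the random input features are hypothetical placeholders, and the stand-in assumes the loss is averaged over the batch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from sparank.modules import DeconvCrossEntropy
except ImportError:
    # Hypothetical stand-in used only if sparank is not installed.
    # Assumes batch-mean soft-label cross-entropy.
    class DeconvCrossEntropy(nn.Module):
        def forward(self, logits: torch.Tensor, target_proportions: torch.Tensor) -> torch.Tensor:
            log_probs = F.log_softmax(logits, dim=1)
            return -(target_proportions * log_probs).sum(dim=1).mean()

torch.manual_seed(0)
B, num_features, num_classes = 8, 50, 5  # hypothetical sizes

model = nn.Linear(num_features, num_classes)  # placeholder predictor producing logits
criterion = DeconvCrossEntropy()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

features = torch.randn(B, num_features)
# Random soft labels: each row is a valid proportion vector summing to 1.
target_proportions = torch.softmax(torch.randn(B, num_classes), dim=1)

optimizer.zero_grad()
logits = model(features)                      # shape (B, num_classes)
loss = criterion(logits, target_proportions)  # scalar tensor
loss.backward()
optimizer.step()

# At inference time, convert logits to predicted proportions.
with torch.no_grad():
    predicted_proportions = torch.softmax(model(features), dim=1)
print(loss.item(), predicted_proportions.shape)
```

Because forward expects raw logits, predicted proportions are obtained with a softmax over the class dimension. In PyTorch 1.10 and later, torch.nn.functional.cross_entropy also accepts class-probability targets of shape (B, num_classes); with the default mean reduction it computes the same batch-averaged quantity as the stand-in.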