sparank.modules.FusionLayer
- class sparank.modules.FusionLayer(dim, n_modalities, fusion_type='concat', attn_heads=4)
Bases: Module
Fuse N modality embeddings into a single vector.
All modalities share embedding dimension dim.
Supported modes
'concat'
Concatenation → output dim = dim × N.
'gate'
Per-dimension softmax gating across modalities. Output dim = dim.
For each modality m, a learnable projection \(W_g^m \in \mathbb{R}^{d \times d}\) produces gate logits (here \(d\) is dim and \(M\) is n_modalities). The gate vector is normalised across modalities, independently for each dimension:
\[\mathbf{g}^m = \frac{\exp(W_g^m \mathbf{h}^m)} {\sum_{j=1}^{M} \exp(W_g^j \mathbf{h}^j)} \in \mathbb{R}^d\]
ensuring \(\sum_{m=1}^{M} g_i^m = 1\) for every dimension i. The fused embedding is:
\[\mathbf{h}_f = \sum_{m=1}^{M} \mathbf{g}^m \odot \mathbf{h}^m\]
'sigmoid_gate'
Learns a per-dimension sigmoid blending weight for exactly 2 modalities. Produces a weighting vector \(z \in (0, 1)^d\) and fuses the two inputs as:
\[\mathbf{h}_f = z \odot \mathbf{h}^1 + (1 - z) \odot \mathbf{h}^2\]
'attention'
Stack CLS tokens → self-attention → mean-pool. Output dim = dim.
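The 'gate' and 'sigmoid_gate' formulas above can be sketched in PyTorch as follows. This is a minimal illustration, not the sparank source: the class names `SoftmaxGateSketch` and `SigmoidGateSketch` are hypothetical, the projections are assumed to be bias-free to match the formula, and the way \(z\) is computed (a linear layer over the concatenated inputs) is an assumption, since the documentation does not specify it.

```python
import torch
import torch.nn as nn


class SoftmaxGateSketch(nn.Module):
    """Hypothetical stand-in for the 'gate' mode (not the sparank code)."""

    def __init__(self, dim: int, n_modalities: int):
        super().__init__()
        # One d x d projection W_g^m per modality; bias omitted to match the formula.
        self.proj = nn.ModuleList(
            [nn.Linear(dim, dim, bias=False) for _ in range(n_modalities)]
        )

    def forward(self, features):
        # Gate logits stacked along a leading modality axis: (M, B, d).
        logits = torch.stack([p(h) for p, h in zip(self.proj, features)], dim=0)
        # Softmax over the modality axis, independently for each dimension i.
        gates = torch.softmax(logits, dim=0)
        stacked = torch.stack(features, dim=0)  # (M, B, d)
        fused = (gates * stacked).sum(dim=0)  # (B, d)
        return fused, gates


class SigmoidGateSketch(nn.Module):
    """Hypothetical stand-in for the 'sigmoid_gate' mode (exactly 2 inputs)."""

    def __init__(self, dim: int):
        super().__init__()
        # Assumption: z is predicted from the concatenation of both inputs.
        self.proj = nn.Linear(2 * dim, dim)

    def forward(self, h1, h2):
        z = torch.sigmoid(self.proj(torch.cat([h1, h2], dim=-1)))  # (B, d), in (0, 1)
        return z * h1 + (1 - z) * h2, z


torch.manual_seed(0)
feats = [torch.randn(8, 16) for _ in range(3)]
fused, gates = SoftmaxGateSketch(16, 3)(feats)
blend, z = SigmoidGateSketch(16)(feats[0], feats[1])
```

Because the softmax runs over the modality axis, the gates for any fixed sample and dimension sum to 1, so each fused coordinate is a convex combination of the corresponding input coordinates.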
When n_modalities == 1, all modes collapse to identity (the single input is returned unchanged, with output dim = dim).
- Parameters:
dim (int) – Per-modality feature dimension.
n_modalities (int) – Number of input vectors.
fusion_type (str, default 'concat') – 'concat', 'gate', 'sigmoid_gate', or 'attention'.
attn_heads (int, default 4) – Attention heads (only applies when fusion_type is 'attention').
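The output widths stated for each mode can be summarised in a small helper. This is a sketch for illustration only: `fused_output_dim` is a hypothetical function, not part of sparank, and it encodes only what the description above states (concatenation multiplies the width by the number of modalities, every other mode keeps it at dim, and a single modality is passed through unchanged).

```python
def fused_output_dim(dim: int, n_modalities: int, fusion_type: str = "concat") -> int:
    """Return the fused width implied by the documented fusion modes."""
    modes = ("concat", "gate", "sigmoid_gate", "attention")
    if fusion_type not in modes:
        raise ValueError(f"fusion_type must be one of {modes}, got {fusion_type!r}")
    if n_modalities < 1:
        raise ValueError("n_modalities must be at least 1")
    # With a single modality every mode is the identity, so the width is unchanged.
    if n_modalities == 1:
        return dim
    # Concatenation stacks the N vectors end to end.
    if fusion_type == "concat":
        return dim * n_modalities
    # 'gate', 'sigmoid_gate' and 'attention' all produce a vector of width dim.
    return dim
```

A helper like this is handy when sizing the layer that consumes the fused vector, for example `fused_output_dim(64, 3, "concat")` gives 192 while the same call with `"gate"` gives 64.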
Methods
__init__(dim, n_modalities[, fusion_type, ...])
forward(features[, return_gates]) – Fuse a list of (B, dim) tensors into a single (B, out_dim) tensor.
- forward(features, return_gates=False)
Fuse a list of (B, dim) tensors into a single (B, out_dim) tensor.
- Parameters:
features (List[torch.Tensor]) – List of input modality tensors, each of shape (B, dim). Must contain exactly self.n_modalities tensors.
return_gates (bool, default False) – If True, additionally returns a dict describing the fusion weighting used for each sample.
- Returns:
If return_gates=False, returns only the fused tensor. If return_gates=True, returns (fused_tensor, gate_info), where gate_info is a dict containing:
type (str): The fusion mechanism applied.
weights (torch.Tensor or None): The gating weights applied. The shape depends on the type (e.g., (M, B, d) for gates, (B, M, M) for the attention matrix).
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]
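The return convention described above can be illustrated with a small stand-in for the 'attention' mode. This is a sketch under assumptions, not the sparank implementation: `AttentionFusionSketch` is a hypothetical class, and using `torch.nn.MultiheadAttention` with head-averaged weights is one plausible way to obtain a (B, M, M) attention matrix.

```python
from typing import Any, Dict, List, Tuple, Union

import torch
import torch.nn as nn


class AttentionFusionSketch(nn.Module):
    """Hypothetical stand-in: stack inputs, self-attend, mean-pool."""

    def __init__(self, dim: int, n_modalities: int, attn_heads: int = 4):
        super().__init__()
        self.n_modalities = n_modalities
        # dim must be divisible by attn_heads for MultiheadAttention.
        self.attn = nn.MultiheadAttention(dim, attn_heads, batch_first=True)

    def forward(
        self, features: List[torch.Tensor], return_gates: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]:
        if len(features) != self.n_modalities:
            raise ValueError(
                f"expected {self.n_modalities} tensors, got {len(features)}"
            )
        tokens = torch.stack(features, dim=1)  # (B, M, dim)
        # Self-attention over the M modality tokens; weights are averaged
        # over heads by default, giving shape (B, M, M).
        attended, weights = self.attn(tokens, tokens, tokens, need_weights=True)
        fused = attended.mean(dim=1)  # mean-pool over modalities -> (B, dim)
        if return_gates:
            return fused, {"type": "attention", "weights": weights}
        return fused


torch.manual_seed(0)
layer = AttentionFusionSketch(dim=16, n_modalities=3)
feats = [torch.randn(8, 16) for _ in range(3)]
fused_only = layer(feats)
fused, gate_info = layer(feats, return_gates=True)
```

Callers that always want the weights can unpack the tuple directly, while callers that only need the fused tensor can leave return_gates at its default and receive a plain tensor.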