sparank.modules.FusionLayer

class sparank.modules.FusionLayer(dim, n_modalities, fusion_type='concat', attn_heads=4)[source]

Bases: Module

Fuse N modality embeddings into a single vector.

All modalities share embedding dimension dim.

Supported modes

'concat'

Concatenation along the feature axis → output dim = dim × N, where N = n_modalities.
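
As a rough illustration of the output size, the sketch below concatenates the vectors of one sample. Plain Python lists stand in for tensors, and `concat_fuse` is a hypothetical helper, not part of the library.

```python
def concat_fuse(features):
    """Concatenate N per-modality vectors (each of length dim) for one sample."""
    fused = []
    for h in features:
        fused.extend(h)
    return fused

# Hypothetical toy input: N = 3 modalities, dim = 2.
features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
fused = concat_fuse(features)
print(len(fused))  # dim × N = 2 × 3
```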

'gate'

Per-dimension softmax gating across modalities.

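For each modality m, a learnable projection \(W_g^m \in \mathbb{R}^{d \times d}\) (with d = dim) maps its embedding \(\mathbf{h}^m \in \mathbb{R}^d\) to gate logits. The gates are normalised across the M = n_modalities modalities, independently for each dimension (the exponential and the division are element-wise):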

\[\mathbf{g}^m = \frac{\exp(W_g^m \mathbf{h}^m)} {\sum_{j=1}^{M} \exp(W_g^j \mathbf{h}^j)} \in \mathbb{R}^d\]

ensuring \(\sum_{m=1}^{M} g_i^m = 1\) for every dimension i. The fused embedding is:

\[\mathbf{h}_f = \sum_{m=1}^{M} \mathbf{g}^m \odot \mathbf{h}^m\]
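
The two formulas above can be traced by hand with a small sketch. This is an illustrative pure-Python version for a single sample, not the library's implementation: `softmax_gate_fuse` is a hypothetical name, the matrices `W` stand in for the learnable \(W_g^m\), and bias terms are assumed absent.

```python
import math

def softmax_gate_fuse(h, W):
    """Fuse M vectors h[m] (length d) with per-dimension softmax gates.

    W[m] is a d x d matrix (list of rows) standing in for W_g^m.
    Returns (fused, gates); gates[m][i] sums to 1 over m for every i.
    """
    M, d = len(h), len(h[0])
    # Gate logits W_g^m h^m, one length-d vector per modality.
    logits = [[sum(W[m][i][k] * h[m][k] for k in range(d)) for i in range(d)]
              for m in range(M)]
    gates = [[0.0] * d for _ in range(M)]
    for i in range(d):
        # Subtract the largest logit before exponentiating for stability.
        peak = max(logits[m][i] for m in range(M))
        exps = [math.exp(logits[m][i] - peak) for m in range(M)]
        total = sum(exps)
        for m in range(M):
            gates[m][i] = exps[m] / total
    # h_f = sum over m of g^m (element-wise) h^m
    fused = [sum(gates[m][i] * h[m][i] for m in range(M)) for i in range(d)]
    return fused, gates

# Hypothetical toy inputs: M = 2 modalities, d = 2, identity projections.
identity = [[1.0, 0.0], [0.0, 1.0]]
h = [[1.0, 0.0], [0.0, 1.0]]
fused, gates = softmax_gate_fuse(h, [identity, identity])
```

With these inputs each modality receives the larger gate on the dimension where its own logit is larger, and the gates for every dimension sum to 1.
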

'sigmoid_gate'

Learns a per-dimension sigmoid blending weight for exactly two modalities. A gate vector \(z \in (0, 1)^d\) blends them as:

\[\mathbf{h}_f = z \odot \mathbf{h}^1 + (1 - z) \odot \mathbf{h}^2\]
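
A minimal sketch of this blend for one sample follows. How the gate logits are computed from the inputs is not stated here, so the sketch simply takes them as given; `sigmoid_blend` and its `logits` argument are hypothetical.

```python
import math

def sigmoid_blend(h1, h2, logits):
    """Blend two length-d vectors with a per-dimension sigmoid gate z."""
    # z_i = sigmoid(logit_i), so every z_i lies strictly between 0 and 1.
    z = [1.0 / (1.0 + math.exp(-a)) for a in logits]
    # h_f = z * h1 + (1 - z) * h2, element-wise.
    fused = [zi * x1 + (1.0 - zi) * x2 for zi, x1, x2 in zip(z, h1, h2)]
    return fused, z

# Hypothetical toy inputs: zero logits give z = 0.5, an even blend.
h1 = [2.0, 4.0]
h2 = [0.0, 8.0]
fused, z = sigmoid_blend(h1, h2, logits=[0.0, 0.0])
```

Large positive logits push the result toward the first modality, large negative logits toward the second.
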

'attention'

Stack CLS tokens → self-attention → mean-pool. Output dim = dim.
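
The stack, attend, pool sequence can be sketched for a single sample as below. This is a simplified stand-in, not the layer itself: it assumes one attention head with identity query, key and value projections, whereas the real layer has attn_heads heads and presumably learned projections. `attention_fuse` is a hypothetical name.

```python
import math

def attention_fuse(tokens):
    """Self-attend over M stacked vectors (each length d), then mean-pool.

    Returns (pooled, attn): pooled has length d, attn is an M x M matrix
    whose rows each sum to 1.
    """
    M, d = len(tokens), len(tokens[0])
    scale = math.sqrt(d)
    attn = []
    for q in tokens:
        # Scaled dot-product scores of this token against every token.
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / scale for k in tokens]
        peak = max(scores)
        exps = [math.exp(s - peak) for s in scores]
        total = sum(exps)
        attn.append([e / total for e in exps])
    # Each output token is an attention-weighted mix of all input tokens.
    attended = [[sum(attn[i][j] * tokens[j][c] for j in range(M))
                 for c in range(d)] for i in range(M)]
    # Mean-pool over the M tokens, so the output dimension stays d.
    pooled = [sum(row[c] for row in attended) / M for c in range(d)]
    return pooled, attn

# Hypothetical toy input: M = 3 modalities, d = 2.
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
pooled, attn = attention_fuse(tokens)
```

The pooled vector has length d regardless of M, which matches the stated output dim = dim.
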

When n_modalities == 1 all modes collapse to identity (the single input is returned unchanged, with output dim = dim).

param dim:

Per-modality feature dimension.

type dim:

int

param n_modalities:

Number of input vectors.

type n_modalities:

int

param fusion_type:

'concat', 'gate', 'sigmoid_gate', or 'attention'.

type fusion_type:

str, default 'concat'

param attn_heads:

Attention heads (only applies when fusion_type is 'attention').

type attn_heads:

int, default 4

__init__(dim, n_modalities, fusion_type='concat', attn_heads=4)[source]
Parameters:
  • dim (int)

  • n_modalities (int)

  • fusion_type (str)

  • attn_heads (int)

Methods

__init__(dim, n_modalities[, fusion_type, ...])

forward(features[, return_gates])

Fuse a list of (B, dim) tensors into a single (B, out_dim) tensor.

forward(features, return_gates=False)[source]

Fuse a list of (B, dim) tensors into a single (B, out_dim) tensor.

Parameters:
  • features (List[torch.Tensor]) – List of input modality tensors, each of shape (B, dim). Must contain exactly self.n_modalities tensors.

  • return_gates (bool, default False) – If True, additionally returns a dict describing the fusion weighting used for each sample.

Returns:

If return_gates=False, returns only the fused tensor. If return_gates=True, returns (fused_tensor, gate_info) where gate_info is a dict containing:

  • type (str): The fusion mechanism applied.

  • weights (torch.Tensor or None): The fusion weights applied. The shape depends on type, e.g. (M, B, d) for gate weights or (B, M, M) for an attention matrix.

Return type:

Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]
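
Because the return type depends on return_gates, calling code may want to normalise both cases. The helper below is one possible sketch: `split_fusion_output` is a hypothetical name, plain lists stand in for tensors, and the "gate" string used for type is an assumed value.

```python
def split_fusion_output(out):
    """Return (fused, gate_info) whether or not gates were requested.

    `out` is whatever forward() returned: either the fused tensor alone,
    or a (fused_tensor, gate_info) tuple.
    """
    if isinstance(out, tuple):
        fused, gate_info = out
        return fused, gate_info
    return out, None

# Stand-in values mimicking the two documented return shapes.
plain = [0.1, 0.2]
with_gates = ([0.1, 0.2], {"type": "gate", "weights": [[0.5, 0.5], [0.5, 0.5]]})

fused_a, info_a = split_fusion_output(plain)
fused_b, info_b = split_fusion_output(with_gates)
```

Since weights may be None, callers should check it before indexing into it.
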
