# Source code for sparank.modules.layers

"""Shared building blocks: fusion layers for N modalities."""

from __future__ import annotations

from typing import Any, Dict, List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


class FusionLayer(nn.Module):
    r"""Fuse N modality embeddings into a single vector.

    All modalities share embedding dimension *dim*.

    Supported modes
    ---------------
    ``'concat'``
        Concatenation → output dim = dim × N.

    ``'gate'``
        Per-dimension softmax gating across modalities.  For each modality
        *m* a learnable projection
        :math:`W_g^m \in \mathbb{R}^{d \times d}` produces gate logits
        (implemented as ``nn.Linear(dim, dim)``, so a bias term is included
        as well).  The gate vector is normalised across modalities per
        dimension:

        .. math::

            \mathbf{g}^m = \frac{\exp(W_g^m \mathbf{h}^m)}
                                {\sum_{j=1}^{M} \exp(W_g^j \mathbf{h}^j)}
                           \in \mathbb{R}^d

        ensuring :math:`\sum_{m=1}^{M} g_i^m = 1` for every dimension *i*.
        The fused embedding is:

        .. math::

            \mathbf{h}_f = \sum_{m=1}^{M} \mathbf{g}^m \odot \mathbf{h}^m

    ``'sigmoid_gate'``
        Learns a per-dimension sigmoid blending weight for exactly 2
        modalities.  A linear layer over the concatenated inputs produces a
        weighting vector :math:`z \in (0, 1)^d`, fusing them as:

        .. math::

            \mathbf{h}_f = z \odot \mathbf{h}^1 + (1 - z) \odot \mathbf{h}^2

    ``'attention'``
        Stack CLS tokens → self-attention → residual + LayerNorm →
        mean-pool.  Output dim = *dim*.

    When ``n_modalities == 1`` all modes collapse to identity (the single
    input is returned unchanged, with output dim = dim).  In that case no
    sub-modules are created and ``fusion_type`` is not inspected.

    Parameters
    ----------
    dim : int
        Per-modality feature dimension.
    n_modalities : int
        Number of input vectors.
    fusion_type : str, default "concat"
        ``'concat'``, ``'gate'``, ``'sigmoid_gate'``, or ``'attention'``.
    attn_heads : int, default 4
        Attention heads (only applies when fusion_type is ``'attention'``).

    Attributes
    ----------
    out_dim : int
        Feature dimension of the fused output (``dim * n_modalities`` for
        ``'concat'`` with more than one modality, ``dim`` otherwise).

    Raises
    ------
    AssertionError
        If ``fusion_type == 'sigmoid_gate'`` and ``n_modalities`` is
        neither 1 nor 2.
    ValueError
        If ``n_modalities != 1`` and ``fusion_type`` is not recognised.
    """

    def __init__(
        self,
        dim: int,
        n_modalities: int,
        fusion_type: str = "concat",
        attn_heads: int = 4,
    ) -> None:
        super().__init__()
        self.fusion_type = fusion_type
        self.n_modalities = n_modalities
        self.dim = dim

        # Single-modality → identity, always output dim = dim
        if n_modalities == 1:
            self.out_dim = dim
            return

        if fusion_type == "concat":
            self.out_dim = dim * n_modalities
        elif fusion_type == "gate":
            self.out_dim = dim
            # One projection W_g^m per modality
            self.gate_projs = nn.ModuleList(
                [nn.Linear(dim, dim) for _ in range(n_modalities)]
            )
        elif fusion_type == "sigmoid_gate":
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O``; the exception type is kept for compatibility.
            if n_modalities != 2:
                raise AssertionError(
                    "sigmoid_gate only supports exactly 2 modalities"
                )
            self.out_dim = dim
            self.gate_net = nn.Sequential(
                nn.Linear(dim * 2, dim),
                nn.Sigmoid(),
            )
        elif fusion_type == "attention":
            self.out_dim = dim
            self.attn = nn.MultiheadAttention(
                embed_dim=dim,
                num_heads=attn_heads,
                batch_first=True,
            )
            self.norm = nn.LayerNorm(dim)
        else:
            raise ValueError(f"Unknown fusion_type: {fusion_type!r}")

    def forward(
        self,
        features: List[torch.Tensor],
        return_gates: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]:
        """Fuse a list of (B, dim) tensors into a single (B, out_dim) tensor.

        Parameters
        ----------
        features : List[torch.Tensor]
            List of input modality tensors, each of shape ``(B, dim)``.
            Must contain exactly ``self.n_modalities`` tensors.
        return_gates : bool, default False
            If True, additionally returns a dict describing the fusion
            weighting used for each sample.

        Returns
        -------
        Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]
            If ``return_gates=False``, returns only the fused tensor.
            If ``return_gates=True``, returns ``(fused_tensor, gate_info)``
            where ``gate_info`` is a dict containing:

            - ``type`` (str): The fusion mechanism applied
              (``'identity'``, ``'concat'``, ``'gate'``, ``'sigmoid_gate'``
              or ``'attention'``).
            - ``weights`` (torch.Tensor or None): The gating weights
              applied.  ``None`` for identity and concat, ``(M, B, d)`` for
              both gate variants, ``(B, M, M)`` (head-averaged attention
              matrix) for attention.

        Raises
        ------
        AssertionError
            If ``len(features) != self.n_modalities``.
        """
        # Explicit raise instead of ``assert``: under ``python -O`` a
        # stripped assert would let ``zip`` in the gate branch silently
        # drop modalities when too few features are passed.
        if len(features) != self.n_modalities:
            raise AssertionError(
                f"Expected {self.n_modalities} features, got {len(features)}"
            )

        weights = None

        if self.n_modalities == 1:
            # ── Single modality: identity ─────────────────────────
            applied = "identity"
            out = features[0]

        elif self.fusion_type == "concat":
            # ── Concat ────────────────────────────────────────────
            applied = "concat"
            out = torch.cat(features, dim=1)

        elif self.fusion_type == "gate":
            # ── Gate (per-dimension softmax across modalities) ────
            applied = "gate"
            gate_logits = torch.stack(
                [proj(h) for proj, h in zip(self.gate_projs, features)],
                dim=0,
            )  # (M, B, d)
            weights = F.softmax(gate_logits, dim=0)  # (M, B, d)
            h_stack = torch.stack(features, dim=0)  # (M, B, d)
            out = (weights * h_stack).sum(dim=0)  # (B, d)

        elif self.fusion_type == "sigmoid_gate":
            # ── Sigmoid gate (per-dimension blend of 2 modalities) ─
            applied = "sigmoid_gate"
            z = self.gate_net(torch.cat(features, dim=1))  # (B, d)
            out = z * features[0] + (1 - z) * features[1]
            if return_gates:
                # Synthesize the (M, B, d) schema used by 'gate':
                # modality 0 weight = z, modality 1 weight = 1 - z.
                weights = torch.stack([z, 1 - z], dim=0)  # (2, B, d)

        else:
            # ── Attention ─────────────────────────────────────────
            # __init__ rejects unknown fusion types whenever
            # n_modalities != 1, so this is the only remaining mode.
            applied = "attention"
            seq = torch.stack(features, dim=1)  # (B, M, d)
            # Only ask for (head-averaged) attention weights when the
            # caller wants them; otherwise ``weights`` comes back as None.
            attn_out, weights = self.attn(
                seq,
                seq,
                seq,
                need_weights=return_gates,
                average_attn_weights=True,
            )  # weights: (B, M, M)
            out = self.norm(attn_out + seq).mean(dim=1)  # (B, d)

        if return_gates:
            return out, {"type": applied, "weights": weights}
        return out