"""Shared building blocks: fusion layers for N modalities."""
from __future__ import annotations
from typing import Any, Dict, List, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
class FusionLayer(nn.Module):
    r"""Fuse N modality embeddings into a single vector.

    All modalities share embedding dimension *dim*.

    Supported modes
    ---------------
    ``'concat'``
        Concatenation → output dim = dim × N.
    ``'gate'``
        Per-dimension softmax gating across modalities.
        For each modality *m* a learnable affine projection
        :math:`W_g^m \in \mathbb{R}^{d \times d}` (``nn.Linear`` with bias)
        produces gate logits. The gate vector is normalised across
        modalities per dimension:

        .. math::
            \mathbf{g}^m = \frac{\exp(W_g^m \mathbf{h}^m)}
                                {\sum_{j=1}^{M} \exp(W_g^j \mathbf{h}^j)}
                           \in \mathbb{R}^d

        ensuring :math:`\sum_{m=1}^{M} g_i^m = 1` for every dimension *i*.
        The fused embedding is:

        .. math::
            \mathbf{h}_f = \sum_{m=1}^{M} \mathbf{g}^m \odot \mathbf{h}^m
    ``'sigmoid_gate'``
        Per-dimension sigmoid blending of exactly 2 modalities. A gate
        vector :math:`\mathbf{z} = \sigma(W_z [\mathbf{h}^1; \mathbf{h}^2] + b_z)
        \in (0, 1)^d` is computed from the concatenated inputs, and the
        modalities are fused as:

        .. math::
            \mathbf{h}_f = \mathbf{z} \odot \mathbf{h}^1
                           + (1 - \mathbf{z}) \odot \mathbf{h}^2
    ``'attention'``
        Stack the modality vectors (e.g. per-modality CLS tokens) into a
        length-M sequence → multi-head self-attention → residual +
        LayerNorm → mean-pool over modalities. Output dim = *dim*.

    When ``n_modalities == 1`` all modes collapse to identity (the single
    input is returned unchanged, with output dim = dim) and no learnable
    parameters are created.

    Parameters
    ----------
    dim : int
        Per-modality feature dimension.
    n_modalities : int
        Number of input vectors; must be >= 1.
    fusion_type : str, default "concat"
        ``'concat'``, ``'gate'``, ``'sigmoid_gate'``, or ``'attention'``.
    attn_heads : int, default 4
        Attention heads (only applies when fusion_type is ``'attention'``;
        ``nn.MultiheadAttention`` requires it to divide *dim*).

    Raises
    ------
    ValueError
        If *fusion_type* is unknown, *n_modalities* < 1, or
        ``'sigmoid_gate'`` is requested with a number of modalities other
        than 1 or 2.
    """

    #: Fusion modes accepted by ``fusion_type``.
    FUSION_TYPES: Tuple[str, ...] = ("concat", "gate", "sigmoid_gate", "attention")

    def __init__(
        self,
        dim: int,
        n_modalities: int,
        fusion_type: str = "concat",
        attn_heads: int = 4,
    ):
        super().__init__()
        # Validate before the single-modality shortcut below, so that bad
        # configurations are rejected regardless of the modality count.
        # (Explicit raises instead of `assert`, which `python -O` strips.)
        if n_modalities < 1:
            raise ValueError(f"n_modalities must be >= 1, got {n_modalities}")
        if fusion_type not in self.FUSION_TYPES:
            raise ValueError(f"Unknown fusion_type: {fusion_type!r}")

        self.fusion_type = fusion_type
        self.n_modalities = n_modalities
        self.dim = dim

        # Single-modality → identity, always output dim = dim
        if n_modalities == 1:
            self.out_dim = dim
            return

        if fusion_type == "concat":
            self.out_dim = dim * n_modalities
        elif fusion_type == "gate":
            self.out_dim = dim
            # One projection W_g^m per modality
            self.gate_projs = nn.ModuleList(
                [nn.Linear(dim, dim) for _ in range(n_modalities)]
            )
        elif fusion_type == "sigmoid_gate":
            if n_modalities != 2:
                raise ValueError("sigmoid_gate only supports exactly 2 modalities")
            self.out_dim = dim
            # Maps the concatenated pair (B, 2d) to a gate vector z in (0, 1)^d.
            self.gate_net = nn.Sequential(
                nn.Linear(dim * 2, dim),
                nn.Sigmoid(),
            )
        else:  # "attention" (membership already validated above)
            self.out_dim = dim
            self.attn = nn.MultiheadAttention(
                embed_dim=dim, num_heads=attn_heads, batch_first=True,
            )
            self.norm = nn.LayerNorm(dim)

    def forward(
        self,
        features: List[torch.Tensor],
        return_gates: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]:
        """Fuse a list of (B, dim) tensors into a single (B, out_dim) tensor.

        Parameters
        ----------
        features : List[torch.Tensor]
            List of input modality tensors, each of shape ``(B, dim)``.
            Must contain exactly ``self.n_modalities`` tensors.
        return_gates : bool, default False
            If True, additionally returns a dict describing the fusion
            weighting used for each sample.

        Returns
        -------
        Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]
            If ``return_gates=False``, returns only the fused tensor.
            If ``return_gates=True``, returns ``(fused_tensor, gate_info)``
            where ``gate_info`` is a dict containing:

            - ``type`` (str): the fusion mechanism applied (``'identity'``,
              ``'concat'``, ``'gate'``, ``'sigmoid_gate'`` or
              ``'attention'``).
            - ``weights`` (torch.Tensor or None): the weighting applied.
              ``(M, B, d)`` for ``'gate'`` / ``'sigmoid_gate'``,
              ``(B, M, M)`` head-averaged attention matrix for
              ``'attention'``, ``None`` for ``'identity'`` / ``'concat'``.

        Raises
        ------
        ValueError
            If ``len(features) != self.n_modalities``.
        """
        if len(features) != self.n_modalities:
            raise ValueError(
                f"Expected {self.n_modalities} features, got {len(features)}"
            )

        weights = None  # stays None for parameter-free modes
        if self.n_modalities == 1:
            # Identity: hand back the input tensor object itself.
            gate_type = "identity"
            out = features[0]
        elif self.fusion_type == "concat":
            gate_type = "concat"
            out = torch.cat(features, dim=1)
        elif self.fusion_type == "gate":
            # Per-dimension softmax across modalities.
            gate_type = "gate"
            gate_logits = torch.stack(
                [proj(h) for proj, h in zip(self.gate_projs, features)], dim=0
            )  # (M, B, d)
            weights = F.softmax(gate_logits, dim=0)  # sums to 1 over M
            h_stack = torch.stack(features, dim=0)  # (M, B, d)
            out = (weights * h_stack).sum(dim=0)  # (B, d)
        elif self.fusion_type == "sigmoid_gate":
            gate_type = "sigmoid_gate"
            z = self.gate_net(torch.cat(features, dim=1))  # (B, d)
            out = z * features[0] + (1 - z) * features[1]
            if return_gates:
                # Expose the same (M, B, d) schema as 'gate': modality 0
                # weighted by z, modality 1 by 1 - z.
                weights = torch.stack([z, 1 - z], dim=0)  # (2, B, d)
        else:  # "attention"
            gate_type = "attention"
            seq = torch.stack(features, dim=1)  # (B, M, d)
            # Only ask for (head-averaged) weights when the caller wants them;
            # otherwise MultiheadAttention returns None for them.
            attn_out, weights = self.attn(
                seq, seq, seq,
                need_weights=return_gates,
                average_attn_weights=True,
            )  # weights: (B, M, M) or None
            out = self.norm(attn_out + seq).mean(dim=1)  # (B, d)

        if return_gates:
            return out, {"type": gate_type, "weights": weights}
        return out