"""Unified Trainer for SpaRank."""
from __future__ import annotations
import json
import os
from os.path import join
from typing import Dict, Optional, Union, Tuple
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from sparank.config import ExpConfig
from sparank.modules.losses import NTXentLoss, DeconvCrossEntropy
class Trainer:
    """Train a :class:`SpotRankTransformer` with CLS + CL + MRP.

    Up to three objectives are combined into one weighted loss:

    * CLS - classification loss (``DeconvCrossEntropy``); always computed.
    * CL  - contrastive loss (``NTXentLoss``) between two views of a sample;
      computed only when ``config.use_cl`` is true.
    * MRP - masked-region prediction cross-entropy; computed only when
      ``config.use_mrp`` is true.

    Parameters
    ----------
    model : nn.Module
        The unified model (1 or N towers) to be trained. Depending on the
        config flags it is called as ``model(x, context)``,
        ``model.forward_cl(view_a, view_b, context)`` and/or
        ``model.forward_mrp(masked_x, mask_pos)``, and must expose a
        ``specs`` sequence when MRP is enabled (see :meth:`_train_step`).
    config : ExpConfig
        Experiment configuration dataclass containing hyperparameters.
    device : Union[str, torch.device], default "cuda"
        The device to run the training on (e.g., "cpu", "cuda", "cuda:0").
    """

    def __init__(
        self,
        model: nn.Module,
        config: "ExpConfig",
        device: Union[str, torch.device] = "cuda",
    ):
        self.model = model.to(device)
        self.cfg = config
        self.device = torch.device(device)
        self.optimizer = optim.AdamW(self.model.parameters(), lr=config.lr)
        self.criterion_cls = DeconvCrossEntropy()
        self.criterion_cl = NTXentLoss(temperature=config.cl_temperature)
        self.criterion_mrp = nn.CrossEntropyLoss(reduction="mean")
        # Per-epoch mean losses. Every call to ``fit`` appends to these lists
        # (they are not reset between calls).
        self.loss_history: Dict[str, list] = {
            "total": [], "cls": [], "cl": [], "mrp": [],
        }

    # ------------------------------------------------------------------
    def _train_step(
        self, batch: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Execute a single forward pass and loss computation.

        Parameters
        ----------
        batch : Dict[str, torch.Tensor]
            A dictionary containing the batched input tensors (e.g., "x", "y",
            "view_a", "view_b", "masked_x", "mask_pos", "target_x", "context").
            Tensor entries are moved to ``self.device``; any non-tensor
            entries are passed through unchanged.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            A tuple containing three scalar loss tensors:
            - Classification loss (l_cls)
            - Contrastive learning loss (l_cl); a zero tensor when CL is off
            - Masked region prediction loss (l_mrp); a zero tensor when MRP
              is off or no position was masked in any tower
        """
        batch = {
            key: val.to(self.device) if torch.is_tensor(val) else val
            for key, val in batch.items()
        }
        y = batch["y"]
        context_ids = batch.get("context") if self.cfg.use_context else None

        # ── Classification (+ optional CL) ─────────────────────────
        if self.cfg.use_cl:
            # CL path: two views, both go through encoder + fusion.
            proj_a, proj_b, logits = self.model.forward_cl(
                batch["view_a"], batch["view_b"], context_ids
            )
            l_cls = self.criterion_cls(logits, y)
            l_cl = self.criterion_cl(proj_a, proj_b)
        else:
            # CLS-only path: single forward, no proj_head, no 2x overhead.
            logits, _ = self.model(batch["x"], context_ids)
            l_cls = self.criterion_cls(logits, y)
            l_cl = torch.zeros((), device=self.device)

        # ── Optional MRP ──────────────────────────────────────────
        if not self.cfg.use_mrp:
            return l_cls, l_cl, torch.zeros((), device=self.device)

        mrp_results = self.model.forward_mrp(batch["masked_x"], batch["mask_pos"])
        target_x = batch["target_x"]
        mrp_losses = []
        offset = 0
        for spec in self.model.specs:
            top_k = spec["top_k"]
            pred, seg_mask = mrp_results[spec["name"]]
            # Towers occupy consecutive ``top_k``-wide column slices of
            # ``target_x`` (hence the running offset); keep only the masked
            # positions of this tower's slice as ground truth.
            true_ids = target_x[:, offset : offset + top_k][seg_mask]
            if true_ids.numel() > 0:
                mrp_losses.append(self.criterion_mrp(pred, true_ids))
            offset += top_k

        # Average over towers that actually had masked tokens.
        l_mrp = (
            sum(mrp_losses) / len(mrp_losses)
            if mrp_losses
            else torch.zeros((), device=self.device)
        )
        return l_cls, l_cl, l_mrp

    # ------------------------------------------------------------------
    def fit(
        self,
        train_loader: DataLoader,
        epochs: Optional[int] = None,
        save_dir: Optional[str] = None,
        *,
        max_grad_norm: float = 1.0,
    ) -> Dict[str, list]:
        """Run the full training loop.

        Executes the optimization process over the specified dataset for a
        given number of epochs. After every epoch the model ``state_dict`` is
        saved to ``<save_dir>/model_epoch{epoch}.pth`` and the full loss
        history is rewritten to ``<save_dir>/metadata/loss_hist.json``.

        Parameters
        ----------
        train_loader : DataLoader
            PyTorch DataLoader providing the training batches. It does not
            need to define ``__len__``; batches are counted as they arrive.
        epochs : int, optional
            Number of training epochs. Defaults to ``self.cfg.epochs`` when
            ``None``; an explicit ``0`` trains for zero epochs.
        save_dir : str, optional
            Directory to save model checkpoints and metadata.
            Defaults to ``self.cfg.save_dir`` when ``None``.
        max_grad_norm : float, default 1.0
            Maximum total gradient norm passed to
            ``nn.utils.clip_grad_norm_`` at every step.

        Returns
        -------
        Dict[str, list]
            ``self.loss_history`` itself (not a copy): per-epoch mean losses
            under the keys ``"total"``, ``"cls"``, ``"cl"``, and ``"mrp"``,
            accumulated across all calls to ``fit``.

        Raises
        ------
        ValueError
            If ``train_loader`` yields no batches in an epoch, since the
            epoch means would be undefined.
        """
        if epochs is None:
            epochs = self.cfg.epochs
        if save_dir is None:
            save_dir = self.cfg.save_dir
        meta_dir = join(save_dir, "metadata")
        os.makedirs(meta_dir, exist_ok=True)

        keys = ("total", "cls", "cl", "mrp")
        cls_w = self.cfg.cls_weight
        cl_w = self.cfg.cl_weight
        mrp_w = self.cfg.mrp_weight

        for epoch in range(1, epochs + 1):
            self.model.train()
            totals = dict.fromkeys(keys, 0.0)
            n_batches = 0
            with tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}") as loop:
                for batch in loop:
                    self.optimizer.zero_grad(set_to_none=True)
                    l_cls, l_cl, l_mrp = self._train_step(batch)
                    loss = cls_w * l_cls + cl_w * l_cl + mrp_w * l_mrp
                    loss.backward()
                    nn.utils.clip_grad_norm_(
                        self.model.parameters(), max_norm=max_grad_norm
                    )
                    self.optimizer.step()

                    # Pull all four scalars to the host in one transfer rather
                    # than issuing a separate ``.item()`` sync for each use.
                    step_vals = torch.stack(
                        [t.detach() for t in (loss, l_cls, l_cl, l_mrp)]
                    ).tolist()
                    for key, val in zip(keys, step_vals):
                        totals[key] += val
                    n_batches += 1
                    loop.set_postfix(
                        loss=f"{step_vals[0]:.4f}", cls=f"{step_vals[1]:.4f}",
                        cl=f"{step_vals[2]:.4f}", mrp=f"{step_vals[3]:.4f}",
                    )

            if n_batches == 0:
                raise ValueError(
                    "train_loader yielded no batches; cannot compute epoch losses."
                )
            means = {key: totals[key] / n_batches for key in keys}
            for key in keys:
                self.loss_history[key].append(means[key])
            print(
                f"Epoch {epoch} | total={means['total']:.4f} cls={means['cls']:.4f} "
                f"cl={means['cl']:.4f} mrp={means['mrp']:.4f}"
            )

            torch.save(
                self.model.state_dict(), join(save_dir, f"model_epoch{epoch}.pth"),
            )
            with open(join(meta_dir, "loss_hist.json"), "w", encoding="utf-8") as f:
                json.dump(self.loss_history, f, indent=4)

        print("Training finished.")
        return self.loss_history