# Source code for sparank.training.trainer

"""Unified Trainer for SpaRank."""

from __future__ import annotations

import json
import os
from os.path import join
from typing import Dict, Optional, Union, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from sparank.config import ExpConfig
from sparank.modules.losses import NTXentLoss, DeconvCrossEntropy


class Trainer:
    """Train a :class:`SpotRankTransformer` with CLS + CL + MRP.

    The per-batch objective is a weighted sum of three terms:

    * ``cls`` - classification loss (always computed),
    * ``cl``  - contrastive loss between two views (only when
      ``config.use_cl`` is true, otherwise a zero scalar),
    * ``mrp`` - masked-region prediction loss (only when
      ``config.use_mrp`` is true, otherwise a zero scalar).

    The weights come from ``config.cls_weight``, ``config.cl_weight`` and
    ``config.mrp_weight``.

    Parameters
    ----------
    model : nn.Module
        The unified model (1 or N towers) to be trained.
    config : ExpConfig
        Experiment configuration dataclass containing hyperparameters.
    device : Union[str, torch.device], default "cuda"
        The device to run the training on (e.g., "cpu", "cuda", "cuda:0").
    """

    def __init__(
        self,
        model: nn.Module,
        config: ExpConfig,
        device: Union[str, torch.device] = "cuda",
    ):
        self.device = torch.device(device)
        self.model = model.to(self.device)
        self.cfg = config
        # Build the optimizer from the module we actually train (post-.to()).
        self.optimizer = optim.AdamW(self.model.parameters(), lr=config.lr)
        self.criterion_cls = DeconvCrossEntropy()
        self.criterion_cl = NTXentLoss(temperature=config.cl_temperature)
        self.criterion_mrp = nn.CrossEntropyLoss(reduction="mean")
        # Per-epoch mean losses, appended to by ``fit``.
        self.loss_history: Dict[str, list] = {
            "total": [],
            "cls": [],
            "cl": [],
            "mrp": [],
        }

    # ------------------------------------------------------------------
    def _train_step(
        self, batch: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Execute a single forward pass and loss computation.

        Parameters
        ----------
        batch : Dict[str, torch.Tensor]
            A dictionary containing the batched inputs (e.g., "x", "y",
            "view_a", "view_b", "masked_x", "mask_pos", "target_x",
            "context"). Tensor values are moved to ``self.device``; any
            non-tensor values are passed through unchanged.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            Three scalar loss tensors:
            - Classification loss (l_cls)
            - Contrastive learning loss (l_cl), zero when CL is disabled
            - Masked region prediction loss (l_mrp), zero when MRP is
              disabled or no position is masked
        """
        # Only tensors can be moved; other collated entries (e.g. lists of
        # identifiers) would otherwise raise AttributeError on ``.to``.
        batch = {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }
        y = batch["y"]
        context_ids = batch.get("context") if self.cfg.use_context else None

        # ── Classification (+ optional CL) ─────────────────────────
        if self.cfg.use_cl:
            # CL path: two views, both go through encoder + fusion.
            proj_a, proj_b, logits = self.model.forward_cl(
                batch["view_a"], batch["view_b"], context_ids
            )
            l_cls = self.criterion_cls(logits, y)
            l_cl = self.criterion_cl(proj_a, proj_b)
        else:
            # CLS-only path: single forward, no proj_head, no 2x overhead.
            logits, _ = self.model(batch["x"], context_ids)
            l_cls = self.criterion_cls(logits, y)
            l_cl = torch.zeros((), device=self.device)

        # ── Optional MRP ──────────────────────────────────────────
        if self.cfg.use_mrp:
            l_mrp = self._mrp_loss(batch)
        else:
            l_mrp = torch.zeros((), device=self.device)

        return l_cls, l_cl, l_mrp

    def _mrp_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Average the MRP cross-entropy over specs with at least one masked token.

        Returns a zero scalar on ``self.device`` when no spec has any masked
        position in this batch.
        """
        # NOTE(review): context ids are not passed to forward_mrp — confirm
        # this is intended.
        mrp_results = self.model.forward_mrp(batch["masked_x"], batch["mask_pos"])
        spec_losses = []
        offset = 0
        for spec in self.model.specs:
            name, top_k = spec["name"], spec["top_k"]
            pred, seg_mask = mrp_results[name]
            # Columns of ``target_x`` are consumed in spec order, ``top_k``
            # per spec; keep only the ground-truth ids at masked positions.
            true_ids = batch["target_x"][:, offset : offset + top_k][seg_mask]
            if true_ids.numel() > 0:
                spec_losses.append(self.criterion_mrp(pred, true_ids))
            offset += top_k

        if not spec_losses:
            return torch.zeros((), device=self.device)
        return sum(spec_losses) / len(spec_losses)

    # ------------------------------------------------------------------
    def fit(
        self,
        train_loader: DataLoader,
        epochs: Optional[int] = None,
        save_dir: Optional[str] = None,
        *,
        max_grad_norm: Optional[float] = 1.0,
    ) -> Dict[str, list]:
        """Run the full training loop.

        Executes the optimization process over the specified dataset for a
        given number of epochs. A model checkpoint is saved after every
        epoch, and the loss history is written to
        ``<save_dir>/metadata/loss_hist.json`` once training ends.

        Parameters
        ----------
        train_loader : DataLoader
            PyTorch DataLoader providing the training batches. Its length is
            not required; batches are counted as they are consumed.
        epochs : int, optional
            Number of training epochs. Defaults to ``self.cfg.epochs`` only
            when omitted (``None``); an explicit ``0`` runs no epochs.
        save_dir : str, optional
            Directory to save model checkpoints and metadata. Defaults to
            ``self.cfg.save_dir`` when omitted (``None``).
        max_grad_norm : float or None, keyword-only, default 1.0
            Max total gradient norm passed to
            ``nn.utils.clip_grad_norm_``. ``None`` disables clipping.

        Returns
        -------
        Dict[str, list]
            A dictionary containing the loss history per epoch, with keys:
            ``"total"``, ``"cls"``, ``"cl"``, and ``"mrp"``.

        Raises
        ------
        ValueError
            If ``train_loader`` yields no batches in an epoch (the epoch
            means would otherwise be undefined).
        """
        # ``is None`` rather than ``or`` so that epochs=0 / save_dir="" are
        # honoured instead of silently replaced by the config values.
        if epochs is None:
            epochs = self.cfg.epochs
        if save_dir is None:
            save_dir = self.cfg.save_dir
        os.makedirs(join(save_dir, "metadata"), exist_ok=True)

        cls_w = self.cfg.cls_weight
        cl_w = self.cfg.cl_weight
        mrp_w = self.cfg.mrp_weight
        keys = ("total", "cls", "cl", "mrp")

        for epoch in range(1, epochs + 1):
            self.model.train()
            sums = dict.fromkeys(keys, 0.0)
            n_batches = 0

            loop = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}")
            for batch in loop:
                self.optimizer.zero_grad(set_to_none=True)
                l_cls, l_cl, l_mrp = self._train_step(batch)
                loss = cls_w * l_cls + cl_w * l_cl + mrp_w * l_mrp
                loss.backward()
                if max_grad_norm is not None:
                    nn.utils.clip_grad_norm_(
                        self.model.parameters(), max_norm=max_grad_norm
                    )
                self.optimizer.step()

                # One host sync per term (was two: accumulation + postfix).
                step = dict(
                    zip(keys, (loss.item(), l_cls.item(), l_cl.item(), l_mrp.item()))
                )
                for key, value in step.items():
                    sums[key] += value
                n_batches += 1

                loop.set_postfix(
                    loss=f"{step['total']:.4f}",
                    cls=f"{step['cls']:.4f}",
                    cl=f"{step['cl']:.4f}",
                    mrp=f"{step['mrp']:.4f}",
                )

            if n_batches == 0:
                raise ValueError(
                    "train_loader yielded no batches; cannot compute epoch losses."
                )
            means = {key: total / n_batches for key, total in sums.items()}
            for key, mean in means.items():
                self.loss_history[key].append(mean)

            print(
                f"Epoch {epoch} | total={means['total']:.4f} cls={means['cls']:.4f} "
                f"cl={means['cl']:.4f} mrp={means['mrp']:.4f}"
            )
            torch.save(
                self.model.state_dict(),
                join(save_dir, f"model_epoch{epoch}.pth"),
            )

        with open(
            join(save_dir, "metadata", "loss_hist.json"), "w", encoding="utf-8"
        ) as f:
            json.dump(self.loss_history, f, indent=4)

        print("Training finished.")
        return self.loss_history