sparank.training.Trainer

class sparank.training.Trainer(model, config, device='cuda')

Bases: object

Train a SpotRankTransformer on the combined CLS + CL + MRP objective.

Parameters:
  • model (nn.Module) – The unified model (1 or N towers) to be trained.

  • config (ExpConfig) – Experiment configuration dataclass containing hyperparameters.

  • device (Union[str, torch.device], default "cuda") – The device to run training on (e.g., "cpu", "cuda", "cuda:0").

__init__(model, config, device='cuda')

Initialize the trainer. Parameters are as described for the class above.

Methods

__init__(model, config[, device]) – Initialize the trainer.

fit(train_loader[, epochs, save_dir]) – Run the full training loop.

fit(train_loader, epochs=None, save_dir=None)

Run the full training loop.

Optimizes the model over the batches yielded by train_loader for the given number of epochs, saving model checkpoints and loss metadata to save_dir.

Parameters:
  • train_loader (DataLoader) – PyTorch DataLoader providing the training batches.

  • epochs (int, optional) – Number of training epochs. If None (the default), self.cfg.epochs is used.

  • save_dir (str, optional) – Directory in which to save model checkpoints and metadata. If None (the default), self.cfg.save_dir is used.
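The fallback rule for epochs and save_dir can be sketched as follows. This is an illustration, not sparank code: ExpConfigSketch and resolve_fit_args are hypothetical names, the sketch assumes only that the config exposes the epochs and save_dir attributes named in the parameter descriptions, and the default values (10, "checkpoints") are placeholders.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class ExpConfigSketch:
    """Hypothetical stand-in for ExpConfig with only the fields fit() falls back to."""
    epochs: int = 10              # placeholder default, not sparank's
    save_dir: str = "checkpoints"  # placeholder default, not sparank's


def resolve_fit_args(
    cfg: ExpConfigSketch,
    epochs: Optional[int] = None,
    save_dir: Optional[str] = None,
) -> Tuple[int, str]:
    """Return the (epochs, save_dir) pair that the documented defaults imply."""
    # An explicit argument wins; None means "use the config value".
    # Testing `is None` (rather than `or`) keeps an explicit epochs=0 intact.
    resolved_epochs = cfg.epochs if epochs is None else epochs
    resolved_dir = cfg.save_dir if save_dir is None else save_dir
    return resolved_epochs, resolved_dir


if __name__ == "__main__":
    cfg = ExpConfigSketch(epochs=20, save_dir="runs/exp1")
    print(resolve_fit_args(cfg))            # (20, 'runs/exp1')
    print(resolve_fit_args(cfg, epochs=5))  # (5, 'runs/exp1')
```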

Returns:

A dictionary containing the loss history per epoch, with keys: "total", "cls", "cl", and "mrp".

Return type:

Dict[str, list]
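One way to consume the returned history is sketched below. summarize_history is a hypothetical helper, not part of sparank, and it assumes each list holds one float per epoch in training order. The numbers in the usage lines are arbitrary placeholders, not real training results.

```python
from typing import Dict, List


def summarize_history(history: Dict[str, List[float]]) -> dict:
    """Report the lowest total loss, its 1-based epoch, and each loss's final value."""
    totals = history["total"]
    # Index of the smallest total loss; ties resolve to the earliest epoch.
    best_idx = min(range(len(totals)), key=totals.__getitem__)
    summary = {"best_epoch": best_idx + 1, "best_total": totals[best_idx]}
    # Last recorded value of every documented loss key.
    for key in ("total", "cls", "cl", "mrp"):
        summary[f"final_{key}"] = history[key][-1]
    return summary


if __name__ == "__main__":
    # Placeholder values standing in for the dict returned by fit().
    history = {
        "total": [3.2, 2.4, 2.7],
        "cls": [1.0, 0.8, 0.9],
        "cl": [1.5, 1.2, 1.2],
        "mrp": [0.6, 0.5, 0.5],
    }
    print(summarize_history(history)["best_epoch"])  # 2
```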