sparank.training.Trainer
- class sparank.training.Trainer(model, config, device='cuda')
Bases: object
Train a SpotRankTransformer with the CLS + CL + MRP losses.
- Parameters:
model (nn.Module) – The unified model (1 or N towers) to be trained.
config (ExpConfig) – Experiment configuration dataclass containing hyperparameters.
device (Union[str, torch.device], default "cuda") – The device to run training on (e.g., "cpu", "cuda", "cuda:0").
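The class trains with three objectives (CLS, CL, MRP), and fit reports a "total" loss alongside each of them. A minimal sketch of one common way to combine such terms, assuming the total is a weighted sum; the helper `combine_losses` and its default weight of 1.0 per term are hypothetical and not taken from sparank or ExpConfig:

```python
from typing import Dict, Optional


def combine_losses(
    terms: Dict[str, float],
    weights: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    """Return the per-objective losses plus a weighted "total" (hypothetical helper)."""
    # Assumption: any objective without an explicit weight gets weight 1.0.
    weights = weights or {}
    total = sum(weights.get(name, 1.0) * value for name, value in terms.items())
    # Same key layout as the loss history fit() returns: "total" plus each term.
    return {"total": total, **terms}
```

For example, `combine_losses({"cls": 0.75, "cl": 0.5, "mrp": 0.25})` yields a total of 1.5; passing `{"mrp": 0.5}` as `weights` would halve only the MRP contribution.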
- __init__(model, config, device='cuda')
- Parameters:
model (torch.nn.Module)
config (ExpConfig)
device (str | torch.device)
Methods
__init__(model, config[, device])
fit(train_loader[, epochs, save_dir]) – Run the full training loop.
- fit(train_loader, epochs=None, save_dir=None)
Run the full training loop.
Executes the optimization process over the batches from train_loader for the given number of epochs, saving model checkpoints and loss metadata.
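The loop described above can be pictured with a framework-free sketch that records the mean of each loss per epoch under the same keys fit documents. Everything here is an assumption for illustration: `fit_sketch`, the `step` callback (standing in for forward pass, backward pass, and optimizer update), the `default_epochs` fallback when `epochs` is None, and the `loss_history.json` file name are hypothetical, and checkpoint saving is omitted:

```python
import json
import os
from typing import Callable, Dict, Iterable, List, Optional

LOSS_KEYS = ("total", "cls", "cl", "mrp")


def fit_sketch(
    train_loader: Iterable,
    step: Callable[[object], Dict[str, float]],
    epochs: Optional[int] = None,
    default_epochs: int = 1,
    save_dir: Optional[str] = None,
) -> Dict[str, List[float]]:
    """Hypothetical epoch loop returning the mean of each loss per epoch."""
    # Assumption: epochs=None falls back to a configured default.
    n_epochs = default_epochs if epochs is None else epochs
    history: Dict[str, List[float]] = {k: [] for k in LOSS_KEYS}
    for _ in range(n_epochs):
        sums = dict.fromkeys(LOSS_KEYS, 0.0)
        n_batches = 0
        for batch in train_loader:
            # `step` would run forward + backward + optimizer update and
            # return the scalar loss values for this batch.
            losses = step(batch)
            for k in LOSS_KEYS:
                sums[k] += losses[k]
            n_batches += 1
        # Average over batches; max() guards against an empty loader.
        for k in LOSS_KEYS:
            history[k].append(sums[k] / max(n_batches, 1))
        if save_dir is not None:
            # Assumption: loss metadata is rewritten as JSON after each epoch.
            os.makedirs(save_dir, exist_ok=True)
            with open(os.path.join(save_dir, "loss_history.json"), "w") as f:
                json.dump(history, f)
    return history
```

The loader has to be re-iterable (a list or a PyTorch DataLoader, not a one-shot generator), since every epoch iterates it again. Given the returned history `h`, `min(range(len(h["total"])), key=h["total"].__getitem__)` picks the epoch with the lowest total loss.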
- Parameters:
- Returns:
A dictionary containing the loss history per epoch, with keys:
"total", "cls", "cl", and "mrp".
- Return type: