train_splade

Compute the ranking loss and the flops loss for a single step.

Parameters

  • model

    Splade model.

  • optimizer

    Optimizer.

  • anchor (list[str])

    Anchor texts.

  • positive (list[str])

    Positive (relevant) texts paired with the anchors.

  • negative (list[str])

    Negative (non-relevant) texts paired with the anchors.

  • flops_loss_weight (float) – defaults to 1e-4

    Weight applied to the flops (sparsity) loss; see the sketch after this parameter list.

  • sparse_loss_weight (float) – defaults to 1.0

    Weight applied to the ranking loss.

  • in_batch_negatives (bool) – defaults to False

    Whether to use in-batch negatives. Defaults to False.

  • threshold_flops (float) – defaults to 30

  • kwargs
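
The two weights above control how the ranking term and the flops term are mixed
into the training objective. A minimal, illustrative sketch of that weighted sum
(the loss values below are made-up placeholders, not the library's internals):

>>> import torch

>>> # Illustrative stand-in values for the two loss terms.
>>> ranking_loss = torch.tensor(0.7)
>>> flops_loss = torch.tensor(120.0)

>>> # Weighted sum, mirroring sparse_loss_weight and flops_loss_weight.
>>> total_loss = 1.0 * ranking_loss + 1e-4 * flops_loss
>>> round(total_loss.item(), 3)
0.712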

Examples

>>> from neural_cherche import models, utils, train, losses
>>> import torch

>>> _ = torch.manual_seed(42)

>>> model = models.Splade(
...     model_name_or_path="raphaelsty/neural-cherche-sparse-embed",
...     device="mps",
... )

>>> optimizer = torch.optim.AdamW(
...     model.parameters(),
...     lr=1e-6,
... )

>>> X = [
...     ("Sports", "Music", "Cinema"),
...     ("Sports", "Music", "Cinema"),
...     ("Sports", "Music", "Cinema"),
... ]

>>> flops_scheduler = losses.FlopsScheduler()

>>> for anchor, positive, negative in utils.iter(
...         X,
...         epochs=3,
...         batch_size=3,
...         shuffle=False
...     ):
...     loss = train.train_splade(
...         model=model,
...         optimizer=optimizer,
...         anchor=anchor,
...         positive=positive,
...         negative=negative,
...         flops_loss_weight=flops_scheduler.get(),
...         in_batch_negatives=False,
...     )

>>> loss
{'sparse': tensor(0., device='mps:0', grad_fn=<ClampBackward1>), 'flops': tensor(10., device='mps:0', grad_fn=<ClampBackward1>)}
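
train_splade returns the two loss terms as a dictionary of tensors, as shown
above. The example targets Apple's "mps" backend; a small sketch for picking
whichever accelerator is available (plain PyTorch, independent of neural_cherche):

>>> import torch

>>> # Prefer CUDA, then Apple MPS, then fall back to CPU.
>>> device = (
...     "cuda" if torch.cuda.is_available()
...     else "mps" if torch.backends.mps.is_available()
...     else "cpu"
... )

The resulting string can be passed as the device argument of models.Splade.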