ColBERT

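ColBERT ranker. Scores candidate documents against queries using the embeddings produced by a models.ColBERT encoder.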

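Parameters

  • key (str) – name of the document field used as the document identifier (e.g. "id")
  • on (list[str]) – document fields whose text is encoded (e.g. ["document"])
  • model (models.ColBERT) – the ColBERT encoder used to embed queries and documents
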
Examples

>>> from neural_cherche import models, rank
>>> from pprint import pprint
>>> import torch

>>> _ = torch.manual_seed(42)

>>> encoder = models.ColBERT(
...     model_name_or_path="raphaelsty/neural-cherche-colbert",
...     device="mps",  # Apple Silicon GPU; use "cuda" or "cpu" on other machines
... )

>>> documents = [
...     {"id": 0, "document": "Food"},
...     {"id": 1, "document": "Sports"},
...     {"id": 2, "document": "Cinema"},
... ]

>>> queries = ["Food", "Sports", "Cinema"]

>>> ranker = rank.ColBERT(
...     key="id",
...     on=["document"],
...     model=encoder,
... )

>>> queries_embeddings = ranker.encode_queries(
...     queries=queries,
...     batch_size=3,
... )

>>> documents_embeddings = ranker.encode_documents(
...     documents=documents,
...     batch_size=3,
... )

>>> scores = ranker(
...     documents=[documents for _ in queries],
...     queries_embeddings=queries_embeddings,
...     documents_embeddings=documents_embeddings,
...     batch_size=3,
...     tqdm_bar=True,
...     k=3,
... )

>>> pprint(scores)
[[{'document': 'Food', 'id': 0, 'similarity': 20.23601531982422},
  {'document': 'Cinema', 'id': 2, 'similarity': 7.255690574645996},
  {'document': 'Sports', 'id': 1, 'similarity': 6.666046142578125}],
 [{'document': 'Sports', 'id': 1, 'similarity': 21.373430252075195},
  {'document': 'Cinema', 'id': 2, 'similarity': 5.494492053985596},
  {'document': 'Food', 'id': 0, 'similarity': 4.814355850219727}],
 [{'document': 'Sports', 'id': 1, 'similarity': 9.25660228729248},
  {'document': 'Food', 'id': 0, 'similarity': 8.206350326538086},
  {'document': 'Cinema', 'id': 2, 'similarity': 5.496612548828125}]]
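The similarity values above are late-interaction scores. As a minimal sketch, assuming the standard ColBERT MaxSim formulation (each query token embedding is matched with its most similar document token embedding, and those maxima are summed over the query), the scoring step can be written in plain Python. The function names `dot` and `maxsim_score` and the toy 2-dimensional vectors are hypothetical; the library operates on model-produced torch tensors and may batch or normalize differently.

```python
from typing import Sequence

Vector = Sequence[float]


def dot(u: Vector, v: Vector) -> float:
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def maxsim_score(query_tokens: Sequence[Vector], document_tokens: Sequence[Vector]) -> float:
    """ColBERT-style late interaction: for every query token keep the best
    matching document token, then sum those maxima over the query tokens."""
    return sum(
        max(dot(q, d) for d in document_tokens)
        for q in query_tokens
    )


# Toy token embeddings (hypothetical, 2-dimensional for readability).
query = [[1.0, 0.0], [0.0, 1.0]]
doc_a = [[0.9, 0.1], [0.2, 0.8]]
doc_b = [[0.5, 0.5]]

print(maxsim_score(query, doc_a))
print(maxsim_score(query, doc_b))
```

In this formulation the score is a sum over query tokens, so it grows with query length and is not confined to [0, 1] like a single cosine similarity.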

Methods

__call__

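Rank documents given queries: for each query, its candidate documents are sorted by decreasing similarity and the top k are returned.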

Parameters

  • documents (list[list[dict]]) – one list of candidate documents per query
  • queries_embeddings (dict[str, torch.Tensor])
  • documents_embeddings (dict[str, torch.Tensor])
  • batch_size (int) – defaults to 32
  • tqdm_bar (bool) – defaults to True
  • k (int) – defaults to None
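The ranking step of __call__ can be pictured as follows: each query is paired with its own list of candidates (hence list[list[dict]]), every candidate receives a similarity, and the candidates are sorted in descending order and truncated to k. This is a hedged sketch, not the library's implementation: the helper `rank_candidates`, the `score` callback and the toy scorer are hypothetical, and treating `k=None` as "keep every candidate" is an assumption.

```python
from typing import Callable, Optional


def rank_candidates(
    queries: list[str],
    documents: list[list[dict]],
    score: Callable[[str, dict], float],
    k: Optional[int] = None,
) -> list[list[dict]]:
    """Return, for each query, its candidates sorted by decreasing similarity.

    documents[i] holds the candidates for queries[i]; each returned item is a
    copy of the candidate dict with an added "similarity" field.
    """
    if len(queries) != len(documents):
        raise ValueError("expected one candidate list per query")
    ranked = []
    for query, candidates in zip(queries, documents):
        scored = [{**doc, "similarity": score(query, doc)} for doc in candidates]
        scored.sort(key=lambda doc: doc["similarity"], reverse=True)
        # Assumption: k=None keeps all candidates.
        ranked.append(scored if k is None else scored[:k])
    return ranked


def toy_score(query: str, doc: dict) -> float:
    """Hypothetical scorer: number of distinct characters shared with the text."""
    return float(len(set(query.lower()) & set(doc["document"].lower())))


docs = [{"id": 0, "document": "Food"}, {"id": 1, "document": "Sports"}]
print(rank_candidates(["Food"], [docs], toy_score, k=1))
```

Copying each candidate dict before adding "similarity" leaves the caller's documents unchanged, which matters when the same candidate list is reused for several queries, as in the example above.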

encode_documents

Encode documents.

Parameters

  • documents (list[dict]) – each dict holds the key field and the fields named in on
  • batch_size (int) – defaults to 32
  • tqdm_bar (bool) – defaults to True
  • query_mode (bool) – defaults to False
  • kwargs
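encode_documents relies on the key and on settings passed to the constructor: the text of the on fields is what gets embedded, and the value under key identifies each embedding (in the example above, key="id" and on=["document"]). A minimal sketch of that bookkeeping plus batch_size batching, assuming multiple on fields are joined with a single space; the helpers `batchify` and `document_texts` and the extra "title" field are hypothetical, and no model is called.

```python
from typing import Iterator


def batchify(items: list, batch_size: int) -> Iterator[list]:
    """Yield consecutive slices of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def document_texts(documents: list[dict], key: str, on: list[str]) -> dict:
    """Map each document's identifier to the text that would be embedded.

    Assumption: multiple `on` fields are concatenated with a single space.
    """
    return {doc[key]: " ".join(doc[field] for field in on) for doc in documents}


documents = [
    {"id": 0, "document": "Food", "title": "Cooking"},
    {"id": 1, "document": "Sports", "title": "Football"},
    {"id": 2, "document": "Cinema", "title": "Movies"},
]
texts = document_texts(documents, key="id", on=["title", "document"])
for batch in batchify(list(texts.items()), batch_size=2):
    print(batch)  # each batch would be sent to the encoder in a single call
```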

encode_queries

Encode queries.

Parameters

  • queries (list[str])
  • batch_size (int) – defaults to 32
  • tqdm_bar (bool) – defaults to True
  • query_mode (bool) – defaults to True
  • kwargs
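The only default that differs between encode_queries and encode_documents is query_mode (True here, False for documents). In the original ColBERT paper, queries and documents are encoded asymmetrically: a [Q] or [D] marker is prepended, and queries are padded to a fixed length with [MASK] tokens ("query augmentation"). The sketch below reproduces that token-level idea only; whether neural_cherche's query_mode does exactly this is an assumption, and `prepare_tokens`, the marker strings and max_length are illustrative.

```python
def prepare_tokens(tokens: list[str], query_mode: bool, max_length: int = 8) -> list[str]:
    """Prepend a marker token and, in query mode, pad with [MASK] tokens.

    Mirrors the query/document asymmetry described in the ColBERT paper;
    the marker names and max_length are illustrative assumptions.
    """
    if max_length < 1:
        raise ValueError("max_length must be at least 1")
    marker = "[Q]" if query_mode else "[D]"
    # Truncate so that marker + tokens never exceeds max_length.
    sequence = [marker] + tokens[: max_length - 1]
    if query_mode:
        # Query augmentation: pad to a fixed length with [MASK] tokens.
        sequence += ["[MASK]"] * (max_length - len(sequence))
    return sequence


print(prepare_tokens(["food"], query_mode=True, max_length=4))
print(prepare_tokens(["food"], query_mode=False, max_length=4))
```

Under this scheme the [MASK] positions also yield embeddings that take part in MaxSim scoring, which is how the paper lets short queries be softly expanded.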