SentenceTransformer

Sentence Transformer retriever.

Parameters

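  • key (str) – name of the field that holds the document identifier in the returned candidates (for example "id")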

  • device (str) – defaults to cpu

Examples

>>> from neural_tree import retrievers
>>> from sentence_transformers import SentenceTransformer
>>> from pprint import pprint

>>> model = SentenceTransformer("all-mpnet-base-v2")

>>> retriever = retrievers.SentenceTransformer(key="id")

>>> retriever = retriever.add(
...     documents_embeddings={
...         0: model.encode("Paris is the capital of France."),
...         1: model.encode("Berlin is the capital of Germany."),
...         2: model.encode("Paris and Berlin are European cities."),
...         3: model.encode("Paris and Berlin are beautiful cities."),
...     }
... )

>>> queries_embeddings = {
...     0: model.encode("Paris"),
...     1: model.encode("Berlin"),
... }

>>> candidates = retriever(queries_embeddings=queries_embeddings, k=2)
>>> pprint(candidates)
[[{'id': 0, 'similarity': 0.644777984318611},
  {'id': 3, 'similarity': 0.52865785276988}],
 [{'id': 1, 'similarity': 0.6901492368348436},
  {'id': 3, 'similarity': 0.5457692206973245}]]

Methods

__call__

Retrieve documents.

Parameters

  • queries_embeddings (dict[int, numpy.ndarray])
  • k (int | None) – defaults to 100
  • kwargs
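
To make the shape of the returned candidates concrete, here is a minimal NumPy sketch of brute-force top-k retrieval over embedding dictionaries. It is not the library's implementation, which relies on a faiss index; the helper name `top_k_by_cosine` and the choice of cosine similarity are assumptions made purely for illustration.

```python
import numpy as np


def top_k_by_cosine(documents_embeddings, queries_embeddings, k=100, key="id"):
    """Hypothetical helper: brute-force cosine top-k, one result list per query."""
    doc_ids = list(documents_embeddings)
    docs = np.stack([np.asarray(documents_embeddings[i], dtype=float) for i in doc_ids])
    # Normalize rows so that a dot product equals cosine similarity.
    docs = docs / np.linalg.norm(docs, axis=1, keepdims=True)

    results = []
    for query in queries_embeddings.values():
        query = np.asarray(query, dtype=float)
        query = query / np.linalg.norm(query)
        scores = docs @ query
        # Indices of the k highest scores, best first (k=None keeps all documents).
        best = np.argsort(-scores)[:k]
        results.append(
            [{key: doc_ids[i], "similarity": float(scores[i])} for i in best]
        )
    return results


# Toy 2-dimensional embeddings, not produced by a real model.
documents = {0: np.array([1.0, 0.0]), 1: np.array([0.0, 1.0]), 2: np.array([1.0, 1.0])}
queries = {0: np.array([1.0, 0.1])}
candidates = top_k_by_cosine(documents, queries, k=2)
```

As in the example output above, the result is a list with one entry per query, each entry being a list of dictionaries sorted by decreasing similarity.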

add

Add documents to the faiss index.

Parameters

  • documents_embeddings (dict[int, numpy.ndarray])
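
Faiss indexes operate on contiguous float32 matrices of shape (n, d), so a dictionary of embeddings has to be packed into such a matrix at some point. The sketch below shows one way this could be done; `pack_embeddings` is a hypothetical helper, and how the retriever actually performs this step internally is an assumption.

```python
import numpy as np


def pack_embeddings(documents_embeddings):
    """Hypothetical helper: split a dict into a list of ids and an (n, d) float32 matrix."""
    ids = list(documents_embeddings.keys())
    matrix = np.ascontiguousarray(
        np.stack([documents_embeddings[i] for i in ids]), dtype=np.float32
    )
    return ids, matrix


# Toy 3-dimensional embeddings keyed by document id.
ids, matrix = pack_embeddings(
    {10: np.array([0.1, 0.2, 0.3]), 20: np.array([0.4, 0.5, 0.6])}
)
```

Row i of the matrix corresponds to ids[i], which is what allows positions returned by an index search to be mapped back to document identifiers.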