
update scoring function, cleaning
raphaelsty committed Jul 31, 2024
1 parent 8775b7a commit 432f137
Showing 28 changed files with 709 additions and 670 deletions.
28 changes: 28 additions & 0 deletions .github/workflows/python-tests.yml
@@ -0,0 +1,28 @@
name: Python Tests

on:
  pull_request:
    branches:
      - '**'

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v2

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.10'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install ".[dev]"
      - name: Run tests
        run: |
          pytest giga_cherche --cov=giga_cherche --cov-report=html
12 changes: 6 additions & 6 deletions README.md
@@ -15,9 +15,9 @@ For example, to run the BEIR evaluations using giga-cherche indexes:
# Modeling
The modeling of giga-cherche is based on sentence-transformers, which allows building a ColBERT model from any available encoder by appending a projection layer to the encoder's output to reduce the dimension of the embeddings.
```
-from giga_cherche.models import ColBERT
+from giga_cherche import models
 model_name = "bert-base-uncased"
-model = ColBERT(model_name_or_path=model_name)
+model = models.ColBERT(model_name_or_path=model_name)
```
The following parameters can be passed to the constructor to set different properties of the model:
- ```embedding_size```, the output size of the projection layer and thus the dimension of the embeddings (see the sketch below)
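For example, a minimal sketch that sets the projection size at construction time (`128` is an arbitrary illustrative value, not a recommended setting):

```python
from giga_cherche import models

# Illustrative only: project the token embeddings down to 128 dimensions.
model = models.ColBERT(
    model_name_or_path="bert-base-uncased",
    embedding_size=128,
)
```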
@@ -40,7 +40,7 @@ from sentence_transformers import (
SentenceTransformerTrainingArguments,
)

-from giga_cherche import losses, models, data_collator, evaluation
+from giga_cherche import losses, models, datasets, evaluation

model_name = "bert-base-uncased"
batch_size = 32
@@ -77,7 +77,7 @@ trainer = SentenceTransformerTrainer(
     eval_dataset=eval_dataset,
     loss=train_loss,
     evaluator=dev_evaluator,
-    data_collator=data_collator.ColBERT(model.tokenize),
+    data_collator=utils.ColBERTCollator(model.tokenize),
 )

trainer.train()
@@ -88,7 +88,7 @@ trainer.train()
```
import ast
def add_queries_and_documents(example: dict) -> dict:
def add_queries_and_documents(Examples dict) -> dict:
"""Add queries and documents text to the examples."""
scores = ast.literal_eval(node_or_string=example["scores"])
processed_example = {"scores": scores, "query": queries[example["query_id"]]}
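The snippet above is cut off by the diff fold; as a rough sketch of how such a mapping function is typically applied (the file name and columns here are hypothetical):

```python
from datasets import load_dataset

# Hypothetical usage: expand each scored example with its query and document text.
dataset = load_dataset("json", data_files="train_triples.json", split="train")
dataset = dataset.map(add_queries_and_documents)
```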
@@ -135,7 +135,7 @@ You can then compute the ColBERT max-sim scores like this:

```python
from giga_cherche import scores
-similarity_scores = scores.colbert_score(query_embeddings, document_embeddings)
+similarity_scores = scores.colbert_scores(query_embeddings, document_embeddings)
```
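For intuition, a small sketch of the max-sim operation itself, under illustrative tensor shapes; this is not the library's implementation:

```python
import torch

queries_embeddings = torch.randn(2, 4, 128)    # (queries, query tokens, dim)
documents_embeddings = torch.randn(3, 6, 128)  # (documents, document tokens, dim)

# Token-level dot products: (queries, documents, query tokens, document tokens).
token_scores = torch.einsum("qtd,nsd->qnts", queries_embeddings, documents_embeddings)

# Each query token keeps its best-matching document token, then maxima are summed
# per query, yielding one score per (query, document) pair.
max_sim = token_scores.max(dim=-1).values.sum(dim=-1)  # (queries, documents)
```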

## Indexing
11 changes: 6 additions & 5 deletions giga_cherche/__init__.py
@@ -1,9 +1,10 @@
 __all__ = [
-    "models",
-    "losses",
-    "scores",
     "evaluation",
     "indexes",
-    "reranker",
-    "data_collator",
+    "losses",
+    "models",
+    "rerank",
+    "retrieve",
+    "scores",
+    "utils",
 ]
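With the reordered exports, the whole public surface can be imported in one line (a trivial usage sketch):

```python
from giga_cherche import evaluation, indexes, losses, models, rerank, retrieve, scores, utils
```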
3 changes: 0 additions & 3 deletions giga_cherche/data_collator/__init__.py

This file was deleted.

76 changes: 16 additions & 60 deletions giga_cherche/evaluation/beir.py
@@ -3,11 +3,18 @@
import random
from collections import defaultdict

__all__ = ["evaluate", "load_beir", "get_beir_triples"]


 def add_duplicates(queries: list[str], scores: list[list[dict]]) -> list:
-    """Add back duplicates scores to the set of candidates."""
+    """Add back duplicates scores to the set of candidates.
+
+    Parameters
+    ----------
+    queries
+        List of queries.
+    scores
+        Scores of the retrieval model.
+    """
     query_counts = defaultdict(int)
     for query in queries:
         query_counts[query] += 1
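Based on the docstring and the counting logic above, a hypothetical illustration of the intended behavior (shapes assumed, not taken from the repository):

```python
# Queries are deduplicated before retrieval, so identical queries share one
# candidate list; the counts computed above replicate scores back per query.
queries = ["what is colbert", "what is colbert", "what is splade"]
scores = [[{"id": "doc-1"}], [{"id": "doc-7"}]]
# Expected result: [[{"id": "doc-1"}], [{"id": "doc-1"}], [{"id": "doc-7"}]]
```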
@@ -31,7 +38,9 @@ def load_beir(dataset_name: str, split: str = "test") -> tuple[list, list, dict]
     Parameters
     ----------
     dataset_name
-        Dataset name: scifact.
+        Name of the beir dataset.
+    split
+        Split to load.
     """
from beir import util
@@ -85,14 +94,14 @@ def get_beir_triples(
     Examples
     --------
-    >>> from neural_cherche import utils
+    >>> from giga_cherche import evaluation
-    >>> documents, queries, qrels = utils.load_beir(
+    >>> documents, queries, qrels = evaluation.load_beir(
     ...     "scifact",
     ...     split="test",
     ... )
-    >>> triples = utils.get_beir_triples(
+    >>> triples = evaluation.get_beir_triples(
     ...     key="id",
     ...     on=["title", "text"],
     ...     documents=documents,
@@ -146,59 +155,6 @@ def evaluate(
     metrics
         Metrics to compute.
-
-    Examples
-    --------
-    >>> from neural_cherche import models, retrieve, utils
-    >>> import torch
-    >>> _ = torch.manual_seed(42)
-    >>> model = models.Splade(
-    ...     model_name_or_path="raphaelsty/neural-cherche-sparse-embed",
-    ...     device="cpu",
-    ... )
-    >>> documents, queries, qrels = utils.load_beir(
-    ...     "scifact",
-    ...     split="test",
-    ... )
-    >>> documents = documents[:10]
-    >>> retriever = retrieve.Splade(
-    ...     key="id",
-    ...     on=["title", "text"],
-    ...     model=model
-    ... )
-    >>> documents_embeddings = retriever.encode_documents(
-    ...     documents=documents,
-    ...     batch_size=1,
-    ... )
-    >>> documents_embeddings = retriever.add(
-    ...     documents_embeddings=documents_embeddings,
-    ... )
-    >>> queries_embeddings = retriever.encode_queries(
-    ...     queries=queries,
-    ...     batch_size=1,
-    ... )
-    >>> scores = retriever(
-    ...     queries_embeddings=queries_embeddings,
-    ...     k=30,
-    ...     batch_size=1,
-    ... )
-    >>> utils.evaluate(
-    ...     scores=scores,
-    ...     qrels=qrels,
-    ...     queries=queries,
-    ...     metrics=["map", "ndcg@10", "ndcg@100", "recall@10", "recall@100"]
-    ... )
-    {'map': 0.0033333333333333335, 'ndcg@10': 0.0033333333333333335, 'ndcg@100': 0.0033333333333333335, 'recall@10': 0.0033333333333333335, 'recall@100': 0.0033333333333333335}
     """
     from ranx import Qrels, Run, evaluate

19 changes: 11 additions & 8 deletions giga_cherche/evaluation/colbert_triplet_evaluator.py
@@ -10,19 +10,17 @@
from sentence_transformers.SentenceTransformer import SentenceTransformer
from sentence_transformers.similarity_functions import SimilarityFunction

-from giga_cherche.scores.colbert_score import colbert_score
+from ..scores import colbert_scores

logger = logging.getLogger(__name__)

__all__ = ["ColBERTTripletEvaluator"]


class ColBERTTripletEvaluator(SentenceEvaluator):
"""
Evaluate a model based on a triplet: (sentence, positive_example, negative_example).
Checks if colbert distance(sentence, positive_example) < distance(sentence, negative_example).
-    Example:
+    Examples
::
from sentence_transformers import SentenceTransformer
@@ -198,15 +196,20 @@ def __call__(
         # Colbert distance
         # pos_colbert_distances = colbert_pairwise_score(embeddings_anchors, embeddings_positives)
         # neg_colbert_distances = colbert_pairwise_score(embeddings_anchors, embeddings_negatives)
-        pos_colbert_distances_full = colbert_score(
-            embeddings_anchors, embeddings_positives
+        pos_colbert_distances_full = colbert_scores(
+            queries_embeddings=embeddings_anchors,
+            documents_embeddings=embeddings_positives,
         )
-        neg_colbert_distances_full = colbert_score(
-            embeddings_anchors, embeddings_negatives
+
+        neg_colbert_distances_full = colbert_scores(
+            queries_embeddings=embeddings_anchors,
+            documents_embeddings=embeddings_negatives,
         )

         distances_full = torch.cat(
             [pos_colbert_distances_full, neg_colbert_distances_full], dim=1
         )

         # print(distances_full.shape)
         labels = np.arange(0, len(embeddings_anchors))
         indices = np.argsort(-distances_full.cpu().numpy(), axis=1)
2 changes: 0 additions & 2 deletions giga_cherche/indexes/base.py
@@ -1,7 +1,5 @@
 from abc import ABC, abstractmethod

-__all__ = ["Base"]
-

 class Base(ABC):
     """Base class for all indexes. Indexes are used to store and retrieve embeddings."""
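For context, a hypothetical sketch of the shape such an index interface usually takes; the method names below are assumed, not taken from this file:

```python
from abc import ABC, abstractmethod


class Base(ABC):
    """Base class for all indexes. Indexes are used to store and retrieve embeddings."""

    @abstractmethod
    def add_documents(self, documents_ids: list, documents_embeddings: list) -> None:
        """Store document embeddings in the index."""

    @abstractmethod
    def get_documents_embeddings(self, documents_ids: list) -> list:
        """Fetch the stored embeddings for the given document ids."""
```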
9 changes: 5 additions & 4 deletions giga_cherche/indexes/weaviate.py
@@ -1,13 +1,14 @@
 import asyncio
 import time

-import weaviate
-import weaviate.classes as wvc
+try:
+    import weaviate
+    import weaviate.classes as wvc
+except ImportError:
+    pass

 from .base import Base

-__all__ = ["Weaviate"]
-

 # TODO: define Index metaclass
 # max_doc_length is used to set a limit in the fetch embeddings method as the speed is dependent on the number of embeddings fetched
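The try/except above makes weaviate an optional dependency. A common refinement of this pattern (a sketch under assumed class structure, not the repository's code) is to fail at use time with an actionable message:

```python
try:
    import weaviate
except ImportError:
    weaviate = None


class Weaviate:
    """Weaviate index (sketch)."""

    def __init__(self, name: str = "colbert_collection") -> None:
        if weaviate is None:
            # Hypothetical guard: error out only when the index is actually used.
            raise ImportError(
                "Run `pip install weaviate-client` to use the Weaviate index."
            )
        self.name = name
```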
5 changes: 3 additions & 2 deletions giga_cherche/losses/__init__.py
@@ -1,3 +1,4 @@
-from .colbert import ColBERTLossv1, ColBERTLossv2
+from .contrastive import Contrastive
+from .distillation import Distillation

-__all__ = ["ColBERTLossv1", "ColBERTLossv2"]
+__all__ = ["Contrastive", "Distillation"]
