extend-unit-testing
raphaelsty committed Aug 20, 2024
1 parent 80b3451 commit 5bd83f3
Showing 8 changed files with 304 additions and 69 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/python-tests.yml
@@ -25,4 +25,5 @@ jobs:
- name: Run tests
run: |
pytest giga_cherche --cov=giga_cherche --cov-report=html
pytest giga_cherche --cov=doctrings --cov-report=html
pytest tests --cov=tests --cov-report=html
11 changes: 10 additions & 1 deletion .gitignore
@@ -148,4 +148,13 @@ output/

# datasets
evaluation_datasets/
datasets/
./datasets/

*.csv
*.sqlite
*.voy
*.parquet
*.tsv

/test-model/

3 changes: 0 additions & 3 deletions distillation_evaluation_results.csv

This file was deleted.

8 changes: 3 additions & 5 deletions giga_cherche/evaluation/colbert_triplet.py
@@ -180,6 +180,8 @@ def __init__(
"accuracy",
]

self.primary_metric = "accuracy"

def __call__(
self,
model: ColBERT,
@@ -241,11 +243,7 @@ def __call__(
for metric in self.metrics:
logger.info(f"{metric.capitalize()}: \t{metrics[metric]:.2f}")

metrics = self.prefix_name_to_metrics(
metrics,
self.name,
)
self.store_metrics_in_model_card_data(model, metrics)
self.store_metrics_in_model_card_data(model=model, metrics=metrics)

if output_path is not None and self.write_csv:
csv_writer(
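For context, setting primary_metric and storing metrics in the model card data follows the sentence-transformers evaluator pattern. Below is a minimal sketch of that pattern, assuming the SentenceEvaluator base class and using a placeholder metric value rather than the real ColBERT triplet computation; TinyAccuracyEvaluator is a hypothetical name for illustration only.

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import SentenceEvaluator


class TinyAccuracyEvaluator(SentenceEvaluator):
    """Hypothetical evaluator sketching the primary_metric / model-card pattern."""

    def __init__(self, name: str = "tiny") -> None:
        super().__init__()
        self.name = name
        # Trainers use primary_metric to decide which value ranks checkpoints.
        self.primary_metric = "accuracy"

    def __call__(self, model: SentenceTransformer, output_path=None, epoch=-1, steps=-1) -> dict:
        metrics = {"accuracy": 0.5}  # placeholder value, not a real evaluation
        # Record the metrics on the model so they appear in its model card.
        self.store_metrics_in_model_card_data(model=model, metrics=metrics)
        return metrics
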
16 changes: 13 additions & 3 deletions giga_cherche/models/colbert.py
@@ -179,6 +179,19 @@ class ColBERT(SentenceTransformer):
>>> assert len(embeddings) == 2
>>> model.save_pretrained("test-model")
>>> model = models.ColBERT("test-model")
>>> embeddings = model.encode([
... "Hello, how are you?",
... "How is the weather today?"
... ])
>>> assert len(embeddings) == 2
>>> assert embeddings[0].shape == (9, 128)
>>> assert embeddings[1].shape == (9, 128)
"""

def __init__(
@@ -789,9 +802,6 @@ def encode_multi_process(
>>> model.stop_multi_process_pool(pool)
>>> assert len(embeddings) == 3
>>> assert embeddings[0].shape == (9, 128)
>>> assert embeddings[1].shape == (10, 128)
>>> assert embeddings[2].shape == (9, 128)
"""

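Since the commit's focus is extending unit testing through docstring examples, the new examples above can also be checked locally with the standard-library doctest module. A minimal sketch, assuming the giga_cherche package is importable in the current environment (running it loads the underlying model, so it is not a lightweight check):

import doctest

from giga_cherche.models import colbert

# Run every docstring example in the module, including the new encode examples.
results = doctest.testmod(colbert, verbose=False)
print(f"{results.failed} failed out of {results.attempted} examples")
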
177 changes: 121 additions & 56 deletions giga_cherche/utils/processing.py
@@ -1,7 +1,10 @@
import ast
import logging

import datasets

logger = logging.getLogger(__name__)


class KDProcessing:
"""Dataset processing class for knowledge distillation training.
@@ -17,105 +20,167 @@ class KDProcessing:
Examples
--------
from datasets import load_dataset
from giga_cherche import utils
train = load_dataset(
path="./msmarco_fr",
name="train",
cache_dir="./msmarco_fr"
)
queries = load_dataset(
path="./msmarco_fr",
name="queries",
cache_dir="./msmarco_fr"
)
documents = load_dataset(
path="./msmarco_fr", name="documents", cache_dir="./msmarco_fr"
)
train = train.map(
utils.DatasetProcessing(
queries=queries, documents=documents
).transform,
)
>>> from datasets import load_dataset
>>> from giga_cherche import utils
>>> train = load_dataset(
... path="lightonai/lighton-ms-marco-mini",
... name="train",
... split="train",
... )
>>> queries = load_dataset(
... path="lightonai/lighton-ms-marco-mini",
... name="queries",
... split="train",
... )
>>> documents = load_dataset(
... path="lightonai/lighton-ms-marco-mini",
... name="documents",
... split="train",
... )
>>> train.set_transform(
... utils.KDProcessing(
... queries=queries, documents=documents
... ).transform,
... )
>>> for sample in train:
... assert "documents" in sample and isinstance(sample["documents"], list)
... assert "query" in sample and isinstance(sample["query"], str)
... assert "scores" in sample and isinstance(sample["scores"], list)
"""

def __init__(
self, queries: datasets.Dataset, documents: datasets.Dataset, n_ways: int = 32
) -> None:
self.queries = queries
self.documents = documents
self.queries = queries["train"] if "train" in queries else queries
self.documents = documents["train"] if "train" in documents else documents
self.n_ways = n_ways

self.queries_index = {
query_id: i
for i, query_id in enumerate(iterable=self.queries["train"]["query_id"])
query_id: i for i, query_id in enumerate(iterable=self.queries["query_id"])
}

self.documents_index = {
document_id: i
for i, document_id in enumerate(
iterable=self.documents["train"]["document_id"]
)
for i, document_id in enumerate(iterable=self.documents["document_id"])
}

def transform(self, examples: dict) -> dict:
"""Update the input dataset with the queries and documents."""
examples["scores"] = [
ast.literal_eval(node_or_string=score)[: self.n_ways]
for score in examples["scores"]
]
if isinstance(examples["scores"], str):
examples["scores"] = [
ast.literal_eval(node_or_string=score) for score in examples["scores"]
]

examples["scores"] = [score[: self.n_ways] for score in examples["scores"]]

if isinstance(examples["document_ids"], str):
examples["document_ids"] = [
ast.literal_eval(node_or_string=document_ids)
for document_ids in examples["document_ids"]
]

examples["document_ids"] = [
ast.literal_eval(node_or_string=document_ids)[: self.n_ways]
for document_ids in examples["document_ids"]
document_ids[: self.n_ways] for document_ids in examples["document_ids"]
]

examples["query"] = [
[self.queries["train"][self.queries_index[query_id]]["text"]]
self.queries[self.queries_index[query_id]]["text"]
for query_id in examples["query_id"]
]

examples["documents"] = []
for doc_ids in examples["document_ids"]:
for document_ids in examples["document_ids"]:
documents = []
for document_id in doc_ids:
for document_id in document_ids:
try:
documents.append(
self.documents["train"][self.documents_index[document_id]][
"text"
]
self.documents[self.documents_index[document_id]]["text"]
)
except KeyError:
documents.append("")
print(f"KeyError: {document_id}")
logger.warning(f"Unable to find document: {document_id}")

examples["documents"].append(documents)

return examples

def map(self, example: dict) -> dict:
"""Add queries and documents text to the examples."""
scores = ast.literal_eval(node_or_string=example["scores"])[: self.n_ways]
documents_ids = ast.literal_eval(node_or_string=example["document_ids"])[
: self.n_ways
]
"""Process a single example.
Parameters
----------
example
Example to process.
Examples
--------
>>> from datasets import load_dataset
>>> from giga_cherche import utils
>>> train = load_dataset(
... path="lightonai/lighton-ms-marco-mini",
... name="train",
... split="train",
... )
>>> queries = load_dataset(
... path="lightonai/lighton-ms-marco-mini",
... name="queries",
... split="train",
... )
>>> documents = load_dataset(
... path="lightonai/lighton-ms-marco-mini",
... name="documents",
... split="train",
... )
>>> train = train.map(
... utils.KDProcessing(
... queries=queries, documents=documents
... ).map,
... )
>>> for sample in train:
... assert "documents" in sample and isinstance(sample["documents"], list)
... assert "query" in sample and isinstance(sample["query"], str)
... assert "scores" in sample and isinstance(sample["scores"], list)
"""
if isinstance(example["scores"], str):
example["scores"] = ast.literal_eval(node_or_string=example["scores"])

example["scores"] = example["scores"][: self.n_ways]

if isinstance(example["document_ids"], str):
example["document_ids"] = ast.literal_eval(
node_or_string=example["document_ids"]
)

example["document_ids"] = example["document_ids"][: self.n_ways]

processed_example = {
"scores": scores,
"query": self.queries["train"][self.queries_index[example["query_id"]]][
"text"
],
"scores": example["scores"],
"query": self.queries[self.queries_index[example["query_id"]]]["text"],
}

documents = []
for document_id in documents_ids:
for document_id in example["document_ids"]:
try:
documents.append(
self.documents["train"][self.documents_index[document_id]]["text"]
self.documents[self.documents_index[document_id]]["text"]
)
except KeyError:
documents.append("")
print(f"KeyError: {document_id}")
logger.warning(f"Unable to find document: {document_id}")

processed_example["documents"] = documents

return processed_example
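The isinstance checks added in transform and map above handle datasets whose list-valued columns (scores, document_ids) are stored as strings; ast.literal_eval recovers the Python lists before they are truncated to n_ways. A standalone illustration of that parsing step, with made-up values:

import ast

# Made-up raw values mimicking string-encoded dataset columns.
raw_scores = "[0.91, 0.40, 0.12]"
raw_document_ids = "['doc_1', 'doc_2', 'doc_3']"

scores = ast.literal_eval(raw_scores) if isinstance(raw_scores, str) else raw_scores
document_ids = (
    ast.literal_eval(raw_document_ids)
    if isinstance(raw_document_ids, str)
    else raw_document_ids
)

n_ways = 2
print(scores[:n_ways])        # [0.91, 0.4]
print(document_ids[:n_ways])  # ['doc_1', 'doc_2']
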