Skip to content

Commit

Permalink
extend-unit-testing
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelsty committed Aug 19, 2024
1 parent 80b3451 commit a28658b
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 10 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ jobs:
- name: Run tests
run: |
pytest giga_cherche --cov=giga_cherche --cov-report=html
pytest giga_cherche --cov=doctrings --cov-report=html
pytest tests --cov=tests --cov-report=html
8 changes: 7 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,10 @@ output/

# datasets
evaluation_datasets/
datasets/
./datasets/

*.csv
*.sqlite
*.voy

/test-model/
3 changes: 0 additions & 3 deletions distillation_evaluation_results.csv

This file was deleted.

8 changes: 3 additions & 5 deletions giga_cherche/evaluation/colbert_triplet.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ def __init__(
"accuracy",
]

self.primary_metric = "accuracy"

def __call__(
self,
model: ColBERT,
Expand Down Expand Up @@ -241,11 +243,7 @@ def __call__(
for metric in self.metrics:
logger.info(f"{metric.capitalize()}: \t{metrics[metric]:.2f}")

metrics = self.prefix_name_to_metrics(
metrics,
self.name,
)
self.store_metrics_in_model_card_data(model, metrics)
self.store_metrics_in_model_card_data(model=model, metrics=metrics)

if output_path is not None and self.write_csv:
csv_writer(
Expand Down
13 changes: 13 additions & 0 deletions giga_cherche/models/colbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,19 @@ class ColBERT(SentenceTransformer):
>>> assert len(embeddings) == 2
>>> model.save_pretrained("test-model")
>>> model = models.ColBERT("test-model")
>>> embeddings = model.encode([
... "Hello, how are you?",
... "How is the weather today?"
... ])
>>> assert len(embeddings) == 2
>>> assert embeddings[0].shape == (9, 128)
>>> assert embeddings[1].shape == (9, 128)
"""

def __init__(
Expand Down
86 changes: 86 additions & 0 deletions tests/tests_constractive_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Tests the training loop."""

import os
import shutil

import pandas as pd
from datasets import load_dataset
from sentence_transformers import (
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
)
from sentence_transformers.training_args import BatchSamplers

from giga_cherche import evaluation, losses, models, utils


def test_contrastive_training() -> None:
    """Test contrastive training end to end.

    Trains a ColBERT model for one epoch on half of the ``triplet``
    configuration of ``lightonai/lighton-ms-marco-mini`` using the
    ``Contrastive`` loss, evaluating on the other half with
    ``ColBERTTripletEvaluator``. The test then checks that:

    * the trainer created the output directory,
    * calling the evaluator returns a ``dict``,
    * the evaluator wrote ``triplet_evaluation_results.csv`` and the file
      has an ``accuracy`` column.

    The output directory is removed before the run and again in a
    ``finally`` clause, so artifacts are cleaned up even when training or
    one of the assertions fails.
    """
    output_dir = "tests/contrastive"
    results_path = "tests/contrastive/triplet_evaluation_results.csv"

    # Start from a clean slate in case a previous run left files behind.
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)

    try:
        model = models.ColBERT(
            model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"
        )

        dataset = load_dataset(
            "lightonai/lighton-ms-marco-mini", "triplet", split="train"
        )

        # Use one half of the data for training and the other for evaluation.
        splits = dataset.train_test_split(test_size=0.5)
        train_dataset, eval_dataset = splits["train"], splits["test"]

        train_loss = losses.Contrastive(model=model)

        dev_evaluation = evaluation.ColBERTTripletEvaluator(
            anchors=eval_dataset["query"],
            positives=eval_dataset["positive"],
            negatives=eval_dataset["negative"],
        )

        args = SentenceTransformerTrainingArguments(
            output_dir=output_dir,
            overwrite_output_dir=True,
            num_train_epochs=1,
            per_device_train_batch_size=1,
            per_device_eval_batch_size=1,
            fp16=False,
            bf16=False,
            batch_sampler=BatchSamplers.NO_DUPLICATES,
            eval_strategy="steps",
            eval_steps=1,
            save_strategy="epoch",
            save_steps=1,
            save_total_limit=1,
            learning_rate=3e-6,
            do_eval=True,
        )

        trainer = SentenceTransformerTrainer(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            loss=train_loss,
            evaluator=dev_evaluation,
            data_collator=utils.ColBERTCollator(tokenize_fn=model.tokenize),
        )

        trainer.train()

        assert os.path.isdir(output_dir)

        # Run the evaluator directly so it writes its CSV report to disk.
        metrics = dev_evaluation(
            model=model,
            output_path="tests/contrastive/",
        )

        assert isinstance(metrics, dict)

        assert os.path.isfile(results_path)

        results = pd.read_csv(filepath_or_buffer=results_path)

        assert "accuracy" in list(results.columns)
    finally:
        # Always remove training artifacts, including on failure.
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)

0 comments on commit a28658b

Please sign in to comment.