From a28658b29989a79f5ff116e6b5bd46b8d514342c Mon Sep 17 00:00:00 2001 From: Raphael Sourty Date: Wed, 14 Aug 2024 14:25:26 +0200 Subject: [PATCH] extend-unit-testing --- .github/workflows/python-tests.yml | 3 +- .gitignore | 8 +- distillation_evaluation_results.csv | 3 - giga_cherche/evaluation/colbert_triplet.py | 8 +- giga_cherche/models/colbert.py | 13 ++++ tests/tests_constractive_training.py | 86 ++++++++++++++++++++++ 6 files changed, 111 insertions(+), 10 deletions(-) delete mode 100644 distillation_evaluation_results.csv create mode 100644 tests/tests_constractive_training.py diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 2fa9237..12934ea 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -25,4 +25,5 @@ jobs: - name: Run tests run: | - pytest giga_cherche --cov=giga_cherche --cov-report=html \ No newline at end of file + pytest giga_cherche --cov=doctrings --cov-report=html + pytest tests --cov=tests --cov-report=html \ No newline at end of file diff --git a/.gitignore b/.gitignore index f6f8166..e97a932 100644 --- a/.gitignore +++ b/.gitignore @@ -148,4 +148,10 @@ output/ # datasets evaluation_datasets/ -datasets/ \ No newline at end of file +./datasets/ + +*.csv +*.sqlite +*.voy + +/test-model/ \ No newline at end of file diff --git a/distillation_evaluation_results.csv b/distillation_evaluation_results.csv deleted file mode 100644 index 00db192..0000000 --- a/distillation_evaluation_results.csv +++ /dev/null @@ -1,3 +0,0 @@ -epoch,steps,kl_divergence --1,-1,0.07069863379001617 --1,-1,0.091236412525177 diff --git a/giga_cherche/evaluation/colbert_triplet.py b/giga_cherche/evaluation/colbert_triplet.py index 2191c0c..41bcdc3 100644 --- a/giga_cherche/evaluation/colbert_triplet.py +++ b/giga_cherche/evaluation/colbert_triplet.py @@ -180,6 +180,8 @@ def __init__( "accuracy", ] + self.primary_metric = "accuracy" + def __call__( self, model: ColBERT, @@ -241,11 +243,7 @@ def __call__( for metric in self.metrics: logger.info(f"{metric.capitalize()}: \t{metrics[metric]:.2f}") - metrics = self.prefix_name_to_metrics( - metrics, - self.name, - ) - self.store_metrics_in_model_card_data(model, metrics) + self.store_metrics_in_model_card_data(model=model, metrics=metrics) if output_path is not None and self.write_csv: csv_writer( diff --git a/giga_cherche/models/colbert.py b/giga_cherche/models/colbert.py index 3f23a2a..b8a587e 100644 --- a/giga_cherche/models/colbert.py +++ b/giga_cherche/models/colbert.py @@ -179,6 +179,19 @@ class ColBERT(SentenceTransformer): >>> assert len(embeddings) == 2 + >>> model.save_pretrained("test-model") + + >>> model = models.ColBERT("test-model") + + >>> embeddings = model.encode([ + ... "Hello, how are you?", + ... "How is the weather today?" + ... ]) + + >>> assert len(embeddings) == 2 + >>> assert embeddings[0].shape == (9, 128) + >>> assert embeddings[1].shape == (9, 128) + """ def __init__( diff --git a/tests/tests_constractive_training.py b/tests/tests_constractive_training.py new file mode 100644 index 0000000..195e920 --- /dev/null +++ b/tests/tests_constractive_training.py @@ -0,0 +1,86 @@ +"""Tests the training loop.""" + +import os +import shutil + +import pandas as pd +from datasets import load_dataset +from sentence_transformers import ( + SentenceTransformerTrainer, + SentenceTransformerTrainingArguments, +) +from sentence_transformers.training_args import BatchSamplers + +from giga_cherche import evaluation, losses, models, utils + + +def test_contrastive_training() -> None: + """Test constrastive training.""" + if os.path.exists(path="tests/contrastive"): + shutil.rmtree("tests/contrastive") + + model = models.ColBERT(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2") + + dataset = load_dataset("lightonai/lighton-ms-marco-mini", "triplet", split="train") + + splits = dataset.train_test_split(test_size=0.5) + + train_dataset, eval_dataset = splits["train"], splits["test"] + + train_loss = losses.Contrastive(model=model) + + dev_evaluation = evaluation.ColBERTTripletEvaluator( + anchors=eval_dataset["query"], + positives=eval_dataset["positive"], + negatives=eval_dataset["negative"], + ) + + args = SentenceTransformerTrainingArguments( + output_dir="tests/contrastive", + overwrite_output_dir=True, + num_train_epochs=1, + per_device_train_batch_size=1, + per_device_eval_batch_size=1, + fp16=False, + bf16=False, + batch_sampler=BatchSamplers.NO_DUPLICATES, + eval_strategy="steps", + eval_steps=1, + save_strategy="epoch", + save_steps=1, + save_total_limit=1, + learning_rate=3e-6, + do_eval=True, + ) + + trainer = SentenceTransformerTrainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + loss=train_loss, + evaluator=dev_evaluation, + data_collator=utils.ColBERTCollator(tokenize_fn=model.tokenize), + ) + + trainer.train() + + assert os.path.isdir("tests/contrastive") + + metrics = dev_evaluation( + model=model, + output_path="tests/contrastive/", + ) + + assert isinstance(metrics, dict) + + assert os.path.isfile(path="tests/contrastive/triplet_evaluation_results.csv") + + results = pd.read_csv( + filepath_or_buffer="tests/contrastive/triplet_evaluation_results.csv" + ) + + assert "accuracy" in list(results.columns) + + if os.path.exists(path="tests/contrastive"): + shutil.rmtree("tests/contrastive")