From 031525d71803ccb4738c3551326a24f7a26b392c Mon Sep 17 00:00:00 2001
From: Sahel Sharifymoghaddam
Date: Thu, 25 Apr 2024 23:50:24 -0400
Subject: [PATCH] fix retriever and eval tests, add dl23, upgrade requirements (#113)

---
 requirements.txt                            | 16 +++----
 src/rank_llm/demo/experimental_results.py   |  2 +-
 src/rank_llm/retrieve/indices_dict.py       |  4 ++
 src/rank_llm/retrieve/pyserini_retriever.py |  2 +-
 src/rank_llm/retrieve/topics_dict.py        |  1 +
 test/evaluation/test_trec_eval.py           |  8 ++--
 test/retrieve/test_PyseriniRetriever.py     | 49 ++++++++++++---------
 7 files changed, 47 insertions(+), 35 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 47b8766..f400a59 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,10 @@
-tqdm>=4.66.1
-openai>=1.9.0
-tiktoken>=0.5.2
-transformers>=4.37.0
-pyserini>=0.24.0
+tqdm>=4.66.2
+openai>=1.23.6
+tiktoken>=0.6.0
+transformers>=4.40.1
+pyserini>=0.35.0
 python-dotenv>=1.0.1
-faiss-cpu>=1.7.2
-ftfy>=6.1.3
+faiss-cpu>=1.8.0
+ftfy>=6.2.0
 dacite>=1.8.1
-fschat[model_worker]>=0.2.35
+fschat[model_worker]>=0.2.36
diff --git a/src/rank_llm/demo/experimental_results.py b/src/rank_llm/demo/experimental_results.py
index 49140fd..fa1000e 100644
--- a/src/rank_llm/demo/experimental_results.py
+++ b/src/rank_llm/demo/experimental_results.py
@@ -24,7 +24,7 @@
 rerankers = {"rg": g_reranker, "rv": v_reranker, "rz": z_reranker}
 
 results = {}
-for dataset in ["dl19", "dl20", "dl21", "dl22"]:
+for dataset in ["dl19", "dl20", "dl21", "dl22", "dl23"]:
     retrieved_results = Retriever.from_dataset_with_prebuilt_index(dataset, k=20)
     topics = TOPICS[dataset]
     ret_ndcg_10 = EvalFunction.from_results(retrieved_results, topics)
diff --git a/src/rank_llm/retrieve/indices_dict.py b/src/rank_llm/retrieve/indices_dict.py
index afe15a4..f020e32 100644
--- a/src/rank_llm/retrieve/indices_dict.py
+++ b/src/rank_llm/retrieve/indices_dict.py
@@ -4,6 +4,7 @@
     "dl20": "msmarco-v1-passage",
     "dl21": "msmarco-v2-passage",
     "dl22": "msmarco-v2-passage",
+    "dl23": "msmarco-v2-passage",
     "covid": "beir-v1.0.0-trec-covid.flat",
     "arguana": "beir-v1.0.0-arguana.flat",
     "touche": "beir-v1.0.0-webis-touche2020.flat",
@@ -33,6 +34,7 @@
     "dl20": "msmarco-v1-passage-splade-pp-ed-text",
     "dl21": "",
     "dl22": "",
+    "dl23": "",
     "covid": "beir-v1.0.0-trec-covid.splade-pp-ed",
     "arguana": "beir-v1.0.0-arguana.splade-pp-ed",
     "touche": "beir-v1.0.0-webis-touche2020.splade-pp-ed",
@@ -52,6 +54,7 @@
     "dl20": "msmarco-v1-passage.distilbert-dot-tas_b-b256",
     "dl21": "",
     "dl22": "",
+    "dl23": "",
     "covid": "",
     "arguana": "",
     "touche": "",
@@ -71,6 +74,7 @@
     "dl20": "msmarco-v1-passage.openai-ada2",
     "dl21": "",
     "dl22": "",
+    "dl23": "",
     "covid": "",
     "arguana": "",
     "touche": "",
diff --git a/src/rank_llm/retrieve/pyserini_retriever.py b/src/rank_llm/retrieve/pyserini_retriever.py
index ce337cc..e15d239 100644
--- a/src/rank_llm/retrieve/pyserini_retriever.py
+++ b/src/rank_llm/retrieve/pyserini_retriever.py
@@ -202,7 +202,7 @@ def _init_prebuilt_topics(self, topics_path: str, index_path: str):
     def _init_topics_from_dict(self, dataset: str):
         if dataset not in TOPICS:
             raise ValueError("dataset %s not in TOPICS" % dataset)
-        if dataset in ["dl20", "dl21", "dl22"]:
+        if dataset in ["dl20", "dl21", "dl22", "dl23"]:
             topics_key = dataset
         else:
             topics_key = TOPICS[dataset]
diff --git a/src/rank_llm/retrieve/topics_dict.py b/src/rank_llm/retrieve/topics_dict.py
index 4750036..ce2c646 100644
--- a/src/rank_llm/retrieve/topics_dict.py
+++ b/src/rank_llm/retrieve/topics_dict.py
@@ -3,6 +3,7 @@
     "dl20": "dl20-passage",
     "dl21": "dl21-passage",
     "dl22": "dl22-passage",
+    "dl23": "dl23-passage",
     "covid": "beir-v1.0.0-trec-covid-test",
     "arguana": "beir-v1.0.0-arguana-test",
     "touche": "beir-v1.0.0-webis-touche2020-test",
diff --git a/test/evaluation/test_trec_eval.py b/test/evaluation/test_trec_eval.py
index b9d1a91..15bf720 100644
--- a/test/evaluation/test_trec_eval.py
+++ b/test/evaluation/test_trec_eval.py
@@ -15,8 +15,8 @@ def setUp(self):
                 data={
                     "query": {"text": "Query1", "qid": "q1"},
                     "candidates": [
-                        {"qid": "q1", "docid": "D1", "score": 0.9},
-                        {"qid": "q1", "docid": "D2", "score": 0.8},
+                        {"doc": {"text": "Doc1"}, "docid": "D1", "score": 0.9},
+                        {"doc": {"text": "Doc2"}, "docid": "D2", "score": 0.8},
                     ],
                 },
             ),
@@ -24,7 +24,9 @@
                 data_class=Result,
                 data={
                     "query": {"text": "Query2", "qid": "q2"},
-                    "candidates": [{"qid": "q2", "docid": "D3", "score": 0.85}],
+                    "candidates": [
+                        {"doc": {"text": "Doc3"}, "docid": "D3", "score": 0.85}
+                    ],
                 },
             ),
         ]
diff --git a/test/retrieve/test_PyseriniRetriever.py b/test/retrieve/test_PyseriniRetriever.py
index 412654f..743f6b9 100644
--- a/test/retrieve/test_PyseriniRetriever.py
+++ b/test/retrieve/test_PyseriniRetriever.py
@@ -1,7 +1,9 @@
 import unittest
 from unittest.mock import MagicMock, patch
 
-from rank_llm.data import Result
+from dacite import from_dict
+
+from rank_llm.data import Request
 from rank_llm.retrieve.indices_dict import INDICES
 from rank_llm.retrieve.pyserini_retriever import PyseriniRetriever, RetrievalMethod
 
@@ -26,10 +28,11 @@
 
 # Mocking Hits object
 class MockHit:
-    def __init__(self, docid, rank, score):
+    def __init__(self, docid, rank, score, qid):
         self.docid = docid
         self.rank = rank
         self.score = score
+        self.qid = qid
 
 
 class TestPyseriniRetriever(unittest.TestCase):
@@ -75,7 +78,10 @@ def test_retrieve_query(self, mock_json_loads, mock_index_reader):
 
         # Mocking hits
         mock_hits = MagicMock(spec=list[MockHit])
-        mock_hits.__iter__.return_value = [MockHit("d1", 1, 0.5), MockHit("d2", 2, 0.4)]
+        mock_hits.__iter__.return_value = [
+            MockHit("d1", 1, 0.5, "q1"),
+            MockHit("d2", 2, 0.4, "q1"),
+        ]
 
         # Setting up PyseriniRetriever instance
         retriever = PyseriniRetriever("dl19", RetrievalMethod.BM25)
@@ -84,30 +90,29 @@
 
        # Creating lists to store expected and actual results
         expected_results = [
-            Result(
-                query="Sample Query",
-                hits=[
-                    {
-                        "content": "Title: Sample Title Content: Sample Text",
-                        "qid": None,
-                        "docid": "d1",
-                        "rank": 1,
-                        "score": 0.5,
-                    },
-                    {
-                        "content": "Title: Sample Title Content: Sample Text",
-                        "qid": None,
-                        "docid": "d2",
-                        "rank": 2,
-                        "score": 0.4,
-                    },
-                ],
+            from_dict(
+                data_class=Request,
+                data={
+                    "query": {"text": "Sample Query", "qid": "q1"},
+                    "candidates": [
+                        {
+                            "doc": {"title": "Sample Title", "text": "Sample Text"},
+                            "docid": "d1",
+                            "score": 0.5,
+                        },
+                        {
+                            "doc": {"title": "Sample Title", "text": "Sample Text"},
+                            "docid": "d2",
+                            "score": 0.4,
+                        },
+                    ],
+                },
             )
         ]
         actual_results = []
 
         # Calling the _retrieve_query method
-        retriever._retrieve_query("Sample Query", actual_results, 2)
+        retriever._retrieve_query("Sample Query", actual_results, 2, "q1")
 
         # Asserting that Hits object is called with the correct query and k
         retriever._searcher.search.assert_called_once_with("Sample Query", k=2)