Skip to content

Commit

Permalink
fix retriever and eval tests, add dl23, upgrade requirements (#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
sahel-sh committed Apr 26, 2024
1 parent fd40109 commit 031525d
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 35 deletions.
16 changes: 8 additions & 8 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
tqdm>=4.66.1
openai>=1.9.0
tiktoken>=0.5.2
transformers>=4.37.0
pyserini>=0.24.0
tqdm>=4.66.2
openai>=1.23.6
tiktoken>=0.6.0
transformers>=4.40.1
pyserini>=0.35.0
python-dotenv>=1.0.1
faiss-cpu>=1.7.2
ftfy>=6.1.3
faiss-cpu>=1.8.0
ftfy>=6.2.0
dacite>=1.8.1
fschat[model_worker]>=0.2.35
fschat[model_worker]>=0.2.36
2 changes: 1 addition & 1 deletion src/rank_llm/demo/experimental_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
rerankers = {"rg": g_reranker, "rv": v_reranker, "rz": z_reranker}

results = {}
for dataset in ["dl19", "dl20", "dl21", "dl22"]:
for dataset in ["dl19", "dl20", "dl21", "dl22", "dl23"]:
retrieved_results = Retriever.from_dataset_with_prebuilt_index(dataset, k=20)
topics = TOPICS[dataset]
ret_ndcg_10 = EvalFunction.from_results(retrieved_results, topics)
Expand Down
4 changes: 4 additions & 0 deletions src/rank_llm/retrieve/indices_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"dl20": "msmarco-v1-passage",
"dl21": "msmarco-v2-passage",
"dl22": "msmarco-v2-passage",
"dl23": "msmarco-v2-passage",
"covid": "beir-v1.0.0-trec-covid.flat",
"arguana": "beir-v1.0.0-arguana.flat",
"touche": "beir-v1.0.0-webis-touche2020.flat",
Expand Down Expand Up @@ -33,6 +34,7 @@
"dl20": "msmarco-v1-passage-splade-pp-ed-text",
"dl21": "",
"dl22": "",
"dl23": "",
"covid": "beir-v1.0.0-trec-covid.splade-pp-ed",
"arguana": "beir-v1.0.0-arguana.splade-pp-ed",
"touche": "beir-v1.0.0-webis-touche2020.splade-pp-ed",
Expand All @@ -52,6 +54,7 @@
"dl20": "msmarco-v1-passage.distilbert-dot-tas_b-b256",
"dl21": "",
"dl22": "",
"dl23": "",
"covid": "",
"arguana": "",
"touche": "",
Expand All @@ -71,6 +74,7 @@
"dl20": "msmarco-v1-passage.openai-ada2",
"dl21": "",
"dl22": "",
"dl23": "",
"covid": "",
"arguana": "",
"touche": "",
Expand Down
2 changes: 1 addition & 1 deletion src/rank_llm/retrieve/pyserini_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def _init_prebuilt_topics(self, topics_path: str, index_path: str):
def _init_topics_from_dict(self, dataset: str):
if dataset not in TOPICS:
raise ValueError("dataset %s not in TOPICS" % dataset)
if dataset in ["dl20", "dl21", "dl22"]:
if dataset in ["dl20", "dl21", "dl22", "dl23"]:
topics_key = dataset
else:
topics_key = TOPICS[dataset]
Expand Down
1 change: 1 addition & 0 deletions src/rank_llm/retrieve/topics_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"dl20": "dl20-passage",
"dl21": "dl21-passage",
"dl22": "dl22-passage",
"dl23": "dl23-passage",
"covid": "beir-v1.0.0-trec-covid-test",
"arguana": "beir-v1.0.0-arguana-test",
"touche": "beir-v1.0.0-webis-touche2020-test",
Expand Down
8 changes: 5 additions & 3 deletions test/evaluation/test_trec_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,18 @@ def setUp(self):
data={
"query": {"text": "Query1", "qid": "q1"},
"candidates": [
{"qid": "q1", "docid": "D1", "score": 0.9},
{"qid": "q1", "docid": "D2", "score": 0.8},
{"doc": {"text": "Doc1"}, "docid": "D1", "score": 0.9},
{"doc": {"text": "Doc2"}, "docid": "D2", "score": 0.8},
],
},
),
from_dict(
data_class=Result,
data={
"query": {"text": "Query2", "qid": "q2"},
"candidates": [{"qid": "q2", "docid": "D3", "score": 0.85}],
"candidates": [
{"doc": {"text": "Doc3"}, "docid": "D3", "score": 0.85}
],
},
),
]
Expand Down
49 changes: 27 additions & 22 deletions test/retrieve/test_PyseriniRetriever.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import unittest
from unittest.mock import MagicMock, patch

from rank_llm.data import Result
from dacite import from_dict

from rank_llm.data import Request
from rank_llm.retrieve.indices_dict import INDICES
from rank_llm.retrieve.pyserini_retriever import PyseriniRetriever, RetrievalMethod

Expand All @@ -26,10 +28,11 @@

# Mocking Hits object
class MockHit:
def __init__(self, docid, rank, score):
def __init__(self, docid, rank, score, qid):
self.docid = docid
self.rank = rank
self.score = score
self.qid = qid


class TestPyseriniRetriever(unittest.TestCase):
Expand Down Expand Up @@ -75,7 +78,10 @@ def test_retrieve_query(self, mock_json_loads, mock_index_reader):

# Mocking hits
mock_hits = MagicMock(spec=list[MockHit])
mock_hits.__iter__.return_value = [MockHit("d1", 1, 0.5), MockHit("d2", 2, 0.4)]
mock_hits.__iter__.return_value = [
MockHit("d1", 1, 0.5, "q1"),
MockHit("d2", 2, 0.4, "q1"),
]
# Setting up PyseriniRetriever instance
retriever = PyseriniRetriever("dl19", RetrievalMethod.BM25)

Expand All @@ -84,30 +90,29 @@ def test_retrieve_query(self, mock_json_loads, mock_index_reader):

# Creating lists to store expected and actual results
expected_results = [
Result(
query="Sample Query",
hits=[
{
"content": "Title: Sample Title Content: Sample Text",
"qid": None,
"docid": "d1",
"rank": 1,
"score": 0.5,
},
{
"content": "Title: Sample Title Content: Sample Text",
"qid": None,
"docid": "d2",
"rank": 2,
"score": 0.4,
},
],
from_dict(
data_class=Request,
data={
"query": {"text": "Sample Query", "qid": "q1"},
"candidates": [
{
"doc": {"title": "Sample Title", "text": "Sample Text"},
"docid": "d1",
"score": 0.5,
},
{
"doc": {"title": "Sample Title", "text": "Sample Text"},
"docid": "d2",
"score": 0.4,
},
],
},
)
]
actual_results = []

# Calling the _retrieve_query method
retriever._retrieve_query("Sample Query", actual_results, 2)
retriever._retrieve_query("Sample Query", actual_results, 2, "q1")

# Asserting that Hits object is called with the correct query and k
retriever._searcher.search.assert_called_once_with("Sample Query", k=2)
Expand Down

0 comments on commit 031525d

Please sign in to comment.