Skip to content

Commit

Permalink
Clean up MS MARCO V1 passage + BEIR regressions (#1807)
Browse files Browse the repository at this point in the history
Implemented much simpler heuristic for flaky tests where scores vary depending on OS:
Anything within 0.0005, just output "OKish". Everything else, code up case-by-case exceptions.
  • Loading branch information
lintool committed Mar 8, 2024
1 parent 310c828 commit 77f8a81
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 122 deletions.
25 changes: 13 additions & 12 deletions pyserini/2cr/beir.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import pkg_resources
import yaml

from ._base import run_eval_and_return_metric, ok_str, fail_str
from ._base import run_eval_and_return_metric, ok_str, okish_str, fail_str

dense_threads = 16
dense_batch_size = 512
Expand Down Expand Up @@ -152,22 +152,18 @@ def generate_report(args):
s4=f'{table[dataset]["bm25-multifield"]["R@100"]:8.4f}',
s5=f'{table[dataset]["splade-pp-ed"]["nDCG@10"]:8.4f}',
s6=f'{table[dataset]["splade-pp-ed"]["R@100"]:8.4f}',
# s7=f'{table[dataset]["contriever"]["nDCG@10"]:8.4f}',
# s8=f'{table[dataset]["contriever"]["R@100"]:8.4f}',
s7=f'{table[dataset]["contriever-msmarco"]["nDCG@10"]:8.4f}',
s8=f'{table[dataset]["contriever-msmarco"]["R@100"]:8.4f}',
s9=f'{table[dataset]["bge-base-en-v1.5"]["nDCG@10"]:8.4f}',
s10=f'{table[dataset]["bge-base-en-v1.5"]["R@100"]:8.4f}',
cmd1=commands[dataset]["bm25-flat"],
cmd2=commands[dataset]["bm25-multifield"],
cmd3=commands[dataset]["splade-pp-ed"],
# cmd4=commands[dataset]["contriever"],
cmd4=commands[dataset]["contriever-msmarco"],
cmd5=commands[dataset]["bge-base-en-v1.5"],
eval_cmd1=eval_commands[dataset]["bm25-flat"].rstrip(),
eval_cmd2=eval_commands[dataset]["bm25-multifield"].rstrip(),
eval_cmd3=eval_commands[dataset]["splade-pp-ed"].rstrip(),
# eval_cmd4=eval_commands[dataset]["contriever"].rstrip(),
eval_cmd4=eval_commands[dataset]["contriever-msmarco"].rstrip(),
eval_cmd5=eval_commands[dataset]["bge-base-en-v1.5"].rstrip())

Expand Down Expand Up @@ -229,8 +225,13 @@ def run_conditions(args):

score = float(run_eval_and_return_metric(metric, f'beir-v1.0.0-{dataset}-test',
trec_eval_metric_definitions[metric], runfile))
result = ok_str if math.isclose(score, float(expected[metric])) \
else fail_str + f' expected {expected[metric]:.4f}'
if math.isclose(score, float(expected[metric])):
result = ok_str
# If results are within 0.0005, just call it "OKish".
elif abs(score - float(expected[metric])) <= 0.0005:
result = okish_str
else:
result = fail_str
print(f' {metric:7}: {score:.4f} {result}')

table[dataset][name][metric] = score
Expand Down Expand Up @@ -266,24 +267,24 @@ def run_conditions(args):
final_score = (top_level_sums[model][metric] + cqa_score) / 18
final_scores[model][metric] = final_score

print(' ' * 30 + 'BM25-flat' + ' ' * 10 + 'BM25-mf' + ' ' * 13 + 'SPLADE' + ' ' * 11 + 'Contriever' + ' ' * 5 + 'Contriever-msmarco' + ' ' * 5 + 'BGE-base-en-v1.5')
print(' ' * 26 + 'nDCG@10 R@100 ' * 5)
print(' ' * 30 + 'BM25-flat' + ' ' * 10 + 'BM25-mf' + ' ' * 13 + 'SPLADE' + ' ' * 11 + 'Contriever' + ' ' * 5 + 'Contriever-msmarco' + ' ' * 2 + 'BGE-base-en-v1.5')
print(' ' * 26 + 'nDCG@10 R@100 ' * 6)
print(' ' * 27 + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14)
for dataset in beir_keys:
print(f'{dataset:25}' +
f'{table[dataset]["bm25-flat"]["nDCG@10"]:8.4f}{table[dataset]["bm25-flat"]["R@100"]:8.4f} ' +
f'{table[dataset]["bm25-multifield"]["nDCG@10"]:8.4f}{table[dataset]["bm25-multifield"]["R@100"]:8.4f} ' +
f'{table[dataset]["splade-pp-ed"]["nDCG@10"]:8.4f}{table[dataset]["splade-pp-ed"]["R@100"]:8.4f} ' +
f'{table[dataset]["contriever"]["nDCG@10"]:8.4f}{table[dataset]["contriever"]["R@100"]:8.4f} ' +
f'{table[dataset]["contriever-msmarco"]["nDCG@10"]:8.4f}{table[dataset]["contriever-msmarco"]["R@100"]:8.4f}' +
f'{table[dataset]["contriever-msmarco"]["nDCG@10"]:8.4f}{table[dataset]["contriever-msmarco"]["R@100"]:8.4f} ' +
f'{table[dataset]["bge-base-en-v1.5"]["nDCG@10"]:8.4f}{table[dataset]["bge-base-en-v1.5"]["R@100"]:8.4f}')
print(' ' * 27 + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14 + ' ' + '-' * 14)
print('avg' + ' ' * 22 + f'{final_scores["bm25-flat"]["nDCG@10"]:8.4f}{final_scores["bm25-flat"]["R@100"]:8.4f} ' +
f'{final_scores["bm25-multifield"]["nDCG@10"]:8.4f}{final_scores["bm25-multifield"]["R@100"]:8.4f} ' +
f'{final_scores["splade-pp-ed"]["nDCG@10"]:8.4f}{final_scores["splade-pp-ed"]["R@100"]:8.4f} ' +
f'{final_scores["contriever"]["nDCG@10"]:8.4f}{final_scores["contriever"]["R@100"]:8.4f} ' +
f'{final_scores["contriever-msmarco"]["nDCG@10"]:8.4f}{final_scores["contriever-msmarco"]["R@100"]:8.4f}' +
f'{final_scores["bge-base-en-v1.5"]["nDCG@10"]:8.4f}{final_scores["bge-base-en-v1.5"]["R@100"]:8.4f}')
f'{final_scores["contriever-msmarco"]["nDCG@10"]:8.4f}{final_scores["contriever-msmarco"]["R@100"]:8.4f} ' +
f'{final_scores["bge-base-en-v1.5"]["nDCG@10"]:8.4f}{final_scores["bge-base-en-v1.5"]["R@100"]:8.4f}')

end = time.time()

Expand Down
4 changes: 2 additions & 2 deletions pyserini/2cr/beir.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1013,8 +1013,8 @@ conditions:
R@1000: 0.7824
- dataset: scifact
scores:
- nDCG@10: 0.7376
R@100: 0.9700
- nDCG@10: 0.7408
R@100: 0.9667
R@1000: 0.9967
- dataset: signal1m
scores:
Expand Down
16 changes: 8 additions & 8 deletions pyserini/2cr/msmarco-v1-passage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,20 @@ conditions:
- topic_key: msmarco-passage-dev-subset
eval_key: msmarco-passage-dev-subset
scores:
- MRR@10: 0.3557
R@1K: 0.9814
- MRR@10: 0.3583
R@1K: 0.9811
- topic_key: dl19-passage
eval_key: dl19-passage
scores:
- MAP: 0.4436
nDCG@10: 0.7055
R@1K: 0.8472
- MAP: 0.4485
nDCG@10: 0.7016
R@1K: 0.8427
- topic_key: dl20
eval_key: dl20-passage
scores:
- MAP: 0.4651
nDCG@10: 0.6780
R@1K: 0.8503
- MAP: 0.4628
nDCG@10: 0.6768
R@1K: 0.8547
- name: cosdpr-distil-pytorch
display: "cosDPR-distil: PyTorch"
display-html: "cosDPR-distil: PyTorch"
Expand Down
51 changes: 26 additions & 25 deletions pyserini/2cr/msmarco.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@
'unicoil-noexp-otf',
'unicoil-otf',
'slimr-pp'],

# MS MARCO v2 doc
'msmarco-v2-doc':
['bm25-doc-default',
'bm25-doc-segmented-default',
Expand Down Expand Up @@ -337,9 +339,11 @@ def list_conditions(args):
continue
print(condition)

def _get_display_num(num: int) -> str:

def _get_display_num(num: int) -> str:
return f'{num:.4f}' if num != 0 else '-'


def _remove_commands(table, name, s, v1):
v1_unavilable_dict = {
('dl19', 'MAP'): 'Command to generate run on TREC 2019 queries:.*?</div>',
Expand All @@ -359,6 +363,7 @@ def _remove_commands(table, name, s, v1):
s = re.sub(re.compile(v, re.MULTILINE | re.DOTALL), 'Not available.</div>', s)
return s


def generate_report(args):
yaml_file = pkg_resources.resource_filename(__name__, f'{args.collection}.yaml')

Expand Down Expand Up @@ -512,22 +517,15 @@ def generate_report(args):
with open(args.output, 'w') as out:
out.write(Template(html_template).substitute(title=full_name, rows=all_rows))

# Flaky test on Jimmy's Mac Studio

FlakyKey = namedtuple('FlakyKey', ['collection', 'name', 'topic_key', 'metric'])
flaky_dict = {
FlakyKey('msmarco-v1-passage', 'distilbert-kd-tasb-rocchio-prf-pytorch', 'msmarco-passage-dev-subset', 'MRR@10'): 0.0001,
FlakyKey('msmarco-v1-passage', 'tct_colbert-v2-hnp-avg-prf-pytorch', 'dl19-passage', 'nDCG@10'): 0.0001,
FlakyKey('msmarco-v1-passage', 'tct_colbert-v2-hnp-avg-prf-pytorch', 'dl20', 'MAP'): 0.0002,
FlakyKey('msmarco-v1-passage', 'tct_colbert-v2-hnp-avg-prf-pytorch', 'dl20', 'nDCG@10'): 0.0009,
FlakyKey('msmarco-v1-passage', 'tct_colbert-v2-hnp-bm25-pytorch', 'msmarco-passage-dev-subset', 'MRR@10'): 0.0001,
FlakyKey('msmarco-v1-passage', 'ance', 'msmarco-passage-dev-subset', 'MRR@10'): 0.0001,
FlakyKey('msmarco-v1-passage', 'ance-pytorch', 'msmarco-passage-dev-subset', 'MRR@10'): 0.0001,
FlakyKey('msmarco-v1-passage', 'ance-rocchio-prf-pytorch', 'msmarco-passage-dev-subset', 'R@1K'): 0.0002,
FlakyKey('msmarco-v1-passage', 'ance-rocchio-prf-pytorch', 'dl19-passage', 'MAP'): 0.0001,
FlakyKey('msmarco-v1-passage', 'ance-rocchio-prf-pytorch', 'dl19-passage', 'nDCG@10'): 0.0008,
FlakyKey('msmarco-v1-passage', 'ance-avg-prf-pytorch', 'msmarco-passage-dev-subset', 'MRR@10'): 0.0002
# Flaky test on Jimmy's Mac Studio
FlakyKey('msmarco-v1-passage', 'tct_colbert-v2-hnp-avg-prf-pytorch', 'dl20', 'nDCG@10'): 0.0009,
FlakyKey('msmarco-v1-passage', 'ance-rocchio-prf-pytorch', 'dl19-passage', 'nDCG@10'): 0.0008,
}


def run_conditions(args):
start = time.time()

Expand Down Expand Up @@ -589,10 +587,13 @@ def run_conditions(args):
runfile))
if math.isclose(score, float(expected[metric])):
result_str = ok_str
# Flaky test on Jimmy's Mac Studio
# If results are within 0.0005, just call it "OKish".
elif abs(score-float(expected[metric])) <= 0.0005:
result_str = okish_str + f' expected {expected[metric]:.4f}'
# If there are bigger differences, deal with on a case-by-case basis.
elif abs(score-float(expected[metric])) <= \
flaky_dict.get(FlakyKey(collection=args.collection, name=name, topic_key=topic_key, metric=metric), 0):
result_str = okish_str
result_str = okish_str + f' expected {expected[metric]:.4f}'
else:
result_str = fail_str + f' expected {expected[metric]:.4f}'
print(f' {metric:7}: {score:.4f} {result_str}')
Expand All @@ -613,24 +614,24 @@ def run_conditions(args):
names = [ args.condition ]
else:
# Otherwise, print out all rows
names = models[args.collection]
names = models[args.collection]

for name in names:
if not name:
print('')
continue
print(f'{table_keys[name]:65}' +
f'{table[name]["dl19"]["MAP"]:8.4f}{table[name]["dl19"]["nDCG@10"]:8.4f}{table[name]["dl19"]["R@1K"]:8.4f} ' +
f'{table[name]["dl20"]["MAP"]:8.4f}{table[name]["dl20"]["nDCG@10"]:8.4f}{table[name]["dl20"]["R@1K"]:8.4f} ' +
f'{table[name]["dev"]["MRR@10"]:8.4f}{table[name]["dev"]["R@1K"]:8.4f}')
f'{table[name]["dl19"]["MAP"]:8.4f}{table[name]["dl19"]["nDCG@10"]:8.4f}{table[name]["dl19"]["R@1K"]:8.4f} ' +
f'{table[name]["dl20"]["MAP"]:8.4f}{table[name]["dl20"]["nDCG@10"]:8.4f}{table[name]["dl20"]["R@1K"]:8.4f} ' +
f'{table[name]["dev"]["MRR@10"]:8.4f}{table[name]["dev"]["R@1K"]:8.4f}')
else:
print(' ' * 69 + 'TREC 2021' + ' ' * 16 + 'TREC 2022' + ' ' * 16 + 'TREC 2023' + ' ' * 12 + 'MS MARCO dev' + ' ' * 5 + 'MS MARCO dev2')
print(' ' * 62 + 'MAP nDCG@10 R@1K MAP nDCG@10 R@1K MAP nDCG@10 R@1K MRR@100 R@1K MRR@100 R@1K')
print(' ' * 62 + '-' * 22 + ' ' + '-' * 22 + ' ' + '-' * 22 + ' ' + '-' * 14 + ' ' + '-' * 14)

if args.condition:
# If we've used --condition to specify a specific condition, print out only that row.
names = [ args.condition ]
names = [args.condition]
else:
# Otherwise, print out all rows
names = models[args.collection]
Expand All @@ -640,11 +641,11 @@ def run_conditions(args):
print('')
continue
print(f'{table_keys[name]:60}' +
f'{table[name]["dl21"]["MAP@100"]:8.4f}{table[name]["dl21"]["nDCG@10"]:8.4f}{table[name]["dl21"]["R@1K"]:8.4f} ' +
f'{table[name]["dl22"]["MAP@100"]:8.4f}{table[name]["dl22"]["nDCG@10"]:8.4f}{table[name]["dl22"]["R@1K"]:8.4f} ' +
f'{table[name]["dl23"]["MAP@100"]:8.4f}{table[name]["dl23"]["nDCG@10"]:8.4f}{table[name]["dl23"]["R@1K"]:8.4f} ' +
f'{table[name]["dev"]["MRR@100"]:8.4f}{table[name]["dev"]["R@1K"]:8.4f} ' +
f'{table[name]["dev2"]["MRR@100"]:8.4f}{table[name]["dev2"]["R@1K"]:8.4f}')
f'{table[name]["dl21"]["MAP@100"]:8.4f}{table[name]["dl21"]["nDCG@10"]:8.4f}{table[name]["dl21"]["R@1K"]:8.4f} ' +
f'{table[name]["dl22"]["MAP@100"]:8.4f}{table[name]["dl22"]["nDCG@10"]:8.4f}{table[name]["dl22"]["R@1K"]:8.4f} ' +
f'{table[name]["dl23"]["MAP@100"]:8.4f}{table[name]["dl23"]["nDCG@10"]:8.4f}{table[name]["dl23"]["R@1K"]:8.4f} ' +
f'{table[name]["dev"]["MRR@100"]:8.4f}{table[name]["dev"]["R@1K"]:8.4f} ' +
f'{table[name]["dev2"]["MRR@100"]:8.4f}{table[name]["dev2"]["R@1K"]:8.4f}')

end = time.time()
start_str = datetime.utcfromtimestamp(start).strftime('%Y-%m-%d %H:%M:%S')
Expand Down
Loading

0 comments on commit 77f8a81

Please sign in to comment.