evaluate_patches.py

from concurrent.futures import ThreadPoolExecutor, as_completed
from elleelleaime.core.utils.benchmarks import get_benchmark
from elleelleaime.core.benchmarks.bug import Bug
from elleelleaime.core.utils.jsonl import stream_jsonl, write_jsonl
from elleelleaime.evaluate.strategies.registry import PatchEvaluationStrategyRegistry
from pathlib import Path
import numpy as np
import fire
import sys
import tqdm
import logging
import json
import os


def evaluate_candidate(bug: Bug, sample: dict, strategy: str, **kwargs) -> dict:
    """
    Evaluates the candidate patch for the given sample.
    """
    evaluation_strategy = PatchEvaluationStrategyRegistry(**kwargs).get_evaluation(
        strategy
    )
    evaluation = evaluation_strategy.evaluate(bug, sample)
    sample["evaluation"] = evaluation
    return sample
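
# A minimal direct-use sketch (an assumption-based example, not part of the
# script: "defects4j", "replace", and "samples.jsonl" are hypothetical benchmark,
# strategy, and file names; `sample` stands for one dict read from a samples
# .jsonl file):
#
#   benchmark_obj = get_benchmark("defects4j")
#   benchmark_obj.initialize()
#   sample = next(stream_jsonl("samples.jsonl"))
#   bug = benchmark_obj.get_bug(sample["identifier"])
#   result = evaluate_candidate(bug, sample, strategy="replace")
#   print(result["evaluation"])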


def entry_point(
    benchmark: str,
    samples_path: str,
    strategy: str,
    n_workers: int = 4,
    **kwargs,
):
    """
    Evaluates the candidate patches in the given samples file and writes the results
    to f"evaluation_{benchmark}_{prompt_strategy}_{model_name}.jsonl".
    """
    # Derive the prompt strategy and model name from the samples file name
    samples_file_name = os.path.basename(samples_path)
    dir_path = os.path.dirname(samples_path)
    prompt_strategy = samples_file_name.split("_")[2].split(".")[0]
    model_name = samples_file_name.split("_")[3].split(".")[0]
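    # Worked example (hypothetical file name): "samples_defects4j_zero-shot_gpt-4.jsonl"
    # splits on "_" into ["samples", "defects4j", "zero-shot", "gpt-4.jsonl"],
    # giving prompt_strategy == "zero-shot" and model_name == "gpt-4".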

    # Read the samples
    logging.info("Reading samples...")
    samples = list(stream_jsonl(samples_path))

    # Get the benchmark, check that it exists, and initialize it
    benchmark_obj = get_benchmark(benchmark)
    if benchmark_obj is None:
        raise ValueError(f"Unknown benchmark {benchmark}")
    benchmark_obj.initialize()

    # Submit one evaluation task per sample
    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        futures = []
        for sample in tqdm.tqdm(samples, "Launching candidate evaluation..."):
            bug = benchmark_obj.get_bug(sample["identifier"])
            if bug is None:
                raise ValueError(f"Unknown bug {sample['identifier']}")
            futures.append(
                executor.submit(evaluate_candidate, bug, sample, strategy, **kwargs)
            )

        logging.info("Evaluating candidates...")
        results = []
        for future in tqdm.tqdm(as_completed(futures), total=len(futures)):
            results.append(future.result())

    samples = results

    # Write results to a jsonl file
    write_jsonl(
        os.path.join(
            dir_path, f"evaluation_{benchmark}_{prompt_strategy}_{model_name}.jsonl"
        ),
        samples,
    )


def main():
    logging.getLogger().setLevel(logging.INFO)
    fire.Fire(entry_point)


if __name__ == "__main__":
    sys.exit(main())
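
# Example invocation (a sketch; the benchmark name, strategy name, and samples
# file below are assumptions. python-fire maps the keyword parameters of
# entry_point to command-line flags, and the samples file name must follow the
# samples_{benchmark}_{prompt_strategy}_{model_name}.jsonl pattern parsed in
# entry_point):
#
#   python evaluate_patches.py \
#       --benchmark defects4j \
#       --samples_path samples_defects4j_zero-shot_gpt-4.jsonl \
#       --strategy replace \
#       --n_workers 4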