Skip to content

Commit

Permalink
Filter overlap ids based on helm subset
Browse files Browse the repository at this point in the history
  • Loading branch information
Andy Z authored and Andy Z committed Aug 11, 2023
1 parent c495533 commit d23eb30
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/helm/benchmark/presentation/run_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class DisplayRequest:
most relevant request e.g. the request for the chosen choice for multiple choice questions."""


def _read_scenario_state(run_path: str) -> ScenarioState:
def read_scenario_state(run_path: str) -> ScenarioState:
scenario_state_path: str = os.path.join(run_path, "scenario_state.json")
if not os.path.exists(scenario_state_path):
raise ValueError(f"Could not load ScenarioState from {scenario_state_path}")
Expand Down Expand Up @@ -176,7 +176,7 @@ def write_run_display_json(run_path: str, run_spec: RunSpec, schema: Schema, ski
):
hlog(f"Skipping writing display JSON for run {run_spec.name} because all output display JSON files exist.")
return
scenario_state = _read_scenario_state(run_path)
scenario_state = read_scenario_state(run_path)
per_instance_stats = _read_per_instance_stats(run_path)

metric_names = _get_metric_names_for_groups(run_spec.groups, schema)
Expand Down
42 changes: 30 additions & 12 deletions src/helm/benchmark/presentation/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
CONTAMINATION_STYLES,
CONTAMINATION_LEVEL_STRONG,
)
from .run_display import write_run_display_json
from .run_display import write_run_display_json, read_scenario_state

"""
Reads the output of the benchmark runs and produces:
Expand Down Expand Up @@ -405,18 +405,20 @@ def get_stats_file_metadata(data_overlap_dir: str) -> Dict[str, List[str]]:
scenario_spec = light_scenario_key.scenario_spec
num_instances = data_overlap_stats.num_instances
n = data_overlap_stats_key.overlap_protocol_spec.n
"""
TODO: here we are currently just aggregating across all instance ids
for a given scenario rather than the subset run on HELM
"""
num_overlapping_inputs = len(data_overlap_stats.instance_ids_with_overlapping_input)
num_overlapping_references = len(data_overlap_stats.instance_ids_with_overlapping_reference)
if n == OVERLAP_N_COUNT:
scenario_spec_overlap_counts[scenario_spec] = (
num_instances,
num_overlapping_inputs,
num_overlapping_references,
if scenario_spec in self.scenario_spec_instance_id_dict:
instance_ids = self.scenario_spec_instance_id_dict[scenario_spec]
num_overlapping_inputs = len(
set(data_overlap_stats.instance_ids_with_overlapping_input) & instance_ids
)
num_overlapping_references = len(
set(data_overlap_stats.instance_ids_with_overlapping_reference) & instance_ids
)
if n == OVERLAP_N_COUNT:
scenario_spec_overlap_counts[scenario_spec] = (
num_instances,
num_overlapping_inputs,
num_overlapping_references,
)

group_overlap_stats_list: List = []
for group, scenario_specs in group_to_scenario_specs.items():
Expand Down Expand Up @@ -1091,6 +1093,21 @@ def process(run: Run) -> None:

parallel_map(process, self.runs, parallelism=self.num_threads)

def get_scenario_spec_instance_ids(self) -> None:
    """Collect, per scenario spec, the ids of the instances found in its run.

    Rebuilds ``self.scenario_spec_instance_id_dict`` from scratch, mapping
    each scenario spec seen in ``self.runs`` to the set of
    ``request_state.instance.id`` values read from that run's scenario
    state. Only the first run encountered for a given scenario spec is
    read; later runs sharing the same spec are skipped.
    """
    self.scenario_spec_instance_id_dict: Dict[ScenarioSpec, Set[str]] = dict()
    for run in self.runs:
        spec = run.run_spec.scenario_spec
        if spec not in self.scenario_spec_instance_id_dict:
            # Register the (still empty) set before loading the scenario
            # state, so the entry exists even while the state is being read.
            seen_ids = set()
            self.scenario_spec_instance_id_dict[spec] = seen_ids

            state = read_scenario_state(run.run_path)
            seen_ids.update(rs.instance.id for rs in state.request_states)


def symlink_latest(output_path: str, suite: str) -> None:
# Create a symlink runs/latest -> runs/<name_of_suite>,
Expand Down Expand Up @@ -1135,6 +1152,7 @@ def main():
suite=args.suite, output_path=args.output_path, verbose=args.debug, num_threads=args.num_threads
)
summarizer.read_runs()
summarizer.get_scenario_spec_instance_ids()
summarizer.read_overlap_stats()
summarizer.check_metrics_defined()

Expand Down

0 comments on commit d23eb30

Please sign in to comment.