From d23eb30e3b906e4fabd68391c5fa187cb4885c1f Mon Sep 17 00:00:00 2001 From: Andy Z Date: Fri, 11 Aug 2023 13:35:15 -0700 Subject: [PATCH] Filter overlap ids based on helm subset --- .../benchmark/presentation/run_display.py | 4 +- src/helm/benchmark/presentation/summarize.py | 42 +++++++++++++------ 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/src/helm/benchmark/presentation/run_display.py b/src/helm/benchmark/presentation/run_display.py index fce259add11..1940afff79a 100644 --- a/src/helm/benchmark/presentation/run_display.py +++ b/src/helm/benchmark/presentation/run_display.py @@ -76,7 +76,7 @@ class DisplayRequest: most relevant request e.g. the request for the chosen cohice for multiple choice questions.""" -def _read_scenario_state(run_path: str) -> ScenarioState: +def read_scenario_state(run_path: str) -> ScenarioState: scenario_state_path: str = os.path.join(run_path, "scenario_state.json") if not os.path.exists(scenario_state_path): raise ValueError(f"Could not load ScenarioState from {scenario_state_path}") @@ -176,7 +176,7 @@ def write_run_display_json(run_path: str, run_spec: RunSpec, schema: Schema, ski ): hlog(f"Skipping writing display JSON for run {run_spec.name} because all output display JSON files exist.") return - scenario_state = _read_scenario_state(run_path) + scenario_state = read_scenario_state(run_path) per_instance_stats = _read_per_instance_stats(run_path) metric_names = _get_metric_names_for_groups(run_spec.groups, schema) diff --git a/src/helm/benchmark/presentation/summarize.py b/src/helm/benchmark/presentation/summarize.py index 8345e5fe4bd..f164c9be7ba 100644 --- a/src/helm/benchmark/presentation/summarize.py +++ b/src/helm/benchmark/presentation/summarize.py @@ -39,7 +39,7 @@ CONTAMINATION_STYLES, CONTAMINATION_LEVEL_STRONG, ) -from .run_display import write_run_display_json +from .run_display import write_run_display_json, read_scenario_state """ Reads the output of the benchmark runs and produces: @@ -405,18 +405,20 @@ def get_stats_file_metadata(data_overlap_dir: str) -> Dict[str, List[str]]: scenario_spec = light_scenario_key.scenario_spec num_instances = data_overlap_stats.num_instances n = data_overlap_stats_key.overlap_protocol_spec.n - """ - TODO: here we are currently just aggregating across all instance ids - for a given scenario rather than the subset run on HELM - """ - num_overlapping_inputs = len(data_overlap_stats.instance_ids_with_overlapping_input) - num_overlapping_references = len(data_overlap_stats.instance_ids_with_overlapping_reference) - if n == OVERLAP_N_COUNT: - scenario_spec_overlap_counts[scenario_spec] = ( - num_instances, - num_overlapping_inputs, - num_overlapping_references, + if scenario_spec in self.scenario_spec_instance_id_dict: + instance_ids = self.scenario_spec_instance_id_dict[scenario_spec] + num_overlapping_inputs = len( + set(data_overlap_stats.instance_ids_with_overlapping_input) & instance_ids ) + num_overlapping_references = len( + set(data_overlap_stats.instance_ids_with_overlapping_reference) & instance_ids + ) + if n == OVERLAP_N_COUNT: + scenario_spec_overlap_counts[scenario_spec] = ( + num_instances, + num_overlapping_inputs, + num_overlapping_references, + ) group_overlap_stats_list: List = [] for group, scenario_specs in group_to_scenario_specs.items(): @@ -1091,6 +1093,21 @@ def process(run: Run) -> None: parallel_map(process, self.runs, parallelism=self.num_threads) + def get_scenario_spec_instance_ids(self) -> None: + self.scenario_spec_instance_id_dict: Dict[ScenarioSpec, Set[str]] = dict() + for run in self.runs: + run_spec = run.run_spec + scenario_spec = run_spec.scenario_spec + if scenario_spec in self.scenario_spec_instance_id_dict: + continue + self.scenario_spec_instance_id_dict[scenario_spec] = set() + + run_path = run.run_path + scenario_state = read_scenario_state(run_path) + + for request_state in scenario_state.request_states: + self.scenario_spec_instance_id_dict[scenario_spec].add(request_state.instance.id) + def symlink_latest(output_path: str, suite: str) -> None: # Create a symlink runs/latest -> runs/, @@ -1135,6 +1152,7 @@ def main(): suite=args.suite, output_path=args.output_path, verbose=args.debug, num_threads=args.num_threads ) summarizer.read_runs() + summarizer.get_scenario_spec_instance_ids() summarizer.read_overlap_stats() summarizer.check_metrics_defined()