From 8e83dc55ce03106113938ce00dad6f115f2aaf3b Mon Sep 17 00:00:00 2001 From: Mike Lay Date: Fri, 11 Aug 2023 16:10:17 -0700 Subject: [PATCH] Fix GitHub source link in scenario.json (#1762) --- src/helm/benchmark/scenarios/scenario.py | 14 +++++++++----- src/helm/benchmark/scenarios/test_scenario.py | 6 ++++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/helm/benchmark/scenarios/scenario.py b/src/helm/benchmark/scenarios/scenario.py index b448363ef9..73e3aa214c 100644 --- a/src/helm/benchmark/scenarios/scenario.py +++ b/src/helm/benchmark/scenarios/scenario.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field, replace from typing import List, Optional, Tuple -import re +from pathlib import PurePath import inspect from helm.common.object_spec import ObjectSpec, create_object @@ -200,10 +200,14 @@ class Scenario(ABC): """Where the scenario subclass for `self` is defined.""" def __post_init__(self) -> None: - # Assume `/.../src/helm/benchmark/...` - path = inspect.getfile(type(self)) - # Strip out prefix in absolute path and replace with GitHub link. - self.definition_path = re.sub(r"^.*\/src/", "https://github.com/stanford-crfm/helm/blob/main/src/", path) + parts = list(PurePath(inspect.getfile(type(self))).parts) + path = parts.pop() + parts.reverse() + for part in parts: + path = part + "/" + path + if part == "helm": + break + self.definition_path = "https://github.com/stanford-crfm/helm/blob/main/src/" + path @abstractmethod def get_instances(self) -> List[Instance]: diff --git a/src/helm/benchmark/scenarios/test_scenario.py b/src/helm/benchmark/scenarios/test_scenario.py index 1eb083275b..ea8193715c 100644 --- a/src/helm/benchmark/scenarios/test_scenario.py +++ b/src/helm/benchmark/scenarios/test_scenario.py @@ -35,6 +35,12 @@ def test_render_lines(self): "}", ] + def test_definition_path(self): + assert ( + self.scenario.definition_path + == "https://github.com/stanford-crfm/helm/blob/main/src/helm/benchmark/scenarios/simple_scenarios.py" + ) + def test_input_equality(): input1 = Input(text="input1")