Skip to content

Commit

Permalink
FIX: fix PR from FC comments and fix generate baseline
Browse files Browse the repository at this point in the history
  • Loading branch information
nghi-truyen committed Jun 21, 2024
1 parent 56a7066 commit 41b7b2c
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 223 deletions.
23 changes: 9 additions & 14 deletions smash/core/signal_analysis/evaluation/_standardize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

if TYPE_CHECKING:
from smash.fcore._mwd_setup import SetupDT
from smash.util._typing import AnyTuple
from smash.util._typing import ListLike
from smash.util._typing import AnyTuple, ListLike


def _standardize_evaluation_metric(metric: str | ListLike[str]) -> list:
Expand All @@ -20,17 +19,19 @@ def _standardize_evaluation_metric(metric: str | ListLike[str]) -> list:
metric = [metric.lower()]

elif isinstance(metric, list):
for mtc in metric:
for i, mtc in enumerate(metric):
if isinstance(mtc, str):
if mtc.lower() not in METRICS:
raise ValueError(f"Unknown evaluation metric {mtc}. Choices: {METRICS}")
raise ValueError(
f"Unknown evaluation metric {mtc} at index {i} in metric. Choices: {METRICS}"
)
else:
raise TypeError(f"metric '{mtc}' must be str or a list of str")
raise TypeError("List of evaluation metrics must contain only str")

metric = [c.lower() for c in metric]

else:
raise TypeError("metric must be str or a list of str")
raise TypeError("Evaluation metric must be str or a list of str")

return metric

Expand All @@ -40,13 +41,7 @@ def _standardize_evaluation_start_end_eval(eval: str | pd.Timestamp | None, kind
et = pd.Timestamp(setup.end_time)

if eval is None:
if kind == "start":
eval = pd.Timestamp(st)
elif kind == "end":
eval = pd.Timestamp(et)
# % Should be unreachable
else:
pass
eval = pd.Timestamp(getattr(setup, f"{kind}_time"))

else:
if isinstance(eval, str):
Expand All @@ -60,7 +55,7 @@ def _standardize_evaluation_start_end_eval(eval: str | pd.Timestamp | None, kind
pass

else:
raise TypeError("{kind}_eval argument must be str or pandas.Timestamp object")
raise TypeError(f"{kind}_eval argument must be str or pandas.Timestamp object")

if (eval - st).total_seconds() < 0 or (et - eval).total_seconds() < 0:
raise ValueError(
Expand Down
3 changes: 1 addition & 2 deletions smash/core/signal_analysis/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
from smash.core.signal_analysis.evaluation._standardize import _standardize_evaluation_args

if TYPE_CHECKING:
from smash.util._typing import ListLike

from pandas import Timestamp

from smash.core.model.model import Model
from smash.util._typing import ListLike


__all__ = ["evaluation"]
Expand Down
Binary file modified smash/tests/baseline.hdf5
Binary file not shown.
2 changes: 1 addition & 1 deletion smash/tests/core/signal_analysis/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def generic_evaluation(model: smash.Model, qs: np.ndarray, **kwargs) -> dict:
metrics = smash.evaluation(instance, metric=METRICS)

for i, m in enumerate(METRICS):
res[f"metrics.{m}"] = metrics[:, i]
res[f"evaluation.{m}"] = metrics[:, i]

return res

Expand Down
Loading

0 comments on commit 41b7b2c

Please sign in to comment.