From 8ee0aa9dcab3482587669a2faf4f72f7e6dea28f Mon Sep 17 00:00:00 2001
From: Diego Urgell
Date: Thu, 12 Sep 2024 15:44:21 -0700
Subject: [PATCH] Don't include NaN metric values in ckpt paths (#896)

Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/896

Reviewed By: JKSenthil

Differential Revision: D62469085

fbshipit-source-id: 746ba7d16390e2cc7fa513961f500317c73bcf06
---
Note: this copy was reflowed from a whitespace-collapsed version of the
patch. Line breaks were placed to satisfy the hunk line counts; indentation
(of context lines in particular) is inferred and should be checked against
the upstream commit before applying.

 .../callbacks/test_base_checkpointer.py       | 17 ++++++++++
 tests/utils/test_checkpoint.py                | 31 +++++++++++++++++++
 .../framework/callbacks/base_checkpointer.py  |  7 +++++
 torchtnt/utils/checkpoint.py                  |  6 ++++
 4 files changed, 61 insertions(+)

diff --git a/tests/framework/callbacks/test_base_checkpointer.py b/tests/framework/callbacks/test_base_checkpointer.py
index 815e652cad..d40cd2c236 100644
--- a/tests/framework/callbacks/test_base_checkpointer.py
+++ b/tests/framework/callbacks/test_base_checkpointer.py
@@ -919,6 +919,23 @@ def test_get_tracked_metric_value(self) -> None:
         ):
             val_loss = val_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit)
 
+        val_loss_unit.val_loss = float("nan")  # Test nan metric value
+        error_container = []
+        with patch(
+            "torchtnt.framework.callbacks.base_checkpointer.logging.Logger.error",
+            side_effect=error_container.append,
+        ):
+            val_loss = val_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit)
+
+        self.assertEqual(
+            [
+                "Monitored metric 'val_loss' is NaN. Will not be included in checkpoint path, nor tracked for optimality."
+            ],
+            error_container,
+        )
+        self.assertIsNone(val_loss)
+
+        # test with mismatched monitored metric
         train_loss_ckpt_cb = BaseCheckpointSaver(
             dirpath="checkpoint",
             best_checkpoint_config=BestCheckpointConfig("train_loss", "max"),
diff --git a/tests/utils/test_checkpoint.py b/tests/utils/test_checkpoint.py
index 49be9560b3..bbe4ad4024 100644
--- a/tests/utils/test_checkpoint.py
+++ b/tests/utils/test_checkpoint.py
@@ -47,6 +47,37 @@
 
 
 class CheckpointPathTest(unittest.TestCase):
+
+    def test_create_checkpoint_path(self) -> None:
+        # phase-naive and metric-naive
+        ckpt = CheckpointPath("foo", epoch=0, step=1)
+        self.assertEqual(ckpt.path, "foo/epoch_0_step_1")
+
+        # phase-aware and metric-naive
+        ckpt = CheckpointPath("foo", epoch=0, step={Phase.TRAIN: 1})
+        self.assertEqual(ckpt.path, "foo/epoch_0_train_step_1")
+
+        # phase-aware and metric-aware
+        ckpt = CheckpointPath(
+            "foo",
+            epoch=0,
+            step={Phase.TRAIN: 1, Phase.EVALUATE: 1},
+            metric_data=MetricData("foo", 1.0),
+        )
+        self.assertEqual(ckpt.path, "foo/epoch_0_train_step_1_eval_step_1_foo=1.0")
+
+        # nan metric value
+        with self.assertRaisesRegex(
+            ValueError,
+            "Value of monitored metric 'foo' can't be NaN in CheckpointPath.",
+        ):
+            CheckpointPath(
+                "foo",
+                epoch=0,
+                step={Phase.TRAIN: 1, Phase.EVALUATE: 1},
+                metric_data=MetricData("foo", float("nan")),
+            )
+
     def test_from_str(self) -> None:
         # invalid paths
         malformed_paths = [
diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py
index 751e381654..0b98737d65 100644
--- a/torchtnt/framework/callbacks/base_checkpointer.py
+++ b/torchtnt/framework/callbacks/base_checkpointer.py
@@ -8,6 +8,7 @@
 
 import abc
 import logging
+import math
 from datetime import timedelta
 from typing import Any, cast, Iterable, List, Literal, Optional, Union
 
@@ -256,6 +257,12 @@ def _get_tracked_metric_value(
                     "can be converted to float and is not a multi-element tensor value."
                 ) from e
 
+        if metric_value_f and math.isnan(metric_value_f):
+            logger.error(
+                f"Monitored metric '{monitored_metric_name}' is NaN. Will not be included in checkpoint path, nor tracked for optimality."
+            )
+            return None
+
         return metric_value_f
 
     def on_train_start(self, state: State, unit: TTrainUnit) -> None:
diff --git a/torchtnt/utils/checkpoint.py b/torchtnt/utils/checkpoint.py
index d5ca9ea8aa..7b96f72b94 100644
--- a/torchtnt/utils/checkpoint.py
+++ b/torchtnt/utils/checkpoint.py
@@ -7,6 +7,7 @@
 # pyre-strict
 import bisect
 import logging
+import math
 import os
 import re
 from dataclasses import dataclass
@@ -105,6 +106,11 @@ def __init__(
             step if isinstance(step, dict) else {Phase.NONE: step}
         )
 
+        if metric_data and math.isnan(metric_data.value):
+            raise ValueError(
+                f"Value of monitored metric '{metric_data.name}' can't be NaN in CheckpointPath."
+            )
+
     @classmethod
     def from_str(cls, checkpoint_path: str) -> "CheckpointPath":
         """