Skip to content

Commit

Permalink
Don't include NaN metric values in ckpt paths (#896)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #896

Reviewed By: JKSenthil

Differential Revision: D62469085

fbshipit-source-id: 746ba7d16390e2cc7fa513961f500317c73bcf06
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Sep 12, 2024
1 parent 33b98f4 commit 8ee0aa9
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 0 deletions.
17 changes: 17 additions & 0 deletions tests/framework/callbacks/test_base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,23 @@ def test_get_tracked_metric_value(self) -> None:
):
val_loss = val_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit)

val_loss_unit.val_loss = float("nan") # Test nan metric value
error_container = []
with patch(
"torchtnt.framework.callbacks.base_checkpointer.logging.Logger.error",
side_effect=error_container.append,
):
val_loss = val_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit)

self.assertEqual(
[
"Monitored metric 'val_loss' is NaN. Will not be included in checkpoint path, nor tracked for optimality."
],
error_container,
)
self.assertIsNone(val_loss)

# test with mismatched monitored metric
train_loss_ckpt_cb = BaseCheckpointSaver(
dirpath="checkpoint",
best_checkpoint_config=BestCheckpointConfig("train_loss", "max"),
Expand Down
31 changes: 31 additions & 0 deletions tests/utils/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,37 @@


class CheckpointPathTest(unittest.TestCase):

def test_create_checkpoint_path(self) -> None:
# phase-naive and metric-naive
ckpt = CheckpointPath("foo", epoch=0, step=1)
self.assertEqual(ckpt.path, "foo/epoch_0_step_1")

# phase-aware and metric-naive
ckpt = CheckpointPath("foo", epoch=0, step={Phase.TRAIN: 1})
self.assertEqual(ckpt.path, "foo/epoch_0_train_step_1")

# phase-aware and metric-aware
ckpt = CheckpointPath(
"foo",
epoch=0,
step={Phase.TRAIN: 1, Phase.EVALUATE: 1},
metric_data=MetricData("foo", 1.0),
)
self.assertEqual(ckpt.path, "foo/epoch_0_train_step_1_eval_step_1_foo=1.0")

# nan metric value
with self.assertRaisesRegex(
ValueError,
"Value of monitored metric 'foo' can't be NaN in CheckpointPath.",
):
CheckpointPath(
"foo",
epoch=0,
step={Phase.TRAIN: 1, Phase.EVALUATE: 1},
metric_data=MetricData("foo", float("nan")),
)

def test_from_str(self) -> None:
# invalid paths
malformed_paths = [
Expand Down
7 changes: 7 additions & 0 deletions torchtnt/framework/callbacks/base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import abc
import logging
import math
from datetime import timedelta
from typing import Any, cast, Iterable, List, Literal, Optional, Union

Expand Down Expand Up @@ -256,6 +257,12 @@ def _get_tracked_metric_value(
"can be converted to float and is not a multi-element tensor value."
) from e

if metric_value_f and math.isnan(metric_value_f):
logger.error(
f"Monitored metric '{monitored_metric_name}' is NaN. Will not be included in checkpoint path, nor tracked for optimality."
)
return None

return metric_value_f

def on_train_start(self, state: State, unit: TTrainUnit) -> None:
Expand Down
6 changes: 6 additions & 0 deletions torchtnt/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# pyre-strict
import bisect
import logging
import math
import os
import re
from dataclasses import dataclass
Expand Down Expand Up @@ -105,6 +106,11 @@ def __init__(
step if isinstance(step, dict) else {Phase.NONE: step}
)

if metric_data and math.isnan(metric_data.value):
raise ValueError(
f"Value of monitored metric '{metric_data.name}' can't be NaN in CheckpointPath."
)

@classmethod
def from_str(cls, checkpoint_path: str) -> "CheckpointPath":
"""
Expand Down

0 comments on commit 8ee0aa9

Please sign in to comment.