Allow metric-naive checkpoints for missing or malformed metric values
Differential Revision: D65452995
diego-urgell authored and facebook-github-bot committed Nov 5, 2024
1 parent 8150bcc commit cbd1684
Showing 2 changed files with 56 additions and 31 deletions.
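In short, when the monitored metric is missing from the unit or cannot be coerced to a float, the checkpointer now logs an error and writes a metric-naive checkpoint (no metric value in the path) instead of raising a RuntimeError. A minimal sketch of the new behavior, condensed from the updated test below; BaseCheckpointSaver, BestCheckpointConfig, get_dummy_train_state, and MyValLossUnit are the test helpers that appear in the diff, not a public API:

    import os
    import tempfile

    # Assumes the test helpers shown in the diff below.
    with tempfile.TemporaryDirectory() as temp_dir:
        bcs = BaseCheckpointSaver(
            temp_dir,
            save_every_n_epochs=1,
            best_checkpoint_config=BestCheckpointConfig(
                monitored_metric="train_loss",  # MyValLossUnit has no such attribute
                mode="min",
            ),
        )
        # Before this change: RuntimeError. After: an error is logged and a
        # metric-naive checkpoint is still written, with no metric in the path.
        bcs.on_train_epoch_end(get_dummy_train_state(), MyValLossUnit())
        assert os.path.exists(f"{temp_dir}/epoch_0_train_step_0")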
69 changes: 45 additions & 24 deletions tests/framework/callbacks/test_base_checkpointer.py
@@ -760,23 +760,32 @@ def test_keep_last_n_checkpoints_e2e(self) -> None:
         )
 
     def test_best_checkpoint_attr_missing(self) -> None:
-        bcs = BaseCheckpointSaver(
-            "foo",
-            save_every_n_epochs=1,
-            best_checkpoint_config=BestCheckpointConfig(
-                monitored_metric="train_loss",
-                mode="min",
-            ),
-        )
+        with tempfile.TemporaryDirectory() as temp_dir:
+            bcs = BaseCheckpointSaver(
+                temp_dir,
+                save_every_n_epochs=1,
+                best_checkpoint_config=BestCheckpointConfig(
+                    monitored_metric="train_loss",
+                    mode="min",
+                ),
+            )
 
-        state = get_dummy_train_state()
-        my_val_unit = MyValLossUnit()
+            state = get_dummy_train_state()
+            my_val_unit = MyValLossUnit()
 
-        with self.assertRaisesRegex(
-            RuntimeError,
-            "Unit does not have attribute train_loss, unable to retrieve metric to checkpoint.",
-        ):
-            bcs.on_train_epoch_end(state, my_val_unit)
+            error_container = []
+            with patch(
+                "torchtnt.framework.callbacks.base_checkpointer.logging.Logger.error",
+                side_effect=error_container.append,
+            ):
+                bcs.on_train_epoch_end(state, my_val_unit)
+
+            self.assertIn(
+                "Unit does not have attribute train_loss, unable to retrieve metric to checkpoint. Will not be included in checkpoint path, nor tracked for optimality.",
+                error_container,
+            )
+
+            self.assertTrue(os.path.exists(f"{temp_dir}/epoch_0_train_step_0"))
 
     def test_best_checkpoint_no_top_k(self) -> None:
         """
@@ -1008,15 +1017,20 @@ def test_get_tracked_metric_value(self) -> None:

         # pyre-ignore
         val_loss_unit.val_loss = "hola"  # Test weird metric value
-        with self.assertRaisesRegex(
-            RuntimeError,
-            (
-                "Unable to convert monitored metric val_loss to a float. Please ensure the value "
-                "can be converted to float and is not a multi-element tensor value."
-            ),
+        error_container = []
+        with patch(
+            "torchtnt.framework.callbacks.base_checkpointer.logging.Logger.error",
+            side_effect=error_container.append,
         ):
             val_loss = val_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit)
 
+        self.assertIn(
+            "Unable to convert monitored metric val_loss to a float: could not convert string to float: 'hola'. "
+            "Please ensure the value can be converted to float and is not a multi-element tensor value. Will not be "
+            "included in checkpoint path, nor tracked for optimality.",
+            error_container,
+        )
+
         val_loss_unit.val_loss = float("nan")  # Test nan metric value
         error_container = []
         with patch(
@@ -1053,12 +1067,19 @@ def test_get_tracked_metric_value(self) -> None:
             dirpath="checkpoint",
             best_checkpoint_config=BestCheckpointConfig("train_loss", "max"),
         )
-        with self.assertRaisesRegex(
-            RuntimeError,
-            "Unit does not have attribute train_loss, unable to retrieve metric to checkpoint.",
+        error_container = []
+        with patch(
+            "torchtnt.framework.callbacks.base_checkpointer.logging.Logger.error",
+            side_effect=error_container.append,
         ):
             val_loss = train_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit)
 
+        self.assertIn(
+            "Unit does not have attribute train_loss, unable to retrieve metric to checkpoint. "
+            "Will not be included in checkpoint path, nor tracked for optimality.",
+            error_container,
+        )
+
         ckpt_cb = BaseCheckpointSaver(
             dirpath="checkpoint",
         )
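Two details the updated tests rely on: capturing Logger.error calls via unittest.mock.patch with a side_effect, and the exact text float() puts in its ValueError. A self-contained sketch of both, using only the standard library (the logger and function names here are illustrative, not from the repo):

    import logging
    from unittest.mock import patch

    logger = logging.getLogger("demo")  # illustrative name

    def do_work() -> None:
        # Stand-in for code that logs an error instead of raising.
        logger.error("metric unavailable, continuing without it")

    error_container = []
    # side_effect=error_container.append records every message passed to
    # Logger.error; patching the class method covers all loggers, which is
    # what the tests' patch target achieves.
    with patch.object(logging.Logger, "error", side_effect=error_container.append):
        do_work()
    assert "metric unavailable, continuing without it" in error_container

    # float() embeds the offending value in its ValueError message, which is
    # why the expected string includes the fragment after the colon.
    try:
        float("hola")
    except ValueError as exc:
        assert str(exc) == "could not convert string to float: 'hola'"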
18 changes: 11 additions & 7 deletions torchtnt/framework/callbacks/base_checkpointer.py
@@ -283,19 +283,23 @@ def _get_tracked_metric_value(self, unit: TTrainUnit) -> Optional[float]:

         monitored_metric_name = self._best_checkpoint_config.monitored_metric
         if not hasattr(unit, monitored_metric_name):
-            raise RuntimeError(
-                f"Unit does not have attribute {monitored_metric_name}, unable to retrieve metric to checkpoint."
+            logger.error(
+                f"Unit does not have attribute {monitored_metric_name}, unable to retrieve metric to checkpoint. "
+                "Will not be included in checkpoint path, nor tracked for optimality."
             )
+            return None
 
         metric_value_f = None
         if (metric_value := getattr(unit, monitored_metric_name)) is not None:
             try:
                 metric_value_f = float(metric_value)
-            except ValueError as e:
-                raise RuntimeError(
-                    f"Unable to convert monitored metric {monitored_metric_name} to a float. Please ensure the value "
-                    "can be converted to float and is not a multi-element tensor value."
-                ) from e
+            except ValueError as exc:
+                logger.error(
+                    f"Unable to convert monitored metric {monitored_metric_name} to a float: {exc}. Please ensure the value "
+                    "can be converted to float and is not a multi-element tensor value. Will not be included in checkpoint path, "
+                    "nor tracked for optimality."
+                )
+                return None
 
         if metric_value_f and math.isnan(metric_value_f):
             logger.error(
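After the patch, every failure path in _get_tracked_metric_value logs and returns None, and callers treat None as "no metric to track", saving the checkpoint without a metric in its path. A condensed, standalone sketch of the resulting control flow; names are simplified from the real method, and the NaN branch's post-logging behavior is not shown in this hunk, so returning None there is an assumption:

    import logging
    import math
    from typing import Optional

    logger = logging.getLogger(__name__)

    def get_tracked_metric_value(unit: object, metric_name: str) -> Optional[float]:
        # Simplified stand-in for BaseCheckpointer._get_tracked_metric_value.
        if not hasattr(unit, metric_name):
            # Missing attribute: log and fall back to a metric-naive checkpoint.
            logger.error(f"Unit does not have attribute {metric_name}, ...")
            return None
        metric_value = getattr(unit, metric_name)
        if metric_value is None:
            return None
        try:
            metric_value_f = float(metric_value)
        except ValueError as exc:
            # Malformed value (e.g. a non-numeric string): log and skip, do not raise.
            logger.error(f"Unable to convert {metric_name} to a float: {exc}. ...")
            return None
        if math.isnan(metric_value_f):
            # Assumption: NaN is likewise skipped after logging.
            logger.error(f"Monitored metric {metric_name} is NaN. ...")
            return None
        return metric_value_f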
