Skip to content

Commit

Permalink
fix: when there is no train_metrics, do not checkpoint (mosaicml#2502)
Browse files Browse the repository at this point in the history
  • Loading branch information
furkanbiten committed Sep 5, 2023
1 parent 336bf8d commit 3865074
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions composer/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,21 +904,22 @@ def state_dict(self) -> Dict[str, Any]:
'fsdp_config["state_dict_type"] = "full" to disable sharded checkpoints.'))
else:
serialized_value = {}
for k, v in attribute_value.items():
# No need to use __qualname__, we already know this corresponds to
# a metric object when we deserialize.
# Along with the rest of a Composer checkpoint, the state_dict() and _computed attributes of
# a Torchmetrics object are enough information to recreate it upon deserialization. We only serialize
# the minimum metric information to maximize backwards compatibility --- old checkpoints
# will continue to be compatible even if other Torchmetrics attributes have changed.
# metric._computed stores the cached value of the previous metric computation
# We need to serialize this because it cannot always be recomputed from the state dict.
# See https://torchmetrics.readthedocs.io/en/stable/pages/implement.html#torchmetrics.Metric for more details
v.persistent(mode=True)
serialized_value[k] = {
'state_dict': v.state_dict(),
'_computed': v._computed,
}
if attribute_value is not None:
for k, v in attribute_value.items():
# No need to use __qualname__, we already know this corresponds to
# a metric object when we deserialize.
# Along with the rest of a Composer checkpoint, the state_dict() and _computed attributes of
# a Torchmetrics object are enough information to recreate it upon deserialization. We only serialize
# the minimum metric information to maximize backwards compatibility --- old checkpoints
# will continue to be compatible even if other Torchmetrics attributes have changed.
# metric._computed stores the cached value of the previous metric computation
# We need to serialize this because it cannot always be recomputed from the state dict.
# See https://torchmetrics.readthedocs.io/en/stable/pages/implement.html#torchmetrics.Metric for more details
v.persistent(mode=True)
serialized_value[k] = {
'state_dict': v.state_dict(),
'_computed': v._computed,
}
elif attribute_name == 'eval_metrics':
if self.fsdp_sharded_state_dict_enabled:
serialized_value = None
Expand Down

0 comments on commit 3865074

Please sign in to comment.