From 3865074232f79f30365acc1179ef0fd560000ce8 Mon Sep 17 00:00:00 2001
From: furkanbiten
Date: Wed, 6 Sep 2023 01:49:55 +0200
Subject: [PATCH] fix: when there is no train_metrics, do not checkpoint
 (#2502)

---
 composer/core/state.py | 31 ++++++++++++++++---------------
 1 file changed, 16 insertions(+), 15 deletions(-)

diff --git a/composer/core/state.py b/composer/core/state.py
index 1062234252..219c8aec85 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -904,21 +904,22 @@ def state_dict(self) -> Dict[str, Any]:
                         'fsdp_config["state_dict_type"] = "full" to disable sharded checkpoints.'))
                 else:
                     serialized_value = {}
-                    for k, v in attribute_value.items():
-                        # No need to use __qualname__, we already know this corresponds to
-                        # a metric object when we deserialize.
-                        # Along with the rest of a Composer checkpoint, the state_dict() and _computed attributes of
-                        # a Torchmetrics object are enough information to recreate it upon serialization. We only serialize
-                        # the minimum metric information to maximize backwards compatibility --- old checkpoints
-                        # will continue to be compatible even if other Torchmetrics attributes have changed.
-                        # metric._computed stores the cached value of the previous metric computation
-                        # We need to serialize this because it cannot always be recomputed from the state dict.
-                        # See https://torchmetrics.readthedocs.io/en/stable/pages/implement.html#torchmetrics.Metric for more details
-                        v.persistent(mode=True)
-                        serialized_value[k] = {
-                            'state_dict': v.state_dict(),
-                            '_computed': v._computed,
-                        }
+                    if attribute_value is not None:
+                        for k, v in attribute_value.items():
+                            # No need to use __qualname__, we already know this corresponds to
+                            # a metric object when we deserialize.
+                            # Along with the rest of a Composer checkpoint, the state_dict() and _computed attributes of
+                            # a Torchmetrics object are enough information to recreate it upon serialization. We only serialize
+                            # the minimum metric information to maximize backwards compatibility --- old checkpoints
+                            # will continue to be compatible even if other Torchmetrics attributes have changed.
+                            # metric._computed stores the cached value of the previous metric computation
+                            # We need to serialize this because it cannot always be recomputed from the state dict.
+                            # See https://torchmetrics.readthedocs.io/en/stable/pages/implement.html#torchmetrics.Metric for more details
+                            v.persistent(mode=True)
+                            serialized_value[k] = {
+                                'state_dict': v.state_dict(),
+                                '_computed': v._computed,
+                            }
             elif attribute_name == 'eval_metrics':
                 if self.fsdp_sharded_state_dict_enabled:
                     serialized_value = None
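
Note: the sketch below is a minimal, self-contained illustration of the guarded
serialization logic this patch introduces, not Composer's actual code path. The
helper name serialize_metrics is hypothetical; only the None-guard, the
persistent(mode=True) call, and the state_dict/_computed pair come from the
patch itself.

    from typing import Any, Dict, Optional

    import torch
    from torchmetrics import Metric
    from torchmetrics.classification import BinaryAccuracy


    def serialize_metrics(metrics: Optional[Dict[str, Metric]]) -> Dict[str, Any]:
        # Mirrors the patched branch: an absent metrics dict serializes to {}
        # instead of raising AttributeError when .items() is called on None.
        serialized: Dict[str, Any] = {}
        if metrics is not None:
            for name, metric in metrics.items():
                metric.persistent(mode=True)  # ensure metric states appear in state_dict()
                serialized[name] = {
                    'state_dict': metric.state_dict(),
                    '_computed': metric._computed,  # cached result of the last compute()
                }
        return serialized


    acc = BinaryAccuracy()
    acc.update(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))
    acc.compute()
    print(serialize_metrics({'accuracy': acc}))  # serialized state plus cached value
    print(serialize_metrics(None))               # {} -- no crash when no train metrics exist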