Skip to content

Commit

Permalink
Fix torchmetrics backwards compatibility issue (mosaicml#2468)
Browse files Browse the repository at this point in the history
* add fix

* fix tests

* qwf

* dsfg

* add key

* remove short

* add map test

* remove comment

* filter warning

* simplify wrapping

* checkdown

* fix torchmetrics

* 300

* fix tests

* remove metric

* cleanup

* bug fixes

* fix lint

* fix lint

* fix test

* lint

* remove cuda

* fix tests

* fix ignore

* fix loading

* fix test

* save ckpt

---------

Co-authored-by: Mihir Patel <[email protected]>
Co-authored-by: Daniel King <[email protected]>
Co-authored-by: Your Name <[email protected]>
  • Loading branch information
4 people committed Aug 31, 2023
1 parent b8cc2ac commit 6c635cf
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 120 deletions.
108 changes: 50 additions & 58 deletions composer/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,41 +894,49 @@ def state_dict(self) -> Dict[str, Any]:
serialized_value = {type(obj).__qualname__: obj.state_dict() for obj in ensure_tuple(attribute_value)}
elif attribute_name == 'train_metrics':
serialized_value = {}
for k, v in attribute_value.items():
# No need to use __qualname__, we already know this corresponds to
# a metric object when we deserialize.
# Along with the rest of a Composer checkpoint, the state_dict() and _computed attributes of
# a Torchmetrics object are enough information to recreate it upon serialization. We only serialize
# the minimum metric information to maximize backwards compatibility --- old checkpoints
# will continue to be compatible even if other Torchmetrics attributes have changed.
# metric._computed stores the cached value of the previous metric computation
# We need to serialize this because it cannot always be recomputed from the state dict.
# See https://torchmetrics.readthedocs.io/en/stable/pages/implement.html#torchmetrics.Metric for more details
v.persistent(mode=True)
# We cast the metric tensor to a numpy array, so that FSDP doesn't mistake it for a tensor to be sharded upon load.
_computed = v._computed
_computed_device = str(_computed.device) if _computed is not None else None
_np_computed = _computed.cpu().numpy() if _computed is not None else None
serialized_value[k] = {
'state_dict': v.state_dict(),
'_computed': _np_computed,
'_computed_device': _computed_device
}
elif attribute_name == 'eval_metrics':
serialized_value = {}
for eval_key, eval_metrics in attribute_value.items():
serialized_value[eval_key] = {}
for k, v in eval_metrics.items():
if self.fsdp_sharded_state_dict_enabled:
# Sharded state dict breaks in many different ways with torchmetrics, due to both sharding
# metric tensors and only sometimes flattening path names in state dict and _computed, so
# we disable saving metrics with sharded checkpoints.
warnings.warn(
textwrap.dedent('Train metrics are not saved with sharded state dict as metric tensors will '
'be sharded and break on load. If you wish to save metric state, set '
'fsdp_config["state_dict_type"] = "full" to disable sharded checkpoints.'))
else:
for k, v in attribute_value.items():
# No need to use __qualname__, we already know this corresponds to
# a metric object when we deserialize.
# Along with the rest of a Composer checkpoint, the state_dict() and _computed attributes of
# a Torchmetrics object are enough information to recreate it upon serialization. We only serialize
# the minimum metric information to maximize backwards compatibility --- old checkpoints
# will continue to be compatible even if other Torchmetrics attributes have changed.
# metric._computed stores the cached value of the previous metric computation
# We need to serialize this because it cannot always be recomputed from the state dict.
# See https://torchmetrics.readthedocs.io/en/stable/pages/implement.html#torchmetrics.Metric for more details
v.persistent(mode=True)
# We cast the metric tensor to a numpy array, so that FSDP doesn't mistake it for a tensor to be sharded upon load.
_computed = v._computed
_computed_device = str(_computed.device) if _computed is not None else None
_np_computed = _computed.cpu().numpy() if _computed is not None else None
serialized_value[eval_key][k] = {
serialized_value[k] = {
'state_dict': v.state_dict(),
'_computed': _np_computed,
'_computed_device': _computed_device
'_computed': v._computed,
}
elif attribute_name == 'eval_metrics':
serialized_value = {}
if self.fsdp_sharded_state_dict_enabled:
# Sharded state dict breaks in many different ways with torchmetrics, due to both sharding
# metric tensors and only sometimes flattening path names in state dict and _computed, so
# we disable saving metrics with sharded checkpoints.
warnings.warn(
textwrap.dedent('Eval metrics are not saved with sharded state dict as metric tensors will '
'be sharded and break on load. If you wish to save metric state, set '
'fsdp_config["state_dict_type"] = "full" to disable sharded checkpoints.'))
else:
for eval_key, eval_metrics in attribute_value.items():
serialized_value[eval_key] = {}
for k, v in eval_metrics.items():
v.persistent(mode=True)
serialized_value[eval_key][k] = {
'state_dict': v.state_dict(),
'_computed': v._computed,
}
else:
serialized_value = attribute_value

Expand Down Expand Up @@ -1265,24 +1273,16 @@ def load_state_dict(
serialized_value[metric_name]._state_dict_pre_hooks = OrderedDict()
metric_state_dict = serialized_value[metric_name].state_dict()
metric_computed_field = serialized_value[metric_name]._computed
# The metric tensor is saved as a numpy array, so that FSDP doesn't mistake it for a tensor to be sharded upon load.
# So we have to cast it back to a torch tensor.
metric_computed_device = getattr(serialized_value[metric_name], '_computed_device', None)
if metric_computed_field is not None:
metric_computed_field = torch.from_numpy(metric_computed_field) if isinstance(
metric_computed_field, np.ndarray) else metric_computed_field
if metric_computed_device is not None:
metric_computed_field = metric_computed_field.to(metric_computed_device)
elif isinstance(serialized_value[metric_name], dict):
# The metric tensor is saved as a numpy array, so that FSDP doesn't mistake it for a tensor to be sharded upon load.
# So we have to cast it back to a torch tensor.
# For checkpoints saved using Composer >= 0.14
metric_state_dict = serialized_value[metric_name]['state_dict']
metric_computed_field = serialized_value[metric_name]['_computed']
metric_computed_device = serialized_value[metric_name].get('_computed_device', None)
if metric_computed_field is not None:
metric_computed_field = torch.from_numpy(metric_computed_field) if isinstance(
metric_computed_field, np.ndarray) else metric_computed_field
# Backwards compatible loading of torchmetrics from 0.16.0 which casted metric tensors to numpy
if isinstance(metric_computed_field, np.ndarray):
metric_computed_field = torch.from_numpy(metric_computed_field)
metric_computed_device = serialized_value[metric_name].get('_computed_device', None)
if metric_computed_device is not None:
metric_computed_field = metric_computed_field.to(metric_computed_device)
else:
Expand Down Expand Up @@ -1321,25 +1321,17 @@ def load_state_dict(
serialized_value[eval_key][metric_name]._state_dict_pre_hooks = OrderedDict()
eval_metric_state_dict = serialized_value[eval_key][metric_name].state_dict()
eval_metric_computed_field = serialized_value[eval_key][metric_name]._computed
elif isinstance(serialized_value[eval_key][metric_name], dict):
# The metric tensor is saved as a numpy array, so that FSDP doesn't mistake it for a tensor to be sharded upon load.
# So we have to cast it back to a torch tensor.
eval_metric_computed_device = getattr(serialized_value[eval_key][metric_name],
'_computed_device', None)
if eval_metric_computed_field is not None:
eval_metric_computed_field = torch.from_numpy(eval_metric_computed_field) if isinstance(
eval_metric_computed_field, np.ndarray) else eval_metric_computed_field
if eval_metric_computed_device is not None:
eval_metric_computed_field = eval_metric_computed_field.to(eval_metric_computed_device)
elif isinstance(serialized_value[eval_key][metric_name], dict):
# For checkpoints saved using Composer >= 0.14
eval_metric_state_dict = serialized_value[eval_key][metric_name]['state_dict']
eval_metric_computed_field = serialized_value[eval_key][metric_name]['_computed']
# The metric tensor is saved as a numpy array, so that FSDP doesn't mistake it for a tensor to be sharded upon load.
# So we have to cast it back to a torch tensor.
eval_metric_computed_device = serialized_value[eval_key][metric_name]['_computed_device']
if eval_metric_computed_field is not None:
eval_metric_computed_field = torch.from_numpy(eval_metric_computed_field) if isinstance(
eval_metric_computed_field, np.ndarray) else eval_metric_computed_field
# Backwards compatible loading of torchmetrics from 0.16.0 which casted metric tensors to numpy
if isinstance(eval_metric_computed_field, np.ndarray):
eval_metric_computed_field = torch.from_numpy(eval_metric_computed_field)
eval_metric_computed_device = serialized_value[eval_key][metric_name].get(
'_computed_device', None)
if eval_metric_computed_device is not None:
eval_metric_computed_field = eval_metric_computed_field.to(
eval_metric_computed_device)
Expand Down
43 changes: 17 additions & 26 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,11 +400,24 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):

# We need no_grad because we overwrite tensor values with set_() when we do elastic loading and we don't want the set_ op recorded in the computation graph.
with torch.no_grad():
# 1. Load just model first.
model_state_dict = {'state': {'model': state.state_dict()['model']}}
# 1. Load model and metadata first
model_state_dict = None
if load_weights_only:
model_state_dict = {'state': {'model': state.get_model_state_dict()}}
else:
cur_state_dict = state.state_dict()
if ignore_keys:
# Filter provided list of key paths
if not callable(ignore_keys):
ignore_keys = glob_filter(ignore_keys)
# Call function to modify state_dict
ignore_keys(cur_state_dict)
cur_state_dict.pop('optimizers')
model_state_dict = {'state': cur_state_dict}

dist_cp.load_state_dict(model_state_dict, storage_reader)

state.load_model_state(
state.load_state_dict(
model_state_dict['state'],
logger,
strict=strict_model_weights,
Expand All @@ -419,7 +432,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
storage_reader=storage_reader)
state.load_optim_state(optim_state)

# 3. Optionally, load RNG.
# 3. Optionally load RNG
rng_state_dicts = reproducibility.get_rng_state()
if not load_weights_only:
# If we are resuming on more ranks than were used at save time we only want to load in rngs for those ranks
Expand All @@ -434,28 +447,6 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
rng_state_dicts_load['rng'].extend(rng_state_dicts[num_ranks_that_saved_rng:])
rng_state_dicts = rng_state_dicts_load['rng']

# 4. Optionally, load the rest of state.
if not load_weights_only:
cur_state_dict = state.state_dict()

if ignore_keys:
# Filter provided list of key paths
if not callable(ignore_keys):
ignore_keys = glob_filter(ignore_keys)
# Call function to modify state_dict
ignore_keys(cur_state_dict)

# Remove model and optimizers because they were already loaded.
cur_state_dict.pop('model')
cur_state_dict.pop('optimizers')

rest_of_the_state_dict = {'state': cur_state_dict}
dist_cp.load_state_dict(rest_of_the_state_dict, storage_reader)
state.load_state_dict(
rest_of_the_state_dict['state'],
logger,
)

return rng_state_dicts


Expand Down
12 changes: 10 additions & 2 deletions tests/common/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,16 @@ def _check_item(item1: Any, item2: Any, path: str, rtol: float = 0.0, atol: floa
# Increase update count so Torchmetrics doesn't throw warning when computing two metrics which haven't been updated
item1._update_count += 1
item2._update_count += 1
assert item1.compute().allclose(item2.compute(), atol=atol,
rtol=rtol), f'{path} differs: {item1.compute()} != {item2.compute()}'
item1_compute = item1.compute()
item2_compute = item2.compute()
if isinstance(item1_compute, torch.Tensor) and isinstance(item2_compute, torch.Tensor):
assert item1_compute.allclose(item2_compute, atol=atol, rtol=rtol,
equal_nan=True), f'{path} differs: {item1_compute} != {item2_compute}'
elif isinstance(item1_compute, dict):
assert isinstance(item2_compute, dict)
_check_dict_recursively(item1_compute, item2_compute, path, atol, rtol)
else:
assert 'Torchmetric compute() returned unexpected type, please add support in `_check_item`'
item1._update_count -= 1
item2._update_count -= 1
return
Expand Down
59 changes: 58 additions & 1 deletion tests/trainer/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from composer.callbacks import CheckpointSaver
from composer.core import Callback, Time, TimeUnit
from composer.loggers import RemoteUploaderDownloader, remote_uploader_downloader
from composer.metrics import MAP
from composer.optim import ExponentialScheduler
from composer.trainer import trainer
from composer.trainer.trainer import Trainer
Expand Down Expand Up @@ -427,7 +428,13 @@ def _metrics_equal(self, train_metrics_1, train_metrics_2, eval_metrics_1, eval_
except AssertionError:
return False

def get_trainer(self, model=None, max_duration='2ep', latest_filename='latest-rank{rank}.pt', **kwargs):
def get_trainer(
self,
model=None,
max_duration='2ep',
latest_filename='latest-rank{rank}.pt',
**kwargs,
):
if model is None:
model = SimpleConvModel()
optimizer = torch.optim.Adam(model.parameters())
Expand Down Expand Up @@ -563,6 +570,56 @@ def test_other_backends_error(self, load_path: str, monkeypatch: MonkeyPatch):
with pytest.raises(NotImplementedError):
self.get_trainer(load_path=load_path)

def test_load_map(self, tmp_path: pathlib.Path):
map_metric = MAP()

targets = [
{
'boxes': torch.tensor([[258.15, 41.29, 606.41, 285.07]]),
'labels': torch.tensor([4]),
}, # coco image id 42
{
'boxes': torch.tensor([[61.00, 22.75, 565.00, 632.42], [12.66, 3.32, 281.26, 275.23]]),
'labels': torch.tensor([3, 2]),
}, # coco image id 73
]

# Perfect result
predictions = [
{
'boxes': torch.tensor([[258.15, 41.29, 606.41, 285.07]]),
'scores': torch.tensor([0.236]),
'labels': torch.tensor([4]),
}, # coco image id 42
{
'boxes': torch.tensor([[61.00, 22.75, 565.00, 632.42], [12.66, 3.32, 281.26, 275.23]]),
'scores': torch.tensor([0.318, 0.726]),
'labels': torch.tensor([3, 2]),
}, # coco image id 73
]

map_metric.update(predictions, targets)
map_metric.compute()

model_1 = SimpleConvModel()
model_1.train_metrics = map_metric
trainer_1 = self.get_trainer(
model=model_1,
save_folder=str(tmp_path),
)
trainer_1.save_checkpoint('latest-rank0.pt')

model_2 = SimpleConvModel()
model_2.train_metrics = MAP()
trainer_2 = self.get_trainer(
model=model_2,
load_path=str(tmp_path / 'latest-rank0.pt'),
)

assert self._metrics_equal(
trainer_1.state.train_metrics, trainer_2.state.train_metrics, trainer_1.state.eval_metrics,
trainer_2.state.eval_metrics), 'Original metrics do not equal metrics from loaded checkpoint.'

@pytest.mark.parametrize('missing_key', [True, False])
@pytest.mark.parametrize('unexpected_key', [True, False])
def test_strict_errors(self, missing_key: bool, unexpected_key: bool):
Expand Down
Loading

0 comments on commit 6c635cf

Please sign in to comment.