Skip to content

Commit

Permalink
more prints
Browse files Browse the repository at this point in the history
  • Loading branch information
ez2rok committed Sep 5, 2024
1 parent db1431d commit b114fef
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
6 changes: 5 additions & 1 deletion composer/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,7 @@ def get_optim_state_dict(self) -> dict[str, Any]:
if version.parse(torch.__version__) >= version.parse('2.4.0') or (
version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized()
):
ic(1)
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
raise NotImplementedError(
Expand All @@ -1015,8 +1016,9 @@ def get_optim_state_dict(self) -> dict[str, Any]:
'fsdp_state_dict_type to None, "full", or "sharded".',
),
)

ic(2)
optimizer = ensure_tuple(self.optimizers)[0]
ic(3)
optim_state_dict = get_optimizer_state_dict(
model=self.model,
optimizers=optimizer,
Expand All @@ -1026,6 +1028,7 @@ def get_optim_state_dict(self) -> dict[str, Any]:
cpu_offload=self.fsdp_enabled,
),
)
ic(4)
return {type(optimizer).__qualname__: optim_state_dict}
else:
optimizer = ensure_tuple(self.optimizers)[0]
Expand All @@ -1046,6 +1049,7 @@ def state_dict(self) -> dict[str, Any]:
"""
state_dict = {}
for attribute_name in self.serialized_attributes:
ic(attribute_name)
attribute_value = getattr(self, attribute_name)
if attribute_name == 'dataset_state':
serialized_value = self._dataset_state_dict()
Expand Down
4 changes: 2 additions & 2 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,7 @@ def load_sharded_checkpoint(
num_rng_ranks = _get_num_ranks_that_saved_rng(storage_reader.read_metadata())
state_dict: dict[str, Any] = {
'state': cur_state_dict,
'rng': reproducibility.get_rng_state()[:num_rng_ranks],
'rng': 42 # reproducibility.get_rng_state()[:num_rng_ranks],
}

if ignore_keys:
Expand Down Expand Up @@ -1144,7 +1144,7 @@ def _save_checkpoint(
ic('before reproducibility.get_rng_state()')
state_dict = {
'state': state.state_dict(),
'rng': reproducibility.get_rng_state(),
'rng': 42 #reproducibility.get_rng_state(),
}
ic('after reproducibility.get_rng_state()')

Expand Down

0 comments on commit b114fef

Please sign in to comment.