
Commit

Consolidation works
michaelbenayoun committed Apr 9, 2024
1 parent 2d9904c commit bfdb128
Showing 2 changed files with 7 additions and 15 deletions.
optimum/neuron/accelerate/utils/dataclasses.py (1 change: 0 additions & 1 deletion)
@@ -22,7 +22,6 @@
 import torch
 
 from ...distributed import ParallelizersManager
-from ...utils import is_torch_xla_available
 
 
 if TYPE_CHECKING:
optimum/neuron/distributed/checkpointing.py (21 changes: 7 additions & 14 deletions)
@@ -71,7 +71,7 @@ def consolidate_tensor_parallel_checkpoints(
             continue
         state_dicts.append(load_function(sharded_checkpoint.as_posix()))
 
-    parameter_names = state_dicts[0]["model"].keys()
+    parameter_names = state_dicts[0].keys()
     sharded_metadatas = {
         name: (
             ParameterMetadata(**metadata["sharded_metadata"][name])
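Note: the change in this hunk suggests that each sharded file is now expected to hold the model state dict directly rather than nesting it under a "model" key. A minimal sketch of the two layouts, with a made-up parameter name and plain torch tensors, assuming nothing else about the checkpoint format:

    import torch

    # Hypothetical sharded checkpoint contents, for illustration only.
    old_layout = {"model": {"embed.weight": torch.zeros(10, 4)}}  # previous, nested layout
    new_layout = {"embed.weight": torch.zeros(10, 4)}             # flat layout assumed after this commit

    # The consolidation code now reads parameter names from the top level:
    parameter_names = new_layout.keys()
    print(list(parameter_names))  # ['embed.weight']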
@@ -101,9 +101,9 @@ def consolidate_tensor_parallel_checkpoints(
         # This might not be the case anymore when `ParameterMetadata` uses slices.
         sharded_metadata = sharded_metadatas[name]
         if sharded_metadata.is_tied:
-            consolidated_state_dict[original_name] = state_dicts[0]["model"][name].to("cpu")
+            consolidated_state_dict[original_name] = state_dicts[0][name].to("cpu")
         else:
-            weights = [state_dict["model"][name] for state_dict in state_dicts]
+            weights = [state_dict[name] for state_dict in state_dicts]
             tp_size = len(weights)
             full_weight = torch.cat(
                 weights,
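Note: around this hunk, consolidation keeps tied parameters from the first rank only and rebuilds every other parameter by concatenating its tensor-parallel shards, presumably along the partition dimension recorded in the sharded metadata. A standalone sketch of that concatenation step, with made-up shapes and a hard-coded partition dim standing in for the real metadata:

    import torch

    # Two hypothetical tensor-parallel shards of one weight, split along dim 0.
    shards = [torch.ones(2, 4), 2 * torch.ones(2, 4)]
    partition_dim = 0  # stands in for sharded_metadata.partition_dim

    full_weight = torch.cat(shards, dim=partition_dim)
    print(full_weight.shape)  # torch.Size([4, 4])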
@@ -152,28 +152,21 @@ def consolidate_model_parallel_checkpoints(checkpoint_dir: Union[str, Path]) ->
 
     # Case 2: If no file was found, maybe the checkpoint was saved without xser.
     if not sharded_checkpoints:
-        sharded_checkpoints = model_checkpoint_dir.glob("dp_rank_*.pt")
+        sharded_checkpoints = list(model_checkpoint_dir.glob("dp_rank_*.pt"))
         load_function = torch.load
 
     if not sharded_checkpoints:
         raise ValueError(f"Could not find any sharded checkpoint in {model_checkpoint_dir.as_posix()}")
 
-    def get_checkpoint_name(checkpoint_path: Path) -> str:
-        name = checkpoint_path.name
-        if name == "checkpoint.pt":
-            name = checkpoint_path.parent.name
-        return name
-
-    pp_size = max((int(get_checkpoint_name(checkpoint_path)[-2:]) for checkpoint_path in sharded_checkpoints)) + 1
+    pp_size = max((int(checkpoint_path.stem[-2:]) for checkpoint_path in sharded_checkpoints)) + 1
     checkpoints_grouped_by_pp_ranks = [[] for _ in range(pp_size)]
     for pp_rank in range(pp_size):
         for checkpoint_path in sharded_checkpoints:
-            # checkpoint_name = get_checkpoint_name(checkpoint_path)
-            checkpoint_name = checkpoint_path.name
+            checkpoint_name = checkpoint_path.stem
             if int(checkpoint_name[-2:]) == pp_rank:
                 checkpoints_grouped_by_pp_ranks[pp_rank].append(checkpoint_path)
 
-    metadata_path = checkpoint_dir / "user_content"
+    metadata_path = checkpoint_dir / "user_content.pt"
     metadata = torch.load(metadata_path)
 
     consolidated_state_dict = {}
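Note: the last hunk bundles three small fixes. Materializing the non-xser glob into a list matters because a bare glob() generator is always truthy (so the "if not sharded_checkpoints" check just below could never fire) and is exhausted after a single pass. Parsing the pipeline-parallel rank from the file stem rather than the full name keeps the trailing ".pt" out of the digits fed to int(). And the metadata path gains its ".pt" extension. A small sketch of the stem-based rank parsing, using a hypothetical file name:

    from pathlib import Path

    # Hypothetical sharded checkpoint file name, for illustration only.
    checkpoint_path = Path("checkpoints/model/dp_rank_00_tp_rank_01_pp_rank_03.pt")

    # With .name the last two characters are "pt"; with .stem they are the pp rank digits.
    print(checkpoint_path.name[-2:])       # 'pt'
    print(int(checkpoint_path.stem[-2:]))  # 3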
