[WIP] integrate new API for saving and loading
michaelbenayoun committed Apr 9, 2024
1 parent 57fc9a5 commit c7a1f4b
Showing 8 changed files with 80 additions and 143 deletions.
8 changes: 7 additions & 1 deletion optimum/neuron/accelerate/accelerator.py
@@ -683,7 +683,13 @@ def save_state_for_mp(self, output_dir: Optional[str] = None, **save_model_func_
def save_optimizer_func(accelerator, optimizer, model, output_dir, i):
logger.info("Saving parallel model and optimizer")
parallelizer = ParallelizersManager.parallelizer_for_model(model)
parallelizer.save_model_checkpoint(model, output_dir, as_regular=False, optimizer=optimizer)
parallelizer.save_model_sharded_checkpoint(
model,
output_dir,
optimizer=optimizer,
use_xser=self.state.mp_plugin.use_xser,
async_save=self.state.mp_plugin.async_save,
)
logger.info(f"Parallel model and optimizer saved to the directory {output_dir}")

return self._custom_save_state(
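For context, the new call-site pattern used in `save_optimizer_func` above can be sketched as follows. This is a rough illustration, not part of the commit: `my_model`, `my_optimizer`, the paths, and the import path for `ParallelizersManager` are assumptions.

```python
from optimum.neuron.distributed import ParallelizersManager  # import path assumed

my_model = ...      # placeholder: an already-parallelized PreTrainedModel (or NxDPPModel)
my_optimizer = ...  # placeholder: its torch.optim.Optimizer

# New sharded-save path: one call saves model and optimizer shards together.
parallelizer = ParallelizersManager.parallelizer_for_model(my_model)
parallelizer.save_model_sharded_checkpoint(
    my_model,
    "outputs/checkpoint-500",  # output_dir, illustrative path
    optimizer=my_optimizer,    # optional: stored in the same sharded checkpoint
    use_xser=True,             # new flag: torch-xla serialization
    async_save=False,          # new flag: asynchronous saving
)
```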
2 changes: 2 additions & 0 deletions optimum/neuron/accelerate/utils/dataclasses.py
@@ -160,6 +160,8 @@ class ModelParallelismPlugin:
gradient_checkpointing: bool = False
checkpoint_dir: Optional[Union[str, Path]] = None
num_ranks_per_loading_step: int = -1
use_xser: bool = True
async_save: bool = False

def __post_init__(self):
if self.tensor_parallel_size < 1:
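A minimal sketch of constructing the plugin with the two new fields; only `use_xser` and `async_save` come from this diff, and the `tensor_parallel_size` value is illustrative.

```python
from optimum.neuron.accelerate.utils.dataclasses import ModelParallelismPlugin

# `use_xser` and `async_save` are the new fields; `tensor_parallel_size=2` is illustrative.
mp_plugin = ModelParallelismPlugin(
    tensor_parallel_size=2,
    use_xser=True,     # serialize sharded checkpoints through torch-xla
    async_save=False,  # keep saving synchronous (lower host-memory pressure)
)
```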
174 changes: 40 additions & 134 deletions optimum/neuron/distributed/base.py
@@ -17,7 +17,6 @@
import contextlib
import gc
import math
import shutil
from abc import ABC, abstractclassmethod
from collections import defaultdict
from dataclasses import asdict
@@ -27,7 +26,6 @@

import torch
from transformers import PreTrainedModel
from transformers.utils import WEIGHTS_NAME

from ...utils import logging
from ..utils import is_neuronx_distributed_available, is_torch_xla_available
@@ -42,7 +40,7 @@
SequenceCollectiveOpInfo,
)
from .utils import (
TENSOR_PARALLEL_SHARDS_DIR_NAME,
MODEL_PARALLEL_SHARDS_DIR_NAME,
OptimumGQAQKVColumnParallelLinear,
OptimumNeuronFXTracer,
ParameterMetadata,
@@ -762,7 +760,7 @@ def should_parallelize_layer_predicate_func(layer):

# TODO: can we optimize by skipping initialization and weight loading when `checkpoint_dir` is not None.
if not is_precompilation() and checkpoint_dir is not None:
cls.load_model_checkpoint(model, checkpoint_dir)
cls.load_model_sharded_checkpoint(model, checkpoint_dir)

model._gqa_qkv_metadata = gqa_qkv_metadata

@@ -919,162 +917,70 @@ def _get_parameters_tp_metadata(cls, named_parameters: Dict[str, "torch.nn.Param

@classmethod
@requires_neuronx_distributed
def save_model_checkpoint_as_regular(
cls,
model: "PreTrainedModel",
output_dir: Union[str, Path],
optimizer: Optional["torch.optim.Optimizer"] = None,
):
import neuronx_distributed
import torch_xla.core.xla_model as xm
from neuronx_distributed.parallel_layers.parallel_state import (
get_data_parallel_rank,
get_tensor_model_parallel_rank,
)

cls._check_model_was_parallelized(model)

data_parallel_rank = get_data_parallel_rank()
tensor_parallel_rank = get_tensor_model_parallel_rank()

if data_parallel_rank != 0:
return

if not isinstance(output_dir, Path):
output_dir = Path(output_dir)

if is_main_worker() and optimizer is not None:
logger.warning(
"Saving the optimizer state as a regular file under the tensor parallel setting is not supported yet."
)

state_dict = {}
for name, param in model.named_parameters():
if getattr(param, "tensor_model_parallel", False):
if param.partition_dim == 1:
tensor = neuronx_distributed.utils.gather_from_tensor_model_parallel_region(param)
else:
# Because the gather works only on last dim. Need to make it work for all dims.
tensor = neuronx_distributed.utils.gather_from_tensor_model_parallel_region(
param.transpose()
).transpose()
else:
tensor = param
state_dict[name] = tensor

model_state_dict = {"model": state_dict}
should_save = tensor_parallel_rank == 0
xm._maybe_convert_to_cpu(model_state_dict, convert=should_save)
if should_save:
output_path = output_dir / WEIGHTS_NAME
torch.save(model_state_dict["model"], output_path.as_posix())
xm.rendezvous("saving regular checkpoint")

@classmethod
@requires_neuronx_distributed
def save_model_checkpoint_as_sharded(
def save_model_sharded_checkpoint(
cls,
model: Union["PreTrainedModel", "NxDPPModel"],
output_dir: Union[str, Path],
optimizer: Optional["torch.optim.Optimizer"] = None,
use_xser: bool = True,
async_save: bool = False,
):
import torch_xla.core.xla_model as xm
from neuronx_distributed import parallel_layers
from neuronx_distributed.pipeline import NxDPPModel
import neuronx_distributed

cls._check_model_was_parallelized(model)

if not isinstance(output_dir, Path):
output_dir = Path(output_dir)

if isinstance(model, NxDPPModel):
model_state_dict = model.local_state_dict()
else:
model_state_dict = model.state_dict()

state_dict = {"model": model_state_dict}
state_dict["sharded_metadata"] = {
metadata = {}
metadata["sharded_metadata"] = {
k: asdict(v) for k, v in cls._get_parameters_tp_metadata(dict(model.named_parameters())).items()
}
state_dict["gqa_qkv_metadata"] = model._gqa_qkv_metadata

if optimizer is not None:
# TODO: have metadata working for the optimizer.
state_dict["optimizer_state_dict"] = optimizer.state_dict()

output_path = output_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME

if is_main_worker():
if output_path.is_dir():
shutil.rmtree(output_path, ignore_errors=True)
output_path.mkdir()
xm.rendezvous("waiting before saving")
parallel_layers.save(state_dict, output_path.as_posix(), save_xser=True)
metadata["gqa_qkv_metadata"] = model._gqa_qkv_metadata

neuronx_distributed.trainer.save_checkpoint(
output_dir.as_posix(),
tag=MODEL_PARALLEL_SHARDS_DIR_NAME,
model=model,
optimizer=optimizer,
user_content=metadata,
use_xser=use_xser,
async_save=async_save,
)

@classmethod
def save_model_checkpoint(
@requires_neuronx_distributed
def load_sharded_checkpoint(
cls,
model: "PreTrainedModel",
output_dir: Union[str, Path],
as_regular: bool = False,
as_sharded: bool = True,
optimizer: Optional["torch.optim.Optimizer"] = None,
load_dir: Union[str, Path],
model: Optional["PreTrainedModel"] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
):
if not as_regular and not as_sharded:
raise ValueError("At least as_regular or as_sharded must be True.")
if as_regular:
cls.save_model_checkpoint_as_regular(model, output_dir, optimizer=optimizer)
if as_sharded:
cls.save_model_checkpoint_as_sharded(model, output_dir, optimizer=optimizer)

@classmethod
@requires_neuronx_distributed
def load_model_sharded_checkpoint(cls, model: "PreTrainedModel", load_dir: Union[str, Path]):
import neuronx_distributed

cls._check_model_was_parallelized(model)
if model is None and optimizer is None:
raise ValueError("At least a model or an optimizer must be provided")

if not isinstance(load_dir, Path):
load_dir = Path(load_dir)
neuronx_distributed.parallel_layers.load(
load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME,
model_or_optimizer=model,
load_xser=True,
sharded=True,
)
if model is not None:
cls._check_model_was_parallelized(model)

@classmethod
def load_model_checkpoint(cls, model: "PreTrainedModel", load_dir: Union[str, Path]):
if not isinstance(load_dir, Path):
load_dir = Path(load_dir)

if (load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME).is_dir():
cls.load_model_sharded_checkpoint(model, load_dir)
else:
if not (load_dir / MODEL_PARALLEL_SHARDS_DIR_NAME).is_dir():
raise FileNotFoundError(f"Could not find a sharded checkpoint directory under {load_dir.as_posix()}.")

@classmethod
@requires_neuronx_distributed
def load_optimizer_sharded_checkpoint(cls, optimizer: "torch.optim.Optimizer", load_dir: Union[str, Path]):
import neuronx_distributed
from neuronx_distributed.optimizer import NeuronZero1Optimizer

is_zero_1_optimizer = optimizer.__class__.__name__ == "NeuronAcceleratedOptimizer" and isinstance(
optimizer.optimizer, NeuronZero1Optimizer
neuronx_distributed.trainer.load_checkpoint(
load_dir.as_posix(),
tag=MODEL_PARALLEL_SHARDS_DIR_NAME,
model=model,
optimizer=optimizer,
)
is_zero_1_optimizer = is_zero_1_optimizer or isinstance(optimizer, NeuronZero1Optimizer)
if is_zero_1_optimizer:
raise NotImplementedError(
"It is not possible to load a sharded optimizer checkpoint when using ZeRO-1 yet."
)

if not isinstance(load_dir, Path):
load_dir = Path(load_dir)
@classmethod
def load_model_sharded_checkpoint(cls, model: "PreTrainedModel", load_dir: Union[str, Path]):
return cls.load_sharded_checkpoint(load_dir, model=model)

neuronx_distributed.parallel_layers.load(
load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME,
model_or_optimizer=optimizer,
model_key="optimizer_state_dict",
load_xser=True,
sharded=True,
)
@classmethod
def load_optimizer_sharded_checkpoint(cls, optimizer: "torch.optim.Optimizer", load_dir: Union[str, Path]):
return cls.load_sharded_checkpoint(load_dir, optimizer=optimizer)
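Loading mirrors the saving path; a rough sketch of the new entry points defined above, calling them through the base `Parallelizer` class as `trainers.py` does below. Objects and paths are placeholders, and the import path is assumed.

```python
from optimum.neuron.distributed import Parallelizer  # import path assumed

model = ...      # placeholder: an already-parallelized PreTrainedModel
optimizer = ...  # placeholder: its torch.optim.Optimizer
checkpoint_dir = "outputs/checkpoint-500"  # directory containing the "shards" sub-directory

# Restore only the model weights, or only the optimizer state...
Parallelizer.load_model_sharded_checkpoint(model, checkpoint_dir)
Parallelizer.load_optimizer_sharded_checkpoint(optimizer, checkpoint_dir)

# ...or both at once through the shared entry point.
Parallelizer.load_sharded_checkpoint(checkpoint_dir, model=model, optimizer=optimizer)
```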
8 changes: 4 additions & 4 deletions optimum/neuron/distributed/checkpointing.py
@@ -23,7 +23,7 @@
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME

from ..utils.require_utils import requires_neuronx_distributed, requires_safetensors
from .utils import TENSOR_PARALLEL_SHARDS_DIR_NAME, ParameterMetadata, compute_query_indices_for_rank
from .utils import MODEL_PARALLEL_SHARDS_DIR_NAME, ParameterMetadata, compute_query_indices_for_rank


def create_gqa_query_or_output_projection_weight_from_full_weight(
@@ -136,9 +136,9 @@ def consolidate_model_parallel_checkpoints(checkpoint_dir: Union[str, Path]) ->
if not isinstance(checkpoint_dir, Path):
checkpoint_dir = Path(checkpoint_dir)

if checkpoint_dir.name != TENSOR_PARALLEL_SHARDS_DIR_NAME:
if (checkpoint_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME).is_dir():
checkpoint_dir = checkpoint_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME
if checkpoint_dir.name != MODEL_PARALLEL_SHARDS_DIR_NAME:
if (checkpoint_dir / MODEL_PARALLEL_SHARDS_DIR_NAME).is_dir():
checkpoint_dir = checkpoint_dir / MODEL_PARALLEL_SHARDS_DIR_NAME
else:
raise ValueError(f"Could not find the tensor parallel shards from {checkpoint_dir}")

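For reference, a hedged sketch of how the consolidation helper exercised by the tests below can be used once training has written a sharded checkpoint; the exact signature and both paths are assumptions, not spelled out in this diff.

```python
from optimum.neuron.distributed.checkpointing import (
    consolidate_model_parallel_checkpoints_to_unified_checkpoint,
)

# Rebuild a regular (unsharded) checkpoint from the model-parallel shards.
# Paths are placeholders; the (checkpoint_dir, output_dir) signature is assumed.
consolidate_model_parallel_checkpoints_to_unified_checkpoint(
    "outputs/checkpoint-500",  # directory holding the MODEL_PARALLEL_SHARDS_DIR_NAME folder
    "outputs/consolidated",    # destination for the unified checkpoint
)
```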
2 changes: 1 addition & 1 deletion optimum/neuron/distributed/utils.py
@@ -63,7 +63,7 @@ def __init__(self, *args, **kwargs):
logger = logging.get_logger()


TENSOR_PARALLEL_SHARDS_DIR_NAME = "tensor_parallel_shards"
MODEL_PARALLEL_SHARDS_DIR_NAME = "shards"


@deprecate(
8 changes: 7 additions & 1 deletion optimum/neuron/trainers.py
@@ -488,7 +488,13 @@ def _save_xla(self, output_dir: Optional[str] = None):

# This mark_step is needed to avoid hang issues.
xm.mark_step()
Parallelizer.save_model_checkpoint(self.model, output_dir, as_sharded=True, optimizer=self.optimizer)
Parallelizer.save_model_sharded_checkpoint(
self.model,
output_dir,
optimizer=self.optimizer,
use_xser=self.accelerator.state.mp_plugin.use_xser,
async_save=self.accelerator.state.mp_plugin.async_save,
)
else:
safe_save_function_patcher = Patcher(
[("transformers.modeling_utils.safe_save_file", torch_xla_safe_save_file)]
17 changes: 17 additions & 0 deletions optimum/neuron/training_args.py
@@ -114,6 +114,21 @@ class NeuronTrainingArgumentsMixin:
)
},
)
use_xser: bool = field(
default=True,
metadata={
"help": "Whether to use `torch-xla` serialization when saving checkpoints when doing model parallelism"
},
)
async_save: bool = field(
default=False,
metadata={
"help": (
"Whether to use asynchronous saving method when doing model parallelism. It can boost saving "
"performance but will result in more host memory usage, increasing the risk of going OOM."
)
},
)

def __post_init__(self):
# Patches accelerate.utils.imports.is_tpu_available to match `is_torch_xla_available`
@@ -165,6 +180,8 @@ def __post_init__(self):
gradient_checkpointing=self.gradient_checkpointing,
checkpoint_dir=resume_from_checkpoint,
num_ranks_per_loading_step=self.num_ranks_per_loading_step,
use_xser=self.use_xser,
async_save=self.async_save,
)

if self.bf16 and self.half_precision_backend == "amp":
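A hedged sketch of how the two new arguments surface to users; `NeuronTrainingArguments` is assumed to be the concrete class built on this mixin, and every other value is illustrative.

```python
from optimum.neuron import NeuronTrainingArguments  # concrete class assumed

training_args = NeuronTrainingArguments(
    output_dir="outputs",
    tensor_parallel_size=2,  # illustrative; any model-parallel setup applies
    use_xser=True,           # new: torch-xla serialization for sharded checkpoints
    async_save=False,        # new: keep saving synchronous to limit host-memory use
)
# The __post_init__ above forwards both flags to the ModelParallelismPlugin.
```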
4 changes: 2 additions & 2 deletions tests/distributed/test_common.py
@@ -26,7 +26,7 @@
from optimum.neuron.accelerate.utils.dataclasses import NeuronDistributedType
from optimum.neuron.distributed.checkpointing import consolidate_model_parallel_checkpoints_to_unified_checkpoint
from optimum.neuron.distributed.utils import (
TENSOR_PARALLEL_SHARDS_DIR_NAME,
MODEL_PARALLEL_SHARDS_DIR_NAME,
make_optimizer_constructor_lazy,
)
from optimum.neuron.utils.import_utils import (
@@ -364,7 +364,7 @@ def test_save_model_and_load_model(self, parallel_sizes, tmpdir, monkeypatch):
tensors_directory = f"{ref_data_file_name}.tensors"
assert not pytorch_checkpoint_exists
assert not safetensors_checkpoint_exists
assert TENSOR_PARALLEL_SHARDS_DIR_NAME in tmpdir_content
assert MODEL_PARALLEL_SHARDS_DIR_NAME in tmpdir_content
assert ref_data_file_name in tmpdir_content
assert tensors_directory in tmpdir_content
else:
