Save / load from checkpoint TP #269

Merged
13 commits merged on Oct 27, 2023
4 changes: 4 additions & 0 deletions optimum/neuron/accelerate/accelerator.py
@@ -343,9 +343,13 @@ def prepare_model_for_xla_fsdp(
def _prepare_model_for_tp(
self, model: torch.nn.Module, device_placement: Optional[bool] = None, evaluation_mode: bool = False
):
if model in self._models or Parallelizer.was_parallelized(model):
return model

cpu_ids = [id(v) for v in model.parameters()]
# TODO: enable self.device (if needed).
model = self.state.tp_plugin.parallelize_model(model, device=None)

if os.environ.get("XLA_USE_BF16", "0") == "1" or os.environ.get("XLA_DOWNCAST_BF16", "0") == "1":
model.to(torch.bfloat16)
else:
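For context, the new guard makes TP model preparation idempotent: a model that is already tracked by the accelerator, or that was already parallelized, is handed back unchanged. Below is a minimal sketch of the same check driven from user code; it assumes Parallelizer is importable from optimum.neuron.distributed (as it is in this module) and that a configured accelerator and model already exist. Illustrative only, not part of the diff.

    # Sketch: skip preparation for models that were already parallelized,
    # mirroring the guard added in _prepare_model_for_tp above.
    from optimum.neuron.distributed import Parallelizer

    def prepare_for_tp(accelerator, model):
        if Parallelizer.was_parallelized(model):
            # Already sharded across tensor-parallel ranks: nothing to do.
            return model
        return accelerator.prepare(model)
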
5 changes: 5 additions & 0 deletions optimum/neuron/accelerate/utils/dataclasses.py
@@ -17,6 +17,7 @@
import enum
import os
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union

import torch
@@ -143,10 +144,13 @@ class TensorParallelismPlugin:
tensor_parallel_size: int = 1
parallelize_embeddings: bool = True
sequence_parallel_enabled: bool = False
checkpoint_dir: Optional[Union[str, Path]] = None

def __post_init__(self):
if self.tensor_parallel_size < 1:
raise ValueError(f"The tensor parallel size must be >= 1, but {self.tensor_parallel_size} was given here.")
if isinstance(self.checkpoint_dir, str):
self.checkpoint_dir = Path(self.checkpoint_dir)

@property
def should_parallelize(self):
@@ -163,5 +167,6 @@ def parallelize_model(
device=device,
parallelize_embeddings=self.parallelize_embeddings,
sequence_parallel_enabled=self.sequence_parallel_enabled,
checkpoint_dir=self.checkpoint_dir,
)
return parallelized_model
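
The practical effect of the new checkpoint_dir field: a plugin pointed at a sharded checkpoint parallelizes the model and then reloads each rank's shard into it. The hedged sketch below is not part of the diff; the model identifier and checkpoint path are placeholders, and it assumes a Neuron environment with an architecture supported by tensor parallelism.

    # Sketch: configure tensor parallelism and resume from a sharded checkpoint.
    from transformers import AutoModelForCausalLM

    from optimum.neuron.accelerate.utils.dataclasses import TensorParallelismPlugin

    tp_plugin = TensorParallelismPlugin(
        tensor_parallel_size=2,
        sequence_parallel_enabled=False,
        checkpoint_dir="output/checkpoint-500",  # placeholder path to a sharded checkpoint
    )

    model = AutoModelForCausalLM.from_pretrained("my-org/my-model")  # placeholder model id
    if tp_plugin.should_parallelize:
        # parallelize_model transforms the layers and, because checkpoint_dir is set,
        # loads the sharded weights into the parallelized model.
        model = tp_plugin.parallelize_model(model, device=None)
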
101 changes: 83 additions & 18 deletions optimum/neuron/distributed/base.py
@@ -15,6 +15,7 @@
"""Base class related to `neuronx_distributed` to perform parallelism."""

import contextlib
import gc
import shutil
from abc import ABC, abstractclassmethod
from dataclasses import asdict
@@ -28,6 +29,7 @@
from ...utils import logging
from ..utils import is_neuronx_distributed_available, is_torch_xla_available
from ..utils.deprecate_utils import deprecate
from ..utils.require_utils import requires_neuronx_distributed
from .parallel_layers import (
IOSequenceParallelizer,
LayerNormSequenceParallelizer,
@@ -37,14 +39,6 @@
from .utils import TENSOR_PARALLEL_SHARDS_DIR_NAME, ParameterMetadata, WeightInformation, load_tensor_for_weight


if is_neuronx_distributed_available():
import neuronx_distributed
from neuronx_distributed import parallel_layers

if is_torch_xla_available():
import torch_xla.core.xla_model as xm


if TYPE_CHECKING:
from transformers import PreTrainedModel

@@ -164,12 +158,14 @@ def patch_for_sequence_parallelism(cls, model: "PreTrainedModel", sequence_paral
)

@classmethod
@requires_neuronx_distributed
def parallelize(
cls,
model: "PreTrainedModel",
device: Optional["torch.device"] = None,
parallelize_embeddings: bool = True,
sequence_parallel_enabled: bool = False,
checkpoint_dir: Optional[Union[str, Path]] = None,
) -> "PreTrainedModel":
"""
Parallelizes the model by transforming regular layers into their parallel counterparts using
@@ -188,13 +184,18 @@ def parallelize(
This can be disabled in the case when the TP size does not divide the vocabulary size.
sequence_parallel_enabled (`bool`, defaults to `False`):
Whether or not sequence parallelism is enabled.
checkpoint_dir (`Optional[Union[str, Path]]`):
Path to a sharded checkpoint. If specified, the checkpoint weights will be loaded to the parallelized
model.

Returns:
`PreTrainedModel`: The parallelized model.
"""
if sequence_parallel_enabled and cls.SEQUENCE_PARALLEL_LAYERNORM_PATTERNS is None:
raise NotImplementedError(f"Sequence parallelism is not supported for {model.__class__}.")

from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_rank

# Preparing the model for sequence parallelism:
# 1. Transforming the LayerNorms.
layer_norm_qualified_name_patterns = (
@@ -259,7 +260,7 @@ def parallelize(
# parallelization since those are the only classes that we initialize on the `meta` device.
num_dims = current_weight.dim()
partition_dim = getattr(current_weight, "partition_dim")
tp_rank = parallel_layers.parallel_state.get_tensor_model_parallel_rank()
tp_rank = get_tensor_model_parallel_rank()
size_per_rank = current_weight.size(partition_dim)
slices = [
None
@@ -298,14 +299,20 @@ def parallelize(
# `reset_parameters()` method.
mod.reset_parameters()

if checkpoint_dir is not None:
cls.load_model_checkpoint(model, checkpoint_dir)

return model

@classmethod
def deparallelize(cls, model: "PreTrainedModel") -> "PreTrainedModel":
raise NotImplementedError

@classmethod
@requires_neuronx_distributed
def was_parallelized(cls, model: "PreTrainedModel") -> bool:
from neuronx_distributed import parallel_layers

parallel_layer_classes = (
parallel_layers.ParallelEmbedding,
parallel_layers.ColumnParallelLinear,
@@ -410,15 +417,24 @@ def _get_parameters_tp_metadata(cls, named_parameters: Dict[str, "torch.nn.Param
return tp_metadata

@classmethod
@requires_neuronx_distributed
def save_model_checkpoint_as_regular(
cls,
model: "PreTrainedModel",
output_dir: Union[str, Path],
optimizer: Optional["torch.optim.Optimizer"] = None,
):
cls._check_model_was_parallelized(model)
data_parallel_rank = parallel_layers.parallel_state.get_data_parallel_rank()
tensor_parallel_rank = parallel_layers.parallel_state.get_tensor_model_parallel_rank()

import neuronx_distributed
import torch_xla.core.xla_model as xm
from neuronx_distributed.parallel_layers.parallel_state import (
get_data_parallel_rank,
get_tensor_model_parallel_rank,
)

data_parallel_rank = get_data_parallel_rank()
tensor_parallel_rank = get_tensor_model_parallel_rank()

if data_parallel_rank != 0:
return
@@ -454,6 +470,7 @@ def save_model_checkpoint_as_regular(
xm.rendezvous("saving regular checkpoint")

@classmethod
@requires_neuronx_distributed
def save_model_checkpoint_as_sharded(
cls,
model: "PreTrainedModel",
@@ -462,6 +479,8 @@
):
cls._check_model_was_parallelized(model)

import torch_xla.core.xla_model as xm
from neuronx_distributed import parallel_layers
from neuronx_distributed.parallel_layers.parallel_state import (
get_data_parallel_rank,
get_tensor_model_parallel_rank,
@@ -508,12 +527,11 @@ def save_model_checkpoint(
cls.save_model_checkpoint_as_sharded(model, output_dir, optimizer=optimizer)

@classmethod
def load_model_regular_checkpoint(cls, model: "PreTrainedModel", load_dir: Union[str, Path]):
raise NotImplementedError("This requires being able to deparallelize the model.")

@classmethod
@requires_neuronx_distributed
def load_model_sharded_checkpoint(cls, model: "PreTrainedModel", load_dir: Union[str, Path]):
cls._check_model_was_parallelized(model)
from neuronx_distributed import parallel_layers

if not isinstance(load_dir, Path):
load_dir = Path(load_dir)
parallel_layers.load(load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME, model=model, sharded=True)
@@ -525,7 +543,54 @@ def load_model_checkpoint(cls, model: "PreTrainedModel", load_dir: Union[str, Pa

if (load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME).is_dir():
cls.load_model_sharded_checkpoint(model, load_dir)
elif (load_dir / WEIGHTS_NAME).is_file():
cls.load_model_regular_checkpoint(model, load_dir)
else:
raise FileNotFoundError(f"Could not find a checkpoint file under {load_dir.as_posix()}.")
raise FileNotFoundError(f"Could not find a sharded checkpoint directory under {load_dir.as_posix()}.")

@classmethod
@requires_neuronx_distributed
def load_optimizer_sharded_checkpoint(cls, optimizer: "torch.optim.Optimizer", load_dir: Union[str, Path]):
from neuronx_distributed.optimizer import NeuronZero1Optimizer

is_zero_1_optimizer = optimizer.__class__.__name__ == "NeuronAcceleratedOptimizer" and isinstance(
optimizer.optimizer, NeuronZero1Optimizer
)
is_zero_1_optimizer = is_zero_1_optimizer or isinstance(optimizer, NeuronZero1Optimizer)
if is_zero_1_optimizer:
raise NotImplementedError(
"It is not possible to load a sharded optimizer checkpoint when using ZeRO-1 yet."
)

if not isinstance(load_dir, Path):
load_dir = Path(load_dir)

import torch_xla.core.xla_model as xm
from neuronx_distributed.parallel_layers.parallel_state import (
get_pipeline_model_parallel_rank,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_size,
)

world_size = get_tensor_model_parallel_size()
tp_rank = get_tensor_model_parallel_rank()
pp_rank = get_pipeline_model_parallel_rank()

if not (load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME).is_dir():
raise FileNotFoundError(f"Could not find a sharded checkpoint directory under {load_dir.as_posix()}.")

checkpoint_name = load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME / f"tp_rank_{tp_rank:02d}_pp_rank{pp_rank:02d}.pt"

device = "xla"
for group in optimizer.param_groups:
for p in group["params"]:
device = p.device
break

for worker_start in range(0, world_size):
if tp_rank == worker_start:
checkpoint = torch.load(checkpoint_name, map_location="cpu")
optimizer_state_dict = checkpoint["optimizer_state_dict"]
xm.send_cpu_data_to_device(optimizer_state_dict, device)
optimizer.load_state_dict(optimizer_state_dict)
del checkpoint
gc.collect()
xm.rendezvous("neuron.load_checkpoint" + str(worker_start))
4 changes: 1 addition & 3 deletions optimum/neuron/generation/utils.py
@@ -821,9 +821,7 @@ def generate(
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
if generation_config.use_cache:
warnings.warn(
"use_cache is not supported for generation on Neuron devices, switching to use_cache=False."
)
warnings.warn("use_cache is not supported for generation on Neuron devices, switching to use_cache=False.")
model_kwargs["use_cache"] = False

accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())