Save / load from checkpoint TP (#269)
* [WIP] fix resume_from_checkpoint

* [WIP] fix resume_from_checkpoint

* Fix resume from checkpoint

* Save config file

* Fail if using ZeRO-1

* Add docstring

* Apply suggestions

* Styling

* Add test

* Fix

* Final fix

* Fix tests
michaelbenayoun authored Oct 27, 2023
1 parent cfc098d commit 2e2fe40
Showing 10 changed files with 325 additions and 69 deletions.
4 changes: 4 additions & 0 deletions optimum/neuron/accelerate/accelerator.py
@@ -343,9 +343,13 @@ def prepare_model_for_xla_fsdp(
    def _prepare_model_for_tp(
        self, model: torch.nn.Module, device_placement: Optional[bool] = None, evaluation_mode: bool = False
    ):
        if model in self._models or Parallelizer.was_parallelized(model):
            return model

        cpu_ids = [id(v) for v in model.parameters()]
        # TODO: enable self.device (if needed).
        model = self.state.tp_plugin.parallelize_model(model, device=None)

        if os.environ.get("XLA_USE_BF16", "0") == "1" or os.environ.get("XLA_DOWNCAST_BF16", "0") == "1":
            model.to(torch.bfloat16)
        else:
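The method is idempotent (already-parallelized models are returned as-is), records the CPU parameter ids before parallelization, presumably so the accelerator can later map original parameters to their parallelized counterparts, and then picks the precision from the XLA environment variables. A minimal standalone sketch of that precision logic, assuming only a toy `nn.Linear` in place of a real model; `maybe_cast_to_bf16` is a hypothetical helper, not part of this commit:

import os

import torch


def maybe_cast_to_bf16(model: torch.nn.Module) -> torch.nn.Module:
    # Mirrors the check above: either XLA env var opts the whole run into bf16.
    use_bf16 = (
        os.environ.get("XLA_USE_BF16", "0") == "1"
        or os.environ.get("XLA_DOWNCAST_BF16", "0") == "1"
    )
    return model.to(torch.bfloat16) if use_bf16 else model


model = maybe_cast_to_bf16(torch.nn.Linear(4, 4))
print(next(model.parameters()).dtype)  # torch.bfloat16 only if either env var is "1"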
5 changes: 5 additions & 0 deletions optimum/neuron/accelerate/utils/dataclasses.py
@@ -17,6 +17,7 @@
import enum
import os
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union

import torch
@@ -143,10 +144,13 @@ class TensorParallelismPlugin:
    tensor_parallel_size: int = 1
    parallelize_embeddings: bool = True
    sequence_parallel_enabled: bool = False
    checkpoint_dir: Optional[Union[str, Path]] = None

    def __post_init__(self):
        if self.tensor_parallel_size < 1:
            raise ValueError(f"The tensor parallel size must be >= 1, but {self.tensor_parallel_size} was given here.")
        if isinstance(self.checkpoint_dir, str):
            self.checkpoint_dir = Path(self.checkpoint_dir)

    @property
    def should_parallelize(self):
@@ -163,5 +167,6 @@ def parallelize_model(
            device=device,
            parallelize_embeddings=self.parallelize_embeddings,
            sequence_parallel_enabled=self.sequence_parallel_enabled,
            checkpoint_dir=self.checkpoint_dir,
        )
        return parallelized_model
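The new `checkpoint_dir` field threads a sharded-checkpoint path from the plugin down to the parallelizer. A hedged usage sketch — the checkpoint path is a placeholder and the call site is simplified; it only shows that a `str` is normalized to `Path` in `__post_init__` and then forwarded by `parallelize_model`:

from pathlib import Path

from optimum.neuron.accelerate.utils.dataclasses import TensorParallelismPlugin

tp_plugin = TensorParallelismPlugin(
    tensor_parallel_size=2,
    sequence_parallel_enabled=False,
    checkpoint_dir="output/checkpoint-500",  # placeholder path to a sharded checkpoint
)
assert isinstance(tp_plugin.checkpoint_dir, Path)

# Inside the accelerator, the plugin forwards the field to the parallelizer:
# model = tp_plugin.parallelize_model(model, device=None)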
101 changes: 83 additions & 18 deletions optimum/neuron/distributed/base.py
@@ -15,6 +15,7 @@
"""Base class related to `neuronx_distributed` to perform parallelism."""

import contextlib
import gc
import shutil
from abc import ABC, abstractclassmethod
from dataclasses import asdict
@@ -28,6 +29,7 @@
from ...utils import logging
from ..utils import is_neuronx_distributed_available, is_torch_xla_available
from ..utils.deprecate_utils import deprecate
from ..utils.require_utils import requires_neuronx_distributed
from .parallel_layers import (
    IOSequenceParallelizer,
    LayerNormSequenceParallelizer,
@@ -37,14 +39,6 @@
from .utils import TENSOR_PARALLEL_SHARDS_DIR_NAME, ParameterMetadata, WeightInformation, load_tensor_for_weight


if is_neuronx_distributed_available():
    import neuronx_distributed
    from neuronx_distributed import parallel_layers

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm


if TYPE_CHECKING:
    from transformers import PreTrainedModel

@@ -164,12 +158,14 @@ def patch_for_sequence_parallelism(cls, model: "PreTrainedModel", sequence_paral
            )

    @classmethod
    @requires_neuronx_distributed
    def parallelize(
        cls,
        model: "PreTrainedModel",
        device: Optional["torch.device"] = None,
        parallelize_embeddings: bool = True,
        sequence_parallel_enabled: bool = False,
        checkpoint_dir: Optional[Union[str, Path]] = None,
    ) -> "PreTrainedModel":
        """
        Parallelizes the model by transforming regular layers into their parallel counterparts using
@@ -188,13 +184,18 @@ def parallelize(
                This can be disabled in the case when the TP size does not divide the vocabulary size.
            sequence_parallel_enabled (`bool`, defaults to `False`):
                Whether or not sequence parallelism is enabled.
            checkpoint_dir (`Optional[Union[str, Path]]`):
                Path to a sharded checkpoint. If specified, the checkpoint weights will be loaded into the
                parallelized model.
        Returns:
            `PreTrainedModel`: The parallelized model.
        """
        if sequence_parallel_enabled and cls.SEQUENCE_PARALLEL_LAYERNORM_PATTERNS is None:
            raise NotImplementedError(f"Sequence parallelism is not supported for {model.__class__}.")

        from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_rank

        # Preparing the model for sequence parallelism:
        # 1. Transforming the LayerNorms.
        layer_norm_qualified_name_patterns = (
@@ -259,7 +260,7 @@ def parallelize(
                # parallelization since those are the only classes that we initialize on the `meta` device.
                num_dims = current_weight.dim()
                partition_dim = getattr(current_weight, "partition_dim")
                tp_rank = get_tensor_model_parallel_rank()
                size_per_rank = current_weight.size(partition_dim)
                slices = [
                    None
@@ -298,14 +299,20 @@ def parallelize(
                # `reset_parameters()` method.
                mod.reset_parameters()

        if checkpoint_dir is not None:
            cls.load_model_checkpoint(model, checkpoint_dir)

        return model
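The slicing logic above (truncated in this hunk) selects each rank's shard of a full checkpoint weight along `partition_dim`, and when `checkpoint_dir` is given the freshly parallelized model is filled from the sharded checkpoint before being returned. An illustrative reconstruction of the slicing technique — not the elided code itself, and with toy sizes and ranks:

import torch

full_weight = torch.arange(12.0).reshape(6, 2)  # toy "full" checkpoint tensor
partition_dim, tp_size, tp_rank = 0, 2, 1
size_per_rank = full_weight.size(partition_dim) // tp_size  # 3 rows per rank

# One slice per dimension: the partitioned dim gets this rank's window,
# every other dim is taken whole.
slices = tuple(
    slice(tp_rank * size_per_rank, (tp_rank + 1) * size_per_rank) if dim == partition_dim else slice(None)
    for dim in range(full_weight.dim())
)
shard = full_weight[slices]
print(shard)  # rank 1 keeps rows 3..5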

    @classmethod
    def deparallelize(cls, model: "PreTrainedModel") -> "PreTrainedModel":
        raise NotImplementedError

    @classmethod
    @requires_neuronx_distributed
    def was_parallelized(cls, model: "PreTrainedModel") -> bool:
        from neuronx_distributed import parallel_layers

        parallel_layer_classes = (
            parallel_layers.ParallelEmbedding,
            parallel_layers.ColumnParallelLinear,
@@ -410,15 +417,24 @@ def _get_parameters_tp_metadata(cls, named_parameters: Dict[str, "torch.nn.Param
        return tp_metadata

    @classmethod
    @requires_neuronx_distributed
    def save_model_checkpoint_as_regular(
        cls,
        model: "PreTrainedModel",
        output_dir: Union[str, Path],
        optimizer: Optional["torch.optim.Optimizer"] = None,
    ):
        cls._check_model_was_parallelized(model)
        data_parallel_rank = parallel_layers.parallel_state.get_data_parallel_rank()
        tensor_parallel_rank = parallel_layers.parallel_state.get_tensor_model_parallel_rank()

        import neuronx_distributed
        import torch_xla.core.xla_model as xm
        from neuronx_distributed.parallel_layers.parallel_state import (
            get_data_parallel_rank,
            get_tensor_model_parallel_rank,
        )

        data_parallel_rank = get_data_parallel_rank()
        tensor_parallel_rank = get_tensor_model_parallel_rank()

        if data_parallel_rank != 0:
            return
@@ -454,6 +470,7 @@ def save_model_checkpoint_as_regular(
        xm.rendezvous("saving regular checkpoint")
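The method gates on the data-parallel rank so only one replica writes the consolidated checkpoint, and workers meet at a named rendezvous afterwards. A generic sketch of that gating pattern, assuming a torch_xla process group; `save_on_dp_rank_zero` is a made-up helper, not this method:

import torch_xla.core.xla_model as xm


def save_on_dp_rank_zero(state_dict, path, data_parallel_rank):
    if data_parallel_rank == 0:
        # xm.save synchronizes device tensors before writing to disk.
        xm.save(state_dict, path)
    xm.rendezvous("saving regular checkpoint")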

    @classmethod
    @requires_neuronx_distributed
    def save_model_checkpoint_as_sharded(
        cls,
        model: "PreTrainedModel",
@@ -462,6 +479,8 @@
    ):
        cls._check_model_was_parallelized(model)

        import torch_xla.core.xla_model as xm
        from neuronx_distributed import parallel_layers
        from neuronx_distributed.parallel_layers.parallel_state import (
            get_data_parallel_rank,
            get_tensor_model_parallel_rank,
@@ -508,12 +527,11 @@ def save_model_checkpoint(
        cls.save_model_checkpoint_as_sharded(model, output_dir, optimizer=optimizer)

    @classmethod
    def load_model_regular_checkpoint(cls, model: "PreTrainedModel", load_dir: Union[str, Path]):
        raise NotImplementedError("This requires being able to deparallelize the model.")

    @classmethod
    @requires_neuronx_distributed
    def load_model_sharded_checkpoint(cls, model: "PreTrainedModel", load_dir: Union[str, Path]):
        cls._check_model_was_parallelized(model)
        from neuronx_distributed import parallel_layers

        if not isinstance(load_dir, Path):
            load_dir = Path(load_dir)
        parallel_layers.load(load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME, model=model, sharded=True)
@@ -525,7 +543,54 @@ def load_model_checkpoint(cls, model: "PreTrainedModel", load_dir: Union[str, Pa

        if (load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME).is_dir():
            cls.load_model_sharded_checkpoint(model, load_dir)
        elif (load_dir / WEIGHTS_NAME).is_file():
            cls.load_model_regular_checkpoint(model, load_dir)
        else:
            raise FileNotFoundError(f"Could not find a checkpoint file under {load_dir.as_posix()}.")
            raise FileNotFoundError(f"Could not find a sharded checkpoint directory under {load_dir.as_posix()}.")
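After this change, `load_model_checkpoint` dispatches purely on directory layout: a checkpoint is loadable only if it contains the sharded-shards subdirectory. A small sketch of the layout check — the directory name is an assumption about what `TENSOR_PARALLEL_SHARDS_DIR_NAME` resolves to:

from pathlib import Path

TENSOR_PARALLEL_SHARDS_DIR_NAME = "tensor_parallel_shards"  # assumed value of the constant


def is_loadable_tp_checkpoint(load_dir: Path) -> bool:
    # e.g. output/checkpoint-500/tensor_parallel_shards/tp_rank_00_pp_rank00.pt
    return (load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME).is_dir()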

    @classmethod
    @requires_neuronx_distributed
    def load_optimizer_sharded_checkpoint(cls, optimizer: "torch.optim.Optimizer", load_dir: Union[str, Path]):
        from neuronx_distributed.optimizer import NeuronZero1Optimizer

        is_zero_1_optimizer = optimizer.__class__.__name__ == "NeuronAcceleratedOptimizer" and isinstance(
            optimizer.optimizer, NeuronZero1Optimizer
        )
        is_zero_1_optimizer = is_zero_1_optimizer or isinstance(optimizer, NeuronZero1Optimizer)
        if is_zero_1_optimizer:
            raise NotImplementedError(
                "It is not possible to load a sharded optimizer checkpoint when using ZeRO-1 yet."
            )

        if not isinstance(load_dir, Path):
            load_dir = Path(load_dir)

        import torch_xla.core.xla_model as xm
        from neuronx_distributed.parallel_layers.parallel_state import (
            get_pipeline_model_parallel_rank,
            get_tensor_model_parallel_rank,
            get_tensor_model_parallel_size,
        )

        world_size = get_tensor_model_parallel_size()
        tp_rank = get_tensor_model_parallel_rank()
        pp_rank = get_pipeline_model_parallel_rank()

        if not (load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME).is_dir():
            raise FileNotFoundError(f"Could not find a sharded checkpoint directory under {load_dir.as_posix()}.")

        checkpoint_name = load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME / f"tp_rank_{tp_rank:02d}_pp_rank{pp_rank:02d}.pt"

        device = "xla"
        for group in optimizer.param_groups:
            for p in group["params"]:
                device = p.device
                break

        for worker_start in range(0, world_size):
            if tp_rank == worker_start:
                checkpoint = torch.load(checkpoint_name, map_location="cpu")
                optimizer_state_dict = checkpoint["optimizer_state_dict"]
                xm.send_cpu_data_to_device(optimizer_state_dict, device)
                optimizer.load_state_dict(optimizer_state_dict)
                del checkpoint
                gc.collect()
            xm.rendezvous("neuron.load_checkpoint" + str(worker_start))
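The final loop staggers the expensive `torch.load` across tensor-parallel ranks: on each turn exactly one rank deserializes its shard on the CPU, moves it to its XLA device, and frees the host copy, while every rank meets at the rendezvous before the next turn, keeping peak host memory near a single checkpoint. A generic sketch of the pattern under the same torch_xla assumptions (`staggered_load` and `consume` are made-up names; note that `xm.send_cpu_data_to_device` returns the moved copy, which this sketch captures explicitly):

import gc

import torch
import torch_xla.core.xla_model as xm


def staggered_load(checkpoint_path, rank, world_size, device, consume):
    """Load one shard per turn so only one rank holds a CPU copy at a time."""
    for turn in range(world_size):
        if rank == turn:
            state = torch.load(checkpoint_path, map_location="cpu")
            state = xm.send_cpu_data_to_device(state, device)
            consume(state)  # e.g. optimizer.load_state_dict(state["optimizer_state_dict"])
            del state
            gc.collect()
        xm.rendezvous(f"staggered_load_{turn}")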
4 changes: 1 addition & 3 deletions optimum/neuron/generation/utils.py
@@ -821,9 +821,7 @@ def generate(
        model_kwargs["output_attentions"] = generation_config.output_attentions
        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
        if generation_config.use_cache:
            warnings.warn(
                "use_cache is not supported for generation on Neuron devices, switching to use_cache=False."
            )
            warnings.warn("use_cache is not supported for generation on Neuron devices, switching to use_cache=False.")
            model_kwargs["use_cache"] = False

        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
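This last hunk is purely stylistic (the warning collapses onto one line), but the behavior it documents is worth restating: on Neuron devices, `generate` forces `use_cache=False` and warns when a config requested otherwise. A trimmed stand-in showing the observable effect, with a dummy config object in place of a real `GenerationConfig`:

import warnings


class DummyGenerationConfig:
    use_cache = True  # what the caller asked for


generation_config = DummyGenerationConfig()
model_kwargs = {}
if generation_config.use_cache:
    warnings.warn("use_cache is not supported for generation on Neuron devices, switching to use_cache=False.")
    model_kwargs["use_cache"] = False
print(model_kwargs)  # {'use_cache': False}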

