
Commit

[WIP] Zero-1
michaelbenayoun committed Jul 17, 2023
1 parent 9e86c98 commit 9fcbc6e
Showing 7 changed files with 93 additions and 27 deletions.
64 changes: 61 additions & 3 deletions optimum/neuron/accelerate/accelerator.py
@@ -55,6 +55,7 @@

if is_torch_xla_available():
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer
else:
xm = None

@@ -68,7 +69,7 @@
# TODO: should we do a XLAFSDPNeuronAccelerator instead?
class NeuronAccelerator(Accelerator):
# @patch_within_function(("accelerate.accelerator.AcceleratorState", NeuronAcceleratorState))
def __init__(self, *args, tp_plugin: Optional[TensorParallelismPlugin] = None, **kwargs):
def __init__(self, *args, tp_plugin: Optional[TensorParallelismPlugin] = None, zero_1: bool = False, **kwargs):
# Patches accelerate.utils.imports.is_tpu_available to match `is_torch_xla_available`
patch_accelerate_is_tpu_available()

@@ -92,8 +93,15 @@ def __init__(self, *args, tp_plugin: Optional[TensorParallelismPlugin] = None, *
full_kwargs["gradient_accumulation_steps"] = gradient_accumulation_steps

fsdp_plugin = full_kwargs["fsdp_plugin"]
if fsdp_plugin is None and os.environ.get("ACCELERATE_USE_FSDP", "false") == "true":
fsdp_plugin = NeuronFullyShardedDataParallelPlugin()
if fsdp_plugin is None:
    if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true":
        fsdp_plugin = NeuronFullyShardedDataParallelPlugin()
elif not isinstance(fsdp_plugin, NeuronFullyShardedDataParallelPlugin):
    raise ValueError(
        "The fsdp_plugin must be an instance of NeuronFullyShardedDataParallelPlugin to use XLA FSDP with "
        f"the NeuronAccelerator, but an instance of {type(fsdp_plugin)} was given here."
    )
self.fsdp_plugin = fsdp_plugin

use_neuronx_distributed_tp = os.environ.get("ACCELERATE_USE_NEURONX_DISTRIBUTED_TP", "false")
if tp_plugin is None:
@@ -111,6 +119,14 @@ def __init__(self, *args, tp_plugin: Optional[TensorParallelismPlugin] = None, *
with Patcher([("accelerate.accelerator.AcceleratorState", patched_accelerator_state)]):
super().__init__(**full_kwargs)

self.zero_1 = zero_1

if self.fsdp_plugin is not None and self.zero_1:
raise ValueError("Either enable XLA ZeRO Stage 1 or XLA FSDP but not both.")

if self.process_index == -1 and self.zero_1:
raise ValueError("XLA ZeRO Stage 1 can only be enabled in a distributed training setting.")

if fsdp_plugin is not None and tp_plugin is not None:
raise ValueError("It is not possible to both use neuronx_distributed Tensor Parallelism and XLA FSDP.")

@@ -152,10 +168,52 @@ def _prepare_optimizer_for_tp(self, optimizer: torch.optim.Optimizer, device_pla
)
return optimizer

def _prepare_optimizer_for_zero_1(self, optimizer: torch.optim.Optimizer, device_placement=None):
mixed_precision_to_dtype = {
"no": torch.float32,
"bf16": torch.bfloat16,
}
optimizer_dtype = mixed_precision_to_dtype.get(self.state.mixed_precision, None)
if optimizer_dtype is None:
raise ValueError(f"The precision {self.state.mixed_precision} is not supported for ZeRO Stage 1")

if hasattr(optimizer, "_args_to_recreate"):
args, kwargs = optimizer._args_to_recreate
params = args[0]
defaults = args_and_kwargs_to_kwargs_only(optimizer.__class__, args[1:], kwargs)

zero_1_optimizer = ZeroRedundancyOptimizer(
params,
optimizer.__class__,
optimizer_dtype=optimizer_dtype,
grad_clipping=True, # TODO: handle this case.
max_norm=None, # TODO: handle this case.
pin_layout=False,
**defaults,
)
del optimizer
else:
logger.warning(
    f"Creating a ZeroRedundancyOptimizer from {optimizer}; this might change some default values. When "
    "using ZeRO Stage 1, it is recommended to create the ZeroRedundancyOptimizer yourself to avoid this "
    "kind of issue."
)
zero_1_optimizer = ZeroRedundancyOptimizer(
optimizer.param_groups,
optimizer.__class__,
optimizer_dtype=optimizer_dtype,
grad_clipping=True, # TODO: handle this case.
max_norm=None, # TODO: handle this case.
pin_layout=False,
)
return zero_1_optimizer

@patch_within_function(("accelerate.accelerator.AcceleratedOptimizer", NeuronAcceleratedOptimizer))
def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement: Optional[bool] = None):
if self.distributed_type is NeuronDistributedType.TENSOR_PARALLELISM:
optimizer = self._prepare_optimizer_for_tp(optimizer, device_placement=device_placement)
if self.zero_1:
optimizer = self._prepare_optimizer_for_zero_1(optimizer, device_placement=device_placement)
return super().prepare_optimizer(optimizer, device_placement=device_placement)

@patch_within_function(("accelerate.accelerator.AcceleratedScheduler", NeuronAcceleratedScheduler))
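A minimal usage sketch of the new flag (not part of the commit): it assumes a distributed XLA/Neuron launch (the constructor raises otherwise) and uses a placeholder model and optimizer. prepare() routes the optimizer through the _prepare_optimizer_for_zero_1 method above, which wraps it in torch_xla's ZeroRedundancyOptimizer.

import torch

from optimum.neuron.accelerate import NeuronAccelerator

# Placeholder model/optimizer; any torch module and optimizer would do.
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

accelerator = NeuronAccelerator(zero_1=True)  # enable XLA ZeRO Stage 1
# prepare() calls prepare_optimizer(), which shards the optimizer state across
# data-parallel workers via ZeroRedundancyOptimizer.
model, optimizer = accelerator.prepare(model, optimizer)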
5 changes: 4 additions & 1 deletion optimum/neuron/accelerate/optimizer.py
@@ -29,6 +29,7 @@
if is_torch_xla_available():
import accelerate
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

accelerate.optimizer.xm = xm

@@ -68,7 +69,9 @@ def prepare_clip_grad_norm(self, parameters, max_norm, norm_type=2):

def step(self, closure=None):
if self.gradient_state.sync_gradients:
if self.accelerator_state.distributed_type is DistributedType.TPU:
if isinstance(self.optimizer, ZeroRedundancyOptimizer):
self.optimizer.step(closure)
elif self.accelerator_state.distributed_type is DistributedType.TPU:
optimizer_args = {"closure": closure} if closure is not None else {}
xm.optimizer_step(self.optimizer, optimizer_args=optimizer_args)
elif self.accelerator_state.distributed_type is NeuronDistributedType.XLA_FSDP:
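For context, a hedged sketch of how torch_xla's ZeroRedundancyOptimizer is driven on its own, which is presumably why the step above bypasses xm.optimizer_step(): the ZeRO-1 wrapper reduces gradients and re-gathers the sharded parameters itself. The constructor arguments mirror the ones used in this commit; the model and data are placeholders, and a multi-worker XLA launch is assumed.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

device = xm.xla_device()
model = torch.nn.Linear(8, 8).to(device)

optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    torch.optim.AdamW,
    optimizer_dtype=torch.bfloat16,
    pin_layout=False,
    lr=1e-4,  # forwarded to AdamW through **defaults
)

loss = model(torch.randn(2, 8, device=device)).sum()
loss.backward()
optimizer.step()       # reduces gradients, updates the local shard, re-gathers parameters
optimizer.zero_grad()
xm.mark_step()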
17 changes: 1 addition & 16 deletions optimum/neuron/distributed/base.py
@@ -194,21 +194,6 @@ def _check_model_was_parallelized(cls, model: "PreTrainedModel"):
if not cls.was_parallelized(model):
raise ValueError("The model needs to be parallelized first.")

@classmethod
def make_optimizer_constructor_lazy_for_tp(cls, optimizer_cls: Type["torch.optim.Optimizer"]):
"""
Transforms an optimizer constructor (optimizer class) to make it lazy by not initializing the parameters.
This makes the optimizer lightweight and usable to create a "real" optimizer once the model has been
parallelized.
"""

def optimizer_constructor(*args, **kwargs):
optimizer_with_no_parameters = optimizer_cls([torch.nn.Parameter(torch.empty(1))], *args[1:], **kwargs)
optimizer_with_no_parameters._args_to_recreate = (args, kwargs)
return optimizer_with_no_parameters

return optimizer_constructor

@classmethod
def optimizer_for_tp(
cls,
@@ -220,7 +205,7 @@ def optimizer_for_tp(
There are two cases:
1. The optimizer has been created via a lazy constructor from
[`Parallelizer.make_optimizer_constructor_lazy_for_tp`], in which case exactly the intended optimizer is
[`optimum.neuron.distributed.utils.make_optimizer_constructor_lazy`], in which case exactly the intended optimizer is
created for tensor parallelism.
2. The optimizer was created with a regular constructor. In this case the optimizer for tensor parallelism
is created as close as possible to what was intended but that is not guaranteed.
17 changes: 16 additions & 1 deletion optimum/neuron/distributed/utils.py
@@ -19,7 +19,7 @@
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Literal, Optional, Tuple, Union
from typing import Dict, Literal, Optional, Tuple, Type, Union

import torch
from transformers import PretrainedConfig
@@ -468,3 +468,18 @@ def wrapper(*args, **kwargs):
yield
finally:
pass


def make_optimizer_constructor_lazy(optimizer_cls: Type["torch.optim.Optimizer"]):
"""
Transforms an optimizer constructor (optimizer class) to make it lazy by not initializing the parameters.
This makes the optimizer lightweight and usable to create a "real" optimizer once the model has been
parallelized.
"""

def optimizer_constructor(*args, **kwargs):
optimizer_with_no_parameters = optimizer_cls([torch.nn.Parameter(torch.empty(1))], *args[1:], **kwargs)
optimizer_with_no_parameters._args_to_recreate = (args, kwargs)
return optimizer_with_no_parameters

return optimizer_constructor
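An illustrative (hypothetical) usage of the relocated helper: the lazy constructor records its arguments on a throwaway optimizer so that _prepare_optimizer_for_tp or _prepare_optimizer_for_zero_1 can re-create the real optimizer later from _args_to_recreate.

import torch

from optimum.neuron.distributed.utils import make_optimizer_constructor_lazy

model = torch.nn.Linear(4, 4)  # placeholder model

LazyAdamW = make_optimizer_constructor_lazy(torch.optim.AdamW)
# The parameters passed here are only recorded, not used to build optimizer state.
lazy_optimizer = LazyAdamW(list(model.parameters()), lr=1e-4, weight_decay=0.01)

args, kwargs = lazy_optimizer._args_to_recreate
real_optimizer = torch.optim.AdamW(*args, **kwargs)  # what the accelerator does once the model is ready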
4 changes: 2 additions & 2 deletions optimum/neuron/trainer_callback.py
@@ -274,8 +274,6 @@ def local_path_to_path_in_repo(path):
target_file.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(path, self.neuron_cache_path / path_in_cache)

if self.wait_for_everyone_on_push:
xm.rendezvous("wait for everyone after pushing")

if self.use_neuron_cache:
self._update_cache_stats(self.neuron_cache_path)
@@ -314,6 +312,8 @@ def on_save(self, args: "TrainingArguments", state: TrainerState, control: "Trai
"""
if self.push:
self.synchronize_temporary_neuron_cache()
if self.wait_for_everyone_on_push:
xm.rendezvous("wait for everyone after pushing")

def on_train_end(self, args: "TrainingArguments", state: TrainerState, control: "TrainerControl", **kwargs):
"""
10 changes: 6 additions & 4 deletions optimum/neuron/trainers.py
@@ -41,7 +41,8 @@

from ..utils import check_if_transformers_greater, logging
from .accelerate import NeuronAccelerator, NeuronDistributedType
from .distributed import Parallelizer, ParallelizersManager
from .distributed import ParallelizersManager
from .distributed.utils import make_optimizer_constructor_lazy
from .trainer_callback import NeuronCacheCallaback
from .utils import (
DynamicPatch,
@@ -202,6 +203,7 @@ def create_accelerator_and_postprocess(self):
deepspeed_plugin=self.args.deepspeed_plugin,
gradient_accumulation_steps=self.args.gradient_accumulation_steps,
tp_plugin=self.args.tp_plugin,
zero_1=self.args.zero_1,
)

# deepspeed and accelerate flags covering both trainer args and accelerate launcher
@@ -269,8 +271,9 @@ def get_test_dataloader(self, *args, **kwargs) -> DataLoader:
@staticmethod
def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
optimizer_cls, optimizer_kwargs = transformers_get_optimizer_cls_and_kwargs(args)
if check_if_transformers_greater("4.30.0") and args.tp_plugin.should_parallelize:
optimizer_cls = Parallelizer.make_optimizer_constructor_lazy_for_tp(optimizer_cls)
lazy_load = args.tp_plugin.should_parallelize or args.zero_1
if check_if_transformers_greater("4.30.0") and lazy_load:
optimizer_cls = make_optimizer_constructor_lazy(optimizer_cls)
return optimizer_cls, optimizer_kwargs

@patch_within_function(("transformers.Trainer.get_optimizer_cls_and_kwargs", get_optimizer_cls_and_kwargs))
@@ -402,7 +405,6 @@ def _save_checkpoint_with_accelerator(self, model, trial, metrics=None):
self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

def _save_checkpoint(self, model, trial, metrics=None):
# if self.fsdp or self.is_fsdp_enabled:
if check_if_transformers_greater("4.30.0") and self.accelerator.distributed_type in [
NeuronDistributedType.XLA_FSDP,
NeuronDistributedType.TENSOR_PARALLELISM,
3 changes: 3 additions & 0 deletions optimum/neuron/training_args.py
@@ -49,7 +49,10 @@
logger = logging.get_logger(__name__)


@dataclass
class NeuronTrainingArgumentsMixin:
zero_1: bool = field(default=False, metadata={"help": "Whether to use ZeRO Stage 1 Optimization."})

def __post_init__(self):
# Patches accelerate.utils.imports.is_tpu_available to match `is_torch_xla_available`
patch_accelerate_is_tpu_available()
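A hedged example of turning the feature on from the training-arguments side (the import path and concrete class name are assumed, and the hyper-parameters are placeholders): the flag is forwarded to NeuronAccelerator(zero_1=True) in create_accelerator_and_postprocess().

from optimum.neuron import NeuronTrainingArguments  # assumed public entry point built on the mixin above

training_args = NeuronTrainingArguments(
    output_dir="outputs",
    per_device_train_batch_size=8,
    bf16=True,     # the ZeRO-1 path maps "bf16" mixed precision to torch.bfloat16
    zero_1=True,   # shard optimizer states with XLA ZeRO Stage 1
)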
