Zero-1 support #140

Merged (7 commits) on Jul 18, 2023
87 changes: 78 additions & 9 deletions optimum/neuron/accelerate/accelerator.py
@@ -32,6 +32,7 @@

from ...utils import logging
from ..distributed import Parallelizer, ParallelizersManager
from ..distributed.utils import ZeroRedundancyOptimizerCompatibleWithTensorParallelism
from ..utils import Patcher, is_neuronx_distributed_available, is_torch_xla_available, patch_within_function
from ..utils.misc import args_and_kwargs_to_kwargs_only
from .optimizer import NeuronAcceleratedOptimizer
@@ -55,6 +56,7 @@

if is_torch_xla_available():
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer
else:
xm = None

@@ -68,7 +70,7 @@
# TODO: should we do a XLAFSDPNeuronAccelerator instead?
class NeuronAccelerator(Accelerator):
# @patch_within_function(("accelerate.accelerator.AcceleratorState", NeuronAcceleratorState))
def __init__(self, *args, tp_plugin: Optional[TensorParallelismPlugin] = None, **kwargs):
def __init__(self, *args, tp_plugin: Optional[TensorParallelismPlugin] = None, zero_1: bool = False, **kwargs):
# Patches accelerate.utils.imports.is_tpu_available to match `is_torch_xla_available`
patch_accelerate_is_tpu_available()

@@ -92,8 +94,15 @@ def __init__(self, *args, tp_plugin: Optional[TensorParallelismPlugin] = None, *
full_kwargs["gradient_accumulation_steps"] = gradient_accumulation_steps

fsdp_plugin = full_kwargs["fsdp_plugin"]
if fsdp_plugin is None and os.environ.get("ACCELERATE_USE_FSDP", "false") == "true":
fsdp_plugin = NeuronFullyShardedDataParallelPlugin()
if fsdp_plugin is None:
if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true":
fsdp_plugin = NeuronFullyShardedDataParallelPlugin()
elif not isinstance(fsdp_plugin, NeuronFullyShardedDataParallelPlugin):
raise ValueError(
"The fsdp_plugin must be an instance of NeuronFullyShardedDataParallelPlugin to use XLA FSDP with "
f"the NeuronAccelerator, but an instance of {type(fsdp_plugin)} was given here."
)
self.fsdp_plugin = fsdp_plugin

use_neuronx_distributed_tp = os.environ.get("ACCELERATE_USE_NEURONX_DISTRIBUTED_TP", "false")
if tp_plugin is None:
@@ -111,6 +120,14 @@ def __init__(self, *args, tp_plugin: Optional[TensorParallelismPlugin] = None, *
with Patcher([("accelerate.accelerator.AcceleratorState", patched_accelerator_state)]):
super().__init__(**full_kwargs)

self.zero_1 = zero_1

if self.fsdp_plugin is not None and self.zero_1:
raise ValueError("Either enable XLA ZeRO Stage 1 or XLA FSDP but not both.")

if self.process_index == -1 and self.zero_1:
raise ValueError("XLA ZeRO Stage 1 can only be enabled in a distributed training setting.")

if fsdp_plugin is not None and tp_plugin is not None:
raise ValueError("It is not possible to both use neuronx_distributed Tensor Parallelism and XLA FSDP.")

@@ -147,15 +164,67 @@ def prepare_data_loader(self, data_loader: DataLoader, device_placement: Optiona
return super().prepare_data_loader(data_loader, device_placement=device_placement)

def _prepare_optimizer_for_tp(self, optimizer: torch.optim.Optimizer, device_placement=None):
optimizer = Parallelizer.optimizer_for_tp(
optimizer, collections.ChainMap(*self._model_cpu_parameters_to_xla.values())
)
cpu_parameters_to_xla = collections.ChainMap(*self._model_cpu_parameters_to_xla.values())
if not self.zero_1:
optimizer = Parallelizer.optimizer_for_tp(optimizer, cpu_parameters_to_xla)
else:
xla_parameters, _ = Parallelizer.optimizer_cpu_params_to_xla_params(optimizer, cpu_parameters_to_xla)
if hasattr(optimizer, "_args_to_recreate"):
args, kwargs = optimizer._args_to_recreate
args = (xla_parameters,) + args[1:]
optimizer._args_to_recreate = (args, kwargs)
else:
optimizer.param_groups = xla_parameters
return optimizer

def _prepare_optimizer_for_zero_1(self, optimizer: torch.optim.Optimizer, device_placement=None):
mixed_precision_to_dtype = {
"no": torch.float32,
"bf16": torch.bfloat16,
}
optimizer_dtype = mixed_precision_to_dtype.get(self.state.mixed_precision, None)
if optimizer_dtype is None:
raise ValueError(f"The precision {self.state.mixed_precision} is not supported for ZeRO Stage 1")

zero_redundancy_optimizer_class = (
ZeroRedundancyOptimizerCompatibleWithTensorParallelism
if self.state.tp_plugin.should_parallelize
else ZeroRedundancyOptimizer
)

if hasattr(optimizer, "_args_to_recreate"):
args, kwargs = optimizer._args_to_recreate
params = args[0]
defaults = args_and_kwargs_to_kwargs_only(optimizer.__class__, args[1:], kwargs)

zero_1_optimizer = zero_redundancy_optimizer_class(
params,
optimizer.__class__,
optimizer_dtype=optimizer_dtype,
pin_layout=False,
**defaults,
)
del optimizer
else:
logger.warning(
f"Creating a ZeroRedundancyOptimizer from {optimizer}, this might change some default values. When "
"using ZeRO 1 it is recommended to create the ZeroRedundancyOptimizer yourself to avoid this kind of "
"issues."
)
zero_1_optimizer = zero_redundancy_optimizer_class(
optimizer.param_groups,
optimizer.__class__,
optimizer_dtype=optimizer_dtype,
pin_layout=False,
)
return zero_1_optimizer

@patch_within_function(("accelerate.accelerator.AcceleratedOptimizer", NeuronAcceleratedOptimizer))
def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement: Optional[bool] = None):
if self.distributed_type is NeuronDistributedType.TENSOR_PARALLELISM:
optimizer = self._prepare_optimizer_for_tp(optimizer, device_placement=device_placement)
if self.zero_1:
optimizer = self._prepare_optimizer_for_zero_1(optimizer, device_placement=device_placement)
return super().prepare_optimizer(optimizer, device_placement=device_placement)

@patch_within_function(("accelerate.accelerator.AcceleratedScheduler", NeuronAcceleratedScheduler))
@@ -282,7 +351,7 @@ def clip_grad_norm_for_xla_fsdp(self, parameters, max_norm, norm_type: int = 2):
if parameters == list(model.parameters()):
return model.clip_grad_norm_(max_norm, norm_type)

def _clip_grad_norm_for_tp(self, parameters, max_norm, norm_type: int = 2):
def _prepare_clip_grad_norm(self, parameters, max_norm, norm_type: int = 2):
self.unscale_gradients()
parameters = list(parameters)
for model in self._models:
@@ -295,8 +364,8 @@ def _clip_grad_norm_for_tp(self, parameters, max_norm, norm_type: int = 2):
def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
if self.distributed_type is NeuronDistributedType.XLA_FSDP:
return self.clip_grad_norm_for_xla_fsdp(parameters, max_norm, norm_type=norm_type)
elif self.distributed_type is NeuronDistributedType.TENSOR_PARALLELISM:
return self._clip_grad_norm_for_tp(parameters, max_norm, norm_type=norm_type)
elif self.distributed_type is NeuronDistributedType.TENSOR_PARALLELISM or self.zero_1:
return self._prepare_clip_grad_norm(parameters, max_norm, norm_type=norm_type)
return super().clip_grad_norm_(parameters, max_norm, norm_type=norm_type)

def clip_grad_value_(self, parameters, clip_value):
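For context on the new `zero_1` path above: the constructor now accepts `zero_1: bool = False`, and `_prepare_optimizer_for_zero_1` wraps the user's optimizer in torch_xla's ZeRO Stage 1 implementation. Below is a minimal standalone sketch (not part of the diff) of what the non-lazy branch boils down to; the toy model, learning rate, and hard-coded `bf16` choice are placeholders, while the `optimizer_dtype` / `pin_layout=False` arguments mirror the call made in `_prepare_optimizer_for_zero_1`.

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

device = xm.xla_device()

# Placeholder model; in optimum-neuron this would be the model already prepared by the accelerator.
model = torch.nn.Linear(16, 16).to(device)

# Map Accelerate's mixed-precision setting to the dtype of the sharded optimizer state,
# mirroring the `mixed_precision_to_dtype` dict in `_prepare_optimizer_for_zero_1`.
mixed_precision_to_dtype = {"no": torch.float32, "bf16": torch.bfloat16}
optimizer_dtype = mixed_precision_to_dtype["bf16"]

# Instead of instantiating torch.optim.AdamW directly, hand the parameters and the optimizer
# class to torch_xla's ZeRO Stage 1 wrapper: each worker keeps only its shard of the state.
zero_1_optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    torch.optim.AdamW,
    optimizer_dtype=optimizer_dtype,
    pin_layout=False,
    lr=1e-4,
)
```

When the optimizer was built through the lazy constructor (i.e. it carries `_args_to_recreate`), the diff recreates the wrapper from the stored args instead, so the user's original defaults carry over.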
24 changes: 15 additions & 9 deletions optimum/neuron/accelerate/optimizer.py
@@ -29,6 +29,7 @@
if is_torch_xla_available():
import accelerate
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

accelerate.optimizer.xm = xm

@@ -57,18 +58,23 @@ def load_state_dict(self, state_dict):
return super().load_state_dict(state_dict)

def prepare_clip_grad_norm(self, parameters, max_norm, norm_type=2):
# Deferring the clipping to later since gradients need to be reduced first when performing tensor parallelism.
# TODO: find a better way to make sure we are using the right parameters for the right optimizer.
if self.accelerator_state.distributed_type is NeuronDistributedType.TENSOR_PARALLELISM:
parameter_ids = {id(p) for p in parameters}
if parameter_ids == self.parameter_ids:
self.clip_grad_norm_to_perform = {"max_norm": max_norm, "norm_type": norm_type}
else:
raise RuntimeError("The AcceleratedOptimizer handles grad clipping only for tensor parallelism.")
parameter_ids = {id(p) for p in parameters}
if parameter_ids == self.parameter_ids:
self.clip_grad_norm_to_perform = {"max_norm": max_norm, "norm_type": norm_type}

def step(self, closure=None):
if self.gradient_state.sync_gradients:
if self.accelerator_state.distributed_type is DistributedType.TPU:
if isinstance(self.optimizer, ZeroRedundancyOptimizer):
if self.clip_grad_norm_to_perform is not None:
# `ZeroRedundancyOptimizer` does not allow passing a norm type; it could be done, but this is
# postponed for now.
self.optimizer.grad_clipping = True
self.optimizer.max_norm = self.clip_grad_norm_to_perform["max_norm"]
else:
self.optimizer.grad_clipping = False
optimizer_args = {"closure": closure} if closure is not None else {}
self.optimizer.step(closure)
elif self.accelerator_state.distributed_type is DistributedType.TPU:
optimizer_args = {"closure": closure} if closure is not None else {}
xm.optimizer_step(self.optimizer, optimizer_args=optimizer_args)
elif self.accelerator_state.distributed_type is NeuronDistributedType.XLA_FSDP:
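To see how the deferred clipping in `NeuronAcceleratedOptimizer.step` plays out at runtime, here is a rough self-contained sketch (assumptions: a toy model, plain SGD, a single worker); the `grad_clipping` / `max_norm` toggling mirrors the new ZeRO-1 branch above, everything else is illustrative.

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

device = xm.xla_device()
model = torch.nn.Linear(16, 16).to(device)  # placeholder model
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    torch.optim.SGD,
    optimizer_dtype=torch.float32,
    pin_layout=False,
    lr=1e-3,
)

loss = model(torch.randn(4, 16, device=device)).sum()
loss.backward()

# Stand-in for `clip_grad_norm_to_perform["max_norm"]`, which the accelerator records instead
# of clipping eagerly.
max_norm_to_apply = 1.0
if max_norm_to_apply is not None:
    # Let the wrapper clip after it has reduced the gradients across workers; it exposes no
    # norm-type knob, as the comment in the diff notes.
    optimizer.grad_clipping = True
    optimizer.max_norm = max_norm_to_apply
else:
    optimizer.grad_clipping = False

optimizer.step()   # reduces gradients, applies clipping, updates the sharded state
optimizer.zero_grad()
xm.mark_step()
```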
93 changes: 44 additions & 49 deletions optimum/neuron/distributed/base.py
@@ -19,7 +19,7 @@
from abc import ABC, abstractclassmethod
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Dict, Mapping, Optional, Type, Union
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union

import torch
from transformers.utils import WEIGHTS_NAME
@@ -195,19 +195,41 @@ def _check_model_was_parallelized(cls, model: "PreTrainedModel"):
raise ValueError("The model needs to be parallelized first.")

@classmethod
def make_optimizer_constructor_lazy_for_tp(cls, optimizer_cls: Type["torch.optim.Optimizer"]):
"""
Transforms an optimizer constructor (optimizer class) to make it lazy by not initializing the parameters.
This makes the optimizer lightweight and usable to create a "real" optimizer once the model has been
parallelized.
"""

def optimizer_constructor(*args, **kwargs):
optimizer_with_no_parameters = optimizer_cls([torch.nn.Parameter(torch.empty(1))], *args[1:], **kwargs)
optimizer_with_no_parameters._args_to_recreate = (args, kwargs)
return optimizer_with_no_parameters

return optimizer_constructor
def optimizer_cpu_params_to_xla_params(
cls,
optimizer: "torch.optim.Optimizer",
orig_param_to_parallel_param_on_xla: Mapping[int, "torch.nn.Parameter"],
) -> Tuple[List[Dict[str, Any]], bool]:
parameters_on_xla = []
need_to_create_new_optimizer = False
if hasattr(optimizer, "_args_to_recreate"):
args, _ = optimizer._args_to_recreate
parameters = args[0]
for param in parameters:
if isinstance(param, dict):
new_param = {k: v for k, v in param.items() if k != "params"}
params = []
for p in param["params"]:
params.append(orig_param_to_parallel_param_on_xla[id(p)])
new_param["params"] = params
else:
new_param = []
for p in param:
new_param.append(orig_param_to_parallel_param_on_xla[id(p)])
parameters_on_xla.append(new_param)
else:
for param_group in optimizer.param_groups:
new_params = []
params = param_group["params"]
for idx in range(len(params)):
param_on_xla = orig_param_to_parallel_param_on_xla[id(params[idx])]
if params[idx] is not param_on_xla:
need_to_create_new_optimizer = True
new_params.append(param_on_xla)
new_group = {k: v for k, v in param_group.items() if k != "params"}
new_group["params"] = new_params
parameters_on_xla.append(new_group)
return parameters_on_xla, need_to_create_new_optimizer

@classmethod
def optimizer_for_tp(
@@ -220,7 +242,7 @@ def optimizer_for_tp(

There are two cases:
1. The optimizer has been created via a lazy constructor from
[`Parallelizer.make_optimizer_constructor_lazy_for_tp`], in which case the exact intended optimizer is
[`optimum.neuron.distributed.utils.make_optimizer_constructor_lazy`], in which case the exact intended optimizer is
created for tensor parallelism.
2. The optimizer was created with a regular constructor. In this case the optimizer for tensor parallelism
is created as close as possible to what was intended but that is not guaranteed.
@@ -235,45 +257,18 @@ def optimizer_for_tp(
Returns:
`torch.optim.Optimizer`: The tensor parallelism ready optimizer.
"""
parallel_parameters, need_to_create_new_optimizer = cls.optimizer_cpu_params_to_xla_params(
optimizer, orig_param_to_parallel_param_on_xla
)
if hasattr(optimizer, "_args_to_recreate"):
args, kwargs = optimizer._args_to_recreate
parameters = args[0]
parallel_parameters = []
for param in parameters:
if isinstance(param, dict):
new_param = {k: v for k, v in param.items() if k != "params"}
params = []
for p in param["params"]:
params.append(orig_param_to_parallel_param_on_xla[id(p)])
new_param["params"] = params
else:
new_param = []
for p in param:
new_param.append(orig_param_to_parallel_param_on_xla[id(p)])
parallel_parameters.append(new_param)
optimizer_for_tp = optimizer.__class__(parallel_parameters, *args[1:], **kwargs)
del optimizer
elif need_to_create_new_optimizer:
optimizer_for_tp = optimizer.__class__(parallel_parameters)
del optimizer
else:
need_to_create_new_optimizer = False
new_groups = []
for param_group in optimizer.param_groups:
new_params = []
params = param_group["params"]
for idx in range(len(params)):
param_on_xla = orig_param_to_parallel_param_on_xla[id(params[idx])]
if params[idx] != param_on_xla:
need_to_create_new_optimizer = True
new_params.append(param_on_xla)
new_group = {k: v for k, v in param_group.items() if k != "params"}
new_group["params"] = new_params
new_groups.append(new_group)

if need_to_create_new_optimizer:
optimizer_for_tp = optimizer.__class__(new_groups)
del optimizer
else:
optimizer_for_tp = optimizer

optimizer_for_tp = optimizer
return optimizer_for_tp

@classmethod
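To make the remapping performed by the new `optimizer_cpu_params_to_xla_params` helper concrete, here is a small standalone sketch (not from the diff) of the pattern it implements: an `id()`-keyed dictionary from the original CPU parameters to their parallel counterparts on the XLA device is used to rewrite the optimizer's param groups while keeping every other group option (learning rate, weight decay, ...) untouched, with an identity check to detect replaced parameters. The toy model and the CPU-only "parallel" parameters are placeholders.

```python
import torch

# Toy stand-ins: a CPU model and pretend "parallelized on XLA" copies of its parameters.
model = torch.nn.Linear(4, 4)
orig_param_to_parallel_param_on_xla = {
    id(p): torch.nn.Parameter(p.detach().clone()) for p in model.parameters()
}

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

# Rewrite each param group so it points at the parallel parameters, preserving the other options.
parameters_on_xla = []
need_to_create_new_optimizer = False
for param_group in optimizer.param_groups:
    new_params = []
    for p in param_group["params"]:
        parallel_p = orig_param_to_parallel_param_on_xla[id(p)]
        if parallel_p is not p:
            # The parameter object changed (e.g. it was sharded and moved to XLA), so the old
            # optimizer state is stale and a fresh optimizer must be built from the new groups.
            need_to_create_new_optimizer = True
        new_params.append(parallel_p)
    new_group = {k: v for k, v in param_group.items() if k != "params"}
    new_group["params"] = new_params
    parameters_on_xla.append(new_group)
```

The boolean plays the role of `need_to_create_new_optimizer` in the diff: `optimizer_for_tp` only rebuilds the optimizer when at least one parameter was actually replaced, and otherwise returns the original optimizer untouched.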