Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Documentation on tensor parallelism #339

Merged
merged 20 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,20 @@
title: Neuron model cache
- local: guides/fine_tune
title: Fine-tune Transformers with AWS Trainium
- local: guides/distributed_training
title: Distributed Training
- local: guides/models
title: Neuron models for inference
- local: guides/export_model
title: Export a model to Inferentia
- local: guides/models
title: Neuron models for inference
michaelbenayoun marked this conversation as resolved.
Show resolved Hide resolved
- local: guides/pipelines
title: Inference pipelines with AWS Neuron
title: How-To Guides
- sections:
- local: package_reference/trainer
title: Neuron Trainer
- local: package_reference/distributed
title: Neuron Distributed
- local: package_reference/export
title: Neuron Exporter
- local: package_reference/modeling
Expand Down
199 changes: 199 additions & 0 deletions docs/source/guides/distributed_training.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
michaelbenayoun marked this conversation as resolved.
Show resolved Hide resolved

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Distributed Training with `optimum-neuron`

[AWS Trainium instances](https://aws.amazon.com/machine-learning/trainium/) are great to train models. They can contain up to 16 Neuron devices, each device containing 2 Neuron cores and has 32GB of memory (16GB per core). For example a `trn1.32xlarge` instance has 32 x 16 = 512GB of memory.

But there is a caveat: each Neuron core is an independent data-parallel worker by default. It means that the model, the gradient state and the optimizer state, amounting to approximately 4 times the model size, must fit in each of the Neuron cores (16GB) to be able to train. If that is the case, then the activations must also fit in the remaining memory.

To alleviate that, `optimum-neuron` supports parallelism features enabling your to harness the full power of your Trainium instance:
michaelbenayoun marked this conversation as resolved.
Show resolved Hide resolved

1. [ZeRO-1](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/tutorials/training/zero1_gpt2.html): It is a optimization of data-parallism which consists in sharding the optimizer state (which usually represents half of the memory needed on the device) over the data-parallel ranks.
michaelbenayoun marked this conversation as resolved.
Show resolved Hide resolved
2. [Tensor Parallelism](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/tensor_parallelism_overview.html): It is a technique which consists in sharding each of your model parameters along a given dimension on multiple devices. The number of devices to shard your parameters on is called the `tensor_parallel_size`.
3. [Pipeline Parallelism](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/pipeline_parallelism_overview.html): **coming soon!**


The good news is that is it possible to combine those techniques, and `optimum-neuron` makes it very easy!

<Tip>
All the [example scripts](https://github.com/huggingface/optimum-neuron/tree/main/examples) provided in the `optimum-neuron` repo have those features implemented via the [`NeuronTrainer`].
michaelbenayoun marked this conversation as resolved.
Show resolved Hide resolved
</Tip>

## How to enable ZeRO-1?

Whether you use the [`NeuronTrainer`] or decide to have your own training script that uses the [`NeuronAccelerator`], it is very easy to enable the ZeRO-1 optimization.

### Via the `NeuronTrainer`

```python
from optimum.neuron import NeuronTrainingArguments, NeuronTrainer

# To enable ZeRO-1, set the `zero_1` argument to `True` in the training arguments.
training_args = NeuronTrainingArguments(
...,
zero_1=True,
)

trainer = NeuronTrainer(
model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)

trainer.train()
```

<Tip>
Since the [example scripts](https://github.com/huggingface/optimum-neuron/tree/main/examples) use the [`NeuronTrainer`], you can enable ZeRO-1 when using them by add the `--zero_1` flag
to your command line.

For example:

```bash
torchrun --nproc_per_node=2 examples/language-modeling/run_clm.py \
--model_name_or_path TinyLlama/TinyLlama-1.1B-Chat-v0.6 \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--do_train \
--per_device_train_batch_size 1 \
--block_size 1024 \
--bf16 \
--zero_1 \
--output_dir my_training/

```
</Tip>

### Via the `NeuronAccelerator`

There is a little bit more work to do when not using the `NeuronTrainer`:

1. (Optional) Wrap the optimizer class to make it lazy. When ZeRO-1 is enabled the original optimizer is overridden to use a sharded version of it so it is possible to load original optimizer lazily so that the optimizer state is not materialized until it is actually sharded.
michaelbenayoun marked this conversation as resolved.
Show resolved Hide resolved

```python
from torch.optim import AdamW
from optimum.neuron.distributed import make_optimizer_constructor_lazy

lazy_adamw = make_optimizer_constructor_lazy(AdamW)
```

2. Set the `zero_1` argument to `True` when instantiating the `NeuronAccelerator`.

```python
accelerator = NeuronAccelerator(
...
zero_1=True,
)

model = ...
lazy_optimizer = lazy_adamw(...) # Actually instantiate the optimizer.


model, optimizer = accelerator.prepare(model, lazy_optimizer)
```

michaelbenayoun marked this conversation as resolved.
Show resolved Hide resolved

## How to enable Tensor Parallelism?

Just as for ZeRO-1, it is possible to use Tensor Parallelism both with the [`NeuronTrainer`] and the [`NeuronAccelerator`].
michaelbenayoun marked this conversation as resolved.
Show resolved Hide resolved

When doing Tensor Parallelism, you have different settings:
1. The `tensor_parallel_size`. Ideally it should be smallest value for which the model fits.
2. Whether or not sequence parallelism should be enabled. [Sequence parallelism](https://arxiv.org/pdf/2205.05198.pdf) shards the activations on the sequence axis outside of the tensor parallel regions.
It is useful because it saves memory by sharding the activations.
3. Whether or not parallelization of the embedding (and thus the LM head for decoder / seq2seq models) should be done. **It is not supported yet**.

On top of that, it is very important to make sure that the original model is loaded in an efficient manner: the training script is going to be called by `torchrun`, which will dispatch it to workers, one worker per core. If each worker (there are 32 of them in a `trn1.32xlarge` instance) loads the full model weights, it can take a lot of time and go out-of-memory really fast.

`optimum-neuron` provides a context-manager [`distributed.lazy_load_for_parallelism`] that loads the model lazily to prevent that, only the parameters of the corresponding model shard will be materialized in each worker.

## Via the `NeuronTrainer`

```python
from optimum.neuron import NeuronTrainingArguments, NeuronTrainer
from optimum.neuron.distributed import lazy_load_for_parallelism

# Specify the `tensor_parallel_size` in the training arguments.
training_args = NeuronTrainingArguments(
...,
tensor_parallel_size=8,
disable_embedding_parallelization=True, # It is `True` because the feature is not supported yet.
disable_sequence_parallel=False, # It is `False` by default.
)

with lazy_load_for_parallelism(tensor_parallel_size=training_args.tensor_parallel_size):
model = ...


trainer = NeuronTrainer(
model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)

trainer.train()
```

<Tip>
Since the [example scripts](https://github.com/huggingface/optimum-neuron/tree/main/examples) use the [`NeuronTrainer`], you can enable Tensor Parallelism when using them by specifying the `--tensor_parallel_size` argument, optionally the `disable_embedding_parallelization` and `disable_sequence_parallel` flags.
to your command line.

For example:

```bash
torchrun --nproc_per_node=2 examples/language-modeling/run_clm.py \
--model_name_or_path TinyLlama/TinyLlama-1.1B-Chat-v0.6 \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--do_train \
--per_device_train_batch_size 1 \
--block_size 1024 \
--bf16 \
--tensor_parallel_size 2 \
--output_dir my_training/
```
</Tip>


## Via the `NeuronAccelerator`

Just as for ZeRO-1, it is possible to wrap the optimizer class to make it lazy. Since the model parameters are going to be sharded, it is not needed to materialize the optimizer state prior to model parallelization: the wrapper makes sure that it stays unmaterialized.

```python
from torch.optim import AdamW
from optimum.neuron import NeuronAccelerator
from optimum.neuron.accelerate.utils import TensorParallelismPlugin
from optimum.neuron.distributed import lazy_load_for_parallelism

tensor_parallel_size = 8
tp_plugin = TensorParallelismPlugin(
tensor_parallel_size,
not self.disable_embedding_parallelization,
sequence_parallel_enabled=True,
checkpoint_dir=None, # Can be specified when resuming from checkpoint.
)

accelerator = NeuronAccelerator(
...
tp_plugin=tp_plugin,
)

with lazy_load_for_parallelism(tensor_parallel_size=tensor_parallel_size):
model = ...

lazy_adamw = make_optimizer_constructor_lazy(AdamW)
lazy_optimizer = lazy_adamw(...) # Actually instantiate the optimizer.

model, optimizer = accelerator.prepare(model, lazy_optimizer)
```
49 changes: 49 additions & 0 deletions docs/source/package_reference/distributed.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Optimum Neuron Distributed

The `optimum.neuron.distributed` module provides a set of tools to perform distributed training and inference.

## Parallelization

The main task in distributed training / inference is being able to shard things such as the model weights, the gradient and the optimizer state. The `Parallelizer` classes handle that.
michaelbenayoun marked this conversation as resolved.
Show resolved Hide resolved

### Base `Parallelizer`

The [`~optimum.neuron.distributed.Parallelizer`] class is the base abstract class being derived for every model supporting model parallelism. It provides methods to parallelize the model and save and load sharded checkpoints.

[[autodoc]] distributed.Parallelizer
- _parallelize
- parallelize
- optimizer_for_tp
- save_model_checkpoint
- load_model_checkpoint

### Selecting Model-Specific Parallelizer Classes

Each model that supports parallelization in `optimum-neuron` has its own derived `Parallelizer` class. The factory class [`~optimum.neuron.distributed.ParallelizersManager`] allows you to retrieve such model-specific `Parallelizer`s easily.

[[autodoc]] distributed.parallelizers_manager.ParallelizersManager
- get_supported_model_types
- is_model_supported
- parallelizer_for_model


## Utils

### Lazy Loading

Distributed training / inference is usually needed when the model is too big to fit in one device. Tools that allow for lazy loading of model weights and optimizer states are thus needed to avoid going out-of-memory before parallelization.

[[autodoc]] distributed.utils.lazy_load_for_parallelism

[[autodoc]] distributed.utils.make_optimizer_constructor_lazy
2 changes: 1 addition & 1 deletion optimum/neuron/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
consolidate_tensor_parallel_checkpoints_to_unified_checkpoint,
)
from .parallelizers_manager import ParallelizersManager
from .utils import lazy_load_for_parallelism
from .utils import lazy_load_for_parallelism, make_optimizer_constructor_lazy
18 changes: 18 additions & 0 deletions optimum/neuron/distributed/parallelizers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ class ParallelizersManager:

@classmethod
def get_supported_model_types(cls) -> List[str]:
"""
Provides the list of supported model types for parallelization.
"""
return list(cls._MODEL_TYPE_TO_PARALLEL_MODEL_CLASS.keys())

@classmethod
Expand All @@ -74,11 +77,26 @@ def _get_model_type(cls, model_type_or_model: Union[str, PreTrainedModel]) -> st

@classmethod
def is_model_supported(cls, model_type_or_model: Union[str, PreTrainedModel]) -> bool:
"""
Returns `True` if the model can be parallelized, `False` otherwise.

Args:
model_type_or_model (`Union[str, PreTrainedModel]`):
Either the model type or an instance of the model.
"""
model_type = cls._get_model_type(model_type_or_model)
return model_type in cls._MODEL_TYPE_TO_PARALLEL_MODEL_CLASS

@classmethod
def parallelizer_for_model(cls, model_type_or_model: Union[str, PreTrainedModel]) -> Type[Parallelizer]:
"""
Returns the parallelizer class associated to the model.

Args:
model_type_or_model (`Union[str, PreTrainedModel]`):
Either the model type or an instance of the model.

"""
model_type = cls._get_model_type(model_type_or_model)
if not cls.is_model_supported(model_type_or_model):
supported_models = ", ".join(cls._MODEL_TYPE_TO_PARALLEL_MODEL_CLASS.keys())
Expand Down
6 changes: 3 additions & 3 deletions optimum/neuron/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ class NeuronTrainingArgumentsMixin:
default=True,
metadata={"help": "Whether or not the embedding parallelization when doing TP should be disabled."},
)
sequence_parallel_enabled: bool = field(
disable_sequence_parallel: bool = field(
default=False,
metadata={"help": "Whether or not to enable sequence parallelism."},
metadata={"help": "Whether or not to disable sequence parallelism."},
)

def __post_init__(self):
Expand Down Expand Up @@ -108,7 +108,7 @@ def __post_init__(self):
self.tp_plugin = TensorParallelismPlugin(
self.tensor_parallel_size,
not self.disable_embedding_parallelization,
sequence_parallel_enabled=self.sequence_parallel_enabled,
sequence_parallel_enabled=not self.disable_sequence_parallel,
checkpoint_dir=resume_from_checkpoint,
)
super().__post_init__()
Expand Down
Loading