Parallel cross entropy (#222)
* Make ParallelCrossEntropy work

* Add support for all models

* [WIP] Adapt parallelization tests

* Fix all gather in test

* Remove unused code

* Disable the feature since it's not fully working
michaelbenayoun authored Sep 19, 2023
1 parent 0cab527 commit 74d0905
Showing 14 changed files with 558 additions and 172 deletions.
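
For orientation before the diff: this change teaches the optimum-neuron parallelizers to compute the loss directly on vocab-sharded (column-parallel) LM-head outputs. The snippet below is a minimal single-process sketch of that idea, not the code added in this commit; the sizes and shard layout are made up, and the distributed all-reduces are simulated with plain tensor ops.

```python
# Single-process sketch of a vocab-parallel cross-entropy: each "rank" only holds a
# shard of the logits, yet the loss can be computed without gathering the full vocab.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, vocab, n_shards = 4, 32, 2            # hypothetical sizes
logits = torch.randn(batch, vocab)
targets = torch.randint(0, vocab, (batch,))

shards = logits.chunk(n_shards, dim=-1)      # each shard plays the role of one rank
shard_size = vocab // n_shards

# 1) global max over the vocabulary (an all-reduce MAX in the distributed setting)
global_max = torch.stack([s.max(dim=-1).values for s in shards]).max(dim=0).values

# 2) global sum of exp(logit - max) (an all-reduce SUM)
sum_exp = sum((s - global_max[:, None]).exp().sum(dim=-1) for s in shards)

# 3) the logit of the target token, read from whichever shard owns that vocab index
owner = targets // shard_size
local_index = targets % shard_size
target_logit = torch.stack(
    [shards[o][i, j] for i, (o, j) in enumerate(zip(owner.tolist(), local_index.tolist()))]
)

loss = (torch.log(sum_exp) - (target_logit - global_max)).mean()
print(torch.allclose(loss, F.cross_entropy(logits, targets)))  # True
```
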
73 changes: 29 additions & 44 deletions optimum/neuron/accelerate/accelerator.py
@@ -27,18 +27,16 @@
from accelerate import Accelerator
from accelerate.checkpointing import save_accelerator_state, save_custom_state
from accelerate.utils import DistributedType
from packaging import version
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from optimum.neuron.utils.patching import ModelPatcher

from ...utils import logging
from ..distributed import Parallelizer, ParallelizersManager
from ..distributed.utils import ZeroRedundancyOptimizerCompatibleWithTensorParallelism
from ..utils import Patcher, is_neuronx_distributed_available, is_torch_xla_available, patch_within_function
from ..utils.misc import args_and_kwargs_to_kwargs_only
from ..utils.version_utils import get_torch_xla_version
from ..utils.require_utils import requires_neuronx_distributed
from .optimizer import NeuronAcceleratedOptimizer
from .scheduler import NeuronAcceleratedScheduler
from .state import NeuronAcceleratorState
@@ -182,6 +180,7 @@ def _prepare_optimizer_for_tp(self, optimizer: torch.optim.Optimizer, device_pla
optimizer.param_groups = xla_parameters
return optimizer

@requires_neuronx_distributed
def _prepare_optimizer_for_zero_1(self, optimizer: torch.optim.Optimizer, device_placement=None):
mixed_precision_to_dtype = {
"no": torch.float32,
@@ -191,61 +190,47 @@ def _prepare_optimizer_for_zero_1(self, optimizer: torch.optim.Optimizer, device
if optimizer_dtype is None:
raise ValueError(f"The precision {self.state.mixed_precision} is not supported for ZeRO Stage 1")

from neuronx_distributed.optimizer import NeuronZero1Optimizer
from neuronx_distributed.parallel_layers.parallel_state import (
get_data_parallel_group,
get_tensor_model_parallel_group,
model_parallel_is_initialized,
)

if not is_neuronx_distributed_available() or not model_parallel_is_initialized():
sharding_groups = None
grad_norm_groups = None
else:
sharding_groups = get_data_parallel_group(as_list=True)
grad_norm_groups = get_tensor_model_parallel_group(as_list=True)

if hasattr(optimizer, "_args_to_recreate"):
args, kwargs = optimizer._args_to_recreate
params = args[0]
defaults = args_and_kwargs_to_kwargs_only(optimizer.__class__, args[1:], kwargs)

# Prior to 1.13.1+torchneuron8, the vanilla ZeroRedundancyOptimizer was not designed for TP.
full_torch_xla_version = get_torch_xla_version()
torch_xla_version, torchneuron_version = full_torch_xla_version.split("+")
torch_xla_version = version.parse(torch_xla_version)
if torch_xla_version <= version.parse("1.13.1") and int(torchneuron_version[-1]) < 8:
zero_1_optimizer = ZeroRedundancyOptimizerCompatibleWithTensorParallelism(
params,
optimizer.__class__,
optimizer_dtype=optimizer_dtype,
pin_layout=False,
**defaults,
)
# Exception to make sure `ZeroRedundancyOptimizerCompatibleWithTensorParallelism` is removed in 3 releases.
elif torch_xla_version <= version.parse("1.13.1") and int(torchneuron_version[-1]) >= 11:
raise RuntimeError(
"ZeroRedundancyOptimizerCompatibleWithTensorParallelism is deprecated and should be removed from "
"`optimum-neuron`"
)
else:
from neuronx_distributed.parallel_layers.parallel_state import (
get_data_parallel_group,
model_parallel_is_initialized,
)
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

if not is_neuronx_distributed_available() or not model_parallel_is_initialized():
cc_op_groups = None
else:
cc_op_groups = get_data_parallel_group(as_list=True)

zero_1_optimizer = ZeroRedundancyOptimizer(
params,
optimizer.__class__,
optimizer_dtype=optimizer_dtype,
pin_layout=False,
cc_op_groups=cc_op_groups,
**defaults,
)
zero_1_optimizer = NeuronZero1Optimizer(
params,
optimizer.__class__,
optimizer_dtype=optimizer_dtype,
pin_layout=False,
sharding_groups=sharding_groups,
grad_norm_groups=grad_norm_groups,
**defaults,
)
del optimizer
else:
logger.warning(
f"Creating a ZeroRedundancyOptimizer from {optimizer}, this might change some default values. When "
f"Creating a NeuronZero1Optimizer from {optimizer}, this might change some default values. When "
"using ZeRO 1 it is recommended to create the ZeroRedundancyOptimizer yourself to avoid this kind of "
"issues."
)
zero_1_optimizer = ZeroRedundancyOptimizerCompatibleWithTensorParallelism(
zero_1_optimizer = NeuronZero1Optimizer(
optimizer.param_groups,
optimizer.__class__,
optimizer_dtype=optimizer_dtype,
pin_layout=False,
sharding_groups=sharding_groups,
grad_norm_groups=grad_norm_groups,
)
return zero_1_optimizer

14 changes: 13 additions & 1 deletion optimum/neuron/distributed/decoder_models.py
@@ -17,7 +17,7 @@
from typing import TYPE_CHECKING, Dict, Optional

from .base import Parallelizer
from .parallel_layers import ParallelEmbedding, ParallelMLP, ParallelSelfAttention
from .parallel_layers import ParallelCrossEntropy, ParallelEmbedding, ParallelMLP, ParallelSelfAttention
from .utils import linear_to_parallel_linear


@@ -44,6 +44,10 @@ class GPTNeoParallelMLP(ParallelMLP):
SECOND_LINEAR_NAME = "c_proj"


class GPTNeoParallelCrossEntropy(ParallelCrossEntropy):
LAST_LINEAR_PROJECTION_NAME = "lm_head"


class GPTNeoParallelizer(Parallelizer):
@classmethod
def _parallelize(
@@ -62,6 +66,8 @@ def _parallelize(
device=device,
)
block.mlp = GPTNeoParallelMLP.transform(model, block.mlp, device=device)
if parallelize_embeddings:
model = GPTNeoParallelCrossEntropy.transform(model, model, device=device)
return model


@@ -127,6 +133,10 @@ def transform(
return layer


class LlamaParallelCrossEntropy(ParallelCrossEntropy):
LAST_LINEAR_PROJECTION_NAME = "lm_head"


class LlamaParallelizer(Parallelizer):
@classmethod
def _parallelize(
@@ -141,4 +151,6 @@ def _parallelize(
for layer in model.model.layers:
layer.self_attn = LlamaParallelSelfAttention.transform(model, layer.self_attn, device=device)
layer.mlp = LLamaParallelMLP.transform(model, layer.mlp, device=device)
if parallelize_embeddings:
model = LlamaParallelCrossEntropy.transform(model, model, device=device)
return model
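
The decoder diffs above all follow the same recipe. As a hedged illustration (names are hypothetical, not part of the commit), supporting a new decoder-only model whose LM head is exposed as `lm_head` would only require a small subclass and one extra call at the end of its `_parallelize`:

```python
from optimum.neuron.distributed.parallel_layers import ParallelCrossEntropy


class MyDecoderParallelCrossEntropy(ParallelCrossEntropy):
    # Name of the final projection whose output feeds the (parallel) cross-entropy.
    LAST_LINEAR_PROJECTION_NAME = "lm_head"


# Inside the model's Parallelizer._parallelize(...), after attention and MLP blocks
# have been parallelized:
#     if parallelize_embeddings:
#         model = MyDecoderParallelCrossEntropy.transform(model, model, device=device)
```
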
8 changes: 7 additions & 1 deletion optimum/neuron/distributed/encoder_decoder_models.py
Expand Up @@ -21,7 +21,7 @@

from ...utils import NormalizedConfigManager
from .base import Parallelizer
from .parallel_layers import ParallelEmbedding, ParallelMLP, ParallelSelfAttention
from .parallel_layers import ParallelCrossEntropy, ParallelEmbedding, ParallelMLP, ParallelSelfAttention
from .utils import linear_to_parallel_linear


@@ -143,6 +143,10 @@ def transform(
return layer


class T5ParallelCrossEntropy(ParallelCrossEntropy):
LAST_LINEAR_PROJECTION_NAME = "lm_head"


class T5Parallelizer(Parallelizer):
@classmethod
def _parallelize(
@@ -182,4 +186,6 @@ def _parallelize(
block.layer[2].DenseReluDense = T5ParallelMLP.transform(
model, block.layer[2].DenseReluDense, device=device
)
if parallelize_embeddings:
model = T5ParallelCrossEntropy.transform(model, model, device=device)
return model
43 changes: 42 additions & 1 deletion optimum/neuron/distributed/encoder_models.py
@@ -16,8 +16,9 @@

from typing import TYPE_CHECKING, Dict, Optional

from ..utils.require_utils import requires_neuronx_distributed
from .base import Parallelizer
from .parallel_layers import ParallelEmbedding, ParallelSelfAttention, ParallelSelfOutput
from .parallel_layers import ParallelCrossEntropy, ParallelEmbedding, ParallelSelfAttention, ParallelSelfOutput


if TYPE_CHECKING:
@@ -33,6 +34,23 @@ class BertParallelEmbedding(ParallelEmbedding):
"BertForMaskedLM": "cls.predictions.decoder",
}

@classmethod
@requires_neuronx_distributed
def transform(
cls,
model: "PreTrainedModel",
layer: "torch.nn.Module",
orig_to_parallel: Optional[Dict[int, "torch.nn.Parameter"]] = None,
device: Optional["torch.device"] = None,
) -> "torch.nn.Module":
layer = super().transform(model, layer, orig_to_parallel=orig_to_parallel, device=device)
from transformers.models.bert.modeling_bert import BertLMPredictionHead

for mod in layer.modules():
if isinstance(mod, BertLMPredictionHead):
mod.bias = mod.decoder.bias
return layer


class BertParallelSelfAttention(ParallelSelfAttention):
ALL_HEAD_SIZE_NAME = "all_head_size"
@@ -42,6 +60,14 @@ class BertParallelSelfOutput(ParallelSelfOutput):
pass


class BertParallelCrossEntropy(ParallelCrossEntropy):
LAST_LINEAR_PROJECTION_NAME = {
"BertForPreTraining": "cls.predictions.decoder",
"BertLMHeadModel": "cls.predictions.decoder",
"BertForMaskedLM": "cls.predictions.decoder",
}


class BertParallelizer(Parallelizer):
@classmethod
def _parallelize(
@@ -66,6 +92,10 @@ def _parallelize(
orig_to_parallel=orig_to_parallel,
device=device,
)
# Valid because we currently parallelize the cross-entropy loss only for language-modeling tasks where the
# embeddings and the LM head are tied.
if parallelize_embeddings:
model = BertParallelCrossEntropy.transform(model, model, device=device)
return model


@@ -85,6 +115,13 @@ class RobertaParallelSelfOutput(BertParallelSelfOutput):
pass


class RobertaParallelCrossEntropy(ParallelCrossEntropy):
LAST_LINEAR_PROJECTION_NAME = {
"RobertaForCausalLM": "lm_head.decoder",
"RobertaForMaskedLM": "lm_head.decoder",
}


class RobertaParallelizer(Parallelizer):
@classmethod
def _parallelize(
Expand All @@ -109,4 +146,8 @@ def _parallelize(
orig_to_parallel=orig_to_parallel,
device=device,
)
# Valid because we currently parallelize the cross-entropy loss only for language-modeling tasks where the
# embeddings and the LM head are tied.
if parallelize_embeddings:
model = RobertaParallelCrossEntropy.transform(model, model, device=device)
return model
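
The "embeddings and the LM head are tied" comments in the BERT/RoBERTa diffs rely on the Hugging Face implementations sharing one weight tensor between the input embedding and the MLM decoder, so sharding the embedding also shards the projection feeding the parallel cross-entropy; the new `BertParallelEmbedding.transform` then re-points the prediction head's standalone bias at the (replaced) decoder's bias. A quick check of the tying assumption, assuming transformers is installed and the usual Hub checkpoint is reachable:

```python
from transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained("bert-base-uncased")
# With the default tie_word_embeddings=True, these are the same Parameter object.
assert model.cls.predictions.decoder.weight is model.bert.embeddings.word_embeddings.weight
```
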