diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py
index 2f1b5e903..b90d62601 100644
--- a/optimum/neuron/distributed/utils.py
+++ b/optimum/neuron/distributed/utils.py
@@ -19,23 +19,16 @@
 import os
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, Iterator, Literal, Optional, Tuple, Type, Union
+from typing import Dict, Literal, Optional, Tuple, Type, Union

 import torch
 from transformers import PretrainedConfig

-from ..utils import DynamicPatch, Patcher, is_neuronx_distributed_available, is_torch_xla_available
+from ..utils import DynamicPatch, Patcher, is_neuronx_distributed_available
 from ..utils.misc import download_checkpoints_in_cache
 from ..utils.require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla


-if is_torch_xla_available():
-    import torch_xla.core.xla_model as xm
-    from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer
-else:
-    ZeroRedundancyOptimizer = object
-
-
 TENSOR_PARALLEL_SHARDS_DIR_NAME = "tensor_parallel_shards"


@@ -483,6 +476,7 @@ def gqa_key_value_slicing_when_tp_size_greater_than_num_key_value_heads(
     return sliced_linear_layer


+@requires_torch_xla
 @classmethod
 def from_pretrained_for_tp(
     cls,
@@ -542,6 +536,8 @@ def from_pretrained_for_tp(
             **kwargs,
         )

+    import torch_xla.core.xla_model as xm
+
     xm.rendezvous("waiting after download and conversion")

     if not isinstance(config, PretrainedConfig):
@@ -638,61 +634,6 @@ def optimizer_constructor(*args, **kwargs):
     return optimizer_constructor


-@requires_torch_xla
-@requires_neuronx_distributed
-class ZeroRedundancyOptimizerCompatibleWithTensorParallelism(ZeroRedundancyOptimizer):
-    def __init__(
-        self,
-        params: Iterator[torch.Tensor],
-        optimizer_class: Type[torch.optim.Optimizer],
-        optimizer_dtype: Optional[Any] = None,
-        grad_clipping: bool = True,
-        max_norm: Optional[float] = None,
-        pin_layout: bool = True,
-        **defaults: Any,
-    ):
-        from neuronx_distributed.parallel_layers.parallel_state import (
-            get_data_parallel_group,
-            get_data_parallel_rank,
-            get_data_parallel_size,
-            model_parallel_is_initialized,
-        )
-
-        if not is_neuronx_distributed_available() or not model_parallel_is_initialized():
-            return super().__init__(
-                params,
-                optimizer_class,
-                optimizer_dtype=optimizer_dtype,
-                grad_clipping=grad_clipping,
-                max_norm=max_norm,
-                pin_layout=pin_layout,
-                **defaults,
-            )
-
-        self.params = list(params)
-        super(ZeroRedundancyOptimizer, self).__init__(self.params, defaults)
-
-        if isinstance(self.params[0], dict):
-            self.params = [p for pg in self.params for p in pg["params"]]
-
-        self.device = self.params[0].device
-
-        self.rank = get_data_parallel_rank()
-        self.world_size = get_data_parallel_size()
-        self.cc_op_groups = get_data_parallel_group(as_list=True)
-
-        self.optimizer_dtype = optimizer_dtype if optimizer_dtype is not None else torch.float32
-        self.grad_clipping = grad_clipping
-        self.max_norm = max_norm if max_norm is not None else 1.0
-        self.pin_layout = pin_layout
-
-        # Shard parameters for use in optimizer
-        self.sharded_params = []
-        self._shard_parameters()
-        # Optimizer initialization
-        self.base_optimizer = optimizer_class(iter(self.sharded_params), **defaults)
-
-
 @dataclass
 class ParameterMetadata:
     kind: Literal["tied", "sharded"]
diff --git a/optimum/neuron/training_args.py b/optimum/neuron/training_args.py
index 0062e610f..71d3cf56e 100644
--- a/optimum/neuron/training_args.py
+++ b/optimum/neuron/training_args.py
@@ -64,9 +64,6 @@ def __post_init__(self):
         # Patches accelerate.utils.imports.is_tpu_available to match `is_torch_xla_available`
         patch_accelerate_is_tpu_available()

-        # if not self.disable_embedding_parallelization:
-        #     raise NotImplementedError("Disabling the parallelization of the embeddings is not fully supported yet.")
-
         if self.fsdp != "":
             # Disabling FSDP until next release because it is still very experimental and not validated.
             raise RuntimeError("FSDP is not supported yet.")
diff --git a/tests/distributed/model_parallel_test_template.txt b/tests/distributed/model_parallel_test_template.txt
index a0e5b5b94..0e267f219 100644
--- a/tests/distributed/model_parallel_test_template.txt
+++ b/tests/distributed/model_parallel_test_template.txt
@@ -97,7 +97,7 @@ xm.mark_step()
 if is_parallel and parallelize_embeddings:
     gathered_model_outputs = dict()
     for name, output in model_outputs.items():
-        if name == "loss" or output is None:
+        if name == "loss" or output is None or output.shape[-1] != (vocab_size // {tp_size}):
             gathered_model_outputs[name] = output
         else:
             gathered_model_outputs[name] = gather_along_last_dim(output)
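
Note on the test-template change above: an output is now gathered across tensor-parallel ranks only when its last dimension is actually sharded along the vocabulary axis. A minimal sketch of that rule, assuming `gather_along_last_dim`, `vocab_size` and `tp_size` are defined as in the template (the helper name `maybe_gather_output` is hypothetical):

def maybe_gather_output(name, output, vocab_size, tp_size, gather_along_last_dim):
    # The loss is a scalar, and outputs such as hidden states or past key values
    # are not sharded along the vocabulary axis, so only tensors whose last
    # dimension equals vocab_size // tp_size are gathered across TP ranks.
    if name == "loss" or output is None or output.shape[-1] != (vocab_size // tp_size):
        return output
    return gather_along_last_dim(output)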