
Commit

Remove old code
michaelbenayoun committed Sep 12, 2023
1 parent 0ac558d commit 5debbfe
Showing 3 changed files with 6 additions and 68 deletions.
69 changes: 5 additions & 64 deletions optimum/neuron/distributed/utils.py
@@ -19,23 +19,16 @@
 import os
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, Iterator, Literal, Optional, Tuple, Type, Union
+from typing import Dict, Literal, Optional, Tuple, Type, Union

 import torch
 from transformers import PretrainedConfig

-from ..utils import DynamicPatch, Patcher, is_neuronx_distributed_available, is_torch_xla_available
+from ..utils import DynamicPatch, Patcher, is_neuronx_distributed_available
 from ..utils.misc import download_checkpoints_in_cache
 from ..utils.require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla


-if is_torch_xla_available():
-    import torch_xla.core.xla_model as xm
-    from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer
-else:
-    ZeroRedundancyOptimizer = object
-
-
 TENSOR_PARALLEL_SHARDS_DIR_NAME = "tensor_parallel_shards"

@@ -483,6 +476,7 @@ def gqa_key_value_slicing_when_tp_size_greater_than_num_key_value_heads(
     return sliced_linear_layer


+@requires_torch_xla
 @classmethod
 def from_pretrained_for_tp(
     cls,
@@ -542,6 +536,8 @@ def from_pretrained_for_tp(
         **kwargs,
     )

+    import torch_xla.core.xla_model as xm
+
     xm.rendezvous("waiting after download and conversion")

     if not isinstance(config, PretrainedConfig):
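
Note: the two additions above replace the removed module-level torch_xla import with a function-local one guarded by `@requires_torch_xla`, so importing `optimum.neuron.distributed.utils` no longer needs torch_xla at import time. A minimal sketch of that pattern, under stated assumptions: only `requires_torch_xla` and `xm.rendezvous` appear in the diff, everything else here is illustrative and not the repository's actual implementation.

    import functools
    import importlib.util

    def requires_torch_xla_sketch(fn):
        # Hypothetical stand-in for optimum.neuron's `requires_torch_xla`:
        # fail loudly at call time if torch_xla is not installed.
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if importlib.util.find_spec("torch_xla") is None:
                raise ImportError(f"{fn.__name__} requires the torch_xla package.")
            return fn(*args, **kwargs)
        return wrapper

    @requires_torch_xla_sketch
    def rendezvous_after_download():
        # The import happens only when the function runs, mirroring the
        # function-local `import torch_xla.core.xla_model as xm` added above.
        import torch_xla.core.xla_model as xm
        xm.rendezvous("waiting after download and conversion")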
@@ -638,61 +634,6 @@ def optimizer_constructor(*args, **kwargs):
     return optimizer_constructor


-@requires_torch_xla
-@requires_neuronx_distributed
-class ZeroRedundancyOptimizerCompatibleWithTensorParallelism(ZeroRedundancyOptimizer):
-    def __init__(
-        self,
-        params: Iterator[torch.Tensor],
-        optimizer_class: Type[torch.optim.Optimizer],
-        optimizer_dtype: Optional[Any] = None,
-        grad_clipping: bool = True,
-        max_norm: Optional[float] = None,
-        pin_layout: bool = True,
-        **defaults: Any,
-    ):
-        from neuronx_distributed.parallel_layers.parallel_state import (
-            get_data_parallel_group,
-            get_data_parallel_rank,
-            get_data_parallel_size,
-            model_parallel_is_initialized,
-        )
-
-        if not is_neuronx_distributed_available() or not model_parallel_is_initialized():
-            return super().__init__(
-                params,
-                optimizer_class,
-                optimizer_dtype=optimizer_dtype,
-                grad_clipping=grad_clipping,
-                max_norm=max_norm,
-                pin_layout=pin_layout,
-                **defaults,
-            )
-
-        self.params = list(params)
-        super(ZeroRedundancyOptimizer, self).__init__(self.params, defaults)
-
-        if isinstance(self.params[0], dict):
-            self.params = [p for pg in self.params for p in pg["params"]]
-
-        self.device = self.params[0].device
-
-        self.rank = get_data_parallel_rank()
-        self.world_size = get_data_parallel_size()
-        self.cc_op_groups = get_data_parallel_group(as_list=True)
-
-        self.optimizer_dtype = optimizer_dtype if optimizer_dtype is not None else torch.float32
-        self.grad_clipping = grad_clipping
-        self.max_norm = max_norm if max_norm is not None else 1.0
-        self.pin_layout = pin_layout
-
-        # Shard parameters for use in optimizer
-        self.sharded_params = []
-        self._shard_parameters()
-        # Optimizer initialization
-        self.base_optimizer = optimizer_class(iter(self.sharded_params), **defaults)
-
-
 @dataclass
 class ParameterMetadata:
     kind: Literal["tied", "sharded"]
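
The removed `ZeroRedundancyOptimizerCompatibleWithTensorParallelism` wrapper redirected ZeRO's rank, world size, and collective group to the data-parallel axis when model parallelism is initialized, so optimizer state was sharded only across data-parallel replicas. A hypothetical sketch of that grouping arithmetic, with illustrative numbers that are not part of the diff:

    def zero_sharding_group_size(world_size: int, tp_size: int) -> int:
        # With tensor parallelism, each parameter shard is replicated across
        # world_size // tp_size data-parallel ranks; that is the group the
        # removed wrapper asked ZeRO to partition optimizer state over.
        assert world_size % tp_size == 0
        return world_size // tp_size

    # e.g. 32 workers with tp_size=8 leave 4 data-parallel replicas for ZeRO
    print(zero_sharding_group_size(world_size=32, tp_size=8))  # -> 4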
3 changes: 0 additions & 3 deletions optimum/neuron/training_args.py
@@ -64,9 +64,6 @@ def __post_init__(self):
         # Patches accelerate.utils.imports.is_tpu_available to match `is_torch_xla_available`
         patch_accelerate_is_tpu_available()

-        # if not self.disable_embedding_parallelization:
-        #     raise NotImplementedError("Disabling the parallelization of the embeddings is not fully supported yet.")
-
         if self.fsdp != "":
             # Disabling FSDP until next release because it is still very experimental and not validated.
             raise RuntimeError("FSDP is not supported yet.")
2 changes: 1 addition & 1 deletion tests/distributed/model_parallel_test_template.txt
@@ -97,7 +97,7 @@ xm.mark_step()
 if is_parallel and parallelize_embeddings:
     gathered_model_outputs = dict()
     for name, output in model_outputs.items():
-        if name == "loss" or output is None:
+        if name == "loss" or output is None or output.shape[-1] != (vocab_size // {tp_size}):
             gathered_model_outputs[name] = output
         else:
             gathered_model_outputs[name] = gather_along_last_dim(output)
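
The new condition skips gathering for the loss, for `None` outputs, and for any tensor whose last dimension is not the vocabulary shard size, so only vocab-sharded outputs (such as logits under parallelized embeddings) are all-gathered. A framework-free illustration of that rule; `maybe_gather` and the toy gather function below are hypothetical and not part of the repository:

    import torch

    def maybe_gather(model_outputs, vocab_size, tp_size, gather_along_last_dim):
        gathered = {}
        for name, output in model_outputs.items():
            # Keep the loss, missing outputs, and tensors not sharded to
            # vocab_size // tp_size; gather everything else on the last dim.
            if name == "loss" or output is None or output.shape[-1] != (vocab_size // tp_size):
                gathered[name] = output
            else:
                gathered[name] = gather_along_last_dim(output)
        return gathered

    # Toy usage: a fake "gather" that tiles the local shard tp_size times.
    outputs = {"logits": torch.randn(2, 4, 8), "loss": torch.tensor(1.0)}
    full = maybe_gather(outputs, vocab_size=16, tp_size=2,
                        gather_along_last_dim=lambda t: torch.cat([t, t], dim=-1))
    assert full["logits"].shape[-1] == 16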
