[WIP] llama-70b
michaelbenayoun committed Feb 6, 2024
1 parent 3d99397 commit dd49c38
Showing 7 changed files with 38 additions and 23 deletions.
22 changes: 10 additions & 12 deletions optimum/neuron/distributed/base.py
@@ -244,7 +244,7 @@ def _parallelize(
         device: Optional["torch.device"] = None,
         parallelize_embeddings: bool = True,
         sequence_parallel_enabled: bool = False,
-        should_parallelize_predicate_func: Optional[Callable[["torch.nn.Module"], "torch.nn.Module"]] = None,
+        should_parallelize_layer_predicate_func: Optional[Callable[["torch.nn.Module"], bool]] = None,
     ) -> "PreTrainedModel":
         """
         Parallelizes the model by transforming regular layer into their parallel counterparts.
@@ -260,7 +260,9 @@ def _parallelize(
                 This can be disabled in the case when the TP size does not divide the vocabulary size.
             sequence_parallel_enabled (`bool`, defaults to `False`):
                 Whether or not sequence parallelism is enabled.
-            # TODO: add docstring
+            should_parallelize_layer_predicate_func (Optional[Callable[[torch.nn.Module], bool]], defaults to `None`):
+                A function that takes a layer as input and returns a boolean specifying if the input layer should be
+                parallelized. This is useful to skip unnecessary parallelization, for pipeline parallelism for instance.
         Returns:
             `PreTrainedModel`: The parallelized model.
         """
@@ -337,27 +339,23 @@ def parallelize(
         name_to_parameter = dict(named_parameters(model, remove_duplicate=False))
         parameter_to_name = {p: n for n, p in name_to_parameter.items()}

-        xm.master_print(name_to_parameter.keys())
-
         def predicate_func(layer):
-            for n, p in layer.named_parameters():
+            for p in layer.parameters():
                 if p not in parameter_to_name:
-                    xm.master_print(n)
                     return True
             names = {parameter_to_name[p] for p in layer.parameters()}
             return names < names_of_the_parameters_to_consider

-        model.predicate = predicate_func
-
         if tp_size > 1:
             model = cls._parallelize(
                 model,
                 device=device,
                 parallelize_embeddings=parallelize_embeddings,
                 sequence_parallel_enabled=sequence_parallel_enabled,
-                # should_parallelize_predicate_func=predicate_func,
+                should_parallelize_predicate_func=predicate_func,
             )
-            # xm.rendezvous("End of tensor parallelism")

+        xm.rendezvous("End of tensor parallelism")
+
         # Preparing the model for sequence parallelism:
         sp_specs_cls = cls.SEQUENCE_PARALLELSIM_SPECS_CLS
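The `names < names_of_the_parameters_to_consider` comparison is Python's proper-subset operator on sets: the predicate returns True only when every parameter name of the layer appears in the (strictly larger) set of parameters this call is asked to handle. A standalone illustration of that semantics with made-up parameter names:

names_of_the_parameters_to_consider = {
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.k_proj.weight",
    "model.layers.1.self_attn.q_proj.weight",
}

local_layer_names = {"model.layers.0.self_attn.q_proj.weight", "model.layers.0.self_attn.k_proj.weight"}
foreign_layer_names = {"model.layers.7.self_attn.q_proj.weight"}

print(local_layer_names < names_of_the_parameters_to_consider)    # True: parallelize this layer
print(foreign_layer_names < names_of_the_parameters_to_consider)  # False: skip it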
@@ -507,7 +505,7 @@ def predicate_func(layer):
             if left_uninitialized and hasattr(mod, "reset_parameters"):
                 initialize_torch_nn_module(mod, parameter_names)

-        # xm.rendezvous("End of initalization")
+        xm.rendezvous("End of initalization")

         pp_size = get_pipeline_model_parallel_size()
         if pp_size > 1:
@@ -535,7 +533,7 @@ def predicate_func(layer):
         if gradient_checkpointing:
             apply_checkpoint(model)

-        # xxm.rendezvous("End of pipeline paralellism")
+        xm.rendezvous("End of pipeline paralellism")

         if checkpoint_dir is not None:
             cls.load_model_checkpoint(model, checkpoint_dir)
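The `xm.rendezvous(...)` calls that this commit re-enables are named barriers: every XLA worker blocks at the tag until all workers reach it, so no rank starts the next setup phase before the previous one has completed everywhere. A minimal sketch of the pattern, assuming a torch_xla multi-worker launch; the phase names are placeholders:

import torch_xla.core.xla_model as xm

def finish_phase(phase_name: str) -> None:
    # Block until every worker has reached this tag before moving on.
    xm.rendezvous(f"End of {phase_name}")

finish_phase("tensor parallelism")
finish_phase("pipeline parallelism")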
5 changes: 4 additions & 1 deletion optimum/neuron/distributed/decoder_models.py
@@ -15,7 +15,7 @@
 """Classes related to `neuronx-distributed` to perform parallelism."""

 import warnings
-from typing import TYPE_CHECKING, Any, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple

 import torch
 from transformers.cache_utils import Cache
@@ -605,7 +605,10 @@ def transform(
         layer: "torch.nn.Module",
         sequence_parallel_enabled: bool = False,
         device: Optional["torch.device"] = None,
+        should_parallelize_layer_predicate_func: Optional[Callable[["torch.nn.Module"], bool]] = None,
     ) -> "torch.nn.Module":
+        if should_parallelize_layer_predicate_func is not None and not should_parallelize_layer_predicate_func(layer):
+            return layer
         # TODO: Make it smart by merging the gate and the up_proj.
         # WARNING: be careful of the interleaved outputs when doing TP!
         layer = super().transform(model, layer, sequence_parallel_enabled=sequence_parallel_enabled, device=device)
11 changes: 9 additions & 2 deletions optimum/neuron/distributed/encoder_decoder_models.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Classes related to `neuronx-distributed` to perform parallelism."""

-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Callable, Optional

 import torch
 from transformers.models.t5.modeling_t5 import T5Attention, T5ForSequenceClassification, T5LayerNorm
@@ -54,9 +54,12 @@ def transform(
         cls,
         model: "PreTrainedModel",
         layer: "torch.nn.Module",
-        device: Optional["torch.device"] = None,
         sequence_parallel_enabled: bool = False,
+        device: Optional["torch.device"] = None,
+        should_parallelize_layer_predicate_func: Optional[Callable[["torch.nn.Module"], bool]] = None,
     ) -> "torch.nn.Module":
+        if should_parallelize_layer_predicate_func is not None and not should_parallelize_layer_predicate_func(layer):
+            return layer
         from neuronx_distributed.parallel_layers.parallel_state import (
             get_tensor_model_parallel_rank,
             get_tensor_model_parallel_size,
@@ -100,7 +103,11 @@ def transform(
         layer: "torch.nn.Module",
         sequence_parallel_enabled: bool = False,
         device: Optional["torch.device"] = None,
+        should_parallelize_layer_predicate_func: Optional[Callable[["torch.nn.Module"], bool]] = None,
     ) -> "torch.nn.Module":
+        if should_parallelize_layer_predicate_func is not None and not should_parallelize_layer_predicate_func(layer):
+            return layer
+
         from transformers.models.t5.modeling_t5 import T5DenseGatedActDense

         if cls.FIRST_LINEAR_NAME is None or cls.SECOND_LINEAR_NAME is None:
11 changes: 9 additions & 2 deletions optimum/neuron/distributed/encoder_models.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Classes related to `neuronx-distributed` to perform parallelism."""

-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Callable, Optional

 import torch

@@ -64,8 +64,15 @@ def transform(
         layer: "torch.nn.Module",
         sequence_parallel_enabled: bool = False,
         device: Optional["torch.device"] = None,
+        should_parallelize_layer_predicate_func: Optional[Callable[["torch.nn.Module"], bool]] = None,
     ) -> "torch.nn.Module":
-        layer = super().transform(model, layer, sequence_parallel_enabled=sequence_parallel_enabled, device=device)
+        layer = super().transform(
+            model,
+            layer,
+            sequence_parallel_enabled=sequence_parallel_enabled,
+            device=device,
+            should_parallelize_layer_predicate_func=should_parallelize_layer_predicate_func,
+        )
         from transformers.models.bert.modeling_bert import BertLMPredictionHead

         for mod in layer.modules():
5 changes: 3 additions & 2 deletions optimum/neuron/distributed/parallel_layers.py
@@ -20,7 +20,7 @@
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Type, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Optional, Tuple, Type, Union

 import torch
 from torch.nn.modules.loss import _WeightedLoss
@@ -133,8 +133,9 @@ def transform(
         layer: "torch.nn.Module",
         sequence_parallel_enabled: bool = False,
         device: Optional["torch.device"] = None,
+        should_parallelize_layer_predicate_func: Optional[Callable[["torch.nn.Module"], bool]] = None,
     ) -> "torch.nn.Module":
-        if not model.predicate(layer):
+        if should_parallelize_layer_predicate_func is not None and not should_parallelize_layer_predicate_func(layer):
             return layer
         return cls._transform(model, layer, sequence_parallel_enabled=sequence_parallel_enabled, device=device)
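Every `transform` override touched by this commit adds the same guard: consult the optional predicate first and hand the layer back untouched when it declines. A self-contained toy version of that dispatch, with made-up class and layer types standing in for the real optimum-neuron ones:

from typing import Callable, Optional

import torch

class ToyParallelLayer:
    @classmethod
    def transform(
        cls,
        layer: torch.nn.Module,
        should_parallelize_layer_predicate_func: Optional[Callable[[torch.nn.Module], bool]] = None,
    ) -> torch.nn.Module:
        if should_parallelize_layer_predicate_func is not None and not should_parallelize_layer_predicate_func(layer):
            return layer  # e.g. a layer owned by another pipeline stage
        layer.was_parallelized = True  # stand-in for the real column/row-parallel rewrite
        return layer

def only_linear(layer: torch.nn.Module) -> bool:
    return isinstance(layer, torch.nn.Linear)

layers = [torch.nn.Linear(4, 4), torch.nn.Embedding(10, 4)]
results = [ToyParallelLayer.transform(layer, only_linear) for layer in layers]
print([getattr(layer, "was_parallelized", False) for layer in results])  # [True, False]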
5 changes: 2 additions & 3 deletions optimum/neuron/trainers.py
@@ -74,7 +74,7 @@
     is_torch_xla_available,
     patch_within_function,
 )
-from .utils.cache_utils import get_neuron_cache_path, set_neuron_cache_path
+from .utils.cache_utils import get_neuron_cache_path
 from .utils.misc import is_main_worker
 from .utils.require_utils import requires_neuronx_distributed
 from .utils.training_utils import (
@@ -787,8 +787,7 @@ def _inner_training_loop(
         # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.

         # Train!
-        # parameter_count = get_model_param_count(model, trainable_only=True)
-        parameter_count = 10
+        parameter_count = get_model_param_count(model, trainable_only=True)
         if is_main_worker():
             logger.info("***** Running training *****")
             logger.info(f"  Num examples = {num_examples:,}")
2 changes: 1 addition & 1 deletion optimum/neuron/utils/training_utils.py
@@ -419,7 +419,7 @@ def numel(parameter_name, parameter) -> int:

     if get_pipeline_model_parallel_size() > 1:
         param_count = torch.tensor(param_count, dtype=torch.float32).to(xm.xla_device())
-        # param_count = xm.all_reduce(xm.REDUCE_SUM, param_count, groups=get_pipeline_model_parallel_group(as_list=True))
+        param_count = xm.all_reduce(xm.REDUCE_SUM, param_count, groups=get_pipeline_model_parallel_group(as_list=True))
         param_count = int(param_count.detach().item())

     return param_count
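With pipeline parallelism each rank only holds the parameters of its own stage, so the locally computed count has to be summed over the pipeline group before it is reported; the restored `xm.all_reduce(xm.REDUCE_SUM, ...)` does that on the XLA device. A rough sketch of the same bookkeeping without XLA, using hypothetical per-stage counts:

# Hypothetical trainable-parameter counts for a 4-stage pipeline.
per_stage_param_counts = [17_000_000_000, 18_000_000_000, 18_000_000_000, 17_000_000_000]

# xm.all_reduce with REDUCE_SUM leaves every rank holding the same total;
# a plain sum models that result on a single process.
total_param_count = sum(per_stage_param_counts)
print(f"{total_param_count:,}")  # 70,000,000,000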
