[WIP] llama-70b
michaelbenayoun committed Feb 2, 2024
1 parent 861c782 commit eaae663
Showing 5 changed files with 64 additions and 20 deletions.
7 changes: 4 additions & 3 deletions examples/language-modeling/run_clm.py
@@ -466,9 +466,10 @@ def main():

# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
model.resize_token_embeddings(len(tokenizer))
# TODO: uncomment that.
# embedding_size = model.get_input_embeddings().weight.shape[0]
# if len(tokenizer) > embedding_size:
# model.resize_token_embeddings(len(tokenizer))

# Preprocessing the datasets.
# First we tokenize all the texts.
41 changes: 33 additions & 8 deletions optimum/neuron/distributed/base.py
@@ -21,7 +21,7 @@
from dataclasses import asdict
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, Tuple, Type, Union, Callable

import torch
from transformers import PreTrainedModel
@@ -98,7 +98,7 @@ class PipelineParallelismSpecs:

@classmethod
@requires_torch_xla
def create_pipeline_cuts(cls, model: PreTrainedModel, pipeline_parallel_size: int) -> List[str]:
def create_pipeline_cuts(cls, model: PreTrainedModel, pipeline_parallel_size: int, log: bool = True) -> List[str]:
"""
Creates the pipeline cuts, i.e. the names of the layers at which the cuts happen for pipeline parallelism.
"""
@@ -117,7 +117,7 @@ def create_pipeline_cuts(cls, model: PreTrainedModel, pipeline_parallel_size: in
for cut_idx in range(num_layers_per_partition - 1, num_layers - 1, num_layers_per_partition)
]

if xm.get_local_ordinal() == 0:
if log and xm.get_ordinal() == 0:
logger.info(f"Pipeline parallelism cuts: {pipeline_cuts}.")

return pipeline_cuts
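
For intuition, here is a minimal sketch of the cut computation above under assumed values; the layer count, pipeline size, and layer names are hypothetical:

# Assuming a decoder with 8 transformer layers split across 4 pipeline stages.
num_layers = 8
pipeline_parallel_size = 4
num_layers_per_partition = num_layers // pipeline_parallel_size  # 2 layers per stage
cut_indices = list(range(num_layers_per_partition - 1, num_layers - 1, num_layers_per_partition))
print(cut_indices)  # [1, 3, 5] -> cuts after model.layers.1, model.layers.3 and model.layers.5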
@@ -197,7 +197,7 @@ def _get_parameter_names_for_current_pipeline(
if not cls.supports_pipeline_parallelism():
raise NotImplementedError(f"{cls} does not support pipeline parallelism.")

cuts = cls.PIPELINE_PARALLELISM_SPECS_CLS.create_pipeline_cuts(model, pp_size)
cuts = cls.PIPELINE_PARALLELISM_SPECS_CLS.create_pipeline_cuts(model, pp_size, log=False)

start_module_name = cuts[pp_rank - 1] if pp_rank >= 1 else None
end_module_name = None if pp_rank == pp_size - 1 else cuts[pp_rank]
@@ -243,6 +243,7 @@ def _parallelize(
device: Optional["torch.device"] = None,
parallelize_embeddings: bool = True,
sequence_parallel_enabled: bool = False,
should_parallelize_predicate_func: Optional[Callable[["torch.nn.Module"], bool]] = None
) -> "PreTrainedModel":
"""
Parallelizes the model by transforming regular layers into their parallel counterparts.
@@ -258,6 +259,7 @@ def _parallelize(
This can be disabled in the case when the TP size does not divide the vocabulary size.
sequence_parallel_enabled (`bool`, defaults to `False`):
Whether or not sequence parallelism is enabled.
should_parallelize_predicate_func (`Optional[Callable[[torch.nn.Module], bool]]`, defaults to `None`):
A predicate that takes a module and returns whether it should be parallelized.
Returns:
`PreTrainedModel`: The parallelized model.
"""
@@ -304,6 +306,7 @@ def parallelize(
Returns:
`PreTrainedModel`: The parallelized model.
"""
import torch_xla.core.xla_model as xm
from neuronx_distributed import parallel_layers

if sequence_parallel_enabled and not cls.supports_sequence_parallelism():
@@ -322,13 +325,34 @@

# Parallelizing the model.
# This needs to be done prior to preparing the model for sequence parallelism because modules can be overridden.

names_of_the_parameters_to_consider = cls._get_parameter_names_for_current_pipeline(
model, remove_duplicate=True
)

name_to_parameter = dict(named_parameters(model, remove_duplicate=False))
parameter_to_name = {p: n for n, p in name_to_parameter.items()}

def predicate_func(layer):
# Parallelize a layer only when every one of its parameters belongs to the
# current pipeline stage; a parameter missing from `parameter_to_name`
# disqualifies the layer.
for n, p in layer.named_parameters():
if p not in parameter_to_name:
print(n)
return False
names = {parameter_to_name[p] for p in layer.parameters()}
return names < names_of_the_parameters_to_consider

model.predicate = predicate_func
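
A hypothetical illustration of the check performed by predicate_func: a layer is transformed only when its parameter names form a (strict) subset of the names owned by the current pipeline stage. The parameter names below are made up:

names_of_the_parameters_to_consider = {
    "model.layers.0.mlp.gate_proj.weight",
    "model.layers.0.mlp.up_proj.weight",
    "model.layers.0.mlp.down_proj.weight",
    "model.layers.0.input_layernorm.weight",
}
mlp_parameter_names = {
    "model.layers.0.mlp.gate_proj.weight",
    "model.layers.0.mlp.up_proj.weight",
    "model.layers.0.mlp.down_proj.weight",
}
print(mlp_parameter_names < names_of_the_parameters_to_consider)  # True -> the MLP gets parallelized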

if tp_size > 1:
model = cls._parallelize(
model,
device=device,
parallelize_embeddings=parallelize_embeddings,
sequence_parallel_enabled=sequence_parallel_enabled,
# should_parallelize_predicate_func=predicate_func,
)
xm.rendezvous("End of tensor parallelism")

# Preparing the model for sequence parallelism:
sp_specs_cls = cls.SEQUENCE_PARALLELSIM_SPECS_CLS
@@ -358,10 +382,6 @@ def parallelize(
# The model was not loaded lazily, it is already ready.
weight_map = getattr(model, "_weight_map", {})

names_of_the_parameters_to_consider = cls._get_parameter_names_for_current_pipeline(
model, remove_duplicate=True
)

with torch.no_grad():
tied_weights = {}
new_parameters = set()
@@ -482,6 +502,8 @@ def parallelize(
if left_uninitialized and hasattr(mod, "reset_parameters"):
initialize_torch_nn_module(mod, parameter_names)

xm.rendezvous("End of initalization")

pp_size = get_pipeline_model_parallel_size()
if pp_size > 1:
if not cls.supports_pipeline_parallelism():
@@ -506,9 +528,12 @@ def parallelize(
use_zero1_optimizer=pipeline_parallel_use_zero1_optimizer,
)

xm.rendezvous("End of pipeline paralellism")

if checkpoint_dir is not None:
cls.load_model_checkpoint(model, checkpoint_dir)

return model

@classmethod
4 changes: 2 additions & 2 deletions optimum/neuron/distributed/decoder_models.py
@@ -330,7 +330,7 @@ class LLamaParallelMLP(ParallelMLP):
SECOND_LINEAR_NAME = "down_proj"

@classmethod
def transform(
def _transform(
cls,
model: "PreTrainedModel",
layer: "torch.nn.Module",
@@ -339,7 +339,7 @@ def transform(
) -> "torch.nn.Module":
# TODO: Make it smart by merging the gate and the up_proj.
# WARNING: be careful of the interleaved outputs when doing TP!
layer = super().transform(model, layer, sequence_parallel_enabled=sequence_parallel_enabled, device=device)
layer = super()._transform(model, layer, sequence_parallel_enabled=sequence_parallel_enabled, device=device)

weight_map = getattr(model, "_weight_map", None)
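
The TODO above mentions merging the gate and up projections; a hypothetical sketch (not part of this commit, shapes and names assumed) of keeping the per-rank shards paired so the fused column-parallel output is not interleaved incorrectly under TP:

import torch

def fuse_gate_up(gate_weight: torch.Tensor, up_weight: torch.Tensor, tp_size: int) -> torch.Tensor:
    # Split each projection along its output dimension into one shard per TP rank.
    gate_shards = gate_weight.chunk(tp_size, dim=0)
    up_shards = up_weight.chunk(tp_size, dim=0)
    # Re-assemble so that chunk r of the fused weight holds [gate_shard_r; up_shard_r],
    # matching what rank r of a column-parallel linear expects to own.
    return torch.cat([torch.cat([g, u], dim=0) for g, u in zip(gate_shards, up_shards)], dim=0)

fused = fuse_gate_up(torch.randn(16, 8), torch.randn(16, 8), tp_size=4)
assert fused.shape == (32, 8)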

26 changes: 19 additions & 7 deletions optimum/neuron/distributed/parallel_layers.py
@@ -105,7 +105,7 @@ def _get_linear_weight_info(
return linear_layer_weight_info, linear_layer_bias_weight_info

@abstractclassmethod
def transform(
def _transform(
cls,
model: "PreTrainedModel",
layer: "torch.nn.Module",
@@ -126,6 +126,18 @@ def transform(
The device where the new parallel layer should be put.
"""

@classmethod
def transform(
cls,
model: "PreTrainedModel",
layer: "torch.nn.Module",
sequence_parallel_enabled: bool = False,
device: Optional["torch.device"] = None,
) -> "torch.nn.Module":
if not model.predicate(layer):
return layer
return cls._transform(model, layer, sequence_parallel_enabled=sequence_parallel_enabled, device=device)
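
A standalone, simplified illustration of the pattern introduced here (all names below are made up): the public transform() consults the predicate attached to the model and only dispatches to the subclass hook _transform() when the layer should be handled by this worker:

import torch

class FakeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 4)
        # Mirrors `model.predicate = predicate_func` set in Parallelizer.parallelize.
        self.predicate = lambda layer: isinstance(layer, torch.nn.Linear)

class ExampleParallelLayer:
    @classmethod
    def _transform(cls, model, layer, sequence_parallel_enabled=False, device=None):
        # A real implementation would swap the layer for its parallel counterpart.
        return layer

    @classmethod
    def transform(cls, model, layer, sequence_parallel_enabled=False, device=None):
        if not model.predicate(layer):
            return layer  # left untouched, e.g. because it belongs to another pipeline stage
        return cls._transform(model, layer, sequence_parallel_enabled=sequence_parallel_enabled, device=device)

model = FakeModel()
parallel_layer = ExampleParallelLayer.transform(model, model.layer)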


class ParallelEmbedding(ParallelLayer):
"""
@@ -164,7 +176,7 @@ def overwrite_vocab_size_value_for_cross_entropy_computation(cls, layer: "torch.

@classmethod
@requires_neuronx_distributed
def transform(
def _transform(
cls,
model: "PreTrainedModel",
layer: "torch.nn.Module",
@@ -295,7 +307,7 @@ class ParallelSelfAttention(ParallelLayer):

@classmethod
@requires_neuronx_distributed
def transform(
def _transform(
cls,
model: "PreTrainedModel",
layer: "torch.nn.Module",
@@ -475,7 +487,7 @@ class ParallelSelfAttentionWithFusedQKV(ParallelLayer):

@classmethod
@requires_neuronx_distributed
def transform(
def _transform(
cls,
model: "PreTrainedModel",
layer: "torch.nn.Module",
@@ -582,7 +594,7 @@ class ParallelSelfOutput(ParallelLayer):
OUTPUT_PROJECTION_NAME = "dense"

@classmethod
def transform(
def _transform(
cls,
model: "PreTrainedModel",
layer: "torch.nn.Module",
@@ -632,7 +644,7 @@ class ParallelMLP(ParallelLayer):
SECOND_LINEAR_NAME: str

@classmethod
def transform(
def _transform(
cls,
model: "PreTrainedModel",
layer: "torch.nn.Module",
@@ -789,7 +801,7 @@ def patch_cross_entropy(cls, model: "PreTrainedModel"):

@classmethod
@requires_neuronx_distributed
def transform(
def _transform(
cls,
model: "PreTrainedModel",
layer: "torch.nn.Module",
6 changes: 6 additions & 0 deletions optimum/neuron/distributed/utils.py
@@ -408,6 +408,7 @@ def linear_to_parallel_linear(
)
parallel_linear_layer.weight.copy_(weight_data)
mark_parameter_init_status_during_parallelization(parallel_linear_layer.weight, True)
del weight_data
elif linear_layer.weight.device != torch.device("meta"):
parallel_linear_layer.weight.copy_(
linear_layer.weight[tp_rank * row_size : (tp_rank + 1) * row_size, :]
@@ -433,6 +434,7 @@
)
parallel_linear_layer.bias.copy_(bias_weight_data)
mark_parameter_init_status_during_parallelization(parallel_linear_layer.bias, True)
del bias_weight_data
elif linear_layer.bias.device != torch.device("meta"):
if gather_output:
parallel_linear_layer.bias.copy_(linear_layer.bias)
@@ -444,6 +446,10 @@
else:
mark_parameter_init_status_during_parallelization(parallel_linear_layer.bias, False)

del linear_layer.weight
if linear_layer.bias is not None:
del linear_layer.bias

return parallel_linear_layer
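
The new del statements free the original full-size tensors as soon as their shard has been copied. A minimal sketch of that copy-then-free pattern, with made-up shapes:

import torch

full_weight = torch.randn(4096, 4096)  # stand-in for linear_layer.weight
tp_rank, tp_size = 0, 8
row_size = full_weight.shape[0] // tp_size
shard = full_weight[tp_rank * row_size : (tp_rank + 1) * row_size, :].clone()
del full_weight  # mirrors `del linear_layer.weight`: release host memory early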


