[WIP] initial support for pp
michaelbenayoun committed Oct 31, 2023
1 parent e394ec5 commit 2920df7
Showing 3 changed files with 83 additions and 23 deletions.
35 changes: 24 additions & 11 deletions optimum/neuron/accelerate/accelerator.py
@@ -349,32 +349,45 @@ def prepare_model_for_xla_fsdp(

return model

@requires_neuronx_distributed
def _prepare_model_for_tp(
self, model: torch.nn.Module, device_placement: Optional[bool] = None, evaluation_mode: bool = False
):
from neuronx_distributed.pipeline import NxDPPModel

if model in self._models or Parallelizer.was_parallelized(model):
return model

cpu_ids = [id(v) for v in model.parameters()]
cpu_ids = {name: id(param) for name, param in model.named_parameters()}
# TODO: enable self.device (if needed).
model = self.state.mp_plugin.parallelize_model(model, device=None)

if os.environ.get("XLA_USE_BF16", "0") == "1" or os.environ.get("XLA_DOWNCAST_BF16", "0") == "1":
model.to(torch.bfloat16)
else:
model.to(torch.float32)

def _tie_or_clone_weights_for_tp(self, output_embeddings, input_embeddings):
"""Tie or clone module weights depending of whether we are using TorchScript or not"""
output_embeddings.weight = input_embeddings.weight
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
output_embeddings.out_features = input_embeddings.num_embeddings

with ModelPatcher(patching_specs=[(model, "_tie_or_clone_weights", _tie_or_clone_weights_for_tp)]):
model.tie_weights()
move_model_to_device(model, self.device)
model.tie_weights()
self._model_cpu_parameters_to_xla[id(model)] = dict(zip(cpu_ids, model.parameters()))
if isinstance(model, NxDPPModel):
with ModelPatcher(patching_specs=[(model, "_tie_or_clone_weights", _tie_or_clone_weights_for_tp)]):
model.tie_weights()
model.move_model_to_device()
model.tie_weights()
xla_ids = {name: param for name, param in model.local_named_parameters()}
self._model_cpu_parameters_to_xla[id(model)] = {cpu_ids[name]: xla_ids[name] for name, _ in model.local_named_parameters()}
else:
if os.environ.get("XLA_USE_BF16", "0") == "1" or os.environ.get("XLA_DOWNCAST_BF16", "0") == "1":
model.to(torch.bfloat16)
else:
model.to(torch.float32)

with ModelPatcher(patching_specs=[(model, "_tie_or_clone_weights", _tie_or_clone_weights_for_tp)]):
model.tie_weights()
move_model_to_device(model, self.device)
model.tie_weights()
xla_ids = {name: param for name, param in model.named_parameters()}
self._model_cpu_parameters_to_xla[id(model)] = {cpu_ids[name]: xla_ids[name] for name, _ in model.named_parameters()}

device_placement = False

return super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode)
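
The key change in this hunk is that the CPU-to-XLA parameter mapping is now keyed by parameter name rather than built with a positional zip, which would no longer line up once pipeline parallelism keeps only a subset of the parameters on each rank. A minimal sketch of that idea (toy code, not part of this commit; the stand-in model and the dtype cast below are illustrative):

import torch

# Toy stand-in for the real model (illustrative only, not repository code).
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))

# 1. Remember the ids of the CPU parameters, keyed by parameter name.
cpu_ids = {name: id(param) for name, param in model.named_parameters()}

# 2. Parallelize / move the model. A dtype cast stands in here for the real
#    parallelization + move to the XLA device.
model.to(torch.bfloat16)

# 3. Rebuild the mapping by name: CPU parameter id -> parameter after the move.
#    Keying by name stays correct even when only a subset of the parameters
#    (e.g. the local pipeline stage) is present afterwards.
xla_params = dict(model.named_parameters())
cpu_to_xla = {cpu_ids[name]: xla_params[name] for name in xla_params}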
55 changes: 49 additions & 6 deletions optimum/neuron/distributed/base.py
@@ -100,7 +100,6 @@ def patch_for_sequence_parallelism(cls, model: "PreTrainedModel", sequence_paral
)



class PipelineParallelismSpecs:
TRASNFORMER_LAYER_CLS: Type["torch.nn.Module"]
LEAF_MODULE_CLASSES_NAMES: Optional[List[Union[str, Type["torch.nn.Module"]]]] = None
@@ -122,6 +121,22 @@ def create_pipeline_cuts(cls, model: PreTrainedModel, pipeline_parallel_size: in

return pipeline_cuts

# @classmethod
# def create_pipeline_cuts(cls, model, pipeline_parallel_size):
# """
# Evenly split the transformer layers between the PP ranks
# """
# assert model.config.num_hidden_layers % pipeline_parallel_size == 0
# num_layer_per_partition = model.config.num_hidden_layers // pipeline_parallel_size
# pipeline_cuts = []
# current_cut = num_layer_per_partition - 1
# for i in range(pipeline_parallel_size-1):
# pipeline_cuts.append(f"model.layers.{current_cut}")
# current_cut += num_layer_per_partition
# if torch.distributed.get_rank() == 0:
# print(f"pipeline_cuts {pipeline_cuts}")
# return pipeline_cuts

@classmethod
def leaf_module_cls(cls) -> List[str]:
if cls.LEAF_MODULE_CLASSES_NAMES is None:
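
The even-split logic in the commented-out create_pipeline_cuts variant above reduces to the following standalone sketch (illustrative only, with made-up sizes; not part of this commit):

def even_pipeline_cuts(num_hidden_layers, pipeline_parallel_size, prefix="model.layers"):
    # Evenly split the transformer layers across the pipeline-parallel ranks.
    assert num_hidden_layers % pipeline_parallel_size == 0
    layers_per_partition = num_hidden_layers // pipeline_parallel_size
    cuts = []
    current_cut = layers_per_partition - 1
    for _ in range(pipeline_parallel_size - 1):
        cuts.append(f"{prefix}.{current_cut}")
        current_cut += layers_per_partition
    return cuts

# 8 layers over 4 stages -> cut after layers 1, 3 and 5.
print(even_pipeline_cuts(8, 4))  # ['model.layers.1', 'model.layers.3', 'model.layers.5']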
@@ -170,8 +185,9 @@ def _get_parameter_names_for_current_pipeline(cls, model: "torch.nn.Module") ->
)
pp_size = get_pipeline_model_parallel_size()
pp_rank = get_pipeline_model_parallel_rank()
all_parameter_names = {n for n, _ in model.named_parameters()}
if pp_size == 1:
return {n for n, _ in model.named_parameters()}
return all_parameter_names

if cls.PIPELINE_PARALLELISM_SPECS_CLS is None:
raise NotImplementedError(f"{cls} does not support pipeline parallelism.")
@@ -196,7 +212,15 @@ def _get_parameter_names_for_current_pipeline(cls, model: "torch.nn.Module") ->
# `mod.named_parameters()` to get the fully qualified names.
name = parameter2name[param]
parameter_names.add(name)
return parameter_names

parameter_outside_of_transformer_layers_names = set()
for mod in model.modules():
if not isinstance(mod, cls.PIPELINE_PARALLELISM_SPECS_CLS.TRASNFORMER_LAYER_CLS):
for name, _ in mod.named_parameters():
if name not in parameter_names:
parameter_outside_of_transformer_layers_names.add(name)

return parameter_names | parameter_outside_of_transformer_layers_names


@abstractclassmethod
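
To make _get_parameter_names_for_current_pipeline concrete: a rank keeps the parameter names of the transformer layers assigned to it, and the new block additionally adds every parameter that does not belong to a transformer layer (embeddings, final norm, LM head in a typical decoder). A hypothetical example of what that union looks like (the names below are made up, not from the repository):

# Hypothetical names for a 4-layer decoder with pp_size = 2 (illustrative only).
outside_layer_names = {"model.embed_tokens.weight", "model.norm.weight", "lm_head.weight"}

pp_rank, layers_per_rank = 0, 2
local_layers = range(pp_rank * layers_per_rank, (pp_rank + 1) * layers_per_rank)
local_layer_names = {f"model.layers.{i}.mlp.weight" for i in local_layers}

# Rank 0 considers its own two layers plus everything outside the transformer layers.
names_to_consider = local_layer_names | outside_layer_names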
@@ -295,6 +319,8 @@ def parallelize(
)

names_of_the_parameters_to_consider = cls._get_parameter_names_for_current_pipeline(model)
if torch.distributed.get_rank() == 0:
print("NAMES TO CONSIDER", names_of_the_parameters_to_consider)

weight_map = getattr(model, "_weight_map", None)

@@ -309,8 +335,8 @@ def parallelize(
for name, parameter in named_parameters(model, remove_duplicate=False):

# Skipping the parameters that will not end-up in this pipeline rank.
# if name not in names_of_the_parameters_to_consider:
# continue
if name not in names_of_the_parameters_to_consider:
continue

split = name.rsplit(".", maxsplit=1)
module = model.get_submodule(split[0])
@@ -382,17 +408,25 @@ def parallelize(
raise NotImplementedError("{cls} does not support pipeline parallelism.")

model.config.return_dict = False
model.config.use_cache = False
model.config.output_attentions = False
# model.config.output_hidden_states =
model = NxDPPModel(
model,
transformer_layer_cls=cls.PIPELINE_PARALLELISM_SPECS_CLS.TRASNFORMER_LAYER_CLS,
num_microbatches=3,
output_loss_value_spec=(True, False),
input_names=["input_ids", "attention_mask"],
input_names=["input_ids", "attention_mask", "labels"],
pipeline_cuts=cls.PIPELINE_PARALLELISM_SPECS_CLS.create_pipeline_cuts(model, pp_size),
leaf_module_cls=cls.PIPELINE_PARALLELISM_SPECS_CLS.leaf_module_cls(),
trace_file_path="/home/ubuntu/trace",
use_zero1_optimizer=False,
)

for name, p in model.local_named_parameters():
if p.device == torch.device("meta"):
print(name)

# TODO: see how it works out with pp.
if checkpoint_dir is not None:
cls.load_model_checkpoint(model, checkpoint_dir)
@@ -436,18 +470,27 @@ def optimizer_cpu_params_to_xla_params(
new_param = {k: v for k, v in param.items() if k != "params"}
params = []
for p in param["params"]:
# This can be the case with pipeline parallelism.
if id(p) not in orig_param_to_parallel_param_on_xla:
continue
params.append(orig_param_to_parallel_param_on_xla[id(p)])
new_param["params"] = params
else:
new_param = []
for p in param:
# This can be the case with pipeline parallelism.
if id(p) not in orig_param_to_parallel_param_on_xla:
continue
new_param.append(orig_param_to_parallel_param_on_xla[id(p)])
parameters_on_xla.append(new_param)
else:
for param_group in optimizer.param_groups:
new_params = []
params = param_group["params"]
for idx in range(len(params)):
if id(params[idx]) not in orig_param_to_parallel_param_on_xla:
need_to_create_new_optimizer = True
continue
param_on_xla = orig_param_to_parallel_param_on_xla[id(params[idx])]
if params[idx] != param_on_xla:
need_to_create_new_optimizer = True
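
The additions to optimizer_cpu_params_to_xla_params above simply skip optimizer parameters that have no XLA counterpart on the current pipeline rank. A self-contained sketch of that filtering for the param-group case (toy tensors, not repository code):

import torch

# Toy CPU parameters the optimizer was originally built with (illustrative only).
p0, p1, p2 = (torch.nn.Parameter(torch.zeros(2)) for _ in range(3))
param_groups = [{"params": [p0, p1, p2], "lr": 1e-3}]

# Suppose only p0 and p2 have a counterpart on this pipeline rank.
orig_param_to_parallel_param_on_xla = {id(p0): p0, id(p2): p2}

parameters_on_xla = []
for group in param_groups:
    new_group = {k: v for k, v in group.items() if k != "params"}
    # Parameters owned by another pipeline rank are simply skipped.
    new_group["params"] = [
        orig_param_to_parallel_param_on_xla[id(p)]
        for p in group["params"]
        if id(p) in orig_param_to_parallel_param_on_xla
    ]
    parameters_on_xla.append(new_group)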
16 changes: 10 additions & 6 deletions optimum/neuron/trainers.py
@@ -27,8 +27,6 @@
import torch
from packaging import version
from transformers import PreTrainedModel, Seq2SeqTrainer, Trainer, TrainingArguments
from transformers.dependency_versions_check import dep_version_check
from transformers.integrations import is_fairscale_available
from transformers.modeling_utils import unwrap_model
from transformers.trainer import (
OPTIMIZER_NAME,
@@ -80,10 +78,6 @@
else:
IS_SAGEMAKER_MP_POST_1_10 = False

if is_fairscale_available():
dep_version_check("fairscale")


logger = logging.get_logger("transformers.trainer")

KEEP_HF_HUB_PROGRESS_BARS = os.environ.get("KEEP_HF_HUB_PROGRESS_BARS")
@@ -280,6 +274,16 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
def create_optimizer(self):
return super().create_optimizer()

def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
from neuronx_distributed.pipeline import NxDPPModel

if isinstance(model, NxDPPModel):
inputs = self._prepare_inputs(inputs)
loss = model.run_train(**inputs)
return loss.detach() / self.args.gradient_accumulation_steps
return super().training_step(model, inputs)


def compute_loss(self, model, inputs, return_outputs: bool = False):
self.state.last_inputs = inputs
self.trigger_on_step_middle_for_neuron_cache_callback(model)
