GQA optimization for TP (#498)
michaelbenayoun authored Mar 20, 2024
1 parent 709b625 commit 4a7df1a
Showing 17 changed files with 1,394 additions and 285 deletions.
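
For orientation: across these files the commit threads a new `kv_size_multiplier` option from `ModelParallelismPlugin` down to the parallelizer, so that grouped-query attention (GQA) key/value heads can be replicated to match the tensor-parallel degree, and it renames the checkpoint-consolidation helpers from `*_tensor_parallel_*` to `*_model_parallel_*`. A minimal sketch of the replication idea follows; `pick_kv_size_multiplier` is a hypothetical helper, not part of the library, and the real heuristic lives in parallelization code outside the hunks shown here.

```python
def pick_kv_size_multiplier(num_key_value_heads: int, tp_size: int) -> int:
    """Hypothetical helper illustrating the idea behind ``kv_size_multiplier``.

    When a GQA model has fewer key/value heads than tensor-parallel ranks, the
    KV heads can be replicated so that every rank owns at least one. This is a
    sketch of that reasoning, not the heuristic optimum-neuron implements.
    """
    if num_key_value_heads >= tp_size:
        return 1  # enough KV heads already, nothing to replicate
    if tp_size % num_key_value_heads != 0:
        raise ValueError("tp_size must be a multiple of the number of KV heads")
    return tp_size // num_key_value_heads


# Example: 8 KV heads (Llama-2-70B style GQA) spread over a 32-way TP group.
print(pick_kv_size_multiplier(num_key_value_heads=8, tp_size=32))  # -> 4
```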
4 changes: 2 additions & 2 deletions notebooks/text-generation/scripts/run_clm.py
@@ -47,11 +47,11 @@ def training_function(script_args, training_args):
# if (int(os.environ.get("RANK", -1)) == 0) and int(training_args.tensor_parallel_size) > 1:
# print("Converting sharded checkpoint to consolidated format")
# from optimum.neuron.distributed.checkpointing import (
- # consolidate_tensor_parallel_checkpoints_to_unified_checkpoint,
+ # consolidate_model_parallel_checkpoints_to_unified_checkpoint,
# )
# from shutil import rmtree

- # consolidate_tensor_parallel_checkpoints_to_unified_checkpoint(
+ # consolidate_model_parallel_checkpoints_to_unified_checkpoint(
# training_args.output_dir, training_args.output_dir, "pytorch"
# )
# rmtree(os.path.join(training_args.output_dir, "tensor_parallel_shards")) # remove sharded checkpoint files
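Only the helper's name changes here; the call site keeps the same arguments. A minimal usage sketch of the renamed consolidation entry point, with placeholder paths and the argument order taken from the call sites in this commit:

```python
from optimum.neuron.distributed import (
    consolidate_model_parallel_checkpoints_to_unified_checkpoint,
)

# Merge the sharded model-parallel checkpoint found under the first directory
# into a single unified checkpoint written to the second one. Paths are
# placeholders; the save_format values mirror this commit's call sites.
consolidate_model_parallel_checkpoints_to_unified_checkpoint(
    "path/to/sharded/checkpoint",    # placeholder checkpoint directory
    "path/to/consolidated/output",   # placeholder destination directory
    save_format="pytorch",           # the CLI also accepts "safetensors"
)
```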
4 changes: 2 additions & 2 deletions optimum/commands/neuron/subcommands.py
@@ -16,7 +16,7 @@

from typing import TYPE_CHECKING

- from ...neuron.distributed import consolidate_tensor_parallel_checkpoints_to_unified_checkpoint
+ from ...neuron.distributed import consolidate_model_parallel_checkpoints_to_unified_checkpoint
from ...utils import logging
from ..base import BaseOptimumCLICommand

@@ -53,7 +53,7 @@ def parse_args(parser: "ArgumentParser"):
def run(self):
checkpoint_format = "safetensors" if self.args.format == "safetensors" else "pytorch"
logger.info(f"Consolidating checkpoints from {self.args.checkpoint_dir} to the {checkpoint_format} format...")
- consolidate_tensor_parallel_checkpoints_to_unified_checkpoint(
+ consolidate_model_parallel_checkpoints_to_unified_checkpoint(
self.args.checkpoint_dir,
self.args.output_dir,
save_format=self.args.format,
15 changes: 13 additions & 2 deletions optimum/neuron/accelerate/accelerator.py
@@ -413,6 +413,7 @@ def prepare_model_for_xla_fsdp(
def _prepare_model_for_mp(
self, model: torch.nn.Module, device_placement: Optional[bool] = None, evaluation_mode: bool = False
):
+ import torch_xla.core.xla_model as xm
from neuronx_distributed.pipeline import NxDPPModel

if model in self._models or Parallelizer.was_parallelized(model):
@@ -421,7 +422,7 @@ def _prepare_model_for_mp(
cpu_ids = {name: id(param) for name, param in model.named_parameters()}
tied_parameters_dict = get_tied_parameters_dict(model)
model_main_input_name = getattr(model, "main_input_name", None)
- # TODO: enable self.device (if needed).
+ # TODO: use self.device.
model = self.state.mp_plugin.parallelize_model(model, device=None)

if model_main_input_name is not None:
@@ -435,6 +436,11 @@ def _prepare_model_for_mp(
else:
model_to_cast = model

+ # Update CPU ids
+ original_parameter_names_to_gqa_qkv_names = model._gqa_qkv_metadata["original_names_to_gqa_qkv_names"]
+ for key in list(cpu_ids.keys()):
+     cpu_ids[original_parameter_names_to_gqa_qkv_names.get(key, key)] = cpu_ids.pop(key)

model_to_cast = model.local_module if isinstance(model, NxDPPModel) else model
if os.environ.get("XLA_USE_BF16", "0") == "1" or os.environ.get("XLA_DOWNCAST_BF16", "0") == "1":
model_to_cast.to(torch.bfloat16)
@@ -460,6 +466,7 @@ def _tie_or_clone_weights_for_mp(self, output_embeddings, input_embeddings):
move_model_to_device(model, self.device)
tie_parameters(model, tied_parameters_dict)
xla_params = dict(model.named_parameters())

symmetric_diff = set(cpu_ids.keys()).symmetric_difference((xla_params.keys()))
if symmetric_diff:
raise ValueError(
@@ -470,6 +477,7 @@ def _tie_or_clone_weights_for_mp(self, output_embeddings, input_embeddings):
cpu_ids[name]: xla_params[name] for name, _ in model.named_parameters()
}

+ xm.mark_step()
device_placement = False

return super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode)
@@ -485,8 +493,11 @@ def prepare_model(

model = self.patch_model_for_neuron(model)

- # We do not want to use the cache here as it would imply more communication that we do not need.
+ # We do not want to use the cache, or output unused tensors as it would imply more communication that we do not
+ # need.
model.config.use_cache = False
+ model.config.output_attentions = False
+ model.config.output_hidden_states = False

if self.distributed_type is NeuronDistributedType.XLA_FSDP:
return self.prepare_model_for_xla_fsdp(
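The re-keying block added above keeps the `cpu_ids` bookkeeping aligned with the GQA-fused parameter names, so the later CPU-to-XLA parameter matching does not report a spurious symmetric difference. The same pattern, isolated as a small sketch with made-up parameter names and `id()` values:

```python
# Standalone sketch of the cpu_ids re-keying added in _prepare_model_for_mp.
# Parameter names, the mapping, and the id() values below are illustrative only.
cpu_ids = {
    "model.layers.0.self_attn.k_proj.weight": 139700000000001,
    "model.layers.0.self_attn.o_proj.weight": 139700000000002,
}
# Stand-in for model._gqa_qkv_metadata["original_names_to_gqa_qkv_names"]:
original_names_to_gqa_qkv_names = {
    "model.layers.0.self_attn.k_proj.weight": "model.layers.0.self_attn.qkv_proj.weight_k",
}

# Parameters fused into the GQA QKV projection get their new name; all other
# keys are left untouched, so the CPU-to-XLA parameter matching still lines up.
for key in list(cpu_ids.keys()):
    cpu_ids[original_names_to_gqa_qkv_names.get(key, key)] = cpu_ids.pop(key)

print(sorted(cpu_ids))
```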
2 changes: 2 additions & 0 deletions optimum/neuron/accelerate/utils/dataclasses.py
@@ -144,6 +144,7 @@ class ModelParallelismPlugin:
tensor_parallel_size: int = 1
parallelize_embeddings: bool = True
sequence_parallel_enabled: bool = False
+ kv_size_multiplier: Optional[int] = None
pipeline_parallel_size: int = 1
pipeline_parallel_num_microbatches: int = 1
pipeline_parallel_use_zero1_optimizer: bool = False
@@ -175,6 +176,7 @@ def parallelize_model(
device=device,
parallelize_embeddings=self.parallelize_embeddings,
sequence_parallel_enabled=self.sequence_parallel_enabled,
+ kv_size_multiplier=self.kv_size_multiplier,
pipeline_parallel_num_microbatches=self.pipeline_parallel_num_microbatches,
pipeline_parallel_use_zero1_optimizer=self.pipeline_parallel_use_zero1_optimizer,
pipeline_parallel_gradient_checkpointing_enabled=self.gradient_checkpointing,
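`kv_size_multiplier` becomes a regular field of `ModelParallelismPlugin` and is simply forwarded to the parallelizer. A hedged construction example; the field names come from the dataclass above, while the concrete values are illustrative:

```python
from optimum.neuron.accelerate.utils.dataclasses import ModelParallelismPlugin

# Sketch: ask for 4x replication of the GQA key/value heads under 32-way tensor
# parallelism. Leaving kv_size_multiplier at None presumably lets the library
# pick a value on its own.
mp_plugin = ModelParallelismPlugin(
    tensor_parallel_size=32,
    parallelize_embeddings=True,
    sequence_parallel_enabled=True,
    kv_size_multiplier=4,
    pipeline_parallel_size=1,
)
```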
4 changes: 2 additions & 2 deletions optimum/neuron/distributed/__init__.py
@@ -15,8 +15,8 @@

from .base import Parallelizer
from .checkpointing import (
- consolidate_tensor_parallel_checkpoints,
- consolidate_tensor_parallel_checkpoints_to_unified_checkpoint,
+ consolidate_model_parallel_checkpoints,
+ consolidate_model_parallel_checkpoints_to_unified_checkpoint,
)
from .parallelizers_manager import ParallelizersManager
from .utils import lazy_load_for_parallelism, make_optimizer_constructor_lazy