Skip to content

Commit

Permalink
Enable Sequence Parallelism (#429)
Browse files Browse the repository at this point in the history
  • Loading branch information
polisettyvarma authored Sep 4, 2024
1 parent 1280f59 commit 0d6e379
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 14 deletions.
3 changes: 2 additions & 1 deletion megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,8 @@ def validate_args(args, defaults={}):
args.async_tensor_model_parallel_allreduce = False

if not args.use_dataset_only:
if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if deepspeed.accelerator.get_accelerator().device_name() == "cuda" \
and os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if args.sequence_parallel:
raise RuntimeError(
"Using sequence parallelism requires setting the environment variable "
Expand Down
4 changes: 3 additions & 1 deletion megatron/core/tensor_parallel/layers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

# Parts of the code here are adapted from PyTorch
Expand Down Expand Up @@ -450,7 +451,8 @@ def linear_with_grad_accumulation_and_async_allreduce(
]

if not linear_with_grad_accumulation_and_async_allreduce.warned:
if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if get_accelerator().device_name() == "cuda" \
and os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if sequence_parallel:
warnings.warn(
"When using sequence parallelism it is recommended to set the "
Expand Down
14 changes: 12 additions & 2 deletions megatron/model/fused_layer_norm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""This code is copied fron NVIDIA apex:
https://github.com/NVIDIA/apex
with some changes. """

from deepspeed.accelerator.real_accelerator import get_accelerator
import numbers
import torch
from torch.nn.parameter import Parameter
Expand All @@ -13,6 +15,7 @@
import inspect

from megatron.core.utils import make_viewless_tensor
from megatron import get_args

try:
from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
Expand Down Expand Up @@ -56,8 +59,15 @@ def __init__(self, normalized_shape, eps=1e-5,
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.weight = Parameter(torch.Tensor(*normalized_shape))
self.bias = Parameter(torch.Tensor(*normalized_shape))
init_device = None
if get_accelerator().device_name() == 'hpu':
init_device = get_accelerator().current_device_name()
self.weight = Parameter(torch.empty(*normalized_shape,
device=init_device,
dtype=get_args().params_dtype))
self.bias = Parameter(torch.empty(*normalized_shape,
device=init_device,
dtype=get_args().params_dtype))
self.reset_parameters()
self.no_persist_layer_norm = no_persist_layer_norm
self.sequence_parallel = sequence_parallel
Expand Down
8 changes: 6 additions & 2 deletions megatron/model/gpt_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""GPT-2 model."""
Expand Down Expand Up @@ -393,9 +394,12 @@ def _to_float16(inputs):
if args.normalization == 'layernorm':
self.specs.append(LayerSpec(LayerNorm,
args.hidden_size,
eps=args.layernorm_epsilon))
eps=args.layernorm_epsilon,
sequence_parallel=args.sequence_parallel))
else:
self.specs.append(LayerSpec(RMSNorm, args.hidden_size, args.layernorm_epsilon))
self.specs.append(LayerSpec(RMSNorm, args.hidden_size,
args.layernorm_epsilon,
sequence_parallel=args.sequence_parallel))

def _logits_helper(embedding, lm_output):
"""A wrapper to massage inputs/outputs from pipeline. """
Expand Down
5 changes: 3 additions & 2 deletions megatron/model/language_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Transformer based language model."""
Expand Down Expand Up @@ -256,8 +257,8 @@ def forward(self, input_ids, position_ids, tokentype_ids=None):

# Dropout.
if self.sequence_parallel:
# already partition sequence, do not need scatter_to_sequence_parallel_region
# embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
# already partition sequence, do not need scatter_to_sequence_parallel_region ?
embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
with tensor_parallel.get_cuda_rng_tracker().fork():
embeddings = self.embedding_dropout(embeddings)
else:
Expand Down
17 changes: 15 additions & 2 deletions megatron/model/rmsnorm.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,29 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

from deepspeed.accelerator import get_accelerator
from megatron import get_args

import torch
from torch.nn import init
from torch.nn.parameter import Parameter

# Taken from facebookresearch/llama
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = Parameter(torch.ones(dim))
init_device = None
if get_accelerator().device_name() == 'hpu':
init_device = get_accelerator().current_device_name()
self.weight = Parameter(torch.empty(dim,
device=init_device,
dtype=get_args().params_dtype))
init.ones_(self.weight)
setattr(self.weight, 'sequence_parallel', sequence_parallel)

def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
return output * self.weight
13 changes: 9 additions & 4 deletions megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,7 +931,8 @@ def __init__(self, config,
config.hidden_size,
eps=config.layernorm_epsilon)
else:
self.input_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
self.input_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon,
sequence_parallel=config.sequence_parallel)
# Self attention.
self.self_attention = ParallelAttention(
config,
Expand All @@ -957,7 +958,8 @@ def __init__(self, config,
config.hidden_size,
eps=config.layernorm_epsilon)
else:
self.post_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
self.post_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon,
sequence_parallel=config.sequence_parallel)
# Cross attention.
if self.layer_type in (LayerType.decoder,
LayerType.retro_decoder,
Expand All @@ -977,7 +979,9 @@ def __init__(self, config,
apply_layernorm_1p=args.apply_layernorm_1p,
mem_efficient_ln=args.mem_efficient_ln)
else:
self.post_inter_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
self.post_inter_attention_layernorm = RMSNorm(config.hidden_size,
config.layernorm_epsilon,
sequence_parallel=config.sequence_parallel)

# MLP
self.num_experts = num_experts
Expand Down Expand Up @@ -1780,7 +1784,8 @@ def build_layer(layer_number, n_e):
config.hidden_size,
eps=config.layernorm_epsilon)
else:
self.final_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
self.final_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon,
sequence_parallel=config.sequence_parallel)

def _get_layer(self, layer_number):
return self.layers[layer_number]
Expand Down

0 comments on commit 0d6e379

Please sign in to comment.