diff --git a/megatron/arguments.py b/megatron/arguments.py
index 49b3d8e4c6..9228da6ee9 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -399,7 +399,8 @@ def validate_args(args, defaults={}):
         args.async_tensor_model_parallel_allreduce = False

     if not args.use_dataset_only:
-        if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
+        if deepspeed.accelerator.get_accelerator().device_name() == "cuda" \
+            and os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
             if args.sequence_parallel:
                 raise RuntimeError(
                     "Using sequence parallelism requires setting the environment variable "
diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py
index 2245113c9c..67a78853aa 100644
--- a/megatron/core/tensor_parallel/layers.py
+++ b/megatron/core/tensor_parallel/layers.py
@@ -1,3 +1,4 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 # Parts of the code here are adapted from PyTorch
@@ -450,7 +451,8 @@ def linear_with_grad_accumulation_and_async_allreduce(
     ]

     if not linear_with_grad_accumulation_and_async_allreduce.warned:
-        if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
+        if get_accelerator().device_name() == "cuda" \
+            and os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
             if sequence_parallel:
                 warnings.warn(
                     "When using sequence parallelism it is recommended to set the "
diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py
index 2f3b89014b..d1ef034397 100644
--- a/megatron/model/fused_layer_norm.py
+++ b/megatron/model/fused_layer_norm.py
@@ -1,9 +1,11 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 """This code is copied fron NVIDIA apex:
       https://github.com/NVIDIA/apex
    with some changes. """

+from deepspeed.accelerator.real_accelerator import get_accelerator
 import numbers
 import torch
 from torch.nn.parameter import Parameter
@@ -13,6 +15,7 @@
 import inspect

 from megatron.core.utils import make_viewless_tensor
+from megatron import get_args

 try:
     from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
@@ -56,8 +59,15 @@ def __init__(self, normalized_shape, eps=1e-5,
         normalized_shape = (normalized_shape,)
         self.normalized_shape = torch.Size(normalized_shape)
         self.eps = eps
-        self.weight = Parameter(torch.Tensor(*normalized_shape))
-        self.bias = Parameter(torch.Tensor(*normalized_shape))
+        init_device = None
+        if get_accelerator().device_name() == 'hpu':
+            init_device = get_accelerator().current_device_name()
+        self.weight = Parameter(torch.empty(*normalized_shape,
+                                            device=init_device,
+                                            dtype=get_args().params_dtype))
+        self.bias = Parameter(torch.empty(*normalized_shape,
+                                          device=init_device,
+                                          dtype=get_args().params_dtype))
         self.reset_parameters()
         self.no_persist_layer_norm = no_persist_layer_norm
         self.sequence_parallel = sequence_parallel
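Note: the hunks above share one pattern: the CUDA_DEVICE_MAX_CONNECTIONS requirement is only enforced when the active DeepSpeed accelerator reports "cuda", and fused LayerNorm parameters are created directly on the HPU device in the configured params_dtype instead of as fp32 CPU tensors. A minimal standalone sketch of the guard follows; it is illustrative only and the helper name is not part of the patch.

    import os
    from deepspeed.accelerator import get_accelerator

    def check_serialized_cuda_connections(sequence_parallel: bool) -> None:
        # Non-CUDA backends (e.g. HPU) have no notion of
        # CUDA_DEVICE_MAX_CONNECTIONS, so the check is skipped there.
        if get_accelerator().device_name() != "cuda":
            return
        if os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS") != "1" and sequence_parallel:
            raise RuntimeError(
                "Using sequence parallelism requires setting the environment "
                "variable CUDA_DEVICE_MAX_CONNECTIONS to 1")
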
"""GPT-2 model.""" @@ -393,9 +394,12 @@ def _to_float16(inputs): if args.normalization == 'layernorm': self.specs.append(LayerSpec(LayerNorm, args.hidden_size, - eps=args.layernorm_epsilon)) + eps=args.layernorm_epsilon, + sequence_parallel=args.sequence_parallel)) else: - self.specs.append(LayerSpec(RMSNorm, args.hidden_size, args.layernorm_epsilon)) + self.specs.append(LayerSpec(RMSNorm, args.hidden_size, + args.layernorm_epsilon, + sequence_parallel=args.sequence_parallel)) def _logits_helper(embedding, lm_output): """A wrapper to massage inputs/outputs from pipeline. """ diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index ec2ae1877a..3b8e4e0da1 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -1,3 +1,4 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. """Transformer based language model.""" @@ -256,8 +257,8 @@ def forward(self, input_ids, position_ids, tokentype_ids=None): # Dropout. if self.sequence_parallel: - # already partition sequence, do not need scatter_to_sequence_parallel_region - # embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings) + # already partition sequence, do not need scatter_to_sequence_parallel_region ? + embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings) with tensor_parallel.get_cuda_rng_tracker().fork(): embeddings = self.embedding_dropout(embeddings) else: diff --git a/megatron/model/rmsnorm.py b/megatron/model/rmsnorm.py index 60e8978171..4860d81716 100644 --- a/megatron/model/rmsnorm.py +++ b/megatron/model/rmsnorm.py @@ -1,4 +1,10 @@ +# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company. + +from deepspeed.accelerator import get_accelerator +from megatron import get_args + import torch +from torch.nn import init from torch.nn.parameter import Parameter # Taken from facebookresearch/llama @@ -6,11 +12,18 @@ class RMSNorm(torch.nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps - self.weight = Parameter(torch.ones(dim)) + init_device = None + if get_accelerator().device_name() == 'hpu': + init_device = get_accelerator().current_device_name() + self.weight = Parameter(torch.empty(dim, + device=init_device, + dtype=get_args().params_dtype)) + init.ones_(self.weight) + setattr(self.weight, 'sequence_parallel', sequence_parallel) def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): output = self._norm(x.float()).type_as(x) - return output * self.weight \ No newline at end of file + return output * self.weight diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 036c11566a..74e977103f 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -931,7 +931,8 @@ def __init__(self, config, config.hidden_size, eps=config.layernorm_epsilon) else: - self.input_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon) + self.input_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon, + sequence_parallel=config.sequence_parallel) # Self attention. 
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 036c11566a..74e977103f 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -931,7 +931,8 @@ def __init__(self, config,
                 config.hidden_size,
                 eps=config.layernorm_epsilon)
         else:
-            self.input_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
+            self.input_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon,
+                                           sequence_parallel=config.sequence_parallel)
         # Self attention.
         self.self_attention = ParallelAttention(
             config,
@@ -957,7 +958,8 @@ def __init__(self, config,
                 config.hidden_size,
                 eps=config.layernorm_epsilon)
         else:
-            self.post_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
+            self.post_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon,
+                                                    sequence_parallel=config.sequence_parallel)
         # Cross attention.
         if self.layer_type in (LayerType.decoder,
                                LayerType.retro_decoder,
@@ -977,7 +979,9 @@ def __init__(self, config,
                 apply_layernorm_1p=args.apply_layernorm_1p,
                 mem_efficient_ln=args.mem_efficient_ln)
         else:
-            self.post_inter_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
+            self.post_inter_attention_layernorm = RMSNorm(config.hidden_size,
+                                                          config.layernorm_epsilon,
+                                                          sequence_parallel=config.sequence_parallel)

         # MLP
         self.num_experts = num_experts
@@ -1780,7 +1784,8 @@ def build_layer(layer_number, n_e):
                     config.hidden_size,
                     eps=config.layernorm_epsilon)
             else:
-                self.final_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
+                self.final_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon,
+                                               sequence_parallel=config.sequence_parallel)

     def _get_layer(self, layer_number):
         return self.layers[layer_number]
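Note: the transformer.py and gpt_model.py hunks only thread `sequence_parallel=` through every RMSNorm construction site, mirroring what the LayerNorm path already receives. The flag matters downstream: Megatron-style training loops sum the gradients of parameters tagged this way across the tensor-model-parallel group, because each rank only processed its slice of the sequence. The sketch below is a hedged illustration of that consumer; the function name and group handle are assumptions, not code from this patch.

    import torch

    def allreduce_sequence_parallel_grads(model: torch.nn.Module, tp_group) -> None:
        # Sum the gradients of every parameter marked sequence_parallel
        # across the tensor-model-parallel process group.
        for param in model.parameters():
            if getattr(param, 'sequence_parallel', False) and param.grad is not None:
                torch.distributed.all_reduce(param.grad.data, group=tp_group)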