diff --git a/megatron/fused_kernels/__init__.py b/megatron/fused_kernels/__init__.py
index f6ac063ce..1e4c9efac 100644
--- a/megatron/fused_kernels/__init__.py
+++ b/megatron/fused_kernels/__init__.py
@@ -1,3 +1,6 @@
+# Copyright (c) 2024, EleutherAI
+# This file is based on code by the authors denoted below and has been modified from its original version.
+#
 # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,14 +14,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+#
+# This file has been modified from its original version
+#

 import os
 import pathlib
 import subprocess
-
-from pathlib import Path
-
-srcpath = Path(__file__).parent.absolute()
+import torch
+from torch.utils import cpp_extension

 # Setting this param to a list has a problem of generating different
 # compilation commands (with different order of architectures) and
@@ -28,6 +32,138 @@
 os.environ["TORCH_CUDA_ARCH_LIST"] = ""


+def load(neox_args=None):
+
+    # Check if cuda 11 is installed for compute capability 8.0
+    cc_flag = []
+    if torch.version.hip is None:
+        _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
+            cpp_extension.CUDA_HOME
+        )
+        if int(bare_metal_major) >= 11:
+            cc_flag.append("-gencode")
+            cc_flag.append("arch=compute_80,code=sm_80")
+            if int(bare_metal_minor) >= 1:
+                cc_flag.append("-gencode")
+                cc_flag.append("arch=compute_86,code=sm_86")
+            if int(bare_metal_minor) >= 4:
+                cc_flag.append("-gencode")
+                cc_flag.append("arch=compute_87,code=sm_87")
+            if int(bare_metal_minor) >= 8:
+                cc_flag.append("-gencode")
+                cc_flag.append("arch=compute_89,code=sm_89")
+        if int(bare_metal_major) >= 12:
+            cc_flag.append("-gencode")
+            cc_flag.append("arch=compute_90,code=sm_90")
+
+    # Build path
+    srcpath = pathlib.Path(__file__).parent.absolute()
+    buildpath = srcpath / "build"
+    _create_build_dir(buildpath)
+
+    # Determine verbosity
+    verbose = True if neox_args is None else (neox_args.rank == 0)
+
+    # Helper function to build the kernels.
+    def _cpp_extention_load_helper(
+        name, sources, extra_cuda_flags, extra_include_paths
+    ):
+        if torch.version.hip is not None:
+            extra_cuda_cflags = ["-O3"] + extra_cuda_flags + cc_flag
+        else:
+            extra_cuda_cflags = (
+                ["-O3", "-gencode", "arch=compute_70,code=sm_70", "--use_fast_math"]
+                + extra_cuda_flags
+                + cc_flag
+            )
+
+        return cpp_extension.load(
+            name=name,
+            sources=sources,
+            build_directory=buildpath,
+            extra_cflags=[
+                "-O3",
+            ],
+            extra_cuda_cflags=extra_cuda_cflags,
+            extra_include_paths=extra_include_paths,
+            verbose=verbose,
+        )
+
+    # ==============
+    # Fused softmax.
+    # ==============
+
+    if torch.version.hip is not None:
+        extra_include_paths = [os.path.abspath(srcpath)]
+    else:
+        extra_include_paths = []
+
+    if torch.version.hip is not None:
+        extra_cuda_flags = [
+            "-D__HIP_NO_HALF_OPERATORS__=1",
+            "-D__HIP_NO_HALF_CONVERSIONS__=1",
+        ]
+    else:
+        extra_cuda_flags = [
+            "-U__CUDA_NO_HALF_OPERATORS__",
+            "-U__CUDA_NO_HALF_CONVERSIONS__",
+            "--expt-relaxed-constexpr",
+            "--expt-extended-lambda",
+        ]
+
+    # Upper triangular softmax.
+    sources = [
+        srcpath / "scaled_upper_triang_masked_softmax.cpp",
+        srcpath / "scaled_upper_triang_masked_softmax_cuda.cu",
+    ]
+    scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
+        "scaled_upper_triang_masked_softmax_cuda",
+        sources,
+        extra_cuda_flags,
+        extra_include_paths,
+    )
+    # Masked softmax.
+    sources = [
+        srcpath / "scaled_masked_softmax.cpp",
+        srcpath / "scaled_masked_softmax_cuda.cu",
+    ]
+    scaled_masked_softmax_cuda = _cpp_extention_load_helper(
+        "scaled_masked_softmax_cuda", sources, extra_cuda_flags, extra_include_paths
+    )
+    # fused rope
+    sources = [
+        srcpath / "fused_rotary_positional_embedding.cpp",
+        srcpath / "fused_rotary_positional_embedding_cuda.cu",
+    ]
+    fused_rotary_positional_embedding_cuda = _cpp_extention_load_helper(
+        "fused_rotary_positional_embedding_cuda",
+        sources,
+        extra_cuda_flags,
+        extra_include_paths,
+    )
+
+
+def _get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output(
+        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
+    )
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
+
+def _create_build_dir(buildpath):
+    try:
+        os.mkdir(buildpath)
+    except OSError:
+        if not os.path.isdir(buildpath):
+            print(f"Creation of the build directory {buildpath} failed")
+
+
 def load_fused_kernels():
     try:
         import scaled_upper_triang_masked_softmax_cuda
diff --git a/megatron/fused_kernels/scaled_masked_softmax_cuda.cu b/megatron/fused_kernels/scaled_masked_softmax_cuda.cu
index 26c2d1820..7479713ec 100644
--- a/megatron/fused_kernels/scaled_masked_softmax_cuda.cu
+++ b/megatron/fused_kernels/scaled_masked_softmax_cuda.cu
@@ -18,7 +18,9 @@
 #include
 #include
 #include
+#ifndef __HIP_PLATFORM_HCC__
 #include
+#endif
 #include
 #include
 #include "scaled_masked_softmax.h"
diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu b/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu
index 99a52abd5..475c16833 100644
--- a/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu
+++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu
@@ -18,7 +18,9 @@
 #include
 #include
 #include
+#ifndef __HIP_PLATFORM_HCC__
 #include
+#endif
 #include
 #include
 #include "scaled_upper_triang_masked_softmax.h"
diff --git a/megatron/initialize.py b/megatron/initialize.py
index 72779b094..29afe7f9a 100644
--- a/megatron/initialize.py
+++ b/megatron/initialize.py
@@ -59,6 +59,7 @@ def finish_mpu_init():
         or neox_args.scaled_masked_softmax_fusion
         or neox_args.rope_fusion
     ):
+        fused_kernels.load(neox_args)
         fused_kernels.load_fused_kernels()

     if neox_args.lazy_mpu_init:
diff --git a/tests/model/test_fused_kernels.py b/tests/model/test_fused_kernels.py
index b8cb34d1b..cc458bf4a 100644
--- a/tests/model/test_fused_kernels.py
+++ b/tests/model/test_fused_kernels.py
@@ -22,6 +22,7 @@
 from transformers import BertTokenizer, GPT2Tokenizer
 from transformers.models.bert.modeling_bert import BertModel
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model
+from megatron.fused_kernels import load
 import transformers

 transformers.logging.set_verbosity(
@@ -33,6 +34,7 @@
     reason="ModuleNotFoundError: No module named 'scaled_masked_softmax_cuda'"
 )
 def test_load_fused_kernels():
+    load()
     try:
         import scaled_masked_softmax_cuda
         import scaled_upper_triang_masked_softmax_cuda
@@ -47,6 +49,7 @@ def test_load_fused_kernels():
 @pytest.mark.xfail(reason="SystemExit: None")
 def test_fused_softmax():
+    load()
     from megatron.model.fused_softmax import FusedScaleMaskSoftmax, SoftmaxFusionTypes
     from megatron.model.gpt2_model import (
         gpt2_attention_mask_func as attention_mask_func,
     )
@@ -149,6 +152,7 @@ def test_fused_softmax():

 @pytest.mark.xfail(reason="SystemExit: None")
 def test_fused_upper_triangle_mask_softmax():
+    load()
     from megatron.model.gpt2_model import (
         gpt2_attention_mask_func as attention_mask_func,
     )