From 48f12e411768883f285ed58f0b2fe5df4a24fcfc Mon Sep 17 00:00:00 2001
From: apbose
Date: Mon, 23 Sep 2024 11:09:23 -0700
Subject: [PATCH] Use NCCL ops from the TRT-LLM namespace

---
 .../distributed_inference/requirement.txt     |   5 +-
 .../tensor_parallel_simple_example.py         | 191 +++++++++++++++++-
 .../lowering/passes/_aten_lowering_pass.py    |   2 +
 .../lowering/passes/fuse_distributed_ops.py   |  71 +++++++
 4 files changed, 261 insertions(+), 8 deletions(-)
 create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py

diff --git a/examples/distributed_inference/requirement.txt b/examples/distributed_inference/requirement.txt
index 6d8e0aa9f2..542f2776a0 100644
--- a/examples/distributed_inference/requirement.txt
+++ b/examples/distributed_inference/requirement.txt
@@ -1,3 +1,6 @@
 accelerate
 transformers
-diffusers
\ No newline at end of file
+diffusers
+site
+# Install tensorrt-llm separately, without its dependencies: pip install tensorrt-llm --no-deps
+tensorrt-llm
\ No newline at end of file
diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py
index 470487a751..44728e0413 100755
--- a/examples/distributed_inference/tensor_parallel_simple_example.py
+++ b/examples/distributed_inference/tensor_parallel_simple_example.py
@@ -1,8 +1,17 @@
+import ctypes
+import logging
 import os
+import site
 import sys
 import time
+from enum import IntEnum, IntFlag, auto
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
+import numpy as np
+import tensorrt as trt
+import tensorrt_llm
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch_tensorrt
 from torch.distributed._tensor import Shard
@@ -12,6 +21,181 @@
     RowwiseParallel,
     parallelize_module,
 )
+from torch.fx import GraphModule, Node
+from torch.fx.node import Argument, Target
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    dynamo_tensorrt_converter,
+)
+from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
+    custom_fused_all_gather_op,
+    custom_fused_reduce_scatter_op,
+)
+from torch_tensorrt.dynamo.types import TRTTensor
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+
+
+# This is required for env initialization since we launch with mpirun
+def initialize(rank=0, world_size=1, port=29500):
+    local_rank = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
+    )
+    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
+
+    # Set up environment variables to run with mpirun
+    os.environ["RANK"] = str(local_rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(port)
+
+    # Necessary to assign a device to each rank.
+    torch.cuda.set_device(local_rank)
+
+    # We use the NCCL backend
+    dist.init_process_group("nccl")
+
+    # Set a manual seed for reproducibility
+    torch.manual_seed(1111)
+
+    return local_rank, world_size
+
+
+initialize()
+# create a device mesh based on the given world_size.
+_world_size = int(os.environ["WORLD_SIZE"])
+
+device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
+_rank = device_mesh.get_rank()
+device_id = _rank % torch.cuda.device_count()  # Ensure each rank gets a unique device
+torch.cuda.set_device(device_id)
+
+
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+fh = logging.FileHandler(f"./tensor_parallel_simple_example_{_rank}.log", mode="w")
+fh.setLevel(logging.INFO)
+logger.addHandler(fh)
+
+
+# Load the TensorRT-LLM NCCL plugins.
+# tensorrt_llm.__file__ points at tensorrt_llm/__init__.py, so take its directory
+tensorrt_llm_lib_path = os.path.dirname(tensorrt_llm.__file__)
+plugin_lib_path = tensorrt_llm_lib_path + "/libs/libnvinfer_plugin_tensorrt_llm.so"
+try:
+    ctypes.CDLL(plugin_lib_path)
+    logger.info("plugin loaded successfully")
+except OSError as e:
+    logger.info(f"unsuccessful load: {e}")
+trt.init_libnvinfer_plugins(None, "")
+# Iterate over all registered plugin creators
+plugin_registry = trt.get_plugin_registry()
+for plugin_creator in plugin_registry.plugin_creator_list:
+    logger.info(
+        f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
+    )
+
+
+# Enums for the TRT-LLM AllReduce plugin
+class AllReduceStrategy(IntEnum):
+    """Warning: actual definition is in kernels/customAllReduceKernels.h.
+
+    They must be kept in sync.
+    """
+
+    NCCL = 0
+    ONESHOT = 1
+    TWOSHOT = 2
+    AUTO = 3
+
+
+class AllReduceConfig(IntFlag):
+    """Warning: actual definition is in kernels/customAllReduceKernels.h.
+
+    They must be kept in sync.
+    """
+
+    USE_MEMCPY = auto()
+    PUSH_MODE = auto()
+
+
+@dynamo_tensorrt_converter(custom_fused_all_gather_op)
+def insert_nccl_gather_op(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    plug_inputs = [args[0]]
+    # Look up the AllGather plugin creator registered under the tensorrt_llm namespace
+    allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
+        "AllGather", "1", "tensorrt_llm"
+    )
+    assert allgather_plg_creator is not None
+    world_size = dist.get_world_size()
+    group = list(range(world_size))
+    group = trt.PluginField(
+        "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
+    )
+    p_dtype = trt.float16
+    pf_type = trt.PluginField(
+        "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
+    )
+    pfc = trt.PluginFieldCollection([group, pf_type])
+    allgather = allgather_plg_creator.create_plugin("allgather", pfc)
+    layer = ctx.net.add_plugin_v2(plug_inputs, allgather)
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
+
+@dynamo_tensorrt_converter(custom_fused_reduce_scatter_op)
+def insert_nccl_reduce_scatter_plugin(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    plug_inputs = [args[0]]
+    allreduce_plg_creator = trt.get_plugin_registry().get_plugin_creator(
+        "ReduceScatter", "1", "tensorrt_llm"
+    )
+
+    assert allreduce_plg_creator is not None
+
+    counter = 0
+    strategy = AllReduceStrategy.NCCL
+    config = AllReduceConfig(0)
+
+    world_size = dist.get_world_size()
+    group = list(range(world_size))
+    group = trt.PluginField(
+        "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
+    )
+
+    p_dtype = trt.float16
+    pf_dtype = trt.PluginField(
+        "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
+    )
+    pfc = [group, pf_dtype]
+    p_strategy = trt.PluginField(
+        "strategy", np.array([int(strategy)], np.int8), trt.PluginFieldType.INT8
+    )
+    pfc.append(p_strategy)
+    p_config = trt.PluginField(
"config", np.array([int(config)], np.int8), trt.PluginFieldType.INT8 + ) + pfc.append(p_config) + p_counter = trt.PluginField( + "counter", np.array([counter], np.int32), trt.PluginFieldType.INT32 + ) + pfc.append(p_counter) + + pfc = trt.PluginFieldCollection(pfc) + ar_plug = allreduce_plg_creator.create_plugin("allreduce", pfc) + + layer = ctx.net.add_plugin_v2(plug_inputs, ar_plug) + set_layer_name(layer, target, name) + return layer.get_output(0) + """ This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py @@ -36,13 +220,6 @@ def forward(self, x): return x -# create a device mesh based on the given world_size. -_world_size = int(os.environ["WORLD_SIZE"]) - -device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,)) -_rank = device_mesh.get_rank() - - print(f"Starting PyTorch TP example on rank {_rank}.") assert ( _world_size % 2 == 0 diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index b6435c0d8c..92afd7ed29 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -5,6 +5,7 @@ from .accumulate_fp32_matmul import accumulate_fp32_matmul from .constant_folding import constant_fold +from .fuse_distributed_ops import fuse_distributed_ops from .fuse_prims_broadcast import fuse_prims_broadcast from .lower_linear import lower_linear from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention @@ -25,6 +26,7 @@ lower_scaled_dot_product_attention, lower_linear, fuse_prims_broadcast, + fuse_distributed_ops, replace_max_pool_with_indices, replace_full_like_with_full, view_to_reshape, diff --git a/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py new file mode 100644 index 0000000000..9856bf37fb --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/fuse_distributed_ops.py @@ -0,0 +1,71 @@ +import logging +from typing import Sequence + +import torch + +# dead-code elimination, linting, and recompilation for graph, in-place +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = logging.getLogger(__name__) + + +def custom_fused_all_gather_op(args0, args1, args2): + return torch.ops._c10d_functional.wait_tensor.default( + torch.ops._c10d_functional.all_gather_into_tensor.default(args0, args1, args2) + ) + + +def custom_fused_reduce_scatter_op(args0, args1, args2, args3): + return torch.ops._c10d_functional.wait_tensor.default( + torch.ops._c10d_functional.reduce_scatter_tensor.default( + args0, args1, args2, args3 + ) + ) + + +def fuse_distributed_ops( + gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor] +) -> torch.fx.GraphModule: + modified_graph = False + for node in gm.graph.nodes: + if ( + node.target + in ( + torch.ops._c10d_functional.all_gather_into_tensor.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + ) + and len(node.users) == 1 + and list(node.users)[0].target + == torch.ops._c10d_functional.wait_tensor.default + ): + wait_tensor_node = list(node.users)[0] + fused_op = None + if node.target == torch.ops._c10d_functional.all_gather_into_tensor.default: + fused_op = custom_fused_all_gather_op + fused_op_args = (node.args[0], node.args[1], node.args[2]) + else: + fused_op = 
+                fused_op_args = (node.args[0], node.args[1], node.args[2], node.args[3])
+            with gm.graph.inserting_after(wait_tensor_node):
+                fused_node = gm.graph.create_node(
+                    op="call_function",
+                    target=fused_op,  # custom fused function with a registered TRT converter
+                    args=fused_op_args,
+                )
+
+            wait_tensor_node.replace_all_uses_with(fused_node)
+            fused_node.meta.update(node.meta)
+            modified_graph = True
+            gm.graph.erase_node(wait_tensor_node)
+            gm.graph.erase_node(node)
+
+    # If the graph was modified, clean it up
+    if modified_graph:
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(
+            f"Graph after fusing wait_tensor and the distributed op:\n{gm.graph}"
+        )
+
+    return gm
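
The plugin wiring the converters depend on can be sanity-checked outside of torch.distributed. Below is a minimal standalone sketch (not part of the patch; it reuses the same loading calls the example adds above and assumes tensorrt-llm was installed per requirement.txt):

import ctypes
import os

import tensorrt as trt
import tensorrt_llm

# Load the TRT-LLM plugin library and register its plugins with TensorRT,
# mirroring what tensor_parallel_simple_example.py does in this patch
plugin_lib = os.path.join(
    os.path.dirname(tensorrt_llm.__file__), "libs", "libnvinfer_plugin_tensorrt_llm.so"
)
ctypes.CDLL(plugin_lib)
trt.init_libnvinfer_plugins(None, "")

# The converters look plugins up by (name, version, namespace)
registry = trt.get_plugin_registry()
for plugin_name in ("AllGather", "ReduceScatter"):
    creator = registry.get_plugin_creator(plugin_name, "1", "tensorrt_llm")
    print(f"{plugin_name}: {'found' if creator is not None else 'missing'}")

Since initialize() reads the OMPI_COMM_WORLD_* environment variables, the example itself is meant to be launched with mpirun, e.g. mpirun -n 2 python tensor_parallel_simple_example.py on a node with two GPUs (the exact launch command is an assumption; any launcher that sets those variables should work).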