using nccl ops from TRT-LLM namespace
apbose committed Nov 8, 2024
1 parent 6d40ff1 commit b6f5980
Showing 7 changed files with 296 additions and 27 deletions.
16 changes: 16 additions & 0 deletions examples/distributed_inference/README.md
@@ -14,3 +14,19 @@ See the examples started with `data_parallel` for more details.
Here we use torch.distributed as an example, but compilation with tensor parallelism is agnostic to the implementation framework as long as the module is properly sharded.

torchrun --nproc_per_node=2 tensor_parallel_llama2.py

3. Tensor parallel distributed inference using the NCCL ops plugin

apt install libmpich-dev
apt install libopenmpi-dev
pip install tensorrt-llm
# Then pip install the tensorrt and torch versions compatible with the installed Torch-TensorRT
mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py

4. Tensor parallel distributed Llama 3 inference using the NCCL ops plugin

apt install libmpich-dev
apt install libopenmpi-dev
pip install tensorrt-llm
# Then pip install the tensorrt and torch versions compatible with the installed Torch-TensorRT
mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py
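
On the Python side, both of these examples register the TRT-LLM collective converters before any compilation happens. A minimal sketch of that entry point, using the register_nccl_ops helper added in this commit (the log-file prefix is illustrative):

# Minimal sketch: wire in the TRT-LLM NCCL converters at the top of a script
# launched with mpirun. register_nccl_ops comes from tensor_parallel_nccl_ops.py
# (added in this commit); the "./my_tp_example" log prefix is just an example.
import torch

from tensor_parallel_nccl_ops import register_nccl_ops

# Initializes torch.distributed from the OMPI_* env vars, builds a 1-D device
# mesh, and registers the AllGather/ReduceScatter plugin converters.
device_mesh, _world_size, _rank, logger = register_nccl_ops("./my_tp_example")

logger.info(f"Rank {_rank}/{_world_size} using cuda:{torch.cuda.current_device()}")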
4 changes: 3 additions & 1 deletion examples/distributed_inference/requirement.txt
@@ -1,3 +1,5 @@
accelerate
transformers
diffusers
diffusers
site
tensorrt-llm
20 changes: 7 additions & 13 deletions examples/distributed_inference/tensor_parallel_llama3.py
@@ -7,25 +7,20 @@
import torch
import torch_tensorrt
from llama3_model import ModelArgs, ParallelTransformer
from tensor_parallel_nccl_ops import register_nccl_ops
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
)
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

_rank = int(os.environ["RANK"])
_world_size = int(os.environ["WORLD_SIZE"])
tp_size = 2
device_mesh, _world_size, _rank, logger = register_nccl_ops("./tensor_parallel_llama3")

logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(f"./tensor_parallel_log_{_rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)

tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (
_world_size % 2 == 0
), f"TP examples require even number of GPUs, but got {_world_size} gpus"

model_args = ModelArgs(
vocab_size=32000,
@@ -38,7 +33,7 @@
)

with torch.no_grad():
model = ParallelTransformer(model_args, tp_mesh)
model = ParallelTransformer(model_args, device_mesh)
torch.manual_seed(0)
inp = torch.randint(32000, (8, 256), device="cuda")
python_result = model(inp)
@@ -53,7 +48,6 @@
"use_python_runtime": True,
"workspace_size": 1 << 33,
"debug": False,
"timing_cache_path": "/opt/file/cache/timing_cache_llama.bin",
},
dynamic=False,
)
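
The hunk above shows only the tail of the compile call. A hedged reconstruction of the full call, assuming the Torch-TensorRT dynamo backend is selected through its "torch_tensorrt" torch.compile alias (the backend name is an assumption; the options are the ones visible in this diff):

# Hedged reconstruction, not part of the diff: the options dict entries come
# from the hunk above, the backend alias is assumed.
model_tp = torch.compile(
    model,
    backend="torch_tensorrt",
    options={
        "use_python_runtime": True,
        "workspace_size": 1 << 33,
        "debug": False,
    },
    dynamic=False,
)
output = model_tp(inp)  # the first call on each rank triggers the TensorRT build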
186 changes: 186 additions & 0 deletions examples/distributed_inference/tensor_parallel_nccl_ops.py
@@ -0,0 +1,186 @@
import ctypes
import logging
import os
import site
from enum import IntEnum, IntFlag, auto
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
import tensorrt_llm
import torch
import torch.distributed as dist
import torch_tensorrt
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.fx import GraphModule, Node
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
dynamo_tensorrt_converter,
)
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
custom_fused_all_gather_op,
custom_fused_reduce_scatter_op,
)
from torch_tensorrt.dynamo.types import TRTTensor
from torch_tensorrt.fx.converters.converter_utils import set_layer_name


# class for AllReduce
class AllReduceStrategy(IntEnum):
"""Warning: actual definition is in kernels/customAllReduceKernels.h.
They must be kept in sync.
"""

NCCL = 0
ONESHOT = 1
TWOSHOT = 2
AUTO = 3


class AllReduceConfig(IntFlag):
"""Warning: actual definition is in kernels/customAllReduceKernels.h.
They must be kept in sync
"""

USE_MEMCPY = auto()
PUSH_MODE = auto()


def initialize_logger(rank, logger_file_name):
logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)
return logger


# This is required for env initialization since we use mpirun
def initialize_distributed_env(rank=0, world_size=1, port=29500):
local_rank = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
)
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))

# Set up environment variable to run with mpirun
os.environ["RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(port)

# Necessary to assign a device to each rank.
torch.cuda.set_device(local_rank)

# We use nccl backend
dist.init_process_group("nccl")

# set a manual seed for reproducibility
torch.manual_seed(1111)

return local_rank, world_size


def register_nccl_ops(logger_file_name):
# Initialization
initialize_distributed_env()
# create a device mesh based on the given world_size.
_world_size = int(os.environ["WORLD_SIZE"])

device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
_rank = device_mesh.get_rank()
logger = initialize_logger(_rank, logger_file_name)
device_id = (
_rank % torch.cuda.device_count()
) # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)

# TensorRT NCCL plugins
# Iterate over all registered plugin creators
plugin_registry = trt.get_plugin_registry()
for plugin_creator in plugin_registry.plugin_creator_list:
logger.info(
f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
)

@dynamo_tensorrt_converter(custom_fused_all_gather_op)
def insert_nccl_gather_op(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
plug_inputs = [args[0]]
allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
"AllGather", "1", "tensorrt_llm"
)
assert allgather_plg_creator is not None
_world_size = int(os.environ["WORLD_SIZE"])
group = list(range(_world_size))
group = trt.PluginField(
"group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
)
p_dtype = trt.float16
pf_type = trt.PluginField(
"type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
)
pfc = trt.PluginFieldCollection([group, pf_type])
allgather = allgather_plg_creator.create_plugin("allgather", pfc)
layer = ctx.net.add_plugin_v2(plug_inputs, allgather)
set_layer_name(layer, target, name)
return layer.get_output(0)

@dynamo_tensorrt_converter(custom_fused_reduce_scatter_op)
def insert_nccl_reduce_scatter_plugin(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
plug_inputs = [args[0]]
allreduce_plg_creator = trt.get_plugin_registry().get_plugin_creator(
"ReduceScatter", "1", "tensorrt_llm"
)

assert allreduce_plg_creator is not None

counter = 0
strategy = AllReduceStrategy.NCCL
config = AllReduceConfig(0)

world_size = dist.get_world_size()
group = list(range(world_size))
group = trt.PluginField(
"group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
)

p_dtype = trt.float16
pf_dtype = trt.PluginField(
"type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
)
pfc = [group, pf_dtype]
p_strategy = trt.PluginField(
"strategy", np.array([int(strategy)], np.int8), trt.PluginFieldType.INT8
)
pfc.append(p_strategy)
p_config = trt.PluginField(
"config", np.array([int(config)], np.int8), trt.PluginFieldType.INT8
)
pfc.append(p_config)
p_counter = trt.PluginField(
"counter", np.array([counter], np.int32), trt.PluginFieldType.INT32
)
pfc.append(p_counter)

pfc = trt.PluginFieldCollection(pfc)
ar_plug = allreduce_plg_creator.create_plugin("allreduce", pfc)

layer = ctx.net.add_plugin_v2(plug_inputs, ar_plug)
set_layer_name(layer, target, name)
return layer.get_output(0)

return device_mesh, _world_size, _rank, logger
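
A short usage sketch of the helper above, assuming a two-rank mpirun launch; the registry calls are the same ones register_nccl_ops already relies on, and the log prefix is illustrative:

# Sketch: confirm the TRT-LLM collective plugins are visible once
# register_nccl_ops() has run (importing tensorrt_llm registers them with TRT).
import tensorrt as trt

from tensor_parallel_nccl_ops import register_nccl_ops

device_mesh, _world_size, _rank, logger = register_nccl_ops("./plugin_check")

registry = trt.get_plugin_registry()
for plugin_name in ("AllGather", "ReduceScatter"):
    creator = registry.get_plugin_creator(plugin_name, "1", "tensorrt_llm")
    assert creator is not None, f"{plugin_name} not found in the tensorrt_llm namespace"
    logger.info(
        f"Found {plugin_name} v{creator.plugin_version} in '{creator.plugin_namespace}'"
    )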
23 changes: 10 additions & 13 deletions examples/distributed_inference/tensor_parallel_simple_example.py
@@ -1,18 +1,22 @@
import os
import sys
import time

import tensorrt as trt
import tensorrt_llm
import torch
import torch.nn as nn
import torch_tensorrt
from tensor_parallel_nccl_ops import register_nccl_ops
from torch.distributed._tensor import Shard
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)

device_mesh, _world_size, _rank, logger = register_nccl_ops(
"./tensor_parallel_simple_example"
)

"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
"""
@@ -36,14 +40,7 @@ def forward(self, x):
return x


# create a device mesh based on the given world_size.
_world_size = int(os.environ["WORLD_SIZE"])

device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
_rank = device_mesh.get_rank()


print(f"Starting PyTorch TP example on rank {_rank}.")
logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (
_world_size % 2 == 0
), f"TP examples require even number of GPUs, but got {_world_size} gpus"
@@ -91,9 +88,9 @@ def forward(self, x):
output = tp_model(inp)
end = time.time()
if i == 0:
print(f"Compilation time is {end-start}")
logger.info(f"Compilation time is {end-start}")
assert (
python_result - output
).std() < 0.01, "Compilation result is not correct."
elif _rank == 0:
print(f"Inference time is {end-start}")
logger.info(f"Inference time is {end-start}")
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -6,6 +6,7 @@

from .accumulate_fp32_matmul import accumulate_fp32_matmul
from .constant_folding import constant_fold
from .fuse_distributed_ops import fuse_distributed_ops
from .fuse_prims_broadcast import fuse_prims_broadcast
from .lower_linear import lower_linear
from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
@@ -26,6 +27,7 @@
lower_scaled_dot_product_attention,
lower_linear,
fuse_prims_broadcast,
fuse_distributed_ops,
replace_max_pool_with_indices,
replace_full_like_with_full,
view_to_reshape,
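
The converters in tensor_parallel_nccl_ops.py match the fused ops this newly registered pass produces. A conceptual, comment-only sketch of the intended rewrite (the functional-collective signatures are assumptions, not shown in this diff):

# Conceptual sketch of what fuse_distributed_ops is expected to do, leaving a
# single graph node for the converters registered above to match.
#
# Before the pass (functional collectives emitted by tensor-parallel sharding):
#   t   = torch.ops._c10d_functional.all_gather_into_tensor(x, group_size, group_name)
#   out = torch.ops._c10d_functional.wait_tensor(t)
#
# After the pass (one fused node, lowered to the TRT-LLM "AllGather" plugin by
# insert_nccl_gather_op; the argument names here are assumptions):
#   out = custom_fused_all_gather_op(x, group_size, group_name)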
