BurstAttention and Ulysses all2all support for long sequence training. #203

Open
wants to merge 11 commits into base: main
3 changes: 2 additions & 1 deletion bmtrain/benchmark/__init__.py
@@ -1,3 +1,4 @@
from .all_gather import all_gather
from .reduce_scatter import reduce_scatter
from .send_recv import send_recv
from .all2all import all2all, all2one
49 changes: 49 additions & 0 deletions bmtrain/benchmark/all2all.py
@@ -0,0 +1,49 @@
from .. import nccl
from .shape import SHAPES
from ..global_var import config
from ..utils import round_up, print_rank
from .utils import format_size
import torch

def all2all():
    current_stream = torch.cuda.current_stream()
    for shape in SHAPES:
        global_size = round_up(shape, config['world_size'] * 2)

        result_tensor = torch.empty(global_size // 2, dtype=torch.half, device="cuda")
        global_tensor = torch.empty(global_size // 2, dtype=torch.half, device="cuda")

        start_evt = torch.cuda.Event(enable_timing=True)
        end_evt = torch.cuda.Event(enable_timing=True)

        current_stream.record_event(start_evt)
        nccl.all2all(global_tensor.storage(), result_tensor.storage(), config['comm'])
        current_stream.record_event(end_evt)
        current_stream.synchronize()

        time_usage = start_evt.elapsed_time(end_evt)
        bw = global_size / 1024 / 1024 / 1024 * 1000 / time_usage
        print_rank("All to All:\tsize {}\ttime: {:4.3f}\tbw: {:2.6f} GB/s".format(format_size(global_size), time_usage, bw))

def all2one():
    current_stream = torch.cuda.current_stream()
    for shape in SHAPES:
        global_size = round_up(shape, config['world_size'] * 2)

        result_tensor = torch.empty(global_size // 2, dtype=torch.half, device="cuda")
        global_tensor = torch.empty(global_size // 2, dtype=torch.half, device="cuda")

        start_evt = torch.cuda.Event(enable_timing=True)
        end_evt = torch.cuda.Event(enable_timing=True)

        current_stream.record_event(start_evt)
        nccl.groupStart()
        for r in range(config['world_size']):
            nccl.all2one(global_tensor.storage(), result_tensor.storage(), r, config['comm'])
        nccl.groupEnd()
        current_stream.record_event(end_evt)
        current_stream.synchronize()

        time_usage = start_evt.elapsed_time(end_evt)
        bw = global_size / 1024 / 1024 / 1024 * 1000 / time_usage
        print_rank("All to one:\tsize {}\ttime: {:4.3f}\tbw: {:2.6f} GB/s".format(format_size(global_size), time_usage, bw))
2 changes: 1 addition & 1 deletion bmtrain/distributed/__init__.py
@@ -1 +1 @@
from .ops import all_gather, all_reduce, broadcast, recv_activations, send_activations, reduce_scatter
from .ops import all_gather, all_reduce, broadcast, recv_activations, send_activations, reduce_scatter, all_to_all
93 changes: 80 additions & 13 deletions bmtrain/distributed/ops.py
@@ -6,6 +6,7 @@
from ..nccl import reduceScatter as ncclReduceScatter
from ..nccl import send as ncclSend
from ..nccl import recv as ncclRecv
from ..nccl import all2all as ncclAllToAll
from ..nccl import commCount,commRank,NCCLCommunicator
DTYPE_LIST = [
torch.float64,
@@ -44,6 +45,13 @@ def recv_meta(prev_rank, comm):
shape = meta_data[2:n_dims+2].tolist()
return dtype,shape

def to_contiguous(x):
    # The NCCL calls in this module operate on the whole underlying storage, so make sure the
    # tensor is contiguous and owns a compact, zero-offset storage of exactly numel() elements.
    if not x.is_contiguous():
        x = x.contiguous()
    if x.storage_offset() != 0 or x.storage().size() != x.numel():
        x = x.clone()
    return x

class OpBroadcast(torch.autograd.Function):

@staticmethod
@@ -72,10 +80,7 @@ def forward(ctx, input : torch.Tensor, comm = None):
if comm is None:
comm = config["comm"]
world_size = commCount(comm)
if not input.is_contiguous():
input = input.contiguous()
if input.storage_offset() != 0 or input.storage().size() != input.numel():
input = input.clone()
input = to_contiguous(input)
output = torch.empty( (world_size,) + input.size(), dtype=input.dtype, device=input.device)
ctx.comm = comm
ncclAllGather(
@@ -87,6 +92,7 @@ def forward(ctx, input : torch.Tensor, comm = None):

@staticmethod
def backward(ctx, grad_output):
grad_output = to_contiguous(grad_output)
return grad_output[commRank(ctx.comm)], None

def all_gather(x : torch.Tensor, comm = None):
@@ -113,10 +119,7 @@ def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None)
ctx.comm = comm
rank = commRank(comm)
assert input.shape[0] % commCount(comm) == 0, "The dimension 0 must be divisible by the number of communication processes"
if not input.is_contiguous():
input = input.contiguous()
if input.storage_offset() != 0 or input.storage().size() != input.numel():
input = input.clone()
input = to_contiguous(input)
output_shape = (input.shape[0] // commCount(comm), *input.shape[1:])
output = torch.empty( output_shape, dtype=input.dtype, device=input.device )
ncclReduceScatter(
@@ -136,6 +139,7 @@ def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None)

@staticmethod
def backward(ctx, grad_output):
grad_output = to_contiguous(grad_output)
with torch.no_grad():
grad_output = OpAllGather.apply(grad_output, ctx.comm).flatten(0,1)
if ctx.op in ["max", "min", "prod"]:
@@ -169,10 +173,7 @@ def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None)
if comm is None:
comm = config["comm"]
ctx.comm = comm
if not input.is_contiguous():
input = input.contiguous()
if input.storage_offset() != 0 or input.storage().size() != input.numel():
input = input.clone()
input = to_contiguous(input)
output = torch.empty( input.size(), dtype=input.dtype, device=input.device)

ncclAllReduce(
@@ -193,6 +194,7 @@ def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None)

@staticmethod
def backward(ctx, grad_output):
grad_output = to_contiguous(grad_output)
if ctx.op == "sum":
return grad_output, None, None
elif ctx.op == "avg":
@@ -220,4 +222,69 @@ def all_reduce(x : torch.Tensor, op : str = "sum", comm = None):
return OpAllReduce.apply(x, op, comm)



class OpAllToAll(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input : torch.Tensor, comm : NCCLCommunicator = None):
        if comm is None:
            comm = config["comm"]
        ctx.comm = comm
        input = to_contiguous(input)
        output = torch.empty(input.size(), dtype=input.dtype, device=input.device)

        ncclAllToAll(
            input.storage(),
            output.storage(),
            comm
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = to_contiguous(grad_output)
        grad_input = torch.empty(grad_output.size(), dtype=grad_output.dtype, device=grad_output.device)
        ncclAllToAll(
            grad_output.storage(),
            grad_input.storage(),
            ctx.comm
        )
        return grad_input, None

def all_to_all(x : torch.Tensor, comm = None):
    """Split the input tensor into equal chunks and scatter them to all processes in the group.

    Args:
        x (torch.Tensor): The input tensor of shape (...).

    Returns:
        torch.Tensor: the concatenation of the received chunks, with the same shape as the input.

    """
    if not config["initialized"]:
        raise RuntimeError("BMTrain is not initialized")

    assert x.is_cuda
    return OpAllToAll.apply(x, comm)

def inverse_permute(permute_dims):
    inverse_dims = [0] * len(permute_dims)
    for i, dim in enumerate(permute_dims):
        inverse_dims[dim] = i
    return inverse_dims
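As a quick illustration (a sketch, assuming the `inverse_permute` helper above), the permutation used by `all2all_transpose` below inverts as follows:

```python
import torch

order = [1, 2, 0, -1]                    # the order used by all2all_transpose
assert inverse_permute(order) == [2, 0, 1, 3]

x = torch.arange(24).reshape(2, 3, 4, 1)
# permuting and then applying the inverse permutation restores the original layout
assert torch.equal(x.permute(*order).permute(*inverse_permute(order)), x)
```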

def all2all_transpose(tensor : torch.Tensor, gather_dim : int, scatter_dim : int, comm = None):
    # Input shape: (B, S, N, D) | (B, N, S, D)
    # Gathers `gather_dim` across the group and scatters `scatter_dim`, e.g.
    # (B, s, n, D) -> (B, s * world_size, n // world_size, D) for (gather_dim, scatter_dim) = (1, 2).
    if comm is None:
        comm = config["comm"]
    origin_size = list(tensor.size())
    output_size = origin_size.copy()
    count = commCount(comm)
    output_size[gather_dim] = origin_size[gather_dim] * count
    output_size[scatter_dim] = origin_size[scatter_dim] // count
    inv_order = inverse_permute([gather_dim, scatter_dim, 0, -1])
    tensor = tensor.permute(gather_dim, scatter_dim, 0, -1)
    tensor = torch.cat(tensor.chunk(count, dim=1), dim=0).contiguous()
    tensor = all_to_all(tensor, comm)
    tensor = tensor.permute(inv_order).contiguous()
    return tensor
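To show how these primitives compose for long-sequence training, here is a hedged sketch of a Ulysses-style exchange around an attention block. It is illustrative only: the tensor sizes and the attention step are placeholders, it assumes the head count is divisible by the world size, and it imports `all2all_transpose` directly from `bmtrain.distributed.ops` since only `all_to_all` is re-exported by the package `__init__`.

```python
# Illustrative sketch (not part of this PR): Ulysses-style all2all around attention.
import torch
import bmtrain as bmt
from bmtrain.global_var import config
from bmtrain.distributed.ops import all2all_transpose

bmt.init_distributed()
P = config['world_size']
B, S_local, N, D = 2, 1024, 16, 64        # placeholder sizes; N must be divisible by P

x = torch.randn(B, S_local, N, D, dtype=torch.half, device="cuda")

# Before attention: gather the full sequence, scatter the heads across ranks.
# (B, S_local, N, D) -> (B, S_local * P, N // P, D)
x_full_seq = all2all_transpose(x, gather_dim=1, scatter_dim=2, comm=config['comm'])

# ... attention over the full sequence with the local shard of heads ...

# After attention: the inverse exchange restores the sequence sharding.
# (B, S_local * P, N // P, D) -> (B, S_local, N, D)
x_out = all2all_transpose(x_full_seq, gather_dim=2, scatter_dim=1, comm=config['comm'])
```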



