Commit

[Feature] Support the scenario where sp size is not divisible by attn head num (#769)

* Support the scenario where sp size is not divisible by attn head num

* refactor attention.py

* do not have to set sp_inner_size in config

* rename

* fix lint
HIT-cwh authored Jun 17, 2024
1 parent bddf85d commit 7646e7b
Showing 3 changed files with 156 additions and 11 deletions.
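
Note on the approach (illustration only, not part of the commit): when the attention head count h is not divisible by the sequence-parallel world size sp, the patch derives an inner sequence-parallel size insp = sp // gcd(h, sp) and additionally splits each head's channels across an inner group of insp ranks, so that h * insp is always divisible by sp. The standalone snippet below only illustrates that arithmetic; inner_sp_size is a hypothetical helper, not an xtuner API.

import math

def inner_sp_size(num_heads: int, sp_world_size: int) -> int:
    # Smallest inner sp size such that (num_heads * insp) % sp_world_size == 0,
    # mirroring `insp = sp // math.gcd(h, sp)` in attention.py below.
    return sp_world_size // math.gcd(num_heads, sp_world_size)

for h, sp in [(32, 8), (14, 4), (9, 6), (5, 8)]:
    insp = inner_sp_size(h, sp)
    # Each sp rank ends up with (h * insp) // sp sub-heads of head_dim // insp channels.
    print(f'h={h}, sp={sp} -> insp={insp}, sub-heads per rank={(h * insp) // sp}')
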
10 changes: 9 additions & 1 deletion xtuner/parallel/sequence/__init__.py
@@ -14,10 +14,15 @@
 from .setup_distributed import (get_data_parallel_group,
                                 get_data_parallel_rank,
                                 get_data_parallel_world_size,
+                                get_inner_sequence_parallel_group,
+                                get_inner_sequence_parallel_rank,
+                                get_inner_sequence_parallel_world_size,
                                 get_sequence_parallel_group,
                                 get_sequence_parallel_rank,
                                 get_sequence_parallel_world_size,
-                                init_sequence_parallel)
+                                init_inner_sequence_parallel,
+                                init_sequence_parallel,
+                                is_inner_sequence_parallel_initialized)

 __all__ = [
     'sequence_parallel_wrapper', 'pre_process_for_sequence_parallel_attn',
@@ -29,5 +34,8 @@
     'get_data_parallel_rank', 'reduce_sequence_parallel_loss', 'init_dist',
     'all_to_all', 'gather_for_sequence_parallel',
     'split_forward_gather_backward', 'gather_forward_split_backward',
+    'get_inner_sequence_parallel_group', 'get_inner_sequence_parallel_rank',
+    'get_inner_sequence_parallel_world_size', 'init_inner_sequence_parallel',
+    'is_inner_sequence_parallel_initialized',
     'pad_cumulative_len_for_sequence_parallel'
 ]
95 changes: 85 additions & 10 deletions xtuner/parallel/sequence/attention.py
@@ -1,24 +1,74 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import math
+
 import torch.distributed as dist

-from .comm import all_to_all
-from .setup_distributed import (get_sequence_parallel_group,
-                                get_sequence_parallel_world_size)
+from .comm import (all_to_all, gather_forward_split_backward,
+                   split_forward_gather_backward)
+from .setup_distributed import (get_inner_sequence_parallel_group,
+                                get_inner_sequence_parallel_world_size,
+                                get_sequence_parallel_group,
+                                get_sequence_parallel_world_size,
+                                init_inner_sequence_parallel,
+                                is_inner_sequence_parallel_initialized)


 def pre_process_for_sequence_parallel_attn(query_states,
                                            key_states,
                                            value_states,
                                            scatter_dim=2,
                                            gather_dim=1):
-    sequence_parallel_world_size = get_sequence_parallel_world_size()
-    n_head = query_states.shape[2]
-    assert n_head % sequence_parallel_world_size == 0, \
+    b, s_div_sp, h, d = query_states.shape
+    sp = get_sequence_parallel_world_size()
+
+    if not is_inner_sequence_parallel_initialized():
+        insp = sp // math.gcd(h, sp)
+        init_inner_sequence_parallel(insp)
+    else:
+        insp = get_inner_sequence_parallel_world_size()
+
+    def pre_process_for_inner_sp(q, k, v):
+        if scatter_dim != 2 and gather_dim != 1:
+            raise NotImplementedError(
+                'Currently only `scatter_dim == 2` and `gather_dim == 1` '
+                f'is supported. But got scatter_dim = {scatter_dim} and '
+                f'gather_dim = {gather_dim}.')
+
+        # (b, s_div_sp, h, d) ->
+        # (b, s_div_sp, sp/insp, h*insp/sp, insp, d/insp) ->
+        # (b, s_div_sp, sp/insp, insp, h*insp/sp, d/insp) ->
+        # (b, s_div_sp, insp*h, d/insp)
+        q = q.view(b, s_div_sp, sp // insp, h * insp // sp, insp,
+                   d // insp).transpose(3, 4).flatten(2, 4)
+        k = k.view(b, s_div_sp, sp // insp, h * insp // sp, insp,
+                   d // insp).transpose(3, 4).flatten(2, 4)
+        v = v.view(b, s_div_sp, sp // insp, h * insp // sp, insp,
+                   d // insp).transpose(3, 4).flatten(2, 4)
+
+        return q, k, v
+
+    def post_process_for_inner_sp(q, k, v):
+        # (b, s, insp*h/sp, d/insp) -> (b, s, insp*h/sp, d)
+        q = gather_forward_split_backward(q, -1,
+                                          get_inner_sequence_parallel_group())
+        k = gather_forward_split_backward(k, -1,
+                                          get_inner_sequence_parallel_group())
+        v = gather_forward_split_backward(v, -1,
+                                          get_inner_sequence_parallel_group())
+
+        return q, k, v
+
+    assert (h * insp) % sp == 0, \
         ('The number of attention heads should be divisible by '
-         f'sequence_parallel_world_size. But got n_head = {n_head} and '
-         f'sequence_parallel_world_size = {sequence_parallel_world_size}.')
+         '(sequence_parallel_world_size // sequence_parallel_inner_world_size)'
+         f'. But got n_head = {h}, sequence_parallel_world_size = '
+         f'{sp} and sequence_parallel_inner_world_size = {insp}.')
+
+    if insp > 1:
+        query_states, key_states, value_states = pre_process_for_inner_sp(
+            query_states, key_states, value_states)

-    # (b, s // sp_world_size, nd, dim) -> (b, s, nd // sp_world_size, dim)
+    # (b, s_div_sp, insp*h, d/insp) -> (b, s, insp*h/sp, d/insp)
     sequence_parallel_group = get_sequence_parallel_group()
     query_states = all_to_all(
         query_states,
@@ -36,19 +86,44 @@ def pre_process_for_sequence_parallel_attn(query_states,
         scatter_dim=scatter_dim,
         gather_dim=gather_dim)

+    if insp > 1:
+        query_states, key_states, value_states = post_process_for_inner_sp(
+            query_states, key_states, value_states)
+
     return query_states, key_states, value_states


 def post_process_for_sequence_parallel_attn(attn_output,
                                             scatter_dim=1,
                                             gather_dim=2):
-    # (b, s, nd // sp_world_size, dim) -> (b, s // sp_world_size, nd, dim)
+    sp = get_sequence_parallel_world_size()
+    insp = get_inner_sequence_parallel_world_size()
+    b, s, h_mul_insp_div_sp, d = attn_output.shape
+    h = h_mul_insp_div_sp * sp // insp
+    s_div_sp = s // sp
+
+    if insp > 1:
+        # (b, s, insp*h/sp, d) -> (b, s, insp*h/sp, d/insp)
+        attn_output = split_forward_gather_backward(
+            attn_output, -1, get_inner_sequence_parallel_group())
+
+    # (b, s, insp*h/sp, d/insp) -> (b, s_div_sp, insp*h, d/insp)
     sequence_parallel_group = get_sequence_parallel_group()
     output = all_to_all(
         attn_output,
         sequence_parallel_group,
         scatter_dim=scatter_dim,
         gather_dim=gather_dim)

+    if insp > 1:
+        # (b, s_div_sp, insp*h, d/insp) ->
+        # (b, s_div_sp, sp/insp, insp, h*insp/sp, d/insp) ->
+        # (b, s_div_sp, sp/insp, h*insp/sp, insp, d/insp) ->
+        # (b, s_div_sp, h, d)
+        output = output.view(b, s_div_sp, sp // insp, insp, h * insp // sp,
+                             d // insp).transpose(3, 4).reshape(
+                                 b, s_div_sp, h, d)
+
     return output

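
The following single-process snippet (illustration only, not part of the commit) checks that the head/channel regrouping added in pre_process_for_inner_sp and the inverse regrouping added at the end of post_process_for_sequence_parallel_attn undo each other; the all_to_all and the inner-group gather/split that run in between are omitted here.

import torch

b, s_div_sp, h, d = 2, 16, 9, 128   # e.g. 9 heads with sp = 6
sp, insp = 6, 2                     # insp = sp // gcd(h, sp) = 2
assert (h * insp) % sp == 0

x = torch.randn(b, s_div_sp, h, d)

# pre-processing regrouping: (b, s_div_sp, h, d) -> (b, s_div_sp, insp*h, d/insp)
pre = x.view(b, s_div_sp, sp // insp, h * insp // sp, insp,
             d // insp).transpose(3, 4).flatten(2, 4)

# post-processing regrouping: (b, s_div_sp, insp*h, d/insp) -> (b, s_div_sp, h, d)
post = pre.view(b, s_div_sp, sp // insp, insp, h * insp // sp,
                d // insp).transpose(3, 4).reshape(b, s_div_sp, h, d)

assert pre.shape == (b, s_div_sp, insp * h, d // insp)
assert torch.equal(post, x)  # the two regroupings are exact inverses
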
62 changes: 62 additions & 0 deletions xtuner/parallel/sequence/setup_distributed.py
@@ -5,6 +5,10 @@
 _SEQUENCE_PARALLEL_WORLD_SIZE = None
 _SEQUENCE_PARALLEL_RANK = None

+_INNER_SEQUENCE_PARALLEL_GROUP = None
+_INNER_SEQUENCE_PARALLEL_WORLD_SIZE = None
+_INNER_SEQUENCE_PARALLEL_RANK = None
+
 _DATA_PARALLEL_GROUP = None
 _DATA_PARALLEL_WORLD_SIZE = None
 _DATA_PARALLEL_RANK = None
@@ -49,6 +53,64 @@ def init_sequence_parallel(sequence_parallel_size: int = 1):
         _DATA_PARALLEL_GROUP = group


+def init_inner_sequence_parallel(inner_sequence_parallel_size: int = 1):
+    """Build the sequence parallel inner groups.
+
+    They are helpful when sp size is not evenly divided by the number of attn
+    heads.
+    """
+    assert _SEQUENCE_PARALLEL_GROUP is not None, \
+        ('Please call `init_inner_sequence_parallel` after calling '
+         '`init_sequence_parallel`.')
+
+    rank = dist.get_rank()
+    world_size: int = dist.get_world_size()
+
+    n_inner_group = world_size // inner_sequence_parallel_size
+
+    global _INNER_SEQUENCE_PARALLEL_GROUP
+    assert _INNER_SEQUENCE_PARALLEL_GROUP is None
+
+    for i in range(n_inner_group):
+        ranks = range(i * inner_sequence_parallel_size,
+                      (i + 1) * inner_sequence_parallel_size)
+        group = dist.new_group(ranks)
+        if rank in ranks:
+            _INNER_SEQUENCE_PARALLEL_GROUP = group
+
+
+def is_inner_sequence_parallel_initialized():
+    return _INNER_SEQUENCE_PARALLEL_GROUP is not None
+
+
+def get_inner_sequence_parallel_group():
+    return _INNER_SEQUENCE_PARALLEL_GROUP
+
+
+def get_inner_sequence_parallel_world_size():
+    global _INNER_SEQUENCE_PARALLEL_WORLD_SIZE
+    if _INNER_SEQUENCE_PARALLEL_WORLD_SIZE is not None:
+        return _INNER_SEQUENCE_PARALLEL_WORLD_SIZE
+    if not dist.is_initialized() or (_INNER_SEQUENCE_PARALLEL_GROUP is None):
+        _INNER_SEQUENCE_PARALLEL_WORLD_SIZE = 1
+    else:
+        _INNER_SEQUENCE_PARALLEL_WORLD_SIZE = dist.get_world_size(
+            group=get_inner_sequence_parallel_group())
+    return _INNER_SEQUENCE_PARALLEL_WORLD_SIZE
+
+
+def get_inner_sequence_parallel_rank():
+    global _INNER_SEQUENCE_PARALLEL_RANK
+    if _INNER_SEQUENCE_PARALLEL_RANK is not None:
+        return _INNER_SEQUENCE_PARALLEL_RANK
+    if not dist.is_initialized() or (_INNER_SEQUENCE_PARALLEL_GROUP is None):
+        _INNER_SEQUENCE_PARALLEL_RANK = 0
+    else:
+        _INNER_SEQUENCE_PARALLEL_RANK = dist.get_rank(
+            group=get_inner_sequence_parallel_group())
+    return _INNER_SEQUENCE_PARALLEL_RANK
+
+
 def get_sequence_parallel_group():
     """Get the sequence parallel group the caller rank belongs to."""
     return _SEQUENCE_PARALLEL_GROUP
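
For reference, a pure-Python illustration (no process groups are created; inner_group_ranks is a hypothetical helper, not part of the commit) of how init_inner_sequence_parallel partitions ranks: consecutive blocks of inner_sequence_parallel_size ranks form one inner group, and the ranks of an inner group each hold head_dim // insp channels that post_process_for_inner_sp in attention.py gathers back into full heads.

def inner_group_ranks(world_size: int, inner_sequence_parallel_size: int):
    # Mirrors the loop in init_inner_sequence_parallel, but returns the rank
    # lists instead of calling dist.new_group on them.
    n_inner_group = world_size // inner_sequence_parallel_size
    return [list(range(i * inner_sequence_parallel_size,
                       (i + 1) * inner_sequence_parallel_size))
            for i in range(n_inner_group)]

# e.g. 8 ranks with an inner sp size of 2:
print(inner_group_ranks(8, 2))  # [[0, 1], [2, 3], [4, 5], [6, 7]]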
