diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index e79abea3cf..7467190582 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -37,9 +37,12 @@
 try:
     # FlashAttention (1.x)
     from flash_attn.flash_attn_interface import flash_attn_unpadded_func
-    from flash_attn.flash_attn_triton import flash_attn_func
 except ImportError:
     flash_attn_unpadded_func = None
+
+try:
+    from flash_attn.flash_attn_triton import flash_attn_func
+except ImportError:
     flash_attn_func = None
 
 try:
@@ -599,7 +602,11 @@ def __init__(self, config, layer_number,
         if self.enable_ds_sequence_parallel:
             assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version'
             assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0
-            self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group())
+            self.dist_attn = DistributedAttention(
+                local_attn,
+                parallel_state.get_sequence_parallel_group(),
+                gather_idx=1 if args.use_flash_attn_v1 or args.use_flash_attn_v2 else 0)
+            # flash_attn_cuda assumes [b, s, nh, hd] layout, we need to make sure all2all gathers into the correct sequence dimension.
         else:
             if self.use_flash_attn:
                 self.core_attention_flash = local_attn
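
Notes on the two hunks above:

The first hunk splits the FlashAttention imports so that a failing `from flash_attn.flash_attn_triton import flash_attn_func` no longer trips the shared `except ImportError:` and resets `flash_attn_unpadded_func` to `None`; each backend now degrades independently.

The second hunk picks the all-to-all gather dimension for DeepSpeed's `DistributedAttention`: the all-to-all scatters attention heads across the sequence-parallel group and gathers the full sequence onto each rank, so `gather_idx` must point at the sequence dimension of the layout the wrapped local attention expects. That is dim 0 for Megatron's default [sq, b, nh, hd] tensors and dim 1 for the [b, sq, nh, hd] layout that flash_attn_cuda assumes, hence `gather_idx=1` only when `args.use_flash_attn_v1` or `args.use_flash_attn_v2` is set.

Below is a shape-only sketch of that bookkeeping. It assumes heads live on dim 2 in both layouts; `seq_all_to_all` and the tensor sizes are made up for this note and only mimic the output shape of the real all-to-all, not its communication.

# Shape-level sketch (assumed names/sizes; not DeepSpeed's implementation) of why
# the flash-attention layout needs gather_idx=1.
import torch

def seq_all_to_all(x: torch.Tensor, world_size: int, scatter_idx: int, gather_idx: int) -> torch.Tensor:
    # Emulate the output shape of an all-to-all that splits dim `scatter_idx`
    # (attention heads) across ranks and gathers the full sequence into dim
    # `gather_idx`. Single process, so only the shapes are modeled.
    chunks = torch.chunk(x, world_size, dim=scatter_idx)
    return torch.cat(chunks, dim=gather_idx)

world_size = 2
b, sub_seq, nh, hd = 1, 4, 8, 16

# Non-flash core attention consumes [sq, b, nh, hd]: sequence is dim 0 -> gather_idx=0.
q_sbhd = torch.randn(sub_seq, b, nh, hd)
print(seq_all_to_all(q_sbhd, world_size, scatter_idx=2, gather_idx=0).shape)
# torch.Size([8, 1, 4, 16]): full sequence on dim 0, nh/world_size heads per rank

# flash_attn_cuda consumes [b, sq, nh, hd]: sequence is dim 1 -> gather_idx=1.
q_bshd = torch.randn(b, sub_seq, nh, hd)
print(seq_all_to_all(q_bshd, world_size, scatter_idx=2, gather_idx=1).shape)
# torch.Size([1, 8, 4, 16]): full sequence on dim 1, nh/world_size heads per rank

Gathering into the wrong dimension would hand the local attention a tensor whose "sequence" axis actually interleaves sequence and head shards, which is why the layout-dependent `gather_idx` is passed explicitly in the patch.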