[Refactor] fix internlm2 dispatch (#779)
* fix internlm2 dispatch

* add detailed RuntimeError
HIT-cwh authored Jun 18, 2024
1 parent 6bbc274 commit 5f2bca4
Showing 2 changed files with 134 additions and 145 deletions.
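
For orientation before the diffs: xtuner no longer ships (and dispatches) its own cache-based InternLM2RotaryEmbedding. The rewritten attention forwards call the model's own rotary module with position_ids (`cos, sin = self.rotary_emb(value_states, position_ids)`) and apply the returned cos/sin directly, so `apply_rotary_pos_emb` drops its position_ids indexing. The snippet below is a small self-contained check (not xtuner code) that the two conventions produce the same values; names such as `table`, `cos_old`, and `cos_new` are illustrative only.

    import torch

    base, dim, max_len = 1000000, 8, 16
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

    # Old convention (removed below): precompute a [max_len, dim] table once,
    # then index its cos/sin with position_ids inside apply_rotary_pos_emb.
    t = torch.arange(max_len).float()
    freqs = torch.einsum('i,j->ij', t, inv_freq)
    table = torch.cat((freqs, freqs), dim=-1)
    position_ids = torch.tensor([[3, 4, 5]])
    cos_old = table.cos()[position_ids]                  # (1, 3, dim)

    # New convention: the rotary module computes cos/sin for the given
    # position_ids, and the attention forward applies them without indexing.
    freqs_new = position_ids[..., None].float() * inv_freq
    cos_new = torch.cat((freqs_new, freqs_new), dim=-1).cos()

    print(torch.allclose(cos_old, cos_new))              # True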
2 changes: 0 additions & 2 deletions xtuner/model/modules/dispatch/__init__.py
@@ -125,8 +125,6 @@
)

ROTE_DISPATCH_MAPPING = dict(
InternLM2RotaryEmbedding=LazyObject(
'xtuner.model.modules.dispatch.internlm2', 'InternLM2RotaryEmbedding'),
InternLMRotaryEmbedding=LazyObject(
'xtuner.model.modules.dispatch.internlm', 'InternLMRotaryEmbedding'),
MistralRotaryEmbedding=LazyObject('xtuner.model.modules.dispatch.mistral',
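
Removing the InternLM2RotaryEmbedding entry means the dispatcher now leaves InternLM2's rotary module untouched, so the model keeps the implementation from its own modeling_internlm2.py (whose forward takes position_ids, as the rewritten attention forwards below expect). The toy sketch below shows only the name-keyed replacement idea; it is illustrative, not xtuner's actual replace logic. ToyRote, DispatchedToyRote, ToyAttention, and replace_rote are made-up names, and the real mapping stores LazyObject entries that are resolved lazily rather than plain classes.

    import torch.nn as nn

    class ToyRote(nn.Module):            # stands in for a model's own rotary module
        pass

    class DispatchedToyRote(nn.Module):  # stands in for a dispatched replacement
        pass

    # The real mapping pairs class names with LazyObject('module.path', 'ClassName').
    TOY_MAPPING = {'ToyRote': DispatchedToyRote}

    def replace_rote(model, mapping=TOY_MAPPING):
        # Swap any child whose class name has a key in the mapping; children with
        # no key (e.g. InternLM2's rotary module after this commit) are left as-is.
        for name, child in model.named_children():
            if type(child).__name__ in mapping:
                setattr(model, name, mapping[type(child).__name__]())
            else:
                replace_rote(child, mapping)

    class ToyAttention(nn.Module):
        def __init__(self):
            super().__init__()
            self.rotary_emb = ToyRote()

    attn = ToyAttention()
    replace_rote(attn)
    print(type(attn.rotary_emb).__name__)  # DispatchedToyRote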
277 changes: 134 additions & 143 deletions xtuner/model/modules/dispatch/internlm2.py
@@ -1,62 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from mmengine import MessageHub
from transformers.cache_utils import Cache, StaticCache

from .attention import (SUPPORT_FLASH2, flash_attn_w_mask, flash_attn_wo_mask,
varlen_flash_attn)
from .triton_kernels import apply_rotary_emb


class InternLM2RotaryEmbedding(torch.nn.Module):

def __init__(self,
dim,
max_position_embeddings=2048,
base=1000000,
device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.inv_freq = 1.0 / (
base**(torch.arange(0, dim, 2).float().to(device) / dim))

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(
self.max_seq_len_cached,
device=self.inv_freq.device,
dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos()
self.sin_cached = emb.sin()

def forward(self, x, seq_len):
# x: [bs, num_attention_heads, seq_len, head_size]
if (seq_len > self.max_seq_len_cached
or self.cos_cached.device != x.device
or self.cos_cached.dtype != x.dtype):
self.max_seq_len_cached = seq_len
assert self.inv_freq.dtype == torch.float32
t = torch.arange(
self.max_seq_len_cached,
device=x.device,
dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq.to(t.device))
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.cos_cached = emb.cos().to(x.dtype)
self.sin_cached = emb.sin().to(x.dtype)
return (
self.cos_cached[:seq_len, ...],
self.sin_cached[:seq_len, ...],
)
from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
post_process_for_sequence_parallel_attn,
pre_process_for_sequence_parallel_attn)
from .attention import SUPPORT_FLASH2, flash_attn_wo_mask, varlen_flash_attn


def rotate_half(x):
@@ -66,9 +20,9 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
@@ -111,18 +65,17 @@ def internlm2_attn_forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
cache_position: Optional[torch.LongTensor] = None,
):
if 'padding_mask' in kwargs:
warnings.warn(
'Passing `padding_mask` is deprecated and will be removed in v4.37'
'Please make sure use `attention_mask` instead.`')

# overwrite attention_mask with padding_mask
attention_mask = kwargs.pop('padding_mask')
if isinstance(past_key_value, StaticCache):
raise ValueError(
'`static` cache implementation is not compatible with '
'`attn_implementation==flash_attention_2` make sure to use `sdpa` '
'in the mean time, and open an issue at '
'https://github.com/huggingface/transformers')

output_attentions = False

@@ -146,64 +99,68 @@
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]

# This modification is necessary for sequential parallel
assert position_ids is not None and (position_ids.max() + 1) >= kv_seq_len
cos, sin = self.rotary_emb(value_states, seq_len=position_ids.max() + 1)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)
cos, sin)

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None
# sin and cos are specific to RoPE models;
# cache_position needed for the static cache
cache_kwargs = {
'sin': sin,
'cos': cos,
'cache_position': cache_position
}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs)

# repeat kv for sequence parallel
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

if SUPPORT_FLASH2:
# the shape of attention_mask used by flash_attn and
# F.scaled_dot_product_attention are different
assert attention_mask is None or attention_mask.ndim == 2, \
('When using flash_attn, attention_mask.ndim should equal to 2.'
f'But got attention_mask.shape = {attention_mask.shape}.'
'We can pass the `attn_implementation="flash_attention_2"` flag '
'to `.from_pretrained` method when instantiating a Internlm2 '
'model.')
# flash attn 2 need (bs, seq_len, nhead, h_dim)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

causal = self.is_causal and q_len != 1

if attention_mask is not None:
attn_output = flash_attn_w_mask(
query_states,
key_states,
value_states,
attention_mask,
causal=causal,
training=self.training)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

# In PEFT, usually we cast the layer norms in float32 for training
# stability reasons therefore the input hidden states gets silently
# casted in float32. Hence, we need cast them back in the correct dtype
# just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not
# cast the LayerNorms in fp32. (InternLM2RMSNorm handles it correctly)

input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, '_pre_quantization_dtype'):
target_dtype = self.config._pre_quantization_dtype
else:
attn_output = flash_attn_wo_mask(
query_states,
key_states,
value_states,
causal=causal,
training=self.training)
else:
# use flash attention implemented by pytorch
# do not support sequence parallel
attn_output = F.scaled_dot_product_attention(
query_states, key_states, value_states, attn_mask=attention_mask)
attn_output = attn_output.transpose(1, 2)
target_dtype = self.wqkv.weight.dtype

query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)

enable_sequence_parallel = (
dist.is_initialized() and get_sequence_parallel_world_size() > 1
and self.training)
if enable_sequence_parallel:
query_states, key_states, value_states = \
pre_process_for_sequence_parallel_attn(
query_states, key_states, value_states)

dropout_rate = 0.0
attn_output = self._flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
query_states.shape[1],
dropout=dropout_rate)

if enable_sequence_parallel:
attn_output = post_process_for_sequence_parallel_attn(attn_output)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.wo(attn_output)
@@ -217,14 +174,21 @@ def internlm2_attn_forward(
def internlm2_varlen_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
# Modified from https://huggingface.co/internlm/internlm-7b/blob/939a68c0dc1bd5f35b63c87d44af05ce33379061/modeling_internlm.py#L161 # noqa:E501

if isinstance(past_key_value, StaticCache):
raise ValueError(
'`static` cache implementation is not compatible with '
'`attn_implementation==flash_attention_2` make sure to use `sdpa` '
'in the mean time, and open an issue at '
'https://github.com/huggingface/transformers')

message_hub = MessageHub.get_instance('varlen_attn_args')
rank = dist.get_rank()
@@ -238,6 +202,7 @@ def internlm2_varlen_attn_forward(
f' set to 1, but got {bsz}')

qkv_states = self.wqkv(hidden_states)

qkv_states = rearrange(
qkv_states,
'b q (h gs d) -> b q h gs d',
@@ -250,55 +215,81 @@
key_states = qkv_states[..., -2, :]
value_states = qkv_states[..., -1, :]

kv_seq_len = key_states.shape[-3]
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

try:
cos, sin = self.rotary_emb(value_states, position_ids)
except RuntimeError:
raise RuntimeError(
'You are using the old version of InternLM2 model. The '
'`modeling_internlm2.py` is outdated. Please update the InternLM2 '
'model.')
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin)

if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
# sin and cos are specific to RoPE models;
# cache_position needed for the static cache
cache_kwargs = {
'sin': sin,
'cos': cos,
'cache_position': cache_position
}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs)

if use_varlen_atten:
cos, sin = self.rotary_emb(value_states, max_seqlen)
query_states = apply_rotary_emb(query_states,
cos[position_ids].squeeze(0),
sin[position_ids].squeeze(0))
key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0),
sin[position_ids].squeeze(0))
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
cos, sin = self.rotary_emb(value_states, kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

# In PEFT, usually we cast the layer norms in float32 for training
# stability reasons therefore the input hidden states gets silently
# casted in float32. Hence, we need cast them back in the correct dtype
# just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not
# cast the LayerNorms in fp32. (InternLM2RMSNorm handles it correctly)

input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, '_pre_quantization_dtype'):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.wqkv.weight.dtype

query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)

# repeat kv for sequence parallel
key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)

assert SUPPORT_FLASH2

dropout_rate = 0.0
if use_varlen_atten:
attn_output = varlen_flash_attn(
query_states,
key_states,
value_states,
cumulative_len,
max_seqlen,
causal=True,
dropout_p=dropout_rate,
training=self.training)
else:
attn_output = flash_attn_wo_mask(
query_states,
key_states,
value_states,
causal=True,
training=False)
dropout_p=dropout_rate,
training=self.training)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

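One detail worth calling out from both rewritten forwards above: before handing tensors to flash attention, float32 activations (for example, ones silently upcast by fp32 layer norms under PEFT) are cast back to a compute dtype chosen in a fixed order: the autocast dtype, then the pre-quantization dtype recorded on the config, then the wqkv weight dtype. The standalone sketch below reproduces only that selection order; pick_target_dtype and the toy config object are illustrative names, not part of xtuner.

    import torch
    import torch.nn as nn

    def pick_target_dtype(config, wqkv: nn.Linear) -> torch.dtype:
        # Same priority as in the diff: autocast dtype > pre-quantization dtype >
        # dtype of the attention projection weights.
        if torch.is_autocast_enabled():
            return torch.get_autocast_gpu_dtype()
        if hasattr(config, '_pre_quantization_dtype'):
            return config._pre_quantization_dtype
        return wqkv.weight.dtype

    wqkv = nn.Linear(8, 24).to(torch.bfloat16)
    hidden_states = torch.randn(1, 4, 8)           # fp32, e.g. after an fp32 layer norm
    config = object()                               # toy config without a quantization dtype
    target = pick_target_dtype(config, wqkv)
    print(target, hidden_states.to(target).dtype)   # torch.bfloat16 torch.bfloat16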
