Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move applying rotary embeddings inside LlamaRotaryEmbedding class #26307

Closed
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,28 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

def forward(self, x, seq_len=None):
def rotate_half(self, x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(self, x, cos, sin, position_ids):
cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
sin = sin[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
x_embed = (x * cos) + (self.rotate_half(x) * sin)
return x_embed

def forward(self, x, seq_len=None, position_ids=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

return (
cos, sin = (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
return self.apply_rotary_pos_emb(x, cos, sin, position_ids)


# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->OpenLlama
Expand Down Expand Up @@ -181,15 +194,25 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):

def rotate_half(x):
"""Rotates half the hidden dims of the input."""
logger.warning_once(
"Using the global `rotate_half` function is deprecated. Please use `OpenLlamaRotaryEmbedding.rotate_half` instead. "
"This is deprecated to improve the export to ONNX by applying the rotary embeddings during the forward call in "
"the rotary embedding class. "
)
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb with Llama->OpenLlama
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
logger.warning_once(
"Using the global `apply_rotary_pos_emb` function is deprecated. Please use `OpenLlamaRotaryEmbedding.apply_rotary_pos_emb` instead. "
"This is deprecated to improve the export to ONNX by applying the rotary embeddings during the forward call in "
"the rotary embedding class. "
)
cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
sin = sin[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
Expand Down Expand Up @@ -287,8 +310,8 @@ def forward(
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
query_states = self.rotary_emb(query_states, seq_len=kv_seq_len, position_ids=position_ids)
key_states = self.rotary_emb(key_states, seq_len=kv_seq_len, position_ids=position_ids)
# [bsz, nh, t, hd]

if past_key_value is not None:
Expand Down
36 changes: 30 additions & 6 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,9 @@ def forward(
seq_len = key.shape[-2]
if has_layer_past:
seq_len += layer_past[0].shape[-2]
cos, sin = self.rotary_emb(value, seq_len=seq_len)
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
query = self.rotary_emb(query_rot, seq_len=seq_len, position_ids=position_ids)
key = self.rotary_emb(key_rot, seq_len=seq_len, position_ids=position_ids)

query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1)

Expand Down Expand Up @@ -312,15 +313,28 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

def forward(self, x, seq_len=None):
def rotate_half(self, x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(self, x, cos, sin, position_ids):
cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
sin = sin[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
x_embed = (x * cos) + (self.rotate_half(x) * sin)
return x_embed

def forward(self, x, seq_len=None, position_ids=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

return (
cos, sin = (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
return self.apply_rotary_pos_emb(x, cos, sin, position_ids)


# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->GPTNeoX
Expand Down Expand Up @@ -372,15 +386,25 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):

def rotate_half(x):
"""Rotates half the hidden dims of the input."""
logger.warning_once(
"Using the global `rotate_half` function is deprecated. Please use `GPTNeoXRotaryEmbedding.rotate_half` instead. "
"This is deprecated to improve the export to ONNX by applying the rotary embeddings during the forward call in "
"the rotary embedding class. "
)
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb with Llama->GPTNeoX
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
logger.warning_once(
"Using the global `apply_rotary_pos_emb` function is deprecated. Please use `GPTNeoXRotaryEmbedding.apply_rotary_pos_emb` instead. "
"This is deprecated to improve the export to ONNX by applying the rotary embeddings during the forward call in "
"the rotary embedding class. "
)
cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
sin = sin[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,9 @@ def forward(
if has_layer_past:
offset = layer_past[0].shape[-2]
seq_len += offset
cos, sin = self.rotary_emb(value, seq_len=seq_len)
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset)
query = self.rotary_emb(query_rot, seq_len=seq_len, offset=offset)
key = self.rotary_emb(key_rot, seq_len=seq_len, offset=offset)

query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1)

Expand Down Expand Up @@ -238,7 +239,6 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
return attn_output, attn_weights


# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding
class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
Expand All @@ -264,25 +264,48 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

def forward(self, x, seq_len=None):
def rotate_half(self, x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(self, x, cos, sin, offset: int = 0):
cos = cos[..., offset : x.shape[-2] + offset, :]
sin = sin[..., offset : x.shape[-2] + offset, :]
x_embed = (x * cos) + (self.rotate_half(x) * sin)
return x_embed

def forward(self, x, seq_len=None, offset=0):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

return (
cos, sin = (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
return self.apply_rotary_pos_emb(x, cos, sin, offset)


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
logger.warning_once(
"Using the global `rotate_half` function is deprecated. Please use `RotaryEmbedding.rotate_half` instead. "
"This is deprecated to improve the export to ONNX by applying the rotary embeddings during the forward call in "
"the rotary embedding class. "
)
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
logger.warning_once(
"Using the global `apply_rotary_pos_emb` function is deprecated. Please use `RotaryEmbedding.apply_rotary_pos_emb` instead. "
"This is deprecated to improve the export to ONNX by applying the rotary embeddings during the forward call in "
"the rotary embedding class. "
)
cos = cos[..., offset : q.shape[-2] + offset, :]
sin = sin[..., offset : q.shape[-2] + offset, :]
q_embed = (q * cos) + (rotate_half(q) * sin)
Expand Down
36 changes: 30 additions & 6 deletions src/transformers/models/idefics/modeling_idefics.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,28 +526,52 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

def forward(self, x, seq_len=None):
def rotate_half(self, x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(self, x, cos, sin, position_ids):
cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
sin = sin[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
x_embed = (x * cos) + (self.rotate_half(x) * sin)
return x_embed

def forward(self, x, seq_len=None, position_ids=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

return (
cos, sin = (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
return self.apply_rotary_pos_emb(x, cos, sin, position_ids)


# Copied from transformers.models.llama.modeling_llama.rotate_half with Llama->Idefics
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
logger.warning_once(
"Using the global `rotate_half` function is deprecated. Please use `IdeficsRotaryEmbedding.rotate_half` instead. "
"This is deprecated to improve the export to ONNX by applying the rotary embeddings during the forward call in "
"the rotary embedding class. "
)
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb with Llama->Idefics
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
logger.warning_once(
"Using the global `apply_rotary_pos_emb` function is deprecated. Please use `IdeficsRotaryEmbedding.apply_rotary_pos_emb` instead. "
"This is deprecated to improve the export to ONNX by applying the rotary embeddings during the forward call in "
"the rotary embedding class. "
)
cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
sin = sin[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
Expand Down Expand Up @@ -677,8 +701,8 @@ def forward(
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if not is_cross_attention:
cos, sin = self.rotary_emb(value_states, seq_len=max(kv_seq_len, q_len))
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
query_states = self.rotary_emb(query_states, seq_len=max(kv_seq_len, q_len), position_ids=position_ids)
key_states = self.rotary_emb(key_states, seq_len=max(kv_seq_len, q_len), position_ids=position_ids)
# [bsz, nh, t, hd]

if past_key_value is not None:
Expand Down
34 changes: 28 additions & 6 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,28 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

def forward(self, x, seq_len=None):
def rotate_half(self, x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(self, x, cos, sin, position_ids):
cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
sin = sin[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
x_embed = (x * cos) + (self.rotate_half(x) * sin)
return x_embed

def forward(self, x, seq_len=None, position_ids=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

return (
cos, sin = (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
return self.apply_rotary_pos_emb(x, cos, sin, position_ids)


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
Expand Down Expand Up @@ -199,15 +212,24 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):

def rotate_half(x):
"""Rotates half the hidden dims of the input."""
logger.warning_once(
"Using the global `rotate_half` function is deprecated. Please use `LlamaRotaryEmbedding.rotate_half` instead. "
"This is deprecated to improve the export to ONNX by applying the rotary embeddings during the forward call in "
"the rotary embedding class. "
)
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
logger.warning_once(
"Using the global `apply_rotary_pos_emb` function is deprecated. Please use `LlamaRotaryEmbedding.apply_rotary_pos_emb` instead. "
"This is deprecated to improve the export to ONNX by applying the rotary embeddings during the forward call in "
"the rotary embedding class. "
)
cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
sin = sin[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
Expand Down Expand Up @@ -355,8 +377,8 @@ def forward(
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
query_states = self.rotary_emb(query_states, seq_len=kv_seq_len, position_ids=position_ids)
key_states = self.rotary_emb(key_states, seq_len=kv_seq_len, position_ids=position_ids)

if past_key_value is not None:
# reuse k, v, self_attention
Expand Down
Loading