Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move applying rotary embeddings inside LlamaRotaryEmbedding class #26307

Closed
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,32 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
def rotate_half(self, x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (self.rotate_half(q) * sin)
k_embed = (k * cos) + (self.rotate_half(k) * sin)
return q_embed, k_embed

def forward(self, q, k, position_ids, seq_len=None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changing the forward pass this way is breaking 😢
We should try to avoid that!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the forward pass was changed to def forward(self, x, seq_len=None, pos_ids=None) to maintain backward compatibility, then the below lines are called twice when they don't have to be.

if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=q.device, dtype=q.dtype)
cos, sin = (
self.cos_cached[:, :, :seq_len, ...].to(dtype=q.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=q.dtype),
)

cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]

# q/k: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
self._set_cos_sin_cache(seq_len=seq_len, device=q.device, dtype=q.dtype)

return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
cos, sin = (
self.cos_cached[:, :, :seq_len, ...].to(dtype=q.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=q.dtype),
)
return self.apply_rotary_pos_emb(q, k, cos, sin, position_ids)


# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->OpenLlama
Expand Down Expand Up @@ -184,12 +201,22 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):

def rotate_half(x):
"""Rotates half the hidden dims of the input."""
logger.warning_once(
"Using the global `rotate_half` function is deprecated. Please use `OpenLlamaRotaryEmbedding.rotate_half` instead. "
"This is deprecated to improve the export to ONNX by applying the rotary embeddings during the forward call in "
"the rotary embedding class. "
)
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
logger.warning_once(
"Using the global `apply_rotary_pos_emb` function is deprecated. Please use `OpenLlamaRotaryEmbedding.apply_rotary_pos_emb` instead. "
"This is deprecated to improve the export to ONNX by applying the rotary embeddings during the forward call in "
"the rotary embedding class. "
)
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
Expand Down Expand Up @@ -291,8 +318,7 @@ def forward(
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
query_states, key_states = self.rotary_emb(query_states, key_states, position_ids, seq_len=kv_seq_len)
# [bsz, nh, t, hd]

if past_key_value is not None:
Expand Down
40 changes: 34 additions & 6 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,7 @@ def forward(
seq_len = key.shape[-2]
if has_layer_past:
seq_len += layer_past[0].shape[-2]
cos, sin = self.rotary_emb(value, seq_len=seq_len)
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
query, key = self.rotary_emb(query_rot, key_rot, position_ids, seq_len=seq_len)
query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1)

Expand Down Expand Up @@ -309,11 +308,30 @@ def _set_cos_sin_cache(self, seq_len, device):
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
def rotate_half(self, x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids):
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
q_embed = (q * cos) + (self.rotate_half(q) * sin)
k_embed = (k * cos) + (self.rotate_half(k) * sin)
return q_embed, k_embed

def forward(self, q, k, position_ids, seq_len=None):
# q/k: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device)
return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
self._set_cos_sin_cache(seq_len=seq_len, device=q.device)
cos, sin = (
self.cos_cached[:seq_len, ...].to(q.device),
self.sin_cached[:seq_len, ...].to(q.device),
)
return self.apply_rotary_pos_emb(q, k, cos, sin, position_ids)


class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
Expand Down Expand Up @@ -363,12 +381,22 @@ def _set_cos_sin_cache(self, seq_len, device):

def rotate_half(x):
"""Rotates half the hidden dims of the input."""
logger.warning_once(
"Using the global `rotate_half` function is deprecated. Please use `GPTNeoXRotaryEmbedding.rotate_half` instead. "
"This is deprecated to improve the export to ONNX by applying the rotary embeddings during the forward call in "
"the rotary embedding class. "
)
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
logger.warning_once(
"Using the global `apply_rotary_pos_emb` function is deprecated. Please use `GPTNeoXRotaryEmbedding.apply_rotary_pos_emb` instead. "
"This is deprecated to improve the export to ONNX by applying the rotary embeddings during the forward call in "
"the rotary embedding class. "
)
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ def forward(
if has_layer_past:
offset = layer_past[0].shape[-2]
seq_len += offset
cos, sin = self.rotary_emb(value, seq_len=seq_len)
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset)
query, key = self.rotary_emb(query_rot, key_rot, offset=offset, seq_len=seq_len)
query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1)

Expand Down Expand Up @@ -238,7 +237,6 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
return attn_output, attn_weights


# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings, base=10000, device=None):
super().__init__()
Expand All @@ -262,21 +260,48 @@ def _set_cos_sin_cache(self, seq_len, device):
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
def rotate_half(self, x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(self, q, k, cos, sin, offset: int = 0):
cos = cos[..., offset : q.shape[-2] + offset, :]
sin = sin[..., offset : q.shape[-2] + offset, :]
q_embed = (q * cos) + (self.rotate_half(q) * sin)
k_embed = (k * cos) + (self.rotate_half(k) * sin)
return q_embed, k_embed

def forward(self, q, k, offset=0, seq_len=None):
# q/k: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device)
return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
self._set_cos_sin_cache(seq_len=seq_len, device=q.device)
cos, sin = (
self.cos_cached[:seq_len, ...].to(q.device),
self.sin_cached[:seq_len, ...].to(q.device),
)
return self.apply_rotary_pos_emb(q, k, cos, sin, offset)


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
logger.warning_once(
"Using the global `rotate_half` function is deprecated. Please use `RotaryEmbedding.rotate_half` instead. "
"This is deprecated to improve the export to ONNX by applying the rotary embeddings during the forward call in "
"the rotary embedding class. "
)
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
logger.warning_once(
"Using the global `apply_rotary_pos_emb` function is deprecated. Please use `RotaryEmbedding.apply_rotary_pos_emb` instead. "
"This is deprecated to improve the export to ONNX by applying the rotary embeddings during the forward call in "
"the rotary embedding class. "
)
cos = cos[..., offset : q.shape[-2] + offset, :]
sin = sin[..., offset : q.shape[-2] + offset, :]
q_embed = (q * cos) + (rotate_half(q) * sin)
Expand Down
42 changes: 34 additions & 8 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,32 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
def rotate_half(self, x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (self.rotate_half(q) * sin)
k_embed = (k * cos) + (self.rotate_half(k) * sin)
return q_embed, k_embed

def forward(self, q, k, position_ids, seq_len=None):
# q/k: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
self._set_cos_sin_cache(seq_len=seq_len, device=q.device, dtype=q.dtype)

return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
cos, sin = (
self.cos_cached[:, :, :seq_len, ...].to(dtype=q.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=q.dtype),
)
return self.apply_rotary_pos_emb(q, k, cos, sin, position_ids)
kunal-vaishnavi marked this conversation as resolved.
Show resolved Hide resolved


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
Expand Down Expand Up @@ -176,13 +193,23 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):

def rotate_half(x):
"""Rotates half the hidden dims of the input."""
logger.warning_once(
"Using the global `rotate_half` function is deprecated. Please use `LlamaRotaryEmbedding.rotate_half` instead. "
"This is deprecated to improve the export to ONNX by applying the rotary embeddings during the forward call in "
"the rotary embedding class. "
)
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
logger.warning_once(
"Using the global `apply_rotary_pos_emb` function is deprecated. Please use `LlamaRotaryEmbedding.apply_rotary_pos_emb` instead. "
"This is deprecated to improve the export to ONNX by applying the rotary embeddings during the forward call in "
"the rotary embedding class. "
)
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
Expand Down Expand Up @@ -333,8 +360,7 @@ def forward(
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
query_states, key_states = self.rotary_emb(query_states, key_states, position_ids, seq_len=kv_seq_len)
kunal-vaishnavi marked this conversation as resolved.
Show resolved Hide resolved

if past_key_value is not None:
# reuse k, v, self_attention
Expand Down
46 changes: 36 additions & 10 deletions src/transformers/models/persimmon/modeling_persimmon.py
kunal-vaishnavi marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,32 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
def rotate_half(self, x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (self.rotate_half(q) * sin)
k_embed = (k * cos) + (self.rotate_half(k) * sin)
return q_embed, k_embed

def forward(self, q, k, position_ids, seq_len=None):
# q/k: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
self._set_cos_sin_cache(seq_len=seq_len, device=q.device, dtype=q.dtype)

return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
cos, sin = (
self.cos_cached[:, :, :seq_len, ...].to(dtype=q.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=q.dtype),
)
return self.apply_rotary_pos_emb(q, k, cos, sin, position_ids)


# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Persimmon
Expand Down Expand Up @@ -155,17 +172,27 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)


# Copied from transformers.models.llama.modeling_llama.rotate_half
kunal-vaishnavi marked this conversation as resolved.
Show resolved Hide resolved
# Copied from transformers.models.llama.modeling_llama.rotate_half with Llama->Persimmon
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
logger.warning_once(
"Using the global `rotate_half` function is deprecated. Please use `PersimmonRotaryEmbedding.rotate_half` instead. "
"This is deprecated to improve the export to ONNX by applying the rotary embeddings during the forward call in "
"the rotary embedding class. "
)
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
kunal-vaishnavi marked this conversation as resolved.
Show resolved Hide resolved
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb with Llama->Persimmon
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
logger.warning_once(
"Using the global `apply_rotary_pos_emb` function is deprecated. Please use `PersimmonRotaryEmbedding.apply_rotary_pos_emb` instead. "
"This is deprecated to improve the export to ONNX by applying the rotary embeddings during the forward call in "
"the rotary embedding class. "
)
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
Expand Down Expand Up @@ -295,7 +322,6 @@ def forward(
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

# Partial rotary embedding
query_rot, query_pass = (
Expand All @@ -307,7 +333,7 @@ def forward(
key_states[..., self.rotary_emb.dim :],
)
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
query_rot, key_rot = self.rotary_emb(query_rot, key_rot, position_ids, seq_len=kv_seq_len)

# [batch_size, seq_length, num_heads, head_dim]
query_states = torch.cat((query_rot, query_pass), dim=-1)
Expand Down
Loading