Enable flag to not pass PAD tokens in ffwd #775

Merged (21 commits), Dec 11, 2023
Changes from 3 commits
24 changes: 7 additions & 17 deletions llmfoundry/models/layers/__init__.py
@@ -9,24 +9,14 @@
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
from llmfoundry.models.layers.ffn_padding_utils import pad_input, unpad_input
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm

__all__ = [
'scaled_multihead_dot_product_attention',
'flash_attn_fn',
'triton_flash_attn_fn',
'MultiheadAttention',
'MultiQueryAttention',
'attn_bias_shape',
'build_attn_bias',
'build_alibi_bias',
'ATTN_CLASS_REGISTRY',
'MPTMLP',
'MPTBlock',
'NORM_CLASS_REGISTRY',
'LPLayerNorm',
'FC_CLASS_REGISTRY',
'SharedEmbedding',
'FFN_CLASS_REGISTRY',
'build_ffn',
'scaled_multihead_dot_product_attention', 'flash_attn_fn',
'triton_flash_attn_fn', 'MultiheadAttention', 'MultiQueryAttention',
'attn_bias_shape', 'build_attn_bias', 'build_alibi_bias',
'ATTN_CLASS_REGISTRY', 'MPTMLP', 'MPTBlock', 'NORM_CLASS_REGISTRY',
'LPLayerNorm', 'FC_CLASS_REGISTRY', 'SharedEmbedding', 'FFN_CLASS_REGISTRY',
'build_ffn', 'unpad_input', 'pad_input'
]
10 changes: 10 additions & 0 deletions llmfoundry/models/layers/blocks.py
@@ -10,6 +10,7 @@

from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
from llmfoundry.models.layers.ffn_padding_utils import pad_input, unpad_input
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY

attn_config_defaults: Dict = {
@@ -53,6 +54,7 @@ def __init__(
fc_type: str = 'torch',
device: Optional[str] = None,
no_bias: bool = False,
use_pad_tok_in_ffwd: bool = True,
**kwargs: Any,
):
if attn_config is None:
@@ -105,6 +107,8 @@ def __init__(
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
self.resid_ffn_dropout = nn.Dropout(resid_pdrop)

self.use_pad_tok_in_ffwd = use_pad_tok_in_ffwd

def forward(
self,
x: torch.Tensor,
@@ -132,6 +136,12 @@ def forward(
m = x
if self.norm_2 is not None:
m = self.norm_2(x)
batch_size = m.size(0)
seq_len = m.size(1)
if not self.use_pad_tok_in_ffwd:
m, indices, _, _ = unpad_input(m, attention_mask)
n = self.ffn(m)
if not self.use_pad_tok_in_ffwd:
n = pad_input(n, indices, batch_size, seq_len)
x = x + self.resid_ffn_dropout(n)
return x, attn_weights, past_key_value
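
This gating only changes what the FFN sees, not what it produces at the unmasked positions: the FFN is applied position-wise, so dropping pad tokens before it and zero-filling them afterwards leaves the valid outputs unchanged. A minimal, self-contained sketch of that equivalence (not part of the PR; it uses plain boolean indexing in place of unpad_input/pad_input):

import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seqlen, d_model = 2, 5, 8
ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(),
                    nn.Linear(4 * d_model, d_model))

x = torch.randn(batch, seqlen, d_model)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]], dtype=torch.bool)

# Path 1: run the FFN on everything, pad tokens included (use_pad_tok_in_ffwd=True).
out_full = ffn(x)

# Path 2: unpad -> FFN -> re-pad with zeros (use_pad_tok_in_ffwd=False).
flat = x.reshape(batch * seqlen, d_model)
keep = attention_mask.reshape(-1)
out_repadded = torch.zeros_like(flat)
out_repadded[keep] = ffn(flat[keep])
out_repadded = out_repadded.reshape(batch, seqlen, d_model)

# The outputs agree wherever the attention mask is 1; only the pad positions differ.
assert torch.allclose(out_full[attention_mask], out_repadded[attention_mask], atol=1e-6)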
126 changes: 126 additions & 0 deletions llmfoundry/models/layers/ffn_padding_utils.py
@@ -0,0 +1,126 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Padding utils for ffn layers.

Code was adapted from https://github.com/sz128/flash-attention/blob/3c0d49532cb3bc0a36b405f34590ff82cf582853/flash_attn/bert_padding.py.
"""

import torch
import torch.nn.functional as F
from einops import rearrange, repeat


class IndexFirstAxis(torch.autograd.Function):

@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices]
return torch.gather(rearrange(input, 'b ... -> b (...)'), 0,
repeat(indices, 'z -> z d',
d=second_dim)).reshape(-1, *other_shape)

@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
grad_output = rearrange(grad_output, 'b ... -> b (...)')
grad_input = torch.zeros(
[ctx.first_axis_dim, grad_output.shape[1]],
device=grad_output.device,
dtype=grad_output.dtype,
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
# grad_input[indices] = grad_output
grad_input.scatter_(0,
repeat(indices, 'z -> z d', d=grad_output.shape[1]),
grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis = IndexFirstAxis.apply


class IndexPutFirstAxis(torch.autograd.Function):

@staticmethod
def forward(ctx, values, indices, first_axis_dim):
ctx.save_for_backward(indices)
assert indices.ndim == 1
assert values.ndim >= 2
output = torch.zeros(first_axis_dim,
*values.shape[1:],
device=values.device,
dtype=values.dtype)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
output[indices] = values
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
return output

@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
grad_values = grad_output[indices]
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
return grad_values, None, None


index_put_first_axis = IndexPutFirstAxis.apply


def unpad_input(hidden_states, attention_mask):
"""Removes the padding from a hidden state.

Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.

Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
indices: (total_nnz), the indices of the non-pad tokens in the flattened (batch * seqlen) axis.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices are @dim
# times larger than they need to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (
index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'),
indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
)


def pad_input(hidden_states, indices, batch_size, seqlen):
"""Pads the hidden state.

Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
indices: (total_nnz)
batch_size: (int) batch size
seqlen: (int) sequence length to pad the input to

Return:
hidden_states: (batch, seqlen, ...)
"""
dim = hidden_states.shape[-1]
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
# output[indices] = hidden_states
output = index_put_first_axis(hidden_states, indices, batch_size * seqlen)
return rearrange(output, '(b s) ... -> b s ...', b=batch_size)
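
A quick usage sketch with toy shapes (assuming the module is importable at the path added by this PR; expected values are spelled out in comments):

import torch
from llmfoundry.models.layers.ffn_padding_utils import pad_input, unpad_input

hidden_states = torch.randn(2, 4, 8)  # (batch, seqlen, d_model)
attention_mask = torch.tensor([[1, 1, 0, 0],
                               [1, 1, 1, 1]])

unpadded, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
# unpadded.shape == (6, 8): total_nnz = 2 + 4 valid tokens
# indices == tensor([0, 1, 4, 5, 6, 7]): valid positions in the flattened (b s) axis
# cu_seqlens == tensor([0, 2, 6], dtype=torch.int32)
# max_seqlen == 4

# pad_input restores the (batch, seqlen, ...) layout, zero-filling the pad positions.
repadded = pad_input(unpadded, indices, batch_size=2, seqlen=4)
# repadded.shape == (2, 4, 8)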
3 changes: 3 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -60,6 +60,7 @@ def __init__(
init_config: Dict = init_config_defaults,
fc_type: str = 'torch',
tie_word_embeddings: bool = True,
use_pad_tok_in_ffwd: bool = True,
verbose: Optional[int] = None,
**kwargs: Any,
):
@@ -131,6 +132,7 @@ def __init__(
See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
tie_word_embeddings (bool): Whether to tie the input embedding and output layers.
use_pad_tok_in_ffwd (bool): Whether to forward the pad token in the feedforward networks.
"""
self.d_model = d_model
self.n_heads = n_heads
@@ -151,6 +153,7 @@ def __init__(
self.use_cache = use_cache
self.init_config = init_config
self.fc_type = fc_type
self.use_pad_tok_in_ffwd = use_pad_tok_in_ffwd
if verbose is not None:
warnings.warn(
DeprecationWarning(
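
For context, a hypothetical configuration sketch; all arguments other than use_pad_tok_in_ffwd are ordinary MPTConfig fields with illustrative values:

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

config = MPTConfig(
    d_model=128,
    n_heads=4,
    n_layers=2,
    expansion_ratio=4,
    max_seq_len=256,
    use_pad_tok_in_ffwd=False,  # unpad the hidden states before every block's FFN
)
model = MPTForCausalLM(config)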
2 changes: 1 addition & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
@@ -419,7 +419,7 @@ def _attn_bias(
attn_bias = attn_bias.masked_fill(
~attention_mask.view(-1, 1, 1, s_k), min_val)

return attn_bias, None
return attn_bias, attention_mask

def _apply_prefix_mask(self, attn_bias: torch.Tensor,
prefix_mask: torch.Tensor) -> torch.Tensor:
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -22,8 +22,8 @@ include = [

# Pyright
[tool.pyright]
exclude = ['env-**', 'venv*', '**/flash_attn_triton.py']
ignore = ['llmfoundry/models/layers/flash_attn_triton.py']
exclude = ['env-**', 'venv*', '**/flash_attn_triton.py', 'ffn_padding_utils.py']
ignore = ['llmfoundry/models/layers/flash_attn_triton.py', 'llmfoundry/models/layers/ffn_padding_utils.py']
stubPath = "" # suppress useless 'stubPath is not a valid directory' errors

reportUnnecessaryIsInstance = "none" # it is ok to do this for clarity or safety
5 changes: 4 additions & 1 deletion tests/models/test_model.py
@@ -698,8 +698,10 @@ def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict):
},
}])
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
@pytest.mark.parametrize('use_pad_tok_in_ffwd', [True, False])
def test_forward_with_padding(attention_impl: str, pos_emb_config: dict,
tie_word_embeddings: bool):
tie_word_embeddings: bool,
use_pad_tok_in_ffwd: bool):
# Test that different placement of padding does not affect the output.
alibi = pos_emb_config['alibi']
if alibi and attention_impl == 'flash':
@@ -731,6 +733,7 @@ def test_forward_with_padding(attention_impl: str, pos_emb_config: dict,
'init_std': 0.02,
},
tie_word_embeddings=tie_word_embeddings,
use_pad_tok_in_ffwd=use_pad_tok_in_ffwd,
)
mpt = MPTForCausalLM(hf_config)
mpt.eval()