Enable flag to not pass PAD tokens in ffwd #775

Merged on Dec 11, 2023 (21 commits)
Changes from 11 commits
16 changes: 16 additions & 0 deletions llmfoundry/models/layers/blocks.py
@@ -12,6 +12,11 @@
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY

try:
from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip
except:
unpad_input, pad_input = None, None

attn_config_defaults: Dict = {
'attn_type': 'multihead_attention',
'attn_pdrop': 0.0,
@@ -53,6 +58,7 @@ def __init__(
fc_type: str = 'torch',
device: Optional[str] = None,
no_bias: bool = False,
use_pad_tok_in_ffwd: bool = True,
**kwargs: Any,
):
if attn_config is None:
@@ -105,6 +111,8 @@ def __init__(
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
self.resid_ffn_dropout = nn.Dropout(resid_pdrop)

self.use_pad_tok_in_ffwd = use_pad_tok_in_ffwd

def forward(
self,
x: torch.Tensor,
@@ -132,6 +140,14 @@ def forward(
m = x
if self.norm_2 is not None:
m = self.norm_2(x)
batch_size, seq_len = m.size()[:2]
if not self.use_pad_tok_in_ffwd:
if unpad_input is None:
raise RuntimeError(
'Please install flash-attn==1.0.9 or flash-attn==2.3.2')
m, indices, _, _ = unpad_input(m, attention_mask)
n = self.ffn(m)
if not self.use_pad_tok_in_ffwd:
n = pad_input(n, indices, batch_size, seq_len)
x = x + self.resid_ffn_dropout(n)
return x, attn_weights, past_key_value
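
For orientation, here is a minimal PyTorch sketch of what the unpad/pad roundtrip above accomplishes when use_pad_tok_in_ffwd is False: gather only the non-PAD rows, run the feedforward on them, and scatter the results back into the padded layout. This is an illustration of the idea, not the flash-attn implementation; the ffn_skip_pad helper and its signature are hypothetical.

import torch
import torch.nn as nn


def ffn_skip_pad(ffn: nn.Module, x: torch.Tensor,
                 attention_mask: torch.Tensor) -> torch.Tensor:
    """Illustrative stand-in for the unpad_input -> ffn -> pad_input path.

    x: (batch, seq_len, d_model); attention_mask: (batch, seq_len), 1 = real token.
    """
    batch_size, seq_len, d_model = x.shape
    flat = x.reshape(batch_size * seq_len, d_model)
    # Indices of the non-PAD positions in the flattened (batch * seq_len) layout.
    indices = attention_mask.reshape(-1).bool().nonzero(as_tuple=True)[0]
    unpadded = flat[indices]  # (num_real_tokens, d_model)
    out = ffn(unpadded)  # the FFN never sees PAD rows
    # Scatter the outputs back; PAD positions are filled with zeros, as pad_input does.
    restored = torch.zeros_like(flat)
    restored[indices] = out
    return restored.reshape(batch_size, seq_len, d_model)
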
3 changes: 3 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -60,6 +60,7 @@ def __init__(
init_config: Dict = init_config_defaults,
fc_type: str = 'torch',
tie_word_embeddings: bool = True,
use_pad_tok_in_ffwd: bool = True,
verbose: Optional[int] = None,
**kwargs: Any,
):
@@ -131,6 +132,7 @@ def __init__(
See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
tie_word_embeddings (bool): Whether to tie the input embedding and output layers.
use_pad_tok_in_ffwd (bool): Whether to forward the pad token in the feedforward networks.
"""
self.d_model = d_model
self.n_heads = n_heads
@@ -151,6 +153,7 @@ def __init__(
self.use_cache = use_cache
self.init_config = init_config
self.fc_type = fc_type
self.use_pad_tok_in_ffwd = use_pad_tok_in_ffwd
if verbose is not None:
warnings.warn(
DeprecationWarning(
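
As a usage sketch (not code from this PR), the flag would be set like any other MPTConfig field; the import path and the small model sizes below are assumptions, and flash-attn must be installed for the unpad path to work.

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

config = MPTConfig(
    d_model=128,
    n_heads=4,
    n_layers=2,
    use_pad_tok_in_ffwd=False,  # drop PAD tokens before each block's FFN
)
model = MPTForCausalLM(config)

With the flag disabled, each block unpads the hidden states with unpad_input before the feedforward and re-pads them with pad_input afterward, which is why _attn_bias now returns the attention_mask in the change below.
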
2 changes: 1 addition & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
@@ -419,7 +419,7 @@ def _attn_bias(
attn_bias = attn_bias.masked_fill(
~attention_mask.view(-1, 1, 1, s_k), min_val)

return attn_bias, None
return attn_bias, attention_mask

def _apply_prefix_mask(self, attn_bias: torch.Tensor,
prefix_mask: torch.Tensor) -> torch.Tensor:
5 changes: 4 additions & 1 deletion tests/models/test_model.py
@@ -698,8 +698,10 @@ def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict):
},
}])
@pytest.mark.parametrize('tie_word_embeddings', [True, False])
@pytest.mark.parametrize('use_pad_tok_in_ffwd', [True, False])
def test_forward_with_padding(attention_impl: str, pos_emb_config: dict,
tie_word_embeddings: bool):
tie_word_embeddings: bool,
use_pad_tok_in_ffwd: bool):
# Test that different placement of padding does not affect the output.
alibi = pos_emb_config['alibi']
if alibi and attention_impl == 'flash':
@@ -731,6 +733,7 @@ def test_forward_with_padding(attention_impl: str, pos_emb_config: dict,
'init_std': 0.02,
},
tie_word_embeddings=tie_word_embeddings,
use_pad_tok_in_ffwd=use_pad_tok_in_ffwd,
)
mpt = MPTForCausalLM(hf_config)
mpt.eval()
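
To make the intent of the new parametrization concrete, here is a hedged sketch of the property test_forward_with_padding exercises: outputs at real-token positions should not depend on where padding is placed. The helper, pad handling, and tolerances below are assumptions, not code from this PR.

import torch


def check_padding_placement(model, tokens: torch.Tensor, pad_id: int) -> None:
    # tokens: (1, n) real token ids with no padding.
    n = tokens.shape[1]
    pad = torch.full((1, 2), pad_id, dtype=tokens.dtype)
    right = torch.cat([tokens, pad], dim=1)
    right_mask = torch.cat([torch.ones(1, n), torch.zeros(1, 2)], dim=1).long()
    left = torch.cat([pad, tokens], dim=1)
    left_mask = torch.cat([torch.zeros(1, 2), torch.ones(1, n)], dim=1).long()
    with torch.no_grad():
        out_right = model(input_ids=right, attention_mask=right_mask).logits[:, :n]
        out_left = model(input_ids=left, attention_mask=left_mask).logits[:, 2:]
    # Real-token outputs should agree regardless of padding placement.
    torch.testing.assert_close(out_right, out_left, atol=1e-5, rtol=1e-5)
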