encoders and encoder+decoder all work
fxmarty committed Jul 26, 2023
1 parent 212aa3a commit abd8920
Showing 10 changed files with 191 additions and 116 deletions.
180 changes: 137 additions & 43 deletions optimum/bettertransformer/models/encoder_models.py
@@ -104,6 +104,7 @@ def __init__(self, albert_layer, config):
self.attention_head_size = config.hidden_size // config.num_attention_heads
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
self.hidden_dropout_prob = config.hidden_dropout_prob
self.act_fn_callable = ACT2FN[self.act_fn]

self.validate_bettertransformer()

@@ -150,13 +151,6 @@ def forward(self, hidden_states, attention_mask, *_):
qkv = qkv.view(qkv.size()[:-1] + (3, self.num_heads, self.attention_head_size)).permute(2, 0, 3, 1, 4)
query, key, value = qkv[0], qkv[1], qkv[2]

- # TODO: pass scale argument in PyTorch 2.1 release
- query = (
-     query
-     * torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32)).to(torch.get_default_dtype())
-     / 8
- )

# NOTE: In PyTorch 2.0, passing an attention_mask will automatically dispatch
# to the "math" path and will NOT use flash attention / memory-efficient attention.
# We should support xformers / Hazy-flash / rocm-flash directly and stop relying on PyTorch to do the work.
@@ -181,25 +175,23 @@ def forward(self, hidden_states, attention_mask, *_):
training=self.training,
)
+ hidden_states,
- normalized_shape=self.norm1_weight.shape, # TODO: stateful
+ normalized_shape=self.norm1_weight.shape,
weight=self.norm1_weight,
bias=self.norm1_bias,
)

# BertIntermediate
- # TODO: stateful
- act_fn = ACT2FN[self.act_fn]
- x = act_fn(F.linear(attention_out, self.linear1_weight, self.linear1_bias))
+ hidden_states = self.act_fn_callable(F.linear(attention_out, self.linear1_weight, self.linear1_bias))

# BertOutput
hidden_states = F.layer_norm(
attention_out
+ F.dropout(
- F.linear(x, self.linear2_weight, self.linear2_bias),
+ F.linear(hidden_states, self.linear2_weight, self.linear2_bias),
p=self.hidden_dropout_prob,
training=self.training,
),
- normalized_shape=self.norm2_weight.shape, # TODO: stateful
+ normalized_shape=self.norm2_weight.shape,
weight=self.norm2_weight,
bias=self.norm2_bias,
)
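The block removed above pre-scaled the query by sqrt(head_dim) / 8 so that the implicit 1 / sqrt(head_dim) factor inside F.scaled_dot_product_attention becomes an effective 1 / 8, because PyTorch 2.0 exposes no way to override the scale; the removed TODO points at the scale argument added in PyTorch 2.1. A minimal sketch (not part of this commit, assuming torch >= 2.1) of passing a custom scale directly instead of pre-scaling the query:

    import torch
    import torch.nn.functional as F

    batch, num_heads, seq_len, head_dim = 2, 12, 16, 64
    query = torch.randn(batch, num_heads, seq_len, head_dim)
    key = torch.randn(batch, num_heads, seq_len, head_dim)
    value = torch.randn(batch, num_heads, seq_len, head_dim)

    # Default behaviour: softmax(QK^T / sqrt(head_dim)) V
    out_default = F.scaled_dot_product_attention(query, key, value)

    # Explicit scale; 1 / 8 here only mirrors the hard-coded / 8 in the removed lines
    out_custom = F.scaled_dot_product_attention(query, key, value, scale=1.0 / 8.0)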
@@ -292,6 +284,7 @@ def __init__(self, bert_layer, config):
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.hidden_dropout_prob = config.hidden_dropout_prob
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
self.act_fn_callable = ACT2FN[self.act_fn]

self.validate_bettertransformer()

@@ -337,13 +330,6 @@ def forward(self, hidden_states, attention_mask, *_):
qkv = qkv.view(qkv.size()[:-1] + (3, self.num_heads, self.attention_head_size)).permute(2, 0, 3, 1, 4)
query, key, value = qkv[0], qkv[1], qkv[2]

- # TODO: pass scale argument in PyTorch 2.1 release
- query = (
-     query
-     * torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32)).to(torch.get_default_dtype())
-     / 8
- )

# NOTE: In PyTorch 2.0, passing an attention_mask will automatically dispatch
# to the "math" path and will NOT use flash attention / memory-efficient attention.
# We should support xformers / Hazy-flash / rocm-flash directly and stop relying on PyTorch to do the work.
@@ -368,25 +354,23 @@ def forward(self, hidden_states, attention_mask, *_):
training=self.training,
)
+ hidden_states,
- normalized_shape=self.norm1_weight.shape, # TODO: stateful
+ normalized_shape=self.norm1_weight.shape,
weight=self.norm1_weight,
bias=self.norm1_bias,
)

# BertIntermediate
- # TODO: stateful
- act_fn = ACT2FN[self.act_fn]
- x = act_fn(F.linear(attention_out, self.linear1_weight, self.linear1_bias))
+ hidden_states = self.act_fn_callable(F.linear(attention_out, self.linear1_weight, self.linear1_bias))

# BertOutput
hidden_states = F.layer_norm(
attention_out
+ F.dropout(
- F.linear(x, self.linear2_weight, self.linear2_bias),
+ F.linear(hidden_states, self.linear2_weight, self.linear2_bias),
p=self.hidden_dropout_prob,
training=self.training,
),
- normalized_shape=self.norm2_weight.shape, # TODO: stateful
+ normalized_shape=self.norm2_weight.shape,
weight=self.norm2_weight,
bias=self.norm2_bias,
)
@@ -471,6 +455,10 @@ def __init__(self, bart_layer, config):
"norm2_weight": "final_layer_norm.weight",
"norm2_bias": "final_layer_norm.bias",
}
self.dropout = config.attention_dropout
self.activation_dropout = config.activation_dropout
self.attention_head_size = config.d_model // config.encoder_attention_heads
self.act_fn_callable = ACT2FN[self.act_fn]

self.validate_bettertransformer()

@@ -521,7 +509,58 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
elif hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0, original_shape)
else:
raise NotImplementedError("TODO")
qkv = F.linear(hidden_states, weight=self.in_proj_weight, bias=self.in_proj_bias)

qkv = qkv.view(qkv.size()[:-1] + (3, self.num_heads, self.attention_head_size)).permute(2, 0, 3, 1, 4)
query, key, value = qkv[0], qkv[1], qkv[2]

# NOTE: In PyTorch 2.0, passing an attention_mask will automatically dispatch
# to the "math" path and will NOT use flash attention / memory-efficient attention.
# We should support xformers / Hazy-flash / rocm-flash directly and stop relying on PyTorch to do the work.
attention_out = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
is_causal=False,
dropout_p=self.dropout if self.training else 0.0,
)

attention_out = attention_out.permute(0, 2, 1, 3).contiguous()
new_attention_out_shape = attention_out.size()[:-2] + (self.num_heads * self.attention_head_size,)
attention_out = attention_out.view(new_attention_out_shape)

# BertSelfOutput
attention_out = F.layer_norm(
F.dropout(
F.linear(attention_out, self.out_proj_weight, self.out_proj_bias),
p=self.dropout,
training=self.training,
)
+ hidden_states,
normalized_shape=self.norm1_weight.shape,
weight=self.norm1_weight,
bias=self.norm1_bias,
)

# One additional dropout compared to bert
hidden_states = F.dropout(
self.act_fn_callable(F.linear(attention_out, self.linear1_weight, self.linear1_bias)),
p=self.activation_dropout,
training=self.training,
)

hidden_states = F.layer_norm(
attention_out
+ F.dropout(
F.linear(hidden_states, self.linear2_weight, self.linear2_bias),
p=self.dropout,
training=self.training,
),
normalized_shape=self.norm2_weight.shape,
weight=self.norm2_weight,
bias=self.norm2_bias,
)
return (hidden_states,)
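The NOTE in the new forward above says that passing an attention_mask in PyTorch 2.0 forces scaled_dot_product_attention onto the "math" path. A small probe of that dispatch behaviour, not part of the diff and assuming a CUDA device with torch >= 2.0: restricting SDPA to the flash backend and then passing a mask is expected to raise, since the flash kernel does not accept arbitrary masks and the call would otherwise silently fall back to the math implementation.

    import torch
    import torch.nn.functional as F

    q = torch.randn(2, 12, 16, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    attn_mask = torch.zeros(2, 1, 16, 16, device="cuda", dtype=torch.float16)  # additive mask, all zeros

    # Allow only the flash backend, then pass a mask: expected to fail on torch 2.0.
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        try:
            F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        except RuntimeError as err:
            print("flash-only dispatch rejected the masked call:", err)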


@@ -606,6 +645,10 @@ def __init__(self, mbart_layer, config):
"norm2_bias": "final_layer_norm.bias",
"norm2_eps": "final_layer_norm.eps",
}
self.dropout = config.attention_dropout
self.activation_dropout = config.activation_dropout
self.attention_head_size = config.d_model // config.encoder_attention_heads
self.act_fn_callable = ACT2FN[self.act_fn]

self.validate_bettertransformer()

@@ -656,7 +699,60 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
elif hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0, original_shape)
else:
raise NotImplementedError("TODO")
residual = hidden_states
hidden_states = F.layer_norm(
hidden_states,
normalized_shape=self.norm1_weight.shape,
weight=self.norm1_weight,
bias=self.norm1_bias,
)

qkv = F.linear(hidden_states, weight=self.in_proj_weight, bias=self.in_proj_bias)
qkv = qkv.view(qkv.size()[:-1] + (3, self.num_heads, self.attention_head_size)).permute(2, 0, 3, 1, 4)
query, key, value = qkv[0], qkv[1], qkv[2]

# NOTE: In PyTorch 2.0, passing an attention_mask will automatically dispatch
# to the "math" path and will NOT use flash attention / memory-efficient attention.
# We should support xformers / Hazy-flash / rocm-flash directly and stop relying on PyTorch to do the work.
attention_out = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
is_causal=False,
dropout_p=self.dropout if self.training else 0.0,
)

attention_out = attention_out.permute(0, 2, 1, 3).contiguous()
new_attention_out_shape = attention_out.size()[:-2] + (self.num_heads * self.attention_head_size,)
attention_out = attention_out.view(new_attention_out_shape)

hidden_states = residual + F.dropout(
F.linear(attention_out, self.out_proj_weight, self.out_proj_bias),
p=self.dropout,
training=self.training,
)
residual = hidden_states
hidden_states = F.layer_norm(
hidden_states,
normalized_shape=self.norm2_weight.shape,
weight=self.norm2_weight,
bias=self.norm2_bias,
)

# One additional dropout compared to bert
hidden_states = F.dropout(
self.act_fn_callable(F.linear(hidden_states, self.linear1_weight, self.linear1_bias)),
p=self.activation_dropout,
training=self.training,
)

hidden_states = residual + F.dropout(
F.linear(hidden_states, self.linear2_weight, self.linear2_bias),
p=self.dropout,
training=self.training,
)

return (hidden_states,)
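Not part of the diff, but worth noting: the two forwards added above use different residual layouts. The BART encoder layer is post-norm (layer norm applied after the residual add), while the MBART layer is pre-norm (layer norm applied before each sublayer, with the residual added afterwards). A toy sketch of the two orderings with placeholder modules:

    import torch
    from torch import nn

    x = torch.randn(2, 4, 8)
    norm = nn.LayerNorm(8)
    sublayer = nn.Linear(8, 8)  # stand-in for attention / feed-forward

    post_norm = norm(x + sublayer(x))   # BART-style, as in the first new forward
    pre_norm = x + sublayer(norm(x))    # MBART-style, as in the second new forward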


@@ -737,6 +833,7 @@ def __init__(self, bert_layer, config):
self.attention_dropout = config.attention_dropout
self.dropout = config.dropout
self.attention_head_size = config.dim // config.n_heads
self.act_fn_callable = ACT2FN[self.act_fn]

self.validate_bettertransformer()

@@ -786,13 +883,6 @@ def forward(self, hidden_states, attn_mask, head_mask=None, output_attentions=No
qkv = qkv.view(qkv.size()[:-1] + (3, self.num_heads, self.attention_head_size)).permute(2, 0, 3, 1, 4)
query, key, value = qkv[0], qkv[1], qkv[2]

- # TODO: pass scale argument in PyTorch 2.1 release
- query = (
-     query
-     * torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32)).to(torch.get_default_dtype())
-     / 8
- )

# TODO: Kind of stupid to do that at each layer, should be fixed in transformers
attn_mask = attn_mask.unsqueeze(1).unsqueeze(2).to(dtype=query.dtype)
attn_mask = (1.0 - attn_mask) * torch.finfo(query.dtype).min
@@ -821,23 +911,23 @@ def forward(self, hidden_states, attn_mask, head_mask=None, output_attentions=No
training=self.training,
)
+ hidden_states,
- normalized_shape=self.norm1_weight.shape, # TODO: stateful
+ normalized_shape=self.norm1_weight.shape,
weight=self.norm1_weight,
bias=self.norm1_bias,
)

# BertIntermediate
- # TODO: stateful
- act_fn = ACT2FN[self.act_fn]
- x = act_fn(F.linear(attention_out, self.linear1_weight, self.linear1_bias))
+ hidden_states = self.act_fn_callable(F.linear(attention_out, self.linear1_weight, self.linear1_bias))

# BertOutput
hidden_states = F.layer_norm(
attention_out
+ F.dropout(
- F.linear(x, self.linear2_weight, self.linear2_bias), p=self.dropout, training=self.training
+ F.linear(hidden_states, self.linear2_weight, self.linear2_bias),
+ p=self.dropout,
+ training=self.training,
),
- normalized_shape=self.norm2_weight.shape, # TODO: stateful
+ normalized_shape=self.norm2_weight.shape,
weight=self.norm2_weight,
bias=self.norm2_bias,
)
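The two attn_mask lines in this hunk turn DistilBERT's (batch, seq_len) padding mask into an additive float mask of shape (batch, 1, 1, seq_len) that can be handed to scaled_dot_product_attention: kept positions map to 0 and padded positions to the most negative representable value, so they vanish in the softmax. A standalone illustration (not part of the commit):

    import torch

    attn_mask = torch.tensor([[1, 1, 1, 0]])  # 1 = real token, 0 = padding
    query_dtype = torch.float32

    additive_mask = attn_mask.unsqueeze(1).unsqueeze(2).to(dtype=query_dtype)  # shape (1, 1, 1, 4)
    additive_mask = (1.0 - additive_mask) * torch.finfo(query_dtype).min       # 0 for kept, ~-3.4e38 for padded

    print(additive_mask.shape)  # torch.Size([1, 1, 1, 4])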
@@ -1445,7 +1535,9 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
elif hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0, original_shape)
else:
raise ValueError("Training and Autocast are not supported for BetterTransformer + FSMT.")
raise NotImplementedError(
"Training and Autocast are not implemented for BetterTransformer + FSMT. Please open an issue."
)

return (hidden_states, attention_mask)

@@ -1577,7 +1669,9 @@ def forward(self, hidden_states, attention_mask, *_, **__):
elif hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0, original_shape)
else:
raise ValueError("Training and Autocast are not supported for BetterTransformer + ProphetNet.")
raise ValueError(
"Training and Autocast are not implemented for BetterTransformer + ProphetNet. Please open an issue."
)

return (hidden_states,)

2 changes: 1 addition & 1 deletion optimum/bettertransformer/transformation.py
@@ -288,7 +288,7 @@ def transform(

# See: https://github.com/pytorch/pytorch/issues/96099
# TODO: show the warning only for decoders (which do not need an attention mask for training)
- if BetterTransformerManager.is_decoder(model_fast.config.model_type):
+ if False:  # BetterTransformerManager.is_decoder(model_fast.config.model_type):
logging.warning(
f"For decoder training, the BetterTransformer implementation for {model_fast.config.model_type} "
" architecture currently does not support padding as fused kernels do not support custom"
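For context, the warning toggled above sits inside the public BetterTransformer.transform entry point. A minimal usage sketch matching the commit message ("encoders and encoder+decoder all work"); the checkpoints below are only examples, not part of the diff:

    from transformers import AutoModel, AutoModelForSeq2SeqLM
    from optimum.bettertransformer import BetterTransformer

    # Encoder-only model
    encoder_model = AutoModel.from_pretrained("bert-base-uncased")
    encoder_model = BetterTransformer.transform(encoder_model)

    # Encoder-decoder model
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")
    seq2seq_model = BetterTransformer.transform(seq2seq_model)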
6 changes: 0 additions & 6 deletions optimum/utils/testing_utils.py
@@ -23,7 +23,6 @@
from typing import Any, Callable, Dict, Iterable, Optional, Tuple

import torch
- from packaging.version import parse

from . import is_accelerate_available, is_diffusers_available

@@ -63,11 +62,6 @@ def require_torch_gpu(test_case):
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)


- def require_torch_20(test_case):
-     """Decorator marking a test that requires torch>=2.0."""
-     return unittest.skipUnless(parse(torch.__version__) > parse("1.14"), "test requires torch>=2.0")(test_case)


def require_hf_token(test_case):
"""
Decorator marking a test that requires huggingface hub token.