Add bark into bettertransformer (#1199)
* first working POC of bark integration

* first POC of bark integration

* add bark to bettertransformer docs

* modify bark self attention - training is not supported so no attention_mask passed

* forward unimplemented in Bark so no test_invert_model_logits

* Update commentary regarding why bark.encodec_model is skipped

* update bark tests

* modify bark test_fp_16 to be less demanding

* Update dropout in bark self attention

Co-authored-by: fxmarty <[email protected]>

* make style

* fix tests, my fault!

* remove forward checker from bark class

* remove raise_autocast test from bark tests

---------

Co-authored-by: fxmarty <[email protected]>
ylacombe and fxmarty authored Jul 27, 2023
1 parent 5e04785 commit e8c1266
Showing 8 changed files with 219 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/source/bettertransformer/overview.mdx
@@ -29,6 +29,7 @@ We provide an integration with the `BetterTransformer` API to use this function in
The list of supported models below:

- [AlBERT](https://arxiv.org/abs/1909.11942)
- [Bark](https://github.com/suno-ai/bark)
- [BART](https://arxiv.org/abs/1910.13461)
- [BERT](https://arxiv.org/abs/1810.04805)
- [BERT-generation](https://arxiv.org/abs/1907.12461)
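With this docs entry, Bark converts like the other listed architectures. A minimal, hedged usage sketch (the `suno/bark-small` checkpoint, prompt, and generation arguments are illustrative assumptions, not part of this commit):

```python
# Hedged usage sketch, not part of the diff: convert a Bark checkpoint with the
# BetterTransformer API. The checkpoint name and prompt below are illustrative.
from transformers import AutoModel, AutoProcessor

from optimum.bettertransformer import BetterTransformer

model_id = "suno/bark-small"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)

# Replace the eligible BarkSelfAttention modules with their BetterTransformer version.
model = BetterTransformer.transform(model, keep_original_model=False)

inputs = processor(["Hello, my dog is cute"], return_tensors="pt")
audio = model.generate(**inputs, do_sample=False)
```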
5 changes: 5 additions & 0 deletions optimum/bettertransformer/models/__init__.py
@@ -15,6 +15,7 @@

from .attention import _llama_prepare_decoder_attention_mask
from .decoder_models import (
BarkAttentionLayerBetterTransformer,
BartAttentionLayerBetterTransformer,
BlenderbotAttentionLayerBetterTransformer,
CodegenAttentionLayerBetterTransformer,
@@ -48,6 +49,7 @@
class BetterTransformerManager:
MODEL_MAPPING = {
"albert": {"AlbertLayer": AlbertLayerBetterTransformer},
"bark": {"BarkSelfAttention": BarkAttentionLayerBetterTransformer},
"bart": {
"BartEncoderLayer": BartEncoderLayerBetterTransformer,
"BartAttention": BartAttentionLayerBetterTransformer,
@@ -114,6 +116,8 @@ class BetterTransformerManager:
"clip": ["text_model"],
# blip-2's Q-former and vision model should not be identified as the last layers of the model
"blip-2": ["qformer.encoder.layer", "vision_model.encoder.layers"],
# bark.codec_model.encoder is not supported in BetterTransformer
"bark": ["codec_model.encoder.layers"],
}

CAN_NOT_BE_SUPPORTED = {
@@ -122,6 +126,7 @@
}

NOT_REQUIRES_NESTED_TENSOR = {
"bark",
"blenderbot",
"codegen",
"gpt2",
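Registering `bark` in `MODEL_MAPPING` while excluding `codec_model.encoder.layers` means only the semantic/coarse/fine attention layers get swapped. A hedged inspection sketch (checkpoint and module-name checks are assumptions based on the mapping above):

```python
# Hedged inspection sketch (not part of the diff): check which Bark modules were
# swapped and confirm the EnCodec codec_model encoder was skipped.
from transformers import AutoModel

from optimum.bettertransformer import BetterTransformer

model = BetterTransformer.transform(AutoModel.from_pretrained("suno/bark-small"))

swapped = [
    name
    for name, module in model.named_modules()
    if module.__class__.__name__ == "BarkAttentionLayerBetterTransformer"
]
swapped_in_codec = [name for name in swapped if name.startswith("codec_model.encoder")]

# Expected: several swapped attention layers, none of them under codec_model.encoder.
print(f"{len(swapped)} attention layers swapped; {len(swapped_in_codec)} under codec_model.encoder")
```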
24 changes: 24 additions & 0 deletions optimum/bettertransformer/models/attention.py
@@ -89,6 +89,30 @@ def gpt2_wrapped_scaled_dot_product(
return sdpa_result, None


# Adapted from transformers.models.bark.modeling_bark.BarkSelfAttention._attn
def bark_wrapped_scaled_dot_product(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
):
raise_on_head_mask(head_mask)

# When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains
# the query for the last token. scaled_dot_product_attention interprets this as the first token in the
# sequence, so with is_causal=True it would mask out attention to every cached position except the
# first one. This is not what we want, so we work around it by setting is_causal=False in that case.
is_causal = self.is_causal and query.shape[2] != 1

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal
)

return sdpa_result, None


# Adapted from transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoSelfAttention._attn
def gpt_neo_wrapped_scaled_dot_product(
self,
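The `is_causal` workaround described in the comment above can be reproduced with plain PyTorch. An illustrative sketch with arbitrary shapes (not taken from Bark):

```python
# Illustrative sketch of the is_causal workaround (not part of the diff). With a
# single-token query, is_causal=True lets it attend only to the first cached key
# position instead of the whole cache, because SDPA aligns the causal mask top-left.
import torch
import torch.nn.functional as F

query = torch.randn(1, 1, 1, 16)   # only the last token's query (q_len == 1)
key = torch.randn(1, 1, 8, 16)     # 8 cached key positions
value = torch.randn(1, 1, 8, 16)

masked = F.scaled_dot_product_attention(query, key, value, is_causal=True)
full = F.scaled_dot_product_attention(query, key, value, is_causal=False)

# The two results generally differ, which is why is_causal is disabled when q_len == 1.
print(torch.allclose(masked, full))  # usually False
```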
41 changes: 41 additions & 0 deletions optimum/bettertransformer/models/decoder_models.py
@@ -15,6 +15,7 @@

import torch
import torch.nn as nn
from transformers.models.bark.modeling_bark import BarkSelfAttention
from transformers.models.bart.modeling_bart import BartAttention
from transformers.models.blenderbot.modeling_blenderbot import BlenderbotAttention
from transformers.models.codegen.modeling_codegen import CodeGenAttention
@@ -30,6 +31,7 @@
from transformers.models.t5.modeling_t5 import T5Attention

from .attention import (
bark_wrapped_scaled_dot_product,
bart_forward,
codegen_wrapped_scaled_dot_product,
gpt2_wrapped_scaled_dot_product,
@@ -158,6 +160,45 @@ def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)


class BarkAttentionLayerBetterTransformer(BetterTransformerBaseLayer, BarkSelfAttention, nn.Module):
_attn = bark_wrapped_scaled_dot_product

def __init__(self, layer: "nn.Module", config: "PretrainedConfig", is_causal: bool = False):
super().__init__(config)

is_causal = layer.is_causal

config.dropout = layer.dropout

config.hidden_size = layer.embed_dim
config.num_heads = layer.num_heads
config.bias = layer.out_proj.bias is not None

if is_causal:
config.block_size = layer.bias.shape[-1]

with torch.device("meta"):
super(BetterTransformerBaseLayer, self).__init__(config, is_causal)

self.module_mapping = None
submodules = ["dropout", "attn_dropout", "resid_dropout", "att_proj", "out_proj"]

for attr in submodules:
setattr(self, attr, getattr(layer, attr))

self.original_layers_mapping = {submodule: submodule for submodule in submodules}

if is_causal:
setattr(self, "bias", getattr(layer, "bias"))
self.original_layers_mapping["bias"] = "bias"

self.supports_training = False
self.dropout_prob_attn = float(config.dropout)

def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)


class CodegenAttentionLayerBetterTransformer(BetterTransformerBaseLayer, CodeGenAttention, nn.Module):
_attn = codegen_wrapped_scaled_dot_product

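The `BarkAttentionLayerBetterTransformer` wrapper above is instantiated under `torch.device("meta")` before the original submodules are re-attached. A minimal sketch of that meta-device pattern using a generic `nn.Linear` (not Bark-specific):

```python
# Minimal sketch of the meta-device construction pattern used above (not Bark-specific).
# Modules created under torch.device("meta") allocate no real storage, so the wrapper
# can be instantiated cheaply before the original, already-initialized submodules are
# re-attached to it.
import torch
import torch.nn as nn

with torch.device("meta"):
    placeholder = nn.Linear(1024, 1024)  # parameters live on the meta device

print(placeholder.weight.device)  # meta: no memory was allocated for the weights
```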
4 changes: 2 additions & 2 deletions optimum/bettertransformer/transformation.py
@@ -336,14 +336,14 @@ def reverse(bt_model: "PreTrainedModel") -> "PreTrainedModel":
)
config = bt_model.config

if config.model_type not in ["wav2vec2", "hubert"]:
if config.model_type not in ["wav2vec2", "hubert", "bark"]:
with torch.device("meta"):
reversed_model = bt_model.__class__(config)
else:
# TODO: fix once this is fixed in pytorch
# reference: https://github.com/pytorch/pytorch/issues/96409
logger.warning(
"The reverse transform for the architectures wav2vec2 and hubert is memory-heavy due to a bug in PyTorch."
"The reverse transform for the architectures wav2vec2, hubert, bark is memory-heavy due to a bug in PyTorch."
)
reversed_model = bt_model.__class__(config)

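Because `bark` now takes the non-meta branch above, reversing a converted model re-instantiates real weights up front. A hedged usage sketch (checkpoint and output path are assumptions):

```python
# Hedged usage sketch (not part of the diff): reverse a converted Bark model before
# saving it. Since bark takes the non-meta branch, expect a temporary memory peak
# while the reversed model is re-instantiated.
from transformers import AutoModel

from optimum.bettertransformer import BetterTransformer

model = AutoModel.from_pretrained("suno/bark-small")
bt_model = BetterTransformer.transform(model, keep_original_model=False)

restored = BetterTransformer.reverse(bt_model)
restored.save_pretrained("bark-small-vanilla")  # the reversed model can be saved as usual
```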
144 changes: 142 additions & 2 deletions tests/bettertransformer/test_audio.py
@@ -15,23 +15,163 @@
import unittest

import numpy as np
import pytest
import torch
from parameterized import parameterized
from testing_utils import MODELS_DICT, BetterTransformersTestMixin
from transformers import AutoFeatureExtractor, AutoModel, AutoProcessor
from transformers import AutoFeatureExtractor, AutoModel, AutoProcessor, set_seed

from optimum.bettertransformer import BetterTransformer
from optimum.utils.testing_utils import grid_parameters
from optimum.utils.testing_utils import grid_parameters, require_torch_gpu


ALL_AUDIO_MODELS_TO_TEST = [
"openai/whisper-tiny",
"patrickvonplaten/wav2vec2_tiny_random",
"ybelkada/hubert-tiny-random",
"ybelkada/tiny-wav2vec2-stable-ln",
"ylacombe/bark-small",
]


class BetterTransformersBarkTest(BetterTransformersTestMixin, unittest.TestCase):
r"""
Testing suite for Bark - runs all the tests defined in `BetterTransformersTestMixin`.
Since `Bark` is a text-to-speech model, it is preferable
to define its own testing class.
"""
SUPPORTED_ARCH = ["bark"]

FULL_GRID = {
"model_type": SUPPORTED_ARCH,
"keep_original_model": [False],
}

def prepare_inputs_for_class(self, model_id, model_type, batch_size=1, **kwargs):
if batch_size == 1:
texts = ["a dummy input yeah!"]
else:
texts = ["a dummy input yeah!"] + ["and two"] * (batch_size - 1)

processor = AutoProcessor.from_pretrained(model_id)

input_dict = processor(texts, **kwargs)

return input_dict

@require_torch_gpu
def _test_fp16_inference(
self, model_id: str, model_type: str, automodel_class, use_to_operator=False, **preprocessor_kwargs
):
r"""
This tests if the converted model runs fine under fp16.
"""
# The first row of the attention mask needs to be all ones -> check: https://github.com/pytorch/pytorch/blob/19171a21ee8a9cc1a811ac46d3abd975f0b6fc3b/test/test_nn.py#L5283
inputs = self.prepare_inputs_for_class(model_id=model_id, model_type=model_type, **preprocessor_kwargs).to(0)

set_seed(0)

if not use_to_operator:
hf_random_model = automodel_class.from_pretrained(model_id, torch_dtype=torch.float16).to(0)
converted_model = BetterTransformer.transform(hf_random_model, keep_original_model=False)

hf_random_model = automodel_class.from_pretrained(model_id, torch_dtype=torch.float16).to(0)
else:
hf_random_model = automodel_class.from_pretrained(model_id).to(0)
converted_model = BetterTransformer.transform(hf_random_model, keep_original_model=False)

hf_random_model = automodel_class.from_pretrained(model_id).to(0)
hf_random_model = hf_random_model.to(torch.float16)
converted_model = converted_model.to(torch.float16)

self.assertFalse(
hasattr(hf_random_model, "use_bettertransformer"),
f"The model {hf_random_model.__class__.__name__} has been converted to a `fast` model by mistake.",
)

length = 50
rtol = 5e-2

with torch.inference_mode():
r"""
Make sure the models are in eval mode! Also make sure that the original model
has not been converted to a fast model. The check is done above.
"""
output_hf = hf_random_model.generate(
**inputs, fine_temperature=None, do_sample=False, semantic_max_new_tokens=length
)

output_bt = converted_model.generate(
**inputs, fine_temperature=None, do_sample=False, semantic_max_new_tokens=length
)

self.assertTrue(
(output_hf - output_bt).abs().mean() < rtol,
f"Mean absolute diff: {(output_hf - output_bt).abs().mean()}",
)

@parameterized.expand(
grid_parameters(
{
"model_type": SUPPORTED_ARCH,
"use_to_operator": [True, False],
"batch_size": [1, 2],
}
)
)
@pytest.mark.fp16
@require_torch_gpu
@pytest.mark.gpu_test
def test_fp16_inference(self, test_name: str, model_type: str, use_to_operator: bool, batch_size: int):
model_id = MODELS_DICT[model_type]
self._test_fp16_inference(
model_id,
model_type=model_type,
use_to_operator=use_to_operator,
automodel_class=AutoModel,
batch_size=batch_size,
)

@parameterized.expand(grid_parameters({"model_type": SUPPORTED_ARCH, "batch_size": [1, 2]}))
def test_generation(self, test_name: str, model_type: str, batch_size: int):
model_id = MODELS_DICT[model_type]
processor = AutoProcessor.from_pretrained(model_id)

model = AutoModel.from_pretrained(model_id)

text = ["This is me and me"]
if batch_size > 1:
text.append("Please continue this my dear me")
inp = processor(text, return_tensors="pt")

length = 50

result_vanilla = model.generate(
**inp, num_beams=1, fine_temperature=None, do_sample=False, semantic_max_new_tokens=length
)

model = BetterTransformer.transform(model)

result_bettertransformer = model.generate(
**inp, num_beams=1, fine_temperature=None, do_sample=False, semantic_max_new_tokens=length
)

self.assertTrue(
torch.allclose(result_vanilla, result_bettertransformer),
f" Maxdiff: {(result_vanilla - result_bettertransformer).abs().max()}",
)

@parameterized.expand(grid_parameters(FULL_GRID))
def test_invert_modules(self, test_name: str, model_type: str, keep_original_model=False):
model_id = MODELS_DICT[model_type]
self._test_invert_modules(model_id=model_id, keep_original_model=keep_original_model)

@parameterized.expand(grid_parameters(FULL_GRID))
def test_save_load_invertible(self, test_name: str, model_type: str, keep_original_model=False):
model_id = MODELS_DICT[model_type]
self._test_save_load_invertible(model_id=model_id, keep_original_model=keep_original_model)


class BetterTransformersWhisperTest(BetterTransformersTestMixin, unittest.TestCase):
r"""
Testing suite for Whisper - tests all the tests defined in `BetterTransformersTestMixin`
5 changes: 3 additions & 2 deletions tests/bettertransformer/test_common.py
@@ -90,7 +90,8 @@ def test_raise_save_pretrained_error(self, test_name: str, model_type: str, keep
Test if the converted model raises an error when calling `save_pretrained`
but not when the model is reverted
"""
if model_type in ["wav2vec2", "hubert"] and keep_original_model is True:

if model_type in ["wav2vec2", "hubert", "bark"] and keep_original_model is True:
self.skipTest("These architectures do not support deepcopy")

model_ids = (
@@ -118,7 +119,7 @@ def test_raise_activation_fun(self, model_type: str):
if BetterTransformerManager.requires_strict_validation(model_type) is False:
self.skipTest("The architecture does not require a specific activation function")

if model_type in ["wav2vec2", "hubert"]:
if model_type in ["wav2vec2", "hubert", "bark"]:
self.skipTest("These architectures do not support deepcopy (raise unrelated error)")

layer_classes = BetterTransformerManager.MODEL_MAPPING[model_type].keys()
1 change: 1 addition & 0 deletions tests/bettertransformer/testing_utils.py
@@ -27,6 +27,7 @@

MODELS_DICT = {
"albert": "hf-internal-testing/tiny-random-AlbertModel",
"bark": "ylacombe/bark-small", # TODO: put a smaller model, this one is 1.7GB...
"bart": "hf-internal-testing/tiny-random-bart",
"bert": "hf-internal-testing/tiny-random-BertModel",
"bert-generation": "ybelkada/random-tiny-BertGenerationModel",
