Commit

fix bt
IlyasMoutawwakil committed Oct 23, 2024
1 parent 9435122 commit d25cd97
Showing 10 changed files with 68 additions and 38 deletions.
36 changes: 27 additions & 9 deletions optimum/bettertransformer/transformation.py
@@ -20,7 +20,13 @@
import torch
from packaging.version import parse

- from ..utils import check_if_pytorch_greater, is_accelerate_available, recurse_getattr, recurse_setattr
+ from ..utils import (
+     check_if_pytorch_greater,
+     check_if_torch_greater,
+     is_accelerate_available,
+     recurse_getattr,
+     recurse_setattr,
+ )
from .models import BetterTransformerManager


@@ -213,15 +219,18 @@ def transform(
hf_config = model.config
if hf_config.model_type in ["falcon", "gpt_bigcode", "llama", "whisper"]:
raise ValueError(
f"Transformers now supports natively BetterTransformer optimizations (torch.nn.functional.scaled_dot_product_attention) for the model type {hf_config.model_type}. As such, there is no need to use `model.to_bettertransformers()` or `BetterTransformer.transform(model)` from the Optimum library. Please upgrade to transformers>=4.36 and torch>=2.1.1 to use it. Details: https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-and-memory-efficient-attention-through-pytorchs-scaleddotproductattention."
f"Transformers now supports natively BetterTransformer optimizations (torch.nn.functional.scaled_dot_product_attention) for the model type {hf_config.model_type}. "
"As such, there is no need to use `model.to_bettertransformers()` or `BetterTransformer.transform(model)` from the Optimum library. "
"Please upgrade to transformers>=4.36 and torch>=2.1.1 to use it. "
"Details: https://huggingface.co/docs/transformers/perf_infer_gpu_one#pytorch-scaled-dot-product-attention."
)

- # Check if we have to load the model using `accelerate`
- if hasattr(model, "hf_device_map"):
-     load_accelerate = True
-     hf_device_map = model.hf_device_map
- else:
-     load_accelerate = False
+ if hasattr(hf_config, "_attn_implementation") and hf_config._attn_implementation == "sdpa":
+     raise ValueError(
+         "This model already uses BetterTransformer optimizations from Transformers (torch.nn.functional.scaled_dot_product_attention). "
+         "As such, there is no need to use `model.to_bettertransformers()` or `BetterTransformer.transform(model)` from the Optimum library. "
+         "Details: https://huggingface.co/docs/transformers/perf_infer_gpu_one#pytorch-scaled-dot-product-attention."
+     )

if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
raise Exception(
@@ -241,11 +250,20 @@ def transform(
f" Currently supported models are: {BetterTransformerManager.MODEL_MAPPING.keys()}."
)

- if parse(torch.__version__) <= parse("1.14"):
+ if not check_if_torch_greater("2.0"):
raise ValueError(
f"BetterTransformer requires torch>=2.0 but {torch.__version__} is installed. Please upgrade PyTorch."
)

+ hf_config = model.config
+
+ # Check if we have to load the model using `accelerate`
+ if hasattr(model, "hf_device_map"):
+     load_accelerate = True
+     hf_device_map = model.hf_device_map
+ else:
+     load_accelerate = False
+
if load_accelerate:
# Remove the hooks from the original model to avoid weights being on `meta` device.
remove_hook_from_module(model, recurse=True)
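A minimal caller-side sketch of the new behavior (not part of the commit), assuming a transformers release where supported architectures default to _attn_implementation="sdpa"; the checkpoint name is illustrative:

from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

# Default loading selects SDPA attention where available, so transform()
# now fails fast with the ValueError added above.
model = AutoModel.from_pretrained("bert-base-uncased")
try:
    BetterTransformer.transform(model)
except ValueError as err:
    print(err)  # "This model already uses BetterTransformer optimizations ..."

# Forcing eager attention keeps the legacy conversion path open, which is
# why the test changes below pass attn_implementation="eager" everywhere.
model = AutoModel.from_pretrained("bert-base-uncased", attn_implementation="eager")
model = BetterTransformer.transform(model)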
2 changes: 1 addition & 1 deletion optimum/utils/import_utils.py
@@ -206,7 +206,7 @@ def check_if_torch_greater(target_version: str) -> bool:
if not is_torch_available():
return False

- return version.parse(torch_version) >= version.parse(target_version)
+ return torch_version >= version.parse(target_version)


@contextmanager
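The one-line change above is a type fix: import_utils keeps torch_version as an already-parsed packaging Version (the fixed comparison relies on this), so parsing it a second time raised a TypeError under recent packaging releases. A small sketch of the failure mode, with torch_version stubbed to its assumed pre-parsed form:

from packaging import version

torch_version = version.parse("2.1.0")        # already a Version, as in import_utils
print(torch_version >= version.parse("2.0"))  # True: the fixed comparison
version.parse(torch_version)                  # TypeError: expected string or bytes-like object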
1 change: 1 addition & 0 deletions setup.py
@@ -40,6 +40,7 @@
"einops",
"timm",
"scikit-learn",
"sentencepiece",
"rjieba",
]

20 changes: 12 additions & 8 deletions tests/bettertransformer/test_audio.py
@@ -35,7 +35,7 @@

class TestsWhisper(unittest.TestCase):
def test_error_message(self):
model = AutoModel.from_pretrained("openai/whisper-tiny")
model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")

with self.assertRaises(ValueError) as cm:
model = BetterTransformer.transform(model)
@@ -82,15 +82,19 @@ def _test_fp16_inference(
set_seed(0)

if not use_to_operator:
- hf_random_model = automodel_class.from_pretrained(model_id, torch_dtype=torch.float16).to(0)
+ hf_random_model = automodel_class.from_pretrained(
+     model_id, torch_dtype=torch.float16, attn_implementation="eager"
+ ).to(0)
converted_model = BetterTransformer.transform(hf_random_model, keep_original_model=False)

- hf_random_model = automodel_class.from_pretrained(model_id, torch_dtype=torch.float16).to(0)
+ hf_random_model = automodel_class.from_pretrained(
+     model_id, torch_dtype=torch.float16, attn_implementation="eager"
+ ).to(0)
else:
- hf_random_model = automodel_class.from_pretrained(model_id).to(0)
+ hf_random_model = automodel_class.from_pretrained(model_id, attn_implementation="eager").to(0)
converted_model = BetterTransformer.transform(hf_random_model, keep_original_model=False)

- hf_random_model = automodel_class.from_pretrained(model_id).to(0)
+ hf_random_model = automodel_class.from_pretrained(model_id, attn_implementation="eager").to(0)
hf_random_model = hf_random_model.to(torch.float16)
converted_model = converted_model.to(torch.float16)

@@ -147,7 +151,7 @@ def test_generation(self, test_name: str, model_type: str, batch_size: int):
model_id = MODELS_DICT[model_type]
processor = AutoProcessor.from_pretrained(model_id)

- model = AutoModel.from_pretrained(model_id)
+ model = AutoModel.from_pretrained(model_id, attn_implementation="eager")

text = ["This is me and me"]
if batch_size > 1:
@@ -217,14 +221,14 @@ def test_logits(self, model_type: str):
inputs = self.prepare_inputs_for_class(model_id, model_type)

torch.manual_seed(0)
- hf_random_model = AutoModel.from_pretrained(model_id).eval()
+ hf_random_model = AutoModel.from_pretrained(model_id, attn_implementation="eager").eval()
random_config = hf_random_model.config

torch.manual_seed(0)
converted_model = BetterTransformer.transform(hf_random_model)

torch.manual_seed(0)
- hf_random_model = AutoModel.from_pretrained(model_id).eval()
+ hf_random_model = AutoModel.from_pretrained(model_id, attn_implementation="eager").eval()
random_config = hf_random_model.config

self.assertFalse(
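Every edit in this test file (and in the test files below) follows the same pattern: pin attn_implementation="eager" at load time so from_pretrained does not select SDPA and trip the new guard in transform(). A sketch of that pattern as a standalone helper (the helper name is ours, not the repository's):

from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

def load_for_bettertransformer(model_id: str, **kwargs):
    # Eager attention keeps config._attn_implementation off "sdpa",
    # so BetterTransformer.transform() accepts the model.
    model = AutoModel.from_pretrained(model_id, attn_implementation="eager", **kwargs)
    return BetterTransformer.transform(model, keep_original_model=False)

bt_model = load_for_bettertransformer("hf-internal-testing/tiny-random-BertModel")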
12 changes: 7 additions & 5 deletions tests/bettertransformer/test_common.py
@@ -28,7 +28,7 @@

class BetterTransformerIntegrationTests(unittest.TestCase):
def test_raise_error_on_double_transform_call(self):
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-BertModel")
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-BertModel", attn_implementation="eager")

with self.assertRaises(Exception) as cm:
bt_model = BetterTransformer.transform(model)
@@ -59,7 +59,7 @@ def test_raise_on_save(self, model_type: str):
)
for model_id in model_ids:
with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname:
- hf_model = AutoModel.from_pretrained(model_id).eval()
+ hf_model = AutoModel.from_pretrained(model_id, attn_implementation="eager").eval()
bt_model = BetterTransformer.transform(hf_model, keep_original_model=False)
bt_model.save_pretrained(tmpdirname)

@@ -73,7 +73,7 @@ def test_conversion(self, model_type: str):
MODELS_DICT[model_type] if isinstance(MODELS_DICT[model_type], tuple) else (MODELS_DICT[model_type],)
)
for model_id in model_ids:
- hf_random_model = AutoModel.from_pretrained(model_id)
+ hf_random_model = AutoModel.from_pretrained(model_id, attn_implementation="eager")
converted_model = BetterTransformer.transform(hf_random_model)

self.assertTrue(
@@ -99,7 +99,7 @@ def test_raise_save_pretrained_error(self, test_name: str, model_type: str, keep
)
for model_id in model_ids:
# get hf and bt model
- hf_model = AutoModel.from_pretrained(model_id)
+ hf_model = AutoModel.from_pretrained(model_id, attn_implementation="eager")
# get bt model and invert it
bt_model = BetterTransformer.transform(hf_model, keep_original_model=keep_original_model)

@@ -145,9 +145,11 @@ def test_raise_activation_fun(self, model_type: str):
)() # random config class for the model to test
hf_random_config.hidden_act = "silu"

- hf_random_model = AutoModel.from_config(hf_random_config).eval()
+ hf_random_model = AutoModel.from_config(hf_random_config, attn_implementation="eager").eval()

with self.assertRaises(ValueError) as cm:
_ = BetterTransformer.transform(hf_random_model, keep_original_model=True)

self.assertTrue("Activation function" in str(cm.exception))

def test_dict_class_consistency(self):
6 changes: 3 additions & 3 deletions tests/bettertransformer/test_decoder.py
@@ -131,7 +131,7 @@ def test_logits_with_cache(self, test_name: str, model_type: str, batch_size: in

model_id = MODELS_DICT[model_type]

- model = AutoModelForCausalLM.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="eager")

normalized_config = NormalizedConfigManager.get_normalized_config_class(model.config.model_type)(model.config)

@@ -167,7 +167,7 @@ def test_generation(self, test_name: str, model_type: str, batch_size: int, padd
model_id = MODELS_DICT[model_type]
tokenizer = AutoTokenizer.from_pretrained(model_id)

- model = AutoModelForCausalLM.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="eager")

if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
if tokenizer.eos_token != "":
@@ -224,7 +224,7 @@ def test_invert_model_logits(self, test_name: str, model_type: str, keep_origina
@require_torch_gpu
@require_accelerate
def test_accelerate_compatibility_cpu_gpu(self, keep_original_model=True, max_memory=None):
hf_model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto", max_memory=max_memory).eval()
hf_model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto", max_memory=max_memory, attn_implementation="eager").eval()
bt_model = BetterTransformer.transform(
hf_model, keep_original_model=keep_original_model, max_memory=max_memory
)
4 changes: 3 additions & 1 deletion tests/bettertransformer/test_encoder.py
@@ -181,7 +181,9 @@ def check_accelerate_compatibility_cpu_gpu(self, keep_original_model=True, max_m
If this works for roberta, it should work for all other models too.
"""

hf_model = AutoModel.from_pretrained("xlm-roberta-base", device_map="auto", max_memory=max_memory).eval()
hf_model = AutoModel.from_pretrained(
"xlm-roberta-base", device_map="auto", max_memory=max_memory, attn_implementation="eager"
).eval()
bt_model = BetterTransformer.transform(
hf_model, keep_original_model=keep_original_model, max_memory=max_memory
)
3 changes: 1 addition & 2 deletions tests/bettertransformer/test_encoder_decoder.py
@@ -45,7 +45,6 @@ class BetterTransformersEncoderDecoderTest(BetterTransformersTestMixin, unittest
"mbart",
"pegasus",
"prophetnet",
"t5",
]

FULL_GRID = {
@@ -153,7 +152,7 @@ def test_generation(self, test_name: str, model_type: str, batch_size: int, padd
model_id = MODELS_DICT[model_type]
tokenizer = AutoTokenizer.from_pretrained(model_id)

- model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_id, attn_implementation="eager")

if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
4 changes: 3 additions & 1 deletion tests/bettertransformer/test_gpu.py
@@ -26,7 +26,9 @@ def timing_cuda(model, num_batches, input_ids, masks, decoder_input_ids)


def benchmark(model_name: str, num_batches: int, batch_size: int, max_seqlen: int, is_half: bool):
- hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16 if is_half else None).eval()
+ hf_model = AutoModel.from_pretrained(
+     model_name, torch_dtype=torch.float16 if is_half else None, attn_implementation="eager"
+ ).eval()
hf_model = hf_model.to("cuda:0")
bt_model = BetterTransformer.transform(hf_model, keep_original_model=True)

18 changes: 10 additions & 8 deletions tests/bettertransformer/testing_utils.py
@@ -136,10 +136,12 @@ def _test_fp16_inference(

torch.manual_seed(0)
if not use_to_operator:
- hf_random_model = automodel_class.from_pretrained(model_id, torch_dtype=torch.float16).to(0)
+ hf_random_model = automodel_class.from_pretrained(
+     model_id, torch_dtype=torch.float16, attn_implementation="eager"
+ ).to(0)
converted_model = BetterTransformer.transform(hf_random_model, keep_original_model=True)
else:
- hf_random_model = automodel_class.from_pretrained(model_id).to(0)
+ hf_random_model = automodel_class.from_pretrained(model_id, attn_implementation="eager").to(0)
converted_model = BetterTransformer.transform(hf_random_model, keep_original_model=True)
hf_random_model = hf_random_model.to(torch.float16)
converted_model = converted_model.to(torch.float16)
@@ -169,7 +171,7 @@ def _test_fp16_inference(
def _test_logits_backward(self, model_id: str, model_type: str, **preprocessor_kwargs):
inputs = self.prepare_inputs_for_class(model_id=model_id, model_type=model_type, **preprocessor_kwargs)

- hf_random_model = AutoModel.from_pretrained(model_id).eval()
+ hf_random_model = AutoModel.from_pretrained(model_id, attn_implementation="eager").eval()
random_config = hf_random_model.config

# I could not obtain reproducible results with `torch.manual_seed` nor with
@@ -309,7 +311,7 @@ def _test_train_decoder(self, model_id: str, model_type: str, **kwargs):
"""
inputs = self.prepare_inputs_for_class(model_id=model_id, model_type=model_type, **kwargs)

- hf_random_model = AutoModel.from_pretrained(model_id).eval()
+ hf_random_model = AutoModel.from_pretrained(model_id, attn_implementation="eager").eval()

bt_model = BetterTransformer.transform(hf_random_model, keep_original_model=True)
bt_model.train()
@@ -328,7 +330,7 @@ def _test_invert_modules(self, model_id, keep_original_model=False):
r"""
Test that the inverse converted model and hf model have the same modules
"""
- hf_model = AutoModel.from_pretrained(model_id)
+ hf_model = AutoModel.from_pretrained(model_id, attn_implementation="eager")
hf_modules = list(hf_model.modules())

bt_model = BetterTransformer.transform(hf_model, keep_original_model=keep_original_model)
@@ -349,7 +351,7 @@ def _test_invert_modules(self, model_id, keep_original_model=False):

def _test_save_load_invertible(self, model_id, keep_original_model=True):
with tempfile.TemporaryDirectory() as tmpdirname:
- hf_model = AutoModel.from_pretrained(model_id).eval()
+ hf_model = AutoModel.from_pretrained(model_id, attn_implementation="eager").eval()
hf_model_state_dict = copy.deepcopy(hf_model.state_dict())

bt_model = BetterTransformer.transform(hf_model, keep_original_model=keep_original_model)
Expand All @@ -362,7 +364,7 @@ def _test_save_load_invertible(self, model_id, keep_original_model=True):
# saving a normal transformers bark model fails because of shared tensors
bt_model.save_pretrained(tmpdirname, safe_serialization=hf_model.config.model_type != "bark")

- bt_model_from_load = AutoModel.from_pretrained(tmpdirname)
+ bt_model_from_load = AutoModel.from_pretrained(tmpdirname, attn_implementation="eager")

self.assertEqual(
set(bt_model.state_dict().keys()),
@@ -397,7 +399,7 @@ def _test_invert_model_logits(
"""
inputs = self.prepare_inputs_for_class(model_id, model_type=model_type, **preprocessor_kwargs)

- hf_model = AutoModel.from_pretrained(model_id)
+ hf_model = AutoModel.from_pretrained(model_id, attn_implementation="eager")
hf_model = hf_model.eval()

with torch.inference_mode():
