Commit
fix bt test failures due to default sdpa attention
IlyasMoutawwakil committed May 23, 2024
1 parent e0f5812 commit c9deaec
Showing 3 changed files with 3 additions and 3 deletions.
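
Context for the change: recent transformers releases select the "sdpa" attention implementation by default, while BetterTransformer.transform patches the original "eager" attention modules, so the pipeline helper and the affected tests now request eager attention explicitly. A minimal sketch of the resulting pattern follows; the checkpoint name and the toy input are illustrative assumptions, not part of this commit:

# Minimal sketch, not part of the commit: "bert-base-uncased" and the sample
# input below are illustrative assumptions.
import torch
from transformers import AutoModel, AutoTokenizer
from optimum.bettertransformer import BetterTransformer

model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Recent transformers versions default to the "sdpa" attention implementation;
# BetterTransformer expects the original "eager" attention modules, so request
# them explicitly before transforming the model.
hf_model = AutoModel.from_pretrained(model_name, attn_implementation="eager").eval()
bt_model = BetterTransformer.transform(hf_model, keep_original_model=True)

inputs = tokenizer("hello world", return_tensors="pt")
with torch.no_grad():
    outputs = bt_model(**inputs)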
2 changes: 1 addition & 1 deletion optimum/pipelines/pipelines_base.py
@@ -179,7 +179,7 @@ def load_bettertransformer(
     **kwargs,
 ):
     if model_kwargs is None:
-        model_kwargs = {}
+        model_kwargs = {"attn_implementation": "eager"}
 
     if model is None:
         model_id = SUPPORTED_TASKS[targeted_task]["default"]
2 changes: 1 addition & 1 deletion tests/bettertransformer/test_encoder.py
@@ -114,7 +114,7 @@ def test_inference_speed(self):
         """
         model_name = "bert-base-uncased"
 
-        hf_model = AutoModel.from_pretrained(model_name).eval()
+        hf_model = AutoModel.from_pretrained(model_name, attn_implementation="eager").eval()
         bt_model = BetterTransformer.transform(hf_model, keep_original_model=True)
 
         BATCH_SIZE = 8
2 changes: 1 addition & 1 deletion tests/bettertransformer/testing_utils.py
@@ -235,7 +235,7 @@ def _test_logits(self, model_id: str, model_type: str, **preprocessor_kwargs):
         inputs = self.prepare_inputs_for_class(model_id=model_id, model_type=model_type, **preprocessor_kwargs)
 
         torch.manual_seed(0)
-        hf_random_model = AutoModel.from_pretrained(model_id).eval()
+        hf_random_model = AutoModel.from_pretrained(model_id, attn_implementation="eager").eval()
         random_config = hf_random_model.config
 
         hf_random_model = hf_random_model.eval()
