Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix FX CI #1866

Merged
merged 3 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions .github/workflows/test_fx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ name: FX / Python - Test

on:
push:
branches: [ main ]
branches: [main]
pull_request:
branches: [ main ]
branches: [main]

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
Expand All @@ -20,16 +20,19 @@ jobs:

runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install .[tests]
pip install git+https://github.com/huggingface/transformers.git
- name: Test with unittest
working-directory: tests
run: |
python -m pytest fx/optimization/test_transformations.py --exitfirst
- name: Checkout code
uses: actions/checkout@v4

- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
pip install .[tests]

- name: Test with pytest
working-directory: tests
run: |
python -m pytest -s -v -x fx/optimization
24 changes: 20 additions & 4 deletions optimum/fx/optimization/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,31 @@
import operator
import warnings
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List
from typing import List

import torch
from torch.fx import GraphModule, Node
from transformers.file_utils import add_end_docstrings
from transformers.utils.fx import _gen_constructor_wrapper


if TYPE_CHECKING:
from torch.fx import GraphModule, Node
try:
from transformers.utils.fx import _gen_constructor_wrapper
except ImportError:
from transformers.utils.fx import gen_constructor_wrapper

def _gen_constructor_wrapper(*args, **kwargs):
wrapper, target = gen_constructor_wrapper(*args, **kwargs)

def wrapper_with_forced_tracing(*_args, **_kwargs):
import torch.fx._symbolic_trace

orginal_flag = torch.fx._symbolic_trace._is_fx_tracing_flag
torch.fx._symbolic_trace._is_fx_tracing_flag = True
out = wrapper(*_args, **_kwargs)
torch.fx._symbolic_trace._is_fx_tracing_flag = orginal_flag
return out

return wrapper_with_forced_tracing, target


_ATTRIBUTES_DOCSTRING = r"""
Expand Down
3 changes: 2 additions & 1 deletion tests/fx/optimization/test_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ def transform(self, graph_module):


def get_bert_model():
model = BertModel.from_pretrained(_MODEL_NAME)
# sdpa attn became default
model = BertModel.from_pretrained(_MODEL_NAME, attn_implementation="eager")
model.eval()
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
return model, traced
Expand Down
Loading