Commit 5f58a28

use original wrapper generator and eager attn

IlyasMoutawwakil committed May 22, 2024
1 parent c7c214d commit 5f58a28
Showing 2 changed files with 28 additions and 8 deletions.
33 changes: 26 additions & 7 deletions optimum/fx/optimization/transformations.py
@@ -19,15 +19,11 @@
 import operator
 import warnings
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, List
+from typing import List
 
 import torch
+from torch.fx import GraphModule, Node, Proxy
 from transformers.file_utils import add_end_docstrings
-from transformers.utils.fx import gen_constructor_wrapper
-
-
-if TYPE_CHECKING:
-    from torch.fx import GraphModule, Node
 
 
 _ATTRIBUTES_DOCSTRING = r"""
@@ -402,7 +398,7 @@ class FuseBiasInLinear(ReversibleTransformation):
     preserves_computation = True
 
     def transform(self, graph_module: "GraphModule") -> "GraphModule":
-        torch_ones = gen_constructor_wrapper(torch.ones)[0]
+        torch_ones = _gen_constructor_wrapper(torch.ones)[0]
 
         def insert_concat(linear_input):
             shape = linear_input.shape[:-1] + (1,)
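Note: a minimal sketch of how this transformation is used downstream; the checkpoint name and variable names are illustrative, not part of the commit. FuseBiasInLinear fuses each Linear bias into the weight, which is why it needs a wrapped torch.ones that stays traceable inside the graph.

    from transformers import BertModel
    from transformers.utils.fx import symbolic_trace

    from optimum.fx.optimization import FuseBiasInLinear

    # trace the model, then fuse every Linear bias into its weight
    model = BertModel.from_pretrained("bert-base-uncased", attn_implementation="eager")
    traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])

    transformation = FuseBiasInLinear()
    transformed = transformation(traced)
    # ReversibleTransformation also supports undoing the rewrite:
    restored = transformation.reverse(transformed)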
@@ -803,3 +799,26 @@ def reverse(self, graph_module):
             return ComposeTransformation._reverse_composition(graph_module)
 
     return ComposeTransformation()
+
+
+# was removed from transformers in favor of gen_constructor_wrapper, which only works during tracing
+# TODO: should the transformation be applied during tracing? or by a torch.compile backend instead?
+def _gen_constructor_wrapper(target):
+    @functools.wraps(target)
+    def wrapper(*args, **kwargs):
+        proxy = None
+
+        def check_has_proxy(v):
+            if isinstance(v, Proxy):
+                nonlocal proxy
+                proxy = v
+
+        torch.fx.node.map_aggregate(args, check_has_proxy)
+        torch.fx.node.map_aggregate(kwargs, check_has_proxy)
+
+        if proxy is not None:
+            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
+        else:
+            return target(*args, **kwargs)
+
+    return wrapper, target
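Note: a hypothetical usage sketch of the restored helper, not part of the commit. Outside of tracing the wrapper is transparent, while during an FX trace any Proxy argument routes the call through the tracer so the constructor is recorded in the graph.

    import torch

    wrapped_ones, original_ones = _gen_constructor_wrapper(torch.ones)

    # with plain arguments no Proxy is found, so the original constructor runs eagerly
    t = wrapped_ones(2, 3)  # a real torch.Tensor of shape (2, 3)

    # during symbolic tracing, arguments derived from graph inputs are torch.fx.Proxy
    # objects, so the same call is recorded as a "call_function" node instead:
    #     proxy_out = wrapped_ones(proxy_dim, 1)  # -> torch.fx.Proxy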
3 changes: 2 additions & 1 deletion tests/fx/optimization/test_transformations.py
@@ -86,7 +86,8 @@ def transform(self, graph_module):
 
 
 def get_bert_model():
-    model = BertModel.from_pretrained(_MODEL_NAME)
+    # sdpa attn became default
+    model = BertModel.from_pretrained(_MODEL_NAME, attn_implementation="eager")
     model.eval()
     traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
     return model, traced
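Note on this change: recent transformers releases default to attn_implementation="sdpa", while the FX symbolic tracer targets the eager attention path, so the test pins the eager implementation explicitly. A hypothetical check, not part of the commit:

    from transformers import BertModel

    # force the eager attention implementation instead of the sdpa default
    model = BertModel.from_pretrained("bert-base-uncased", attn_implementation="eager")
    print(model.config._attn_implementation)  # -> "eager"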