
Commit 21f709c
fix gpt bigcode ONNX export for transformers<=4.36.0
echarlaix committed Jul 25, 2024
1 parent 5ece6e8 commit 21f709c
Showing 2 changed files with 11 additions and 4 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/test_export_onnx_cli.yml
@@ -17,6 +17,7 @@ jobs:
       matrix:
         python-version: [3.8, 3.9]
         os: [ubuntu-20.04]
+        transformers-version: ["4.26.0", "4.42.*"]

     runs-on: ${{ matrix.os }}
     steps:
@@ -27,6 +28,7 @@ jobs:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies for pytorch export
        run: |
+          pip install transformers==${{ matrix.transformers-version }}
           pip install .[tests,exporters]
       - name: Test with unittest
        working-directory: tests
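The new matrix entry pins one old and one recent transformers release so the ONNX export CLI tests cover both. Below is a minimal sketch of the kind of export those tests exercise, written against optimum's Python entry point; the checkpoint id, output directory, and task are illustrative choices, not values taken from the test suite.

from optimum.exporters.onnx import main_export

# Export a GPT-BigCode checkpoint to ONNX. The export should behave the same
# whether the pinned transformers release is 4.26.0 or 4.42.*, which is what
# the CI matrix above verifies.
main_export(
    "bigcode/gpt_bigcode-santacoder",  # illustrative GPT-BigCode checkpoint
    output="gpt_bigcode_onnx",
    task="text-generation-with-past",
)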
13 changes: 9 additions & 4 deletions optimum/exporters/onnx/model_patcher.py
@@ -276,12 +276,13 @@ def __init__(
             model.decoder.model.decoder.config.use_cache = True


-def _unmask_unattended_patched(
-    expanded_mask: torch.Tensor,
-    min_dtype: float,
+def _unmask_unattended_patched_legacy(
+    expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]
 ):
     return expanded_mask

+def _unmask_unattended_patched(expanded_mask: torch.Tensor, min_dtype: float):
+    return expanded_mask

 def _make_causal_mask_patched(
     input_ids_shape: torch.Size,
@@ -316,7 +317,11 @@ def _make_causal_mask_patched(


 _make_causal_mask_patched_staticmethod = staticmethod(_make_causal_mask_patched)
-_unmask_unattended_patched_staticmethod = staticmethod(_unmask_unattended_patched)
+
+if _transformers_version >= version.parse("4.39.0"):
+    _unmask_unattended_patched_staticmethod = staticmethod(_unmask_unattended_patched)
+else:
+    _unmask_unattended_patched_staticmethod = staticmethod(_unmask_unattended_patched_legacy)


 # Adapted from _prepare_4d_causal_attention_mask
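Both patched functions are no-ops that return the expanded mask unchanged; what changed in transformers 4.39 is the signature of the function being replaced, so the patcher now registers whichever variant matches the installed release. A minimal sketch of that idea, assuming the target is transformers' AttentionMaskConverter._unmask_unattended (how optimum's ModelPatcher actually applies the staticmethod is omitted here):

from packaging import version
import transformers

_transformers_version = version.parse(transformers.__version__)

def _unmask_unattended_patched_legacy(expanded_mask, attention_mask, unmasked_value):
    # transformers < 4.39 signature: return the mask unchanged so the exported
    # graph contains no data-dependent unmasking step.
    return expanded_mask

def _unmask_unattended_patched(expanded_mask, min_dtype):
    # transformers >= 4.39 signature: same no-op behaviour.
    return expanded_mask

# AttentionMaskConverter only exists in recent transformers releases, so guard
# the import; very old releases have nothing to patch.
try:
    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
except ImportError:
    AttentionMaskConverter = None

if AttentionMaskConverter is not None:
    if _transformers_version >= version.parse("4.39.0"):
        AttentionMaskConverter._unmask_unattended = staticmethod(_unmask_unattended_patched)
    else:
        AttentionMaskConverter._unmask_unattended = staticmethod(_unmask_unattended_patched_legacy)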
