
Commit

update setup
echarlaix committed Sep 25, 2024
1 parent 9fa9e9f commit bf913c2
Showing 4 changed files with 6 additions and 6 deletions.
.github/workflows/test_onnxruntime.yml (1 change: 0 additions & 1 deletion)
@@ -42,7 +42,6 @@ jobs:
        run: |
          pip install --upgrade pip
          pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
-         pip install git+https://github.com/huggingface/transformers
          pip install .[tests,onnxruntime]
      - name: Test with pytest (in series)
optimum/bettertransformer/models/attention.py (7 changes: 3 additions & 4 deletions)
@@ -207,9 +207,10 @@ def codegen_wrapped_scaled_dot_product(
        # causal_mask is always [True, ..., True] otherwise, so executing this
        # is unnecessary
        if query_length > 1:
-
            if not check_if_transformers_greater("4.44.99"):
-               causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
+               causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(
+                   torch.bool
+               )

                causal_mask = torch.where(causal_mask, 0, mask_value)

@@ -219,7 +220,6 @@ def codegen_wrapped_scaled_dot_product(
                    # we use torch.min to avoid having tensor(-inf)
                    attention_mask = torch.min(causal_mask, attention_mask)
        else:
-
            attention_mask = attention_mask[:, :, :, : key.shape[-2]]

        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
@@ -229,7 +229,6 @@
    return sdpa_result, None


-
# Adapted from transformers.models.opt.modeling_opt.OPTAttention.forward
def opt_forward(
    self,
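For readers following the version gate above: on transformers releases before 4.45 the attention module still carries a boolean causal_mask buffer, so the additive bias is rebuilt from a slice of it, while newer releases prepare the mask upstream. A minimal sketch of that gating pattern follows; the helper name, its argument list, and the optimum.utils import path are assumptions for illustration, not the actual wrapper in attention.py.

import torch

from optimum.utils import check_if_transformers_greater  # assumed public import path


def gated_causal_bias(causal_mask_buffer, query_length, key_length, mask_value):
    """Hypothetical helper: rebuild the additive causal bias only on old transformers."""
    # Legacy path (transformers <= 4.44.x): slice the module's boolean buffer and
    # turn it into an additive bias (0 where attention is allowed, mask_value elsewhere).
    if query_length > 1 and not check_if_transformers_greater("4.44.99"):
        causal_mask = causal_mask_buffer[
            :, :, key_length - query_length : key_length, :key_length
        ].to(torch.bool)
        # mask_value is expected to be a tensor holding a very negative value.
        return torch.where(causal_mask, 0, mask_value)
    # Newer transformers build the mask upstream, so nothing extra is added here.
    return None

Presumably the threshold is written as "4.44.99" rather than "4.45.0" so that 4.45 pre-releases, such as the git main now pinned in setup.py below, already take the new path.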
optimum/onnxruntime/modeling_decoder.py (1 change: 1 addition & 0 deletions)
@@ -724,6 +724,7 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
        super()._save_pretrained(save_directory)
        self.generation_config.save_pretrained(save_directory)

+

class ORTGPTBigCodeForCausalLM(ORTModelForCausalLM):
    # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
setup.py (3 changes: 2 additions & 1 deletion)
@@ -15,7 +15,8 @@
REQUIRED_PKGS = [
    "coloredlogs",
    "sympy",
-   "transformers[sentencepiece]>=4.29,<4.46.0",
+   "transformers @ git+https://github.com/huggingface/transformers.git",
+   # "transformers[sentencepiece]>=4.29,<4.46.0",
    "torch>=1.11",
    "packaging",
    "numpy<2.0",  # transformers requires numpy<2.0 https://github.com/huggingface/transformers/pull/31569
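The replacement line is a PEP 508 direct reference (a package name plus a VCS URL) rather than a version range, which is presumably why the test workflow above no longer installs transformers from git itself. As a quick, illustrative check of how such a requirement decomposes (using packaging, which is already listed in REQUIRED_PKGS):

from packaging.requirements import Requirement

req = Requirement("transformers @ git+https://github.com/huggingface/transformers.git")
print(req.name)       # transformers
print(req.url)        # git+https://github.com/huggingface/transformers.git
print(req.specifier)  # empty: a direct reference carries no version range

pip resolves the URL at install time; the old pinned range is presumably kept as a comment so it can be restored once a compatible transformers release is published.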
