Support gpt_bigcode in bettertransformer #1252
Conversation
@fxmarty I just quickly tested out support for this commit and seem to run into an error. Error trace:

```
TypeError Traceback (most recent call last)
Cell In[19], line 2
1 with torch.no_grad():
----> 2 gen_tokens = model.generate(
3 **encoding,
4 generation_config=generation_config,
5 #use_cache=True
6 )
File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/transformers/generation/utils.py:1538, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
1532 raise ValueError(
1533 "num_return_sequences has to be 1 when doing greedy search, "
1534 f"but is {generation_config.num_return_sequences}."
1535 )
1537 # 11. run greedy search
-> 1538 return self.greedy_search(
1539 input_ids,
1540 logits_processor=logits_processor,
1541 stopping_criteria=stopping_criteria,
1542 pad_token_id=generation_config.pad_token_id,
1543 eos_token_id=generation_config.eos_token_id,
1544 output_scores=generation_config.output_scores,
1545 return_dict_in_generate=generation_config.return_dict_in_generate,
1546 synced_gpus=synced_gpus,
1547 streamer=streamer,
1548 **model_kwargs,
1549 )
1551 elif is_contrastive_search_gen_mode:
1552 if generation_config.num_return_sequences > 1:
File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/transformers/generation/utils.py:2362, in GenerationMixin.greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
2359 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
2361 # forward pass to get next token
-> 2362 outputs = self(
2363 **model_inputs,
2364 return_dict=True,
2365 output_attentions=output_attentions,
2366 output_hidden_states=output_hidden_states,
2367 )
2369 if synced_gpus and this_peer_finished:
2370 continue # don't waste resources running the code we don't need
File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)
File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py:807, in GPTBigCodeForCausalLM.forward(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, use_cache, output_attentions, output_hidden_states, return_dict)
799 r"""
800 labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
801 Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
802 `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
803 are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
804 """
805 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
--> 807 transformer_outputs = self.transformer(
808 input_ids,
809 past_key_values=past_key_values,
810 attention_mask=attention_mask,
811 token_type_ids=token_type_ids,
812 position_ids=position_ids,
813 head_mask=head_mask,
814 inputs_embeds=inputs_embeds,
815 encoder_hidden_states=encoder_hidden_states,
816 encoder_attention_mask=encoder_attention_mask,
817 use_cache=use_cache,
818 output_attentions=output_attentions,
819 output_hidden_states=output_hidden_states,
820 return_dict=return_dict,
821 )
822 hidden_states = transformer_outputs[0]
824 lm_logits = self.lm_head(hidden_states)
File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py:672, in GPTBigCodeModel.forward(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions, output_hidden_states, return_dict)
662 outputs = torch.utils.checkpoint.checkpoint(
663 create_custom_forward(block),
664 hidden_states,
(...)
669 encoder_attention_mask,
670 )
671 else:
--> 672 outputs = block(
673 hidden_states,
674 layer_past=layer_past,
675 attention_mask=attention_mask,
676 head_mask=head_mask[i],
677 encoder_hidden_states=encoder_hidden_states,
678 encoder_attention_mask=encoder_attention_mask,
679 use_cache=use_cache,
680 output_attentions=output_attentions,
681 )
683 hidden_states = outputs[0]
684 if use_cache:
File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)
File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py:316, in GPTBigCodeBlock.forward(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)
314 residual = hidden_states
315 hidden_states = self.ln_1(hidden_states)
--> 316 attn_outputs = self.attn(
317 hidden_states,
318 layer_past=layer_past,
319 attention_mask=attention_mask,
320 head_mask=head_mask,
321 use_cache=use_cache,
322 output_attentions=output_attentions,
323 )
324 attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
325 outputs = attn_outputs[1:]
File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)
File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/optimum/bettertransformer/models/decoder_models.py:388, in GPTBigCodeAttentionLayerBetterTransformer.forward(self, *args, **kwargs)
387 def forward(self, *args, **kwargs):
--> 388 return gpt_bigcode_forward(self, *args, **kwargs)
File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/optimum/bettertransformer/models/attention.py:792, in gpt_bigcode_forward(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)
788 key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
790 # Difference with the transformers implementation: there is no need to transpose the key here,
791 # as SDPA expects seq_length to be at index -2
--> 792 attn_output = self._attn(query, key, value, attention_mask, head_mask)
794 if not self.multi_query:
795 attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/optimum/bettertransformer/models/attention.py:685, in gpt_bigcode_wrapped_scaled_dot_product(self, query, key, value, attention_mask, head_mask)
683 softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else query.dtype
684 if self.scale_attention_softmax_in_fp32 and query.dtype != softmax_dtype:
--> 685 query = query / (self.layer_idx + 1)
686 else:
687 query = query / self.head_dim**0.5
TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'
```
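The failing expression is `query = query / (self.layer_idx + 1)` in optimum's `gpt_bigcode_wrapped_scaled_dot_product`: if the attention module's `layer_idx` attribute was never populated it stays `None`, and `None + 1` raises exactly this `TypeError`. A minimal sketch of the failure mode (`Attn` is a hypothetical stand-in, not optimum's actual class):

```python
class Attn:
    def __init__(self, layer_idx=None):
        # The patched forward assumes layer_idx is an int, but here it
        # was never set, so it remains None.
        self.layer_idx = layer_idx

attn = Attn()
try:
    query = 1.0 / (attn.layer_idx + 1)  # None + 1 -> TypeError
except TypeError as e:
    print(e)  # unsupported operand type(s) for +: 'NoneType' and 'int'
```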
@anmolagarwal999 Thank you, I can reproduce the issue. Edit: fixed
@fxmarty The outputs with and without BetterTransformer do not seem to match for me. Since I am using greedy decoding, i.e. no probabilistic sampling, I expected the outputs to match. NOTE: I am not providing batched inputs.

Input 1: `Write a function to add 2 numbers.`

Model Output (without BetterTransformer):

```
def add(a, b):
    return a + b

Example usage:

print(add(2, 3))  # Output: 5
```

Model Output (with BetterTransformer):

```
Write a function to add 2 numbers.
```

Input 2: `Write a function to find if a number is prime or not`

Model Output (without BetterTransformer):

```
Write a function to find if a number is prime or not.

def is_prime(n):
    if n < 2:
        return False
    for i in range(2, int(n**0.5)+1):
        if n % i == 0:
            return False
    return True

The function takes an integer
```

Model Output (with BetterTransformer):

```
Write a function to find if a number is prime or not. The function should return true if the number is prime or not. Here's an example of a function that checks if the number is a number or not.
```
Thank you, this is surprising to me as I tested quite a bit: https://github.com/huggingface/optimum/blob/main/tests/bettertransformer/test_decoder.py. Can you provide a reproduction? I guess the logits tests should also be run on GPU + fp16, just for safety.
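A logits-level check of the kind described here could be sketched as below; `logits_match` is a hypothetical helper, not part of optimum's test suite, and the fp16 tolerance is an assumed value:

```python
import torch

def logits_match(vanilla_model, bt_model, inputs, atol=1e-3):
    # Compare logits between the vanilla and the BetterTransformer-converted
    # model on identical inputs. fp16 needs a loose tolerance; atol=1e-3 is
    # an assumption, not the value used by optimum's tests.
    with torch.no_grad():
        ref = vanilla_model(**inputs).logits
        bt = bt_model(**inputs).logits
    return torch.allclose(ref, bt, atol=atol)
```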
I am new to this. How can I provide a reproduction? @fxmarty
You can copy-paste the script you used to get those results here, with "```python" before and "```" after for code formatting. Edit: I can actually reproduce the issue. Interestingly, only in fp16.
That is correct. This is how I was initializing the model:

```python
import torch
from transformers import AutoModelForCausalLM

# `base_model` and `load_8bit` are defined earlier in my script.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    load_in_8bit=load_8bit,
    torch_dtype=torch.float16,
    device_map="auto",
)
```
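For reference, a minimal reproduction along the lines requested above might look like the following sketch. The checkpoint name and prompt are placeholders, and it assumes `BetterTransformer.transform` (optimum's conversion entry point) plus greedy decoding:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.bettertransformer import BetterTransformer

# Placeholder checkpoint: any gpt_bigcode model loaded in fp16 should
# exercise the code path discussed in this thread.
checkpoint = "bigcode/gpt_bigcode-santacoder"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, torch_dtype=torch.float16, device_map="auto"
)
inputs = tokenizer("Write a function to add 2 numbers.", return_tensors="pt").to(model.device)

with torch.no_grad():
    baseline = model.generate(**inputs, max_new_tokens=64, do_sample=False)

# Convert to BetterTransformer and rerun the exact same greedy decode.
model = BetterTransformer.transform(model)
with torch.no_grad():
    converted = model.generate(**inputs, max_new_tokens=64, do_sample=False)

# With greedy decoding, the two outputs should be token-identical.
print(torch.equal(baseline, converted))
```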
Thank you. This issue occurs only in fp16. It will be fixed by #1255, and I'll add tests as well. Thanks a lot for testing it - it proves that the current test suite is not strong enough to detect these tricky bugs!
@anmolagarwal999 Let me know if the output matches now (on the main branch). For a speedup, note that you need to use https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel to make sure you actually dispatch to flash attention or memory-efficient attention. It is possible the hardware you are running on does not support it.
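As a sketch of what that dispatch check looks like, reusing the `model` and `encoding` from the traceback above (the `torch.backends.cuda.sdp_kernel` context manager is PyTorch 2.0 API; the flag choice below is an assumption about which backend you want to test):

```python
import torch

# Allow only flash attention: if SDPA cannot dispatch to it for this
# hardware/dtype/shape combination, the call errors out instead of
# silently falling back to the math kernel.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
), torch.no_grad():
    gen_tokens = model.generate(**encoding, generation_config=generation_config)
```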
@fxmarty The hardware does support flash attention and memory-efficient attention. I checked it using:

```python
import torch

print(torch.backends.cuda.flash_sdp_enabled())          # True
print(torch.backends.cuda.mem_efficient_sdp_enabled())  # True
print(torch.backends.cuda.math_sdp_enabled())           # True
```
This one was much more painful than I thought - the indexing in bigcode is slightly unusual.
Fixes #1249