Support gpt_bigcode in bettertransformer #1252

Merged · 1 commit merged into huggingface:main on Aug 4, 2023

Conversation

fxmarty (Contributor) commented Aug 3, 2023

This one was much more painful than I thought - the indexing in bigcode is slightly unusual.

Fixes #1249

@fxmarty fxmarty merged commit 393113f into huggingface:main Aug 4, 2023
61 of 64 checks passed
@anmolagarwal999

@fxmarty I just quickly tested this commit and ran into an error.
Python version: 3.8. System: Ubuntu 20.04.6 LTS.

Error trace:

TypeError                                 Traceback (most recent call last)
Cell In[19], line 2
      1 with torch.no_grad():
----> 2                 gen_tokens = model.generate(
      3                     **encoding,
      4                     generation_config=generation_config, 
      5                     #use_cache=True
      6                 )

File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/transformers/generation/utils.py:1538, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
   1532         raise ValueError(
   1533             "num_return_sequences has to be 1 when doing greedy search, "
   1534             f"but is {generation_config.num_return_sequences}."
   1535         )
   1537     # 11. run greedy search
-> 1538     return self.greedy_search(
   1539         input_ids,
   1540         logits_processor=logits_processor,
   1541         stopping_criteria=stopping_criteria,
   1542         pad_token_id=generation_config.pad_token_id,
   1543         eos_token_id=generation_config.eos_token_id,
   1544         output_scores=generation_config.output_scores,
   1545         return_dict_in_generate=generation_config.return_dict_in_generate,
   1546         synced_gpus=synced_gpus,
   1547         streamer=streamer,
   1548         **model_kwargs,
   1549     )
   1551 elif is_contrastive_search_gen_mode:
   1552     if generation_config.num_return_sequences > 1:

File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/transformers/generation/utils.py:2362, in GenerationMixin.greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
   2359 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   2361 # forward pass to get next token
-> 2362 outputs = self(
   2363     **model_inputs,
   2364     return_dict=True,
   2365     output_attentions=output_attentions,
   2366     output_hidden_states=output_hidden_states,
   2367 )
   2369 if synced_gpus and this_peer_finished:
   2370     continue  # don't waste resources running the code we don't need

File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py:807, in GPTBigCodeForCausalLM.forward(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, use_cache, output_attentions, output_hidden_states, return_dict)
    799 r"""
    800 labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
    801     Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
    802     `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
    803     are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
    804 """
    805 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
--> 807 transformer_outputs = self.transformer(
    808     input_ids,
    809     past_key_values=past_key_values,
    810     attention_mask=attention_mask,
    811     token_type_ids=token_type_ids,
    812     position_ids=position_ids,
    813     head_mask=head_mask,
    814     inputs_embeds=inputs_embeds,
    815     encoder_hidden_states=encoder_hidden_states,
    816     encoder_attention_mask=encoder_attention_mask,
    817     use_cache=use_cache,
    818     output_attentions=output_attentions,
    819     output_hidden_states=output_hidden_states,
    820     return_dict=return_dict,
    821 )
    822 hidden_states = transformer_outputs[0]
    824 lm_logits = self.lm_head(hidden_states)

File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py:672, in GPTBigCodeModel.forward(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions, output_hidden_states, return_dict)
    662     outputs = torch.utils.checkpoint.checkpoint(
    663         create_custom_forward(block),
    664         hidden_states,
   (...)
    669         encoder_attention_mask,
    670     )
    671 else:
--> 672     outputs = block(
    673         hidden_states,
    674         layer_past=layer_past,
    675         attention_mask=attention_mask,
    676         head_mask=head_mask[i],
    677         encoder_hidden_states=encoder_hidden_states,
    678         encoder_attention_mask=encoder_attention_mask,
    679         use_cache=use_cache,
    680         output_attentions=output_attentions,
    681     )
    683 hidden_states = outputs[0]
    684 if use_cache:

File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py:316, in GPTBigCodeBlock.forward(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)
    314 residual = hidden_states
    315 hidden_states = self.ln_1(hidden_states)
--> 316 attn_outputs = self.attn(
    317     hidden_states,
    318     layer_past=layer_past,
    319     attention_mask=attention_mask,
    320     head_mask=head_mask,
    321     use_cache=use_cache,
    322     output_attentions=output_attentions,
    323 )
    324 attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
    325 outputs = attn_outputs[1:]

File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/optimum/bettertransformer/models/decoder_models.py:388, in GPTBigCodeAttentionLayerBetterTransformer.forward(self, *args, **kwargs)
    387 def forward(self, *args, **kwargs):
--> 388     return gpt_bigcode_forward(self, *args, **kwargs)

File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/optimum/bettertransformer/models/attention.py:792, in gpt_bigcode_forward(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)
    788 key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
    790 # Difference with the transformers implementation: there is no need to transpose the key here,
    791 # as SDPA expects seq_length to be at index -2
--> 792 attn_output = self._attn(query, key, value, attention_mask, head_mask)
    794 if not self.multi_query:
    795     attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)

File ~/anaconda3/envs/wizard_coder/lib/python3.8/site-packages/optimum/bettertransformer/models/attention.py:685, in gpt_bigcode_wrapped_scaled_dot_product(self, query, key, value, attention_mask, head_mask)
    683     softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else query.dtype
    684     if self.scale_attention_softmax_in_fp32 and query.dtype != softmax_dtype:
--> 685         query = query / (self.layer_idx + 1)
    686 else:
    687     query = query / self.head_dim**0.5

TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'

fxmarty (Contributor, Author) commented Aug 4, 2023

@anmolagarwal999 Thank you, I can reproduce the issue. scale_attention_softmax_in_fp32=True was not tested in our CI, which is why this slipped through. I'll fix it promptly.

Edit: fixed
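
For context, the failing branch is visible in the traceback above (optimum/bettertransformer/models/attention.py): when scale_attention_softmax_in_fp32=True and the wrapped module ends up with layer_idx=None, the per-layer unscaling evaluates None + 1. A minimal sketch of that pattern with a defensive fallback, purely for illustration (this is not the actual patch), could look like:

import torch

def scale_query(query, layer_idx, head_dim,
                attention_softmax_in_fp32=True,
                scale_attention_softmax_in_fp32=True):
    # Illustrative reimplementation of the scaling branch shown in the
    # traceback above; the real code lives in optimum's
    # gpt_bigcode_wrapped_scaled_dot_product.
    softmax_dtype = torch.float32 if attention_softmax_in_fp32 else query.dtype
    if scale_attention_softmax_in_fp32 and query.dtype != softmax_dtype:
        if layer_idx is None:
            # Hypothetical guard: without a layer index the per-layer
            # unscaling cannot be applied, so fall back to the standard
            # 1/sqrt(head_dim) scaling instead of crashing on `None + 1`.
            return query / head_dim**0.5
        return query / (layer_idx + 1)
    return query / head_dim**0.5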

anmolagarwal999 commented Aug 4, 2023

@fxmarty The outputs with and without BetterTransformer do not match for me. Since I am using greedy decoding, i.e. no probabilistic sampling, I expected the outputs to match. Note: I am not providing batched inputs.

Input 1:

Write a function to add 2 numbers.

Model Output (without BetterTransformer):

def add(a, b):
    return a + b

Example usage:

print(add(2, 3)) # Output: 5

Model Output (with BetterTransformer):

Write a function to add 2 numbers.


Input 2:

'Write a function to find if a number is prime or not'

Model Output (without BetterTransformer):

Write a function to find if a number is prime or not.

def is_prime(n):
    if n < 2:
        return False
    for i in range(2, int(n**0.5)+1):
        if n % i == 0:
            return False
    return True

The function takes an integer n as input and returns True if it is prime, and False otherwise. It first checks if n is less than 2, in which case it is not prime. Then it loops through all the integers from 2 to the square root of n (inclusive) and checks if any of them divide n evenly. If so, then n is not prime and the function returns False. If no such integer is found, then n is prime and the function returns True.

Model Output (with BetterTransformer):

Write a function to find if a number is prime or not.

The function should return true if the number is prime or not.

Here's an example of a function that checks if the number is a number or not.

fxmarty (Contributor, Author) commented Aug 4, 2023

Thank you, this is surprising to me, as I tested it quite a bit in https://github.com/huggingface/optimum/blob/main/tests/bettertransformer/test_decoder.py. Can you provide a reproduction?

I guess the logits tests should be run on GPU + fp16 as well, just for safety.

anmolagarwal999 commented Aug 4, 2023

I am new to this. How can I provide a reproduction? @fxmarty

fxmarty (Contributor, Author) commented Aug 4, 2023

You can copy-paste the script you used to get those results here, with "```python" before it and "```" after it for code formatting.

Edit: I can actually reproduce the issue. Interestingly only on fp16.

@anmolagarwal999

That is correct. This is how I was initializing the model:

import torch
from transformers import AutoModelForCausalLM

# base_model and load_8bit are defined earlier in the script
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    load_in_8bit=load_8bit,
    torch_dtype=torch.float16,
    device_map="auto",
)
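
For reference, a minimal reproduction sketch along these lines (the checkpoint name and prompt are placeholders, not the ones used above; the original ran a WizardCoder checkpoint loaded as shown) would compare greedy outputs before and after the transform:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.bettertransformer import BetterTransformer

checkpoint = "bigcode/gpt_bigcode-santacoder"  # placeholder gpt_bigcode checkpoint
prompt = "Write a function to add 2 numbers."

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, torch_dtype=torch.float16, device_map="auto"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    # Greedy decoding before the transform.
    baseline = model.generate(**inputs, do_sample=False, max_new_tokens=64)

model = BetterTransformer.transform(model)

with torch.no_grad():
    # Greedy decoding after the transform; the two sequences should match.
    bt_output = model.generate(**inputs, do_sample=False, max_new_tokens=64)

print(tokenizer.decode(baseline[0], skip_special_tokens=True))
print(tokenizer.decode(bt_output[0], skip_special_tokens=True))
print("Match:", torch.equal(baseline, bt_output))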

fxmarty mentioned this pull request on Aug 4, 2023
fxmarty (Contributor, Author) commented Aug 4, 2023

Thank you. This issue occurs only in fp16 with scale_attention_softmax_in_fp32=True. This was not caught by the CI.

This will be fixed with #1255, I'll add tests as well.

Thanks a lot for testing it - this proves that the current test suite is not strong enough to detect these tricky bugs!
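
A rough sketch of the kind of fp16 logits check mentioned above, assuming a small gpt_bigcode checkpoint and a loose fp16 tolerance (this is not the actual CI test):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.bettertransformer import BetterTransformer

model_id = "bigcode/gpt_bigcode-santacoder"  # assumed small gpt_bigcode checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to("cuda")

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
with torch.no_grad():
    ref_logits = model(**inputs).logits

bt_model = BetterTransformer.transform(model)
with torch.no_grad():
    bt_logits = bt_model(**inputs).logits

# Loose tolerance: fp16 SDPA kernels are not bit-identical to the eager path.
assert torch.allclose(ref_logits, bt_logits, atol=3e-3, rtol=1e-2)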

fxmarty (Contributor, Author) commented Aug 4, 2023

@anmolagarwal999 Let me know if the output matches now (on the main branch).

For a speedup, note that you need to use https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel to make sure you actually dispatch to flash attention or memory-efficient attention. It is possible that the hardware you are running on does not support them.
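
For illustration, forcing the fused kernels via that context manager (PyTorch 2.0 API; model and inputs as in the sketch further above) would look roughly like:

import torch

# Disable the math fallback so generation fails loudly if neither flash
# attention nor the memory-efficient kernel can be used for this hardware/dtype.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_mem_efficient=True, enable_math=False
):
    with torch.no_grad():
        output = model.generate(**inputs, do_sample=False, max_new_tokens=64)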

@anmolagarwal999

@fxmarty The hardware does support flash attention and memory-efficient attention. I checked this using:

print(torch.backends.cuda.flash_sdp_enabled()) # True
print(torch.backends.cuda.mem_efficient_sdp_enabled()) #True
print(torch.backends.cuda.math_sdp_enabled()) #True
  • However, the speed difference I see is only a factor of 1.14. Is this close to the expected speedup?
  • Also, can I use batched inference with BetterTransformer?
