From 1d504c851c26d54e8a07b2a2245fd6cbd4d283e0 Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Mon, 6 Nov 2023 15:00:19 -0800 Subject: [PATCH] Adding support for Rotary Position Embeddings (#675) * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * removed the roformer impementation of rope * .. * fixed all the lint errors * .. * .. * ../llmfoundry/models/mpt/modeling_mpt.py * .. * .. * .. * added unit test to test rotary embeddings * .. * .. * .. * .. * .. * .. * .. * .. * .. * Update llmfoundry/models/mpt/modeling_mpt.py Accepting the suggestion Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> * incorporated some suggestions from the pr * .. * .. * .. * .. * .. * .. * .. * added mark for gpu in the rotary embedding test * .. * .. * .. * removed thecode for hf implementation of rope * .. * .. * added tests * .. * .. * ... * .. * .. * .. * .. * .. * fixed the tests after the merge * minor change * Fixed some tests failing due to a transformers library bug * added check for flash_attention before importing their rotary embedding * added check for flash_attention in tests before using dail rope * fixed tests * .. * .. * temporary fix * .. * .. * fixed a test * .. * minor change * minor changes * added documentation * added documentation * temp commit * made _set_config_defaults recursive * minor changes * reformatted tutorial table * reformatted tutorial table * reformatted tutorial table * added documentation on how to install flash attention 2 * minor changes * minor changes * minor changes * minor changes * minor changes * minor changes * .. * resolved some comments from the PR * fixed tests * modified is_flash_v2_installed * minor changes * Update TUTORIAL.md Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Update TUTORIAL.md Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Update TUTORIAL.md Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * Update TUTORIAL.md Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * resolved PR comments --------- Co-authored-by: Shashank Rajput Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- TUTORIAL.md | 49 +- llmfoundry/models/layers/attention.py | 71 ++- llmfoundry/models/layers/blocks.py | 43 +- llmfoundry/models/mpt/configuration_mpt.py | 72 ++- llmfoundry/models/mpt/modeling_mpt.py | 129 ++++- tests/test_flash_triton_torch.py | 73 ++- tests/test_model.py | 557 +++++++++++++++++---- tests/test_rope_dail_vs_hf.py | 145 ++++++ 8 files changed, 952 insertions(+), 187 deletions(-) create mode 100644 tests/test_rope_dail_vs_hf.py diff --git a/TUTORIAL.md b/TUTORIAL.md index d019eb9f83..86bd9829e9 100644 --- a/TUTORIAL.md +++ b/TUTORIAL.md @@ -8,27 +8,42 @@ Forging LLMs can be quite complicated — you have to get your data prepared, se This tutorial will provide a brief intro to the repo’s structure and underlying tools (all courtesy of MosaicML, of course), will go over a few example workflows and point you to the related resources within the repo, and will finally cover a number of FAQs that we have encountered since release. +- [LLM Foundry Tutorial](#llm-foundry-tutorial) - [Intro](#intro) - [How this repo is structured](#how-this-repo-is-structured) - [Key components](#key-components) + - [Composer](#composer) + - [StreamingDataset](#streamingdataset) + - [MCLI](#mcli) - [How the YAMLs work](#how-the-yamls-work) - [Example Workflows](#example-workflows) - [Workflow 1: I want to play with a HF model like MPT-7B locally](#workflow-1-i-want-to-play-with-a-hf-model-like-mpt-7b-locally) - [Workflow 2: I want to deploy an inference endpoint with a HF model like MPT-7B](#workflow-2-i-want-to-deploy-an-inference-endpoint-with-a-hf-model-like-mpt-7b) - [Workflow 3: I want to finetune a HF model like MPT-7B](#workflow-3-i-want-to-finetune-a-hf-model-like-mpt-7b) + - [Supervised FineTuning and Instruction FineTuning](#supervised-finetuning-and-instruction-finetuning) + - [Domain Adaptation and Sequence Length Adaptation](#domain-adaptation-and-sequence-length-adaptation) + - [Data](#data) + - [Modeling](#modeling) - [Workflow 4: I want to train a new HF model from scratch](#workflow-4-i-want-to-train-a-new-hf-model-from-scratch) - [FAQs](#faqs) - - [Why is the script only using 1 out of N GPUs?](#why-is-the-script-only-using-1-out-of-n-gpus) - - [I’m running into an Out-Of-Memory (OOM) error. What do I do?](#im-running-into-an-out-of-memory-oom-error-what-do-i-do) - - [What hardware can I train on?](#what-hardware-can-i-train-on) - - [What hardware can I run eval on?](#what-hardware-can-i-run-eval-on) - - [What is FSDP?](#what-is-fsdp) - - [What are the different attention options `torch` / `flash` / `triton` for MPT and which one should I use?](#what-are-the-different-attention-options-torch--flash--triton-for-mpt-and-which-one-should-i-use) - - [Can I finetune using PEFT / LORA?](#can-i-finetune-using-peft--lora) - - [Can I quantize these models and/or run on CPU?](#can-i-quantize-these-models-andor-run-on-cpu) - - [How do I deploy with ONNX/FasterTransformer?](#how-do-i-deploy-with-onnxfastertransformer) - - [How expensive is it to build LLMs?](#how-expensive-is-it-to-build-llms) - - [Common installation issues](#common-installation-issues) + - [Why is the script only using 1 out of N GPUs?](#why-is-the-script-only-using-1-out-of-n-gpus) + - [I’m running into an Out-Of-Memory (OOM) error. What do I do?](#im-running-into-an-out-of-memory-oom-error-what-do-i-do) + - [What hardware can I train on?](#what-hardware-can-i-train-on) + - [What hardware can I run eval on?](#what-hardware-can-i-run-eval-on) + - [What hardware can I run inference on?](#what-hardware-can-i-run-inference-on) + - [What is FSDP?](#what-is-fsdp) + - [What are the different attention options `torch` / `flash` / `triton` for MPT and which one should I use?](#what-are-the-different-attention-options-torch--flash--triton--for-mpt-and-which-one-should-i-use) + - [Limitations](#limitations) + - [What is `triton-pre-mlir`?](#what-is-triton-pre-mlir) + - [Known issue with sm86+ GPUs](#known-issue-with-sm86-gpus) + - [Support for FlashAttention-2](#support-for-flashattention-2) + - [What kinds of positional embeddings does LLM Foundry support?](#what-kinds-of-positional-embeddings-does-llm-foundry-support) + - [Can I finetune using PEFT / LoRA?](#can-i-finetune-using-peft--lora) + - [Can I quantize these models and/or run on CPU?](#can-i-quantize-these-models-andor-run-on-cpu) + - [How do I deploy with ONNX/FasterTransformer?](#how-do-i-deploy-with-onnxfastertransformer) + - [TransformerEngine and amp\_fp8 support](#transformerengine-and-amp_fp8-support) + - [How expensive is it to build LLMs?](#how-expensive-is-it-to-build-llms) + - [Common installation issues](#common-installation-issues) Let’s get started! @@ -328,6 +343,18 @@ The majority of our training setups use `triton`. --> Updating to LLVM14 (or LLVM15) cannot be done because there are breaking changes. What is the result of this? Although sm89+ is not **formally** supported until LLVM15, our testing on H100 GPUs shows that `attn_impl=triton` still works well and still runs fast. The only issue is that when the network is starting to run, LLVM might throw a warning like: `'sm_90' is not a recognized processor for this target (ignoring processor)`. This warning does not seem to affect performance. +#### Support for FlashAttention-2 +- [FlashAttention-2](https://arxiv.org/pdf/2307.08691.pdf) improves upon FlashAttention to get even faster attention computation. LLM Foundry supports FlashAttention-2. Please follow the instructions [here](https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#flashattention). + +### What kinds of positional embeddings does LLM Foundry support? +Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706.03762.pdf), [Attention with Linear Biases (ALiBi)](https://arxiv.org/pdf/2108.12409.pdf), and [Rotary Positional Embeddings (RoPE)](https://arxiv.org/pdf/2104.09864.pdf). There is also an option to switch off all of these embeddings to get [No Positional Embedding](https://arxiv.org/pdf/2203.16634.pdf). + +| Name | YAML Config | Training MFU on MPT-7B trained on 8 A100 80GB GPUs | Notes | +|:-----------------------------------|:------------------------------------------------------------------|:---------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| Learned Positional Embeddings |
model:
learned_pos_emb: True
| 65.7 | | +| ALiBi |
model:
attn_config:
alibi: True
| 64.5 | Requires Triton or Torch attention. | +| RoPE (Dao-AILab Implementation) |
model:
attn_config:
rope: True
rope_impl: dail
| 64.5 | Requires a CUDA GPU and the [flash-attn library](https://github.com/Dao-AILab/flash-attention) v2.0.1 or higher to be installed. Please see the instructions in the [paragraph above](#support-for-flashattention-2) on how to install flash-attn v2. Note that the attention implementation can still be `torch`, `triton`, or `flash`. | +| RoPE (Hugging Face Implementation) |
model:
attn_config:
rope: True
rope_impl: hf
| 62.3 | | ### Can I finetune using PEFT / LoRA? - The LLM Foundry codebase does not directly have examples of PEFT or LORA workflows. However, our MPT model is a subclass of HuggingFace `PretrainedModel`, and https://github.com/mosaicml/llm-foundry/pull/346 added required features to enable HuggingFace’s [PEFT](https://huggingface.co/docs/peft/index) / [LORA](https://huggingface.co/docs/peft/conceptual_guides/lora) workflows for MPT. MPT models with LoRA modules can be trained either using LLM Foundry or Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index). Within LLM Foundry, run (`scripts/train/train.py`), adding `lora` arguments to the config `.yaml`, like so: diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 39fa7162ac..0503d6d75a 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -5,7 +5,7 @@ import math import warnings -from typing import Any, List, Optional, Tuple +from typing import Any, Optional import torch import torch.nn as nn @@ -17,12 +17,13 @@ from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY -def is_flash_v2_installed(): +def is_flash_v2_installed(v2_version: str = '2.0.0'): + assert version.parse(v2_version) >= version.parse('2.0.0') try: import flash_attn as flash_attn except: return False - return version.parse(flash_attn.__version__) >= version.parse('2.0.0') + return version.parse(flash_attn.__version__) >= version.parse(v2_version) def is_flash_v1_installed(): @@ -33,6 +34,16 @@ def is_flash_v1_installed(): return version.parse(flash_attn.__version__) < version.parse('2.0.0') +# Before importing any transformers models, we need to disable transformers flash attention if +# we are in an environment with flash attention version <2. Transformers hard errors on a not properly +# gated import otherwise. +if is_flash_v1_installed(): + import transformers + transformers.utils.is_flash_attn_available = lambda: False + +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb + + def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool) -> bool: # disable causal when it is not needed @@ -70,7 +81,7 @@ def scaled_multihead_dot_product_attention( value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int] = None, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, softmax_scale: Optional[float] = None, attn_bias: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None, @@ -79,7 +90,7 @@ def scaled_multihead_dot_product_attention( training: bool = False, needs_weights: bool = False, multiquery: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: if multiquery: @@ -183,7 +194,7 @@ def scaled_multihead_dot_product_attention( def check_valid_inputs(*tensors: torch.Tensor, - valid_dtypes: Optional[List[torch.dtype]] = None): + valid_dtypes: Optional[list[torch.dtype]] = None): if valid_dtypes is None: valid_dtypes = [torch.float16, torch.bfloat16] for tensor in tensors: @@ -199,7 +210,7 @@ def flash_attn_fn( value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int] = None, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, softmax_scale: Optional[float] = None, attn_bias: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None, @@ -208,7 +219,7 @@ def flash_attn_fn( training: bool = False, needs_weights: bool = False, multiquery: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: try: from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip @@ -337,7 +348,7 @@ def triton_flash_attn_fn( value: torch.Tensor, n_heads: int, kv_n_heads: Optional[int] = None, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, softmax_scale: Optional[float] = None, attn_bias: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None, @@ -346,7 +357,7 @@ def triton_flash_attn_fn( training: bool = False, needs_weights: bool = False, multiquery: bool = False, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: try: from llmfoundry.models.layers.flash_attn_triton import flash_attn_func @@ -552,12 +563,13 @@ def __init__( def forward( self, x: torch.Tensor, - past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + rotary_emb_w_meta_info: Optional[dict] = None, is_causal: bool = True, needs_weights: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: qkv = self.Wqkv(x) @@ -581,6 +593,39 @@ def forward( query = self.q_ln(query).to(dtype) key = self.k_ln(key).to(dtype) + if rotary_emb_w_meta_info is not None: + rotary_emb = rotary_emb_w_meta_info['rotary_emb'] + seq_len = rotary_emb_w_meta_info['seq_len'] + offset_info = rotary_emb_w_meta_info['offset_info'] + bsz, seqlen = query.shape[:2] + query = query.view(bsz, seqlen, -1, self.head_dim) + key = key.view(bsz, seqlen, -1, self.head_dim) + + if rotary_emb_w_meta_info['impl'] == 'dail': + value = value.view(bsz, seqlen, -1, self.head_dim) + + kv = torch.stack([key, value], dim=2) + query, kv = rotary_emb(query, + kv, + seqlen_offset=offset_info, + max_seqlen=seq_len) + [key, value] = torch.unbind(kv, dim=2) + + value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim) + elif rotary_emb_w_meta_info['impl'] == 'hf': + (cos, sin) = rotary_emb(value, seq_len) + # The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb + query = query.transpose(1, 2) + key = key.transpose(1, 2) + query, key = apply_rotary_pos_emb(query, key, cos, sin, + offset_info) + # The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb + query = query.transpose(1, 2) + key = key.transpose(1, 2) + + query = query.view(bsz, seqlen, self.d_model) + key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim) + context, attn_weights, past_key_value = self.attn_fn( query, key, @@ -677,7 +722,7 @@ def __init__( def attn_bias_shape( attn_impl: str, n_heads: int, seq_len: int, alibi: bool, prefix_lm: bool, causal: bool, - use_sequence_id: bool) -> Optional[Tuple[int, int, int, int]]: + use_sequence_id: bool) -> Optional[tuple[int, int, int, int]]: if attn_impl == 'flash': return None elif attn_impl in ['torch', 'triton']: diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index a08ef6d77f..6605807c6b 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -12,6 +12,31 @@ from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY +attn_config_defaults: Dict = { + 'attn_type': 'multihead_attention', + 'attn_pdrop': 0.0, + 'attn_impl': 'triton', + 'qk_ln': False, + 'clip_qkv': None, + 'softmax_scale': None, + 'prefix_lm': False, + 'attn_uses_sequence_id': False, + 'alibi': False, + 'alibi_bias_max': 8, + 'rope': False, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +} + class MPTBlock(nn.Module): @@ -30,18 +55,7 @@ def __init__( **kwargs: Any, ): if attn_config is None: - attn_config = { - 'attn_type': 'multihead_attention', - 'attn_pdrop': 0.0, - 'attn_impl': 'triton', - 'qk_ln': False, - 'clip_qkv': None, - 'softmax_scale': None, - 'prefix_lm': False, - 'attn_uses_sequence_id': False, - 'alibi': False, - 'alibi_bias_max': 8, - } + attn_config = attn_config_defaults if ffn_config is None: ffn_config = { @@ -58,7 +72,8 @@ def __init__( # necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs args_to_exclude_in_attn_class = { 'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id', - 'alibi_bias_max' + 'alibi_bias_max', 'rope', 'rope_theta', 'rope_impl', + 'rope_dail_config', 'rope_hf_config' } attn_config_subset_for_attn_class = { k: v @@ -94,6 +109,7 @@ def forward( x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, + rotary_emb_w_meta_info: Optional[Dict] = None, attention_mask: Optional[torch.ByteTensor] = None, is_causal: bool = True, output_attentions: bool = False, @@ -104,6 +120,7 @@ def forward( a, past_key_value=past_key_value, attn_bias=attn_bias, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions, diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 251e4f5caf..c4ca68d733 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -8,18 +8,16 @@ from transformers import PretrainedConfig -attn_config_defaults: Dict = { - 'attn_type': 'multihead_attention', - 'attn_pdrop': 0.0, - 'attn_impl': 'triton', - 'qk_ln': False, - 'clip_qkv': None, - 'softmax_scale': None, - 'prefix_lm': False, - 'attn_uses_sequence_id': False, - 'alibi': False, - 'alibi_bias_max': 8, -} +from llmfoundry.models.layers.attention import is_flash_v2_installed +from llmfoundry.models.layers.blocks import attn_config_defaults + +# NOTE: All utils are imported directly even if unused so that +# HuggingFace can detect all the needed files to copy into its modules folder. +# Otherwise, certain modules are missing. +# isort: off +from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY # type: ignore (see note) +from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note) +from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note) ffn_config_defaults: Dict = { 'ffn_type': 'mptmlp', @@ -94,6 +92,16 @@ def __init__( Defaults to ``False`` meaning any provided `sequence_id` will be ignored. alibi (bool): Whether to use the alibi bias instead of position embeddings. alibi_bias_max (int): The maximum value of the alibi bias. + rope (bool): Whether to use rotary positional embeddings. + rope_theta (int): The base frequency for rope. + rope_impl (str): The implementation of rope to use. One of 'hf' (to use the implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) or 'dail' (to use the implementation from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py). + rope_dail_config (Dict): The configuration for the dail implementation of rope. + type (str): The type of rotary position embedding to use. Options: 'original' (for https://arxiv.org/pdf/2104.09864.pdf), 'xpos' (for https://arxiv.org/pdf/2212.10554.pdf). + pos_idx_in_fp32 (bool): If True, the position indices [0, ..., seqlen - 1] are in fp32, otherwise they might be in lower precision. A consequence could be, for example, that bf16 rounds position 1995 to 2000, which leads to them having the same positional embedding. + xpos_scale_base (float): The scale base for XPos (if using XPos). + rope_hf_config (Dict): A dictionary used to configure rope's scaling behavior (when scaling beyond the training length). + type (str): Can be one of 'no_scaling', 'linear', or 'dynamic'. 'no_scaling' uses the default implementation for rotary embeddings, 'linear' uses linear scaling as proposed by the Reddit user /u/kaiokendev, and 'dynamic' uses Dynamic NTK scaling as proposed by the Reddit users /u/bloc97 and /u/emozilla. + factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type. kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads. ffn_config (Dict): A dictionary used to configure the model's ffn module: ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp @@ -150,10 +158,12 @@ def __init__( del kwargs['name'] if 'loss_fn' in kwargs: del kwargs['loss_fn'] - if self.attn_config.get('alibi', False): + if self.attn_config.get('alibi', False) or self.attn_config.get( + 'rope', False): self.learned_pos_emb = False warnings.warn( - f'alibi is turned on, setting `learned_pos_emb` to `False.`') + f'alibi or rope is turned on, setting `learned_pos_emb` to `False.`' + ) super().__init__(**kwargs) self._validate_config() @@ -164,6 +174,10 @@ def _set_config_defaults(self, config: Dict[str, Any], for k, v in config_defaults.items(): if k not in config: config[k] = v + elif isinstance(v, dict): + # recursively set default values for any sub-dicts + config[k] = self._set_config_defaults( + config[k] if (config[k] is not None) else {}, v) return config def _validate_config(self) -> None: @@ -206,6 +220,31 @@ def _validate_config(self) -> None: raise NotImplementedError( 'attn_uses_sequence_id only implemented with torch and triton attention.' ) + if self.attn_config['rope'] and (self.attn_config['rope_impl'] + not in ['dail', 'hf']): + raise ValueError( + 'If rope is being used then rope_impl should be either "dail", or "hf".' + ) + if self.attn_config['rope'] and ( + self.attn_config['rope_impl'] + == 'hf') and self.attn_config['rope_hf_config']['type'] not in [ + 'no_scaling', 'linear', 'dynamic' + ]: + raise ValueError( + 'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".' + ) + if self.attn_config['rope'] and (self.attn_config['rope_impl'] + == 'dail'): + if self.attn_config['rope_dail_config']['type'] not in [ + 'original', 'xpos' + ]: + raise ValueError( + 'If using the dail implementation of rope, the type should be one of "original" or "xpos".' + ) + if not is_flash_v2_installed(v2_version='2.0.1'): + raise ImportError( + 'If using the dail implementation of rope, the flash_attn library v2.0.1 or higher must be installed. Please check the instructions at https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#what-kinds-of-positional-embeddings-does-llm-foundry-support' + ) if self.embedding_fraction > 1 or self.embedding_fraction <= 0: raise ValueError( 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!' @@ -217,9 +256,10 @@ def _validate_config(self) -> None: ) if self.init_config.get('name', None) is None: raise ValueError(f"{self.init_config=} 'name' needs to be set.") - if not self.learned_pos_emb and not self.attn_config['alibi']: + if not (self.learned_pos_emb or self.attn_config['alibi'] or + self.attn_config['rope']): warnings.warn( - f'Positional information not being provided to the model using either learned_pos_emb or alibi.' + f'Positional information not being provided to the model using either learned_pos_emb or alibi or rope.' ) if self.fc_type == 'te' or self.ffn_config['ffn_type'] == 'te_ln_mlp': try: diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 4f4581b177..0cb3ebd56c 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -23,11 +23,27 @@ from composer.metrics.nlp import LanguageCrossEntropy, LanguagePerplexity from composer.models import HuggingFaceModel from composer.utils import dist + +from llmfoundry.models.layers.attention import is_flash_v2_installed + +if is_flash_v2_installed(): + try: # This try...except is needed because transformers requires it despite the 'if' statement above + from flash_attn.layers.rotary import \ + RotaryEmbedding as DAILRotaryEmbedding + except Exception as e: + raise e + from omegaconf import DictConfig from omegaconf import OmegaConf as om from transformers import PreTrainedModel, PreTrainedTokenizerBase from transformers.modeling_outputs import (BaseModelOutputWithPast, CausalLMOutputWithPast) +from transformers.models.llama.modeling_llama import \ + LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding +from transformers.models.llama.modeling_llama import \ + LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding +from transformers.models.llama.modeling_llama import \ + LlamaRotaryEmbedding as HFRotaryEmbedding from llmfoundry.models.layers.attention import attn_bias_shape, build_attn_bias from llmfoundry.models.layers.blocks import MPTBlock @@ -70,6 +86,50 @@ log = logging.getLogger(__name__) +def gen_rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int, + rope_dail_config: dict, rope_hf_config: dict, + max_seq_len: int): + if rope_impl == 'dail': + return DAILRotaryEmbedding( + dim=rope_head_dim, + base=rope_theta, + interleaved=False, + scale_base=rope_dail_config['xpos_scale_base'] if + (rope_dail_config['type'] == 'xpos') else None, + pos_idx_in_fp32=rope_dail_config['pos_idx_in_fp32'], + device= + 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) + elif rope_impl == 'hf': + if rope_hf_config['type'] == 'no_scaling': + return HFRotaryEmbedding( + rope_head_dim, + max_position_embeddings=max_seq_len, + base=rope_theta, + device= + 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) + elif rope_hf_config['type'] == 'linear': + return HFLinearScalingRotaryEmbedding( + rope_head_dim, + max_position_embeddings=max_seq_len, + base=rope_theta, + scaling_factor=rope_hf_config['factor'], + device= + 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) + elif rope_hf_config['type'] == 'dynamic': + return HFDynamicNTKScalingRotaryEmbedding( + rope_head_dim, + max_position_embeddings=max_seq_len, + base=rope_theta, + scaling_factor=rope_hf_config['factor'], + device= + 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu + ) + raise ValueError('rope_impl needs to be either dail or hf') + + class MPTPreTrainedModel(PreTrainedModel): config_class = MPTConfig base_model_prefix = 'model' @@ -123,6 +183,18 @@ def __init__(self, config: MPTConfig): ]) self.norm_f = norm_class(config.d_model, device=config.init_device) + self.rope = config.attn_config['rope'] + self.rope_impl = None + if self.rope: + self.rope_impl = config.attn_config['rope_impl'] + self.rotary_embedding = gen_rotary_embedding( + rope_head_dim=config.d_model // config.n_heads, + rope_impl=self.rope_impl, + rope_theta=config.attn_config['rope_theta'], + rope_dail_config=config.attn_config['rope_dail_config'], + rope_hf_config=config.attn_config['rope_hf_config'], + max_seq_len=self.config.max_seq_len) + if config.init_device != 'meta': log.info( f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.' @@ -361,8 +433,9 @@ def forward( S <= self.config.max_seq_len ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' - tok_emb = self.wte(input_ids) - if self.learned_pos_emb: + rotary_emb_w_meta_info = None + x = self.wte(input_ids) + if self.learned_pos_emb or self.rope: past_position = 0 if past_key_values is not None: if len(past_key_values) != self.config.n_layers: @@ -378,31 +451,44 @@ def forward( if self.attn_impl == 'torch': past_position = past_key_values[0][0].size(3) - if S + past_position > self.config.max_seq_len: + if self.learned_pos_emb and (S + past_position > + self.config.max_seq_len): raise ValueError( f'Cannot forward input with past sequence length {past_position} and current sequence length ' + f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.' ) - pos = torch.arange( - past_position, - S + past_position, - dtype=torch.long, - device=input_ids.device, - ).unsqueeze(0) - if attention_mask is not None: - # adjust the position indices to account for padding tokens - pos = torch.clamp( - pos - torch.cumsum((~attention_mask).to(torch.int32), - dim=1)[:, past_position:], - min=0, - ) - pos_emb = self.wpe(pos) - x = tok_emb + pos_emb - else: - # ALiBi and NoPE use this path (RoPE will also use this path if / when enabled) - x = tok_emb + if self.learned_pos_emb or (self.rope and self.rope_impl == 'hf'): + pos = torch.arange( + past_position, + S + past_position, + dtype=torch.long, + device=input_ids.device, + ).unsqueeze(0) + if attention_mask is not None: + # adjust the position indices to account for padding tokens + pos = torch.clamp( + pos - torch.cumsum((~attention_mask).to(torch.int32), + dim=1)[:, past_position:], + min=0, + ) + if self.learned_pos_emb: + x = x + self.wpe(pos) + elif self.rope and self.rope_impl == 'hf': + rotary_emb_w_meta_info = { + 'impl': self.rope_impl, + 'rotary_emb': self.rotary_embedding, + 'offset_info': pos, + 'seq_len': S + past_position, + } + elif self.rope and self.rope_impl == 'dail': + rotary_emb_w_meta_info = { + 'impl': self.rope_impl, + 'rotary_emb': self.rotary_embedding, + 'offset_info': past_position, + 'seq_len': S + past_position, + } if self.embedding_fraction == 1: x = self.emb_drop(x) @@ -439,6 +525,7 @@ def forward( x, past_key_value=past_key_value, attn_bias=attn_bias, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index e6fe8eb438..3f2c229d6d 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -5,6 +5,9 @@ import torch from omegaconf import OmegaConf as om +from llmfoundry.models.layers.attention import is_flash_v2_installed +from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding + def allclose_helper(t0: torch.Tensor, t1: torch.Tensor, @@ -18,7 +21,32 @@ def allclose_helper(t0: torch.Tensor, @pytest.mark.parametrize('attn_impl_1', ['flash', 'triton', 'torch']) @pytest.mark.parametrize('clip_qkv', [True, False]) @pytest.mark.parametrize('qk_ln', [True, False]) -@pytest.mark.parametrize('alibi', [True, False]) +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +}]) @pytest.mark.parametrize( 'attn_type', ['multihead_attention', 'multiquery_attention', 'grouped_query_attention']) @@ -26,18 +54,24 @@ def test_attn_impl(attn_impl_0: str, attn_impl_1: str, clip_qkv: bool, qk_ln: bool, - alibi: bool, + pos_emb_config: dict, attn_type: str, device: str = 'cuda'): """Compare all attn impl with each other. - Includes testing with and without attn_clip_qkv, attn_qk_ln, and alibi. + Includes testing with and without attn_clip_qkv, attn_qk_ln, alibi, and + rope. """ from llmfoundry.models.layers import attention - + alibi = pos_emb_config['alibi'] + rope = pos_emb_config['rope'] if alibi and (attn_impl_0 == 'flash' or attn_impl_1 == 'flash'): pytest.xfail('flash attn does not support alibi') + if rope and (pos_emb_config['rope_impl'] + == 'dail') and (not is_flash_v2_installed()): + pytest.skip('dail implementation of rope requires flash attention 2.') + cfg = om.create({ 'attn_impl': 'flash', 'd_model': 128, @@ -48,7 +82,7 @@ def test_attn_impl(attn_impl_0: str, }) n, s, f = 2, 16, cfg.d_model - + assert cfg.d_model % cfg.n_heads == 0 if attn_type == 'grouped_query_attention': cfg.kv_n_heads = 2 @@ -91,16 +125,45 @@ def gen_bias(attn_impl: str): with torch.autocast(x0.device.type): attn_bias = gen_bias(attn0.attn_impl) + + rotary_emb_w_meta_info = None + if rope: + rotary_embedding = gen_rotary_embedding( + rope_head_dim=cfg.d_model // cfg.n_heads, + rope_impl=pos_emb_config['rope_impl'], + rope_theta=pos_emb_config['rope_theta'], + rope_dail_config=pos_emb_config.get('rope_dail_config', {}), + rope_hf_config=pos_emb_config.get('rope_hf_config', {}), + max_seq_len=s).to(device) + pos = torch.arange(s).unsqueeze(0).to(device=device) + # adjust the position indices to account for padding tokens + pos = torch.clamp( + pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1), + min=0, + ) + rotary_emb_w_meta_info = { + 'impl': + pos_emb_config['rope_impl'], + 'rotary_emb': + rotary_embedding, + 'offset_info': + pos if (pos_emb_config['rope_impl'] == 'hf') else 0, + 'seq_len': + s, + } + y0, _, _ = attn0(x0, past_key_value=None, attn_bias=attn_bias, attention_mask=attention_mask, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, is_causal=True) attn_bias = gen_bias(attn1.attn_impl) y1, _, _ = attn1(x1, past_key_value=None, attn_bias=attn_bias, attention_mask=attention_mask, + rotary_emb_w_meta_info=rotary_emb_w_meta_info, is_causal=True) y0 *= attention_mask.unsqueeze(-1) y1 *= attention_mask.unsqueeze(-1) diff --git a/tests/test_model.py b/tests/test_model.py index 1c7033ed48..41b62f0ccf 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -16,7 +16,7 @@ from composer.core.precision import Precision, get_precision_context from composer.optim import DecoupledAdamW from composer.trainer.dist_strategy import prepare_fsdp_module -from composer.utils import dist, get_device +from composer.utils import dist, get_device, reproducibility from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, @@ -28,6 +28,7 @@ from llmfoundry import COMPOSER_MODEL_REGISTRY, ComposerHFCausalLM from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss from llmfoundry.models.layers import NORM_CLASS_REGISTRY, build_alibi_bias +from llmfoundry.models.layers.attention import is_flash_v2_installed from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM from llmfoundry.utils import build_tokenizer @@ -517,16 +518,49 @@ def test_mpt_creation(norm_type: str, no_bias: bool): ('flash', 'gpu'), ('triton', 'gpu'), ('torch', 'gpu')]) -@pytest.mark.parametrize('alibi', [True, False]) -def test_forward_with_padding(attention_impl: str, device: str, alibi: bool): +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +}]) +def test_forward_with_padding(attention_impl: str, device: str, + pos_emb_config: dict): # Test that different placement of padding does not affect the output. if not torch.cuda.is_available() and device == 'gpu': pytest.skip( f'This test requires CUDA to be available in order to run with {attention_impl} attention.' ) + alibi = pos_emb_config['alibi'] if alibi and attention_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') + rope = pos_emb_config['rope'] + if rope and pos_emb_config['rope_impl'] == 'dail' and ( + device != 'gpu' or not is_flash_v2_installed()): + pytest.skip( + f'dail implementation of rope requires gpu and flash attention 2.') + composer_device = get_device(device) hf_config = MPTConfig( @@ -540,7 +574,7 @@ def test_forward_with_padding(attention_impl: str, device: str, alibi: bool): resid_pdrop=0.2, attn_config={ 'attn_impl': attention_impl, - 'alibi': alibi, + **pos_emb_config, }, init_config={ 'name': 'baseline_', @@ -612,23 +646,35 @@ def test_forward_with_padding(attention_impl: str, device: str, alibi: bool): attention_mask=batched_attention_mask).logits # check that right padding and left padding produce the same output + right_pad_v_left_pad_rtol = 1e-5 + right_pad_v_left_pad_atol = 1e-6 if attention_impl == 'torch' else 1e-8 + if rope and pos_emb_config['rope_impl'] == 'dail': + # dail implementation of rope uses bf16 precision and hence the rotations have small numerical errors. This causes some differences between the outputs of padded and unpadded inputs. + right_pad_v_left_pad_rtol = 1e-2 + right_pad_v_left_pad_atol = 1e-2 assert torch.allclose(right_padding_output[0, :3], left_padding_output[0, 3:], - atol=1e-6 if attention_impl == 'torch' else 1e-8) - if not alibi: + rtol=right_pad_v_left_pad_rtol, + atol=right_pad_v_left_pad_atol) + + if not (alibi or (rope and pos_emb_config['rope_impl'] == 'dail')): # check that right padding and middle padding produce the same output # Note: alibi not implemented for middle padding. + # Note: dail implementation of rope does not support middle padding. assert torch.allclose( right_padding_output[0, :3], middle_padding_output[0, [0, 1, 5]], atol=1e-6 if attention_impl == 'torch' else 1e-8) + # check that right padding and right padding in a batch produce the same output assert torch.allclose(right_padding_output[0, :3], batched_output[0, :3], atol=1e-6 if attention_impl == 'torch' else 1e-8) - if not alibi: + + if not (alibi or (rope and pos_emb_config['rope_impl'] == 'dail')): # check that middle padding and middle padding in a batch produce the same output # Note: alibi not implemented for middle padding. + # Note: dail implementation of rope does not support middle padding. assert torch.allclose( middle_padding_output[0], batched_output[1, :], @@ -694,17 +740,47 @@ def test_advanced_mask_building(attention_impl: str): ('flash', 'gpu'), ('triton', 'gpu'), ('torch', 'gpu')]) -@pytest.mark.parametrize('alibi', [True, False]) -def test_generate(attention_impl: str, device: str, alibi: bool): +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +}]) +def test_generate(attention_impl: str, device: str, pos_emb_config: dict): # Test that generate works, and produces the same output with or without # padding in the input. if not torch.cuda.is_available() and device == 'gpu': pytest.skip( f'This test requires CUDA to be available in order to run with {attention_impl} attention.' ) - if alibi and attention_impl == 'flash': + if pos_emb_config['alibi'] and attention_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and ( + device != 'gpu' or not is_flash_v2_installed()): + pytest.skip( + f'dail implementation of rope requires gpu and flash attention 2.') + composer_device = get_device(device) hf_config = MPTConfig( @@ -718,7 +794,7 @@ def test_generate(attention_impl: str, device: str, alibi: bool): resid_pdrop=0.2, attn_config={ 'attn_impl': attention_impl, - 'alibi': alibi, + **pos_emb_config, }, ) mpt = MPTForCausalLM(hf_config) @@ -886,9 +962,54 @@ def test_save_from_pretrained(tmp_path: pathlib.Path): check_hf_model_equivalence(mpt, mpt2) -@pytest.mark.parametrize('alibi', [True, False]) -def test_forward_with_cache_and_padding(alibi: bool): +@pytest.mark.parametrize('attn_impl,device', [ + ('torch', 'cpu'), + ('flash', 'gpu'), + ('triton', 'gpu'), + ('torch', 'gpu'), +]) +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +}]) +def test_forward_with_cache_and_padding(attn_impl: str, device: str, + pos_emb_config: dict): # Tests that the result is the same with or without padding when using kv caching + if not torch.cuda.is_available() and device == 'gpu': + pytest.skip( + f'This test requires CUDA to be available in order to run with {attn_impl} attention.' + ) + if pos_emb_config['alibi'] and attn_impl == 'flash': + pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and ( + device != 'gpu' or not is_flash_v2_installed()): + pytest.skip( + f'dail implementation of rope requires gpu and flash attention 2.') + + composer_device = get_device(device) + hf_config = MPTConfig( init_device='cpu', d_model=128, @@ -899,8 +1020,8 @@ def test_forward_with_cache_and_padding(alibi: bool): emb_pdrop=0.1, resid_pdrop=0.2, attn_config={ - 'attn_impl': 'torch', - 'alibi': alibi, + 'attn_impl': attn_impl, + **pos_emb_config, }, use_cache=True, init_config={ @@ -910,47 +1031,74 @@ def test_forward_with_cache_and_padding(alibi: bool): ) mpt = MPTForCausalLM(hf_config) + mpt = composer_device.module_to_device(mpt) mpt.eval() - - first_input_ids_no_padding = torch.tensor([[11274, 16390, 11]]) - first_attention_mask_no_padding = torch.tensor([[1, 1, 1]]).bool() - - # start with passing the first three tokens through (no padding) - first_output_no_padding = mpt( - first_input_ids_no_padding, - attention_mask=first_attention_mask_no_padding) - - second_input_ids_no_padding = torch.tensor([[11274, 16390, 11, 11274]]) - second_attention_mask_no_padding = torch.tensor([[1, 1, 1, 1]]).bool() - - # pass through the fourth token by itself, using the key-value cache (no padding) - second_output_no_padding = mpt( - second_input_ids_no_padding[:, -1].unsqueeze(-1), - attention_mask=second_attention_mask_no_padding, - past_key_values=first_output_no_padding.past_key_values) - - first_input_ids_padding = torch.tensor([[50256, 11274, 16390, 11]]) - first_attention_mask_padding = torch.tensor([[0, 1, 1, 1]]).bool() - - # start with passing the first three tokens through (with left padding) - first_output_padding = mpt(first_input_ids_padding, - attention_mask=first_attention_mask_padding) - - second_input_ids_padding = torch.tensor([[50256, 11274, 16390, 11, 11274]]) - second_attention_mask_padding = torch.tensor([[0, 1, 1, 1, 1]]).bool() - - # pass through the fourth token by itself, using the key-value cache (with left padding) - second_output_padding = mpt( - second_input_ids_padding[:, -1].unsqueeze(-1), - attention_mask=second_attention_mask_padding, - past_key_values=first_output_padding.past_key_values) - - # check that the outputs are the same with or without padding - torch.testing.assert_close(second_output_no_padding.logits, - second_output_padding.logits[:, - -1, :].unsqueeze(1), - atol=1e-6, - rtol=1e-6) + with get_precision_context('amp_bf16' if composer_device.name == + 'gpu' else 'fp32'): + first_input_ids_no_padding = torch.tensor([[11274, 16390, 11]]) + first_input_ids_no_padding = composer_device.tensor_to_device( + first_input_ids_no_padding) + first_attention_mask_no_padding = torch.tensor([[1, 1, 1]]).bool() + first_attention_mask_no_padding = composer_device.tensor_to_device( + first_attention_mask_no_padding) + + # start with passing the first three tokens through (no padding) + first_output_no_padding = mpt( + first_input_ids_no_padding, + attention_mask=first_attention_mask_no_padding) + + second_input_ids_no_padding = torch.tensor([[11274, 16390, 11, 11274]]) + second_input_ids_no_padding = composer_device.tensor_to_device( + second_input_ids_no_padding) + second_attention_mask_no_padding = torch.tensor([[1, 1, 1, 1]]).bool() + second_attention_mask_no_padding = composer_device.tensor_to_device( + second_attention_mask_no_padding) + + # pass through the fourth token by itself, using the key-value cache (no padding) + second_output_no_padding = mpt( + second_input_ids_no_padding[:, -1].unsqueeze(-1), + attention_mask=second_attention_mask_no_padding, + past_key_values=first_output_no_padding.past_key_values) + + first_input_ids_padding = torch.tensor([[50256, 11274, 16390, 11]]) + first_input_ids_padding = composer_device.tensor_to_device( + first_input_ids_padding) + first_attention_mask_padding = torch.tensor([[0, 1, 1, 1]]).bool() + first_attention_mask_padding = composer_device.tensor_to_device( + first_attention_mask_padding) + + # start with passing the first three tokens through (with left padding) + first_output_padding = mpt(first_input_ids_padding, + attention_mask=first_attention_mask_padding) + + second_input_ids_padding = torch.tensor( + [[50256, 11274, 16390, 11, 11274]]) + second_input_ids_padding = composer_device.tensor_to_device( + second_input_ids_padding) + second_attention_mask_padding = torch.tensor([[0, 1, 1, 1, 1]]).bool() + second_attention_mask_padding = composer_device.tensor_to_device( + second_attention_mask_padding) + + # pass through the fourth token by itself, using the key-value cache (with left padding) + second_output_padding = mpt( + second_input_ids_padding[:, -1].unsqueeze(-1), + attention_mask=second_attention_mask_padding, + past_key_values=first_output_padding.past_key_values) + + # check that the outputs are the same with or without padding + if pos_emb_config['rope'] and pos_emb_config[ + 'rope_impl'] == 'dail': # dail implementation of rope uses bf16 precision and hence the rotations have small numerical errors. This causes some differences between the outputs of padded and unpadded inputs. + torch.testing.assert_close( + second_output_no_padding.logits, + second_output_padding.logits[:, -1, :].unsqueeze(1), + atol=1e-2, + rtol=1e-6) + else: + torch.testing.assert_close( + second_output_no_padding.logits, + second_output_padding.logits[:, -1, :].unsqueeze(1), + atol=1e-6, + rtol=1e-6) @pytest.mark.parametrize('attn_impl,device', [ @@ -959,17 +1107,47 @@ def test_forward_with_cache_and_padding(alibi: bool): ('triton', 'gpu'), ('torch', 'gpu'), ]) -@pytest.mark.parametrize('alibi', [True, False]) -def test_forward_with_cache(attn_impl: str, device: str, alibi: bool): +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +}]) +def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict): # Test that model forward with and without the key-value cache produces the # same output. if not torch.cuda.is_available() and device == 'gpu': pytest.skip( f'This test requires CUDA to be available in order to run with {attn_impl} attention.' ) - if alibi and attn_impl == 'flash': + if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and ( + device != 'gpu' or not is_flash_v2_installed()): + pytest.skip( + f'dail implementation of rope requires gpu and flash attention 2.') + composer_device = get_device(device) hf_config = MPTConfig( @@ -983,10 +1161,8 @@ def test_forward_with_cache(attn_impl: str, device: str, alibi: bool): resid_pdrop=0.2, attn_config={ 'attn_impl': attn_impl, - 'alibi': alibi, + **pos_emb_config, }, - attn_impl=attn_impl, - alibi=alibi, use_cache=True, init_config={ 'name': 'baseline_', @@ -1066,8 +1242,53 @@ def test_forward_with_cache(attn_impl: str, device: str, alibi: bool): ) -@pytest.mark.parametrize('alibi', [True, False]) -def test_generate_with_past_kv(alibi: bool): +@pytest.mark.parametrize('attn_impl,device', [ + ('torch', 'cpu'), + ('flash', 'gpu'), + ('triton', 'gpu'), + ('torch', 'gpu'), +]) +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +}]) +def test_generate_with_past_kv(attn_impl: str, device: str, + pos_emb_config: dict): + if not torch.cuda.is_available() and device == 'gpu': + pytest.skip( + f'This test requires CUDA to be available in order to run with {attn_impl} attention.' + ) + if pos_emb_config['alibi'] and attn_impl == 'flash': + pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and ( + device != 'gpu' or not is_flash_v2_installed()): + pytest.skip( + f'dail implementation of rope requires gpu and flash attention 2.') + + composer_device = get_device(device) + hf_config = MPTConfig( init_device='cpu', d_model=128, @@ -1078,8 +1299,8 @@ def test_generate_with_past_kv(alibi: bool): emb_pdrop=0.1, resid_pdrop=0.2, attn_config={ - 'attn_impl': 'torch', - 'alibi': alibi, + 'attn_impl': attn_impl, + **pos_emb_config, }, use_cache=True, init_config={ @@ -1088,33 +1309,46 @@ def test_generate_with_past_kv(alibi: bool): }, ) mpt = MPTForCausalLM(hf_config) + mpt = composer_device.module_to_device(mpt) mpt.eval() # no padding in the input no_padding_input_ids = torch.tensor([[11274, 16390, 11]]) + no_padding_input_ids = composer_device.tensor_to_device( + no_padding_input_ids) no_padding_attention_mask = torch.tensor([[1, 1, 1]]) + no_padding_attention_mask = composer_device.tensor_to_device( + no_padding_attention_mask) - with mock.patch.object(MPTForCausalLM, 'forward', - autospec=True) as forward_mocked: - forward_mocked.return_value = CausalLMOutputWithPast( - logits=torch.randn((1, 3, hf_config.vocab_size)), - past_key_values=[(torch.randn(1, 3, hf_config.d_model), - torch.randn(1, 3, hf_config.d_model)) - for _ in range(hf_config.n_layers)]) - _ = mpt.generate(input_ids=no_padding_input_ids, - attention_mask=no_padding_attention_mask, - max_new_tokens=2) - - assert forward_mocked.call_count == 2 - _, _, kwargs = forward_mocked.mock_calls[0] - assert kwargs['past_key_values'] is None - _, _, kwargs = forward_mocked.mock_calls[1] - assert kwargs['past_key_values'] is not None - assert len(kwargs['past_key_values']) == hf_config.n_layers - assert kwargs['past_key_values'][0][0].shape == (1, 3, - hf_config.d_model) + with get_precision_context('amp_bf16' if composer_device.name == + 'gpu' else 'fp32'): + with mock.patch.object(MPTForCausalLM, 'forward', + autospec=True) as forward_mocked: + forward_mocked.return_value = CausalLMOutputWithPast( + logits=torch.randn((1, 3, hf_config.vocab_size)), + past_key_values=[(torch.randn(1, 3, hf_config.d_model), + torch.randn(1, 3, hf_config.d_model)) + for _ in range(hf_config.n_layers)]) + _ = mpt.generate(input_ids=no_padding_input_ids, + attention_mask=no_padding_attention_mask, + max_new_tokens=2) + + assert forward_mocked.call_count == 2 + _, _, kwargs = forward_mocked.mock_calls[0] + assert kwargs['past_key_values'] is None + _, _, kwargs = forward_mocked.mock_calls[1] + assert kwargs['past_key_values'] is not None + assert len(kwargs['past_key_values']) == hf_config.n_layers + assert kwargs['past_key_values'][0][0].shape == (1, 3, + hf_config.d_model) +@pytest.mark.parametrize('attn_impl,device', [ + ('torch', 'cpu'), + ('flash', 'gpu'), + ('triton', 'gpu'), + ('torch', 'gpu'), +]) @pytest.mark.parametrize('generation_kwargs', [{ 'max_new_tokens': 2, 'num_beams': 4 @@ -1126,9 +1360,49 @@ def test_generate_with_past_kv(alibi: bool): 'do_sample': True, 'top_p': 0.95 }]) -@pytest.mark.parametrize('alibi', [True, False]) -def test_generation_kwargs_dont_crash(generation_kwargs: Dict[str, Any], - alibi: bool): +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +}]) +def test_generation_kwargs_dont_crash(attn_impl: str, device: str, + generation_kwargs: Dict[str, Any], + pos_emb_config: dict): + if not torch.cuda.is_available() and device == 'gpu': + pytest.skip( + f'This test requires CUDA to be available in order to run with {attn_impl} attention.' + ) + if pos_emb_config['alibi'] and attn_impl == 'flash': + pytest.skip(f'alibi only implemented with torch and triton attention.') + + if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and ( + device != 'gpu' or not is_flash_v2_installed()): + pytest.skip( + f'dail implementation of rope requires gpu and flash attention 2.') + composer_device = get_device(device) + if device == 'gpu': # Switch deteminism off + torch.use_deterministic_algorithms(False) hf_config = MPTConfig( init_device='cpu', d_model=128, @@ -1139,35 +1413,73 @@ def test_generation_kwargs_dont_crash(generation_kwargs: Dict[str, Any], emb_pdrop=0.1, resid_pdrop=0.2, attn_config={ - 'attn_impl': 'torch', - 'alibi': alibi, + 'attn_impl': attn_impl, + **pos_emb_config, }, use_cache=True, ) mpt = MPTForCausalLM(hf_config) + mpt = composer_device.module_to_device(mpt) mpt.eval() - # no padding in the input - no_padding_input_ids = torch.tensor([[11274, 16390, 11]]) - no_padding_attention_mask = torch.tensor([[1, 1, 1]]) + with get_precision_context('amp_bf16' if composer_device.name == + 'gpu' else 'fp32'): + # no padding in the input + no_padding_input_ids = torch.tensor([[11274, 16390, 11]]) + no_padding_input_ids = composer_device.tensor_to_device( + no_padding_input_ids) + no_padding_attention_mask = torch.tensor([[1, 1, 1]]) + no_padding_attention_mask = composer_device.tensor_to_device( + no_padding_attention_mask) - _ = mpt.generate(input_ids=no_padding_input_ids, - attention_mask=no_padding_attention_mask, - **generation_kwargs) + _ = mpt.generate(input_ids=no_padding_input_ids, + attention_mask=no_padding_attention_mask, + **generation_kwargs) + if device == 'gpu': # Switch deteminism back on + reproducibility.configure_deterministic_mode() @pytest.mark.gpu @pytest.mark.parametrize('attention_impl', ['torch', 'flash', 'triton']) -@pytest.mark.parametrize('alibi', [True, False]) -def test_model_to(attention_impl: str, alibi: bool): +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +}]) +def test_model_to(attention_impl: str, pos_emb_config: dict): # test that moving the model to diff devices and dtypes in diff ways does not break the model if not torch.cuda.is_available(): pytest.skip( f'This test requires CUDA to be available in order to run with {attention_impl} attention.' ) - if alibi and attention_impl == 'flash': + if pos_emb_config['alibi'] and attention_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') + if pos_emb_config['rope'] and pos_emb_config[ + 'rope_impl'] == 'dail' and not is_flash_v2_installed(): + pytest.skip(f'dail implementation of rope requires flash attention 2.') + hf_config = MPTConfig( init_device='cpu', d_model=128, @@ -1179,7 +1491,7 @@ def test_model_to(attention_impl: str, alibi: bool): resid_pdrop=0.2, attn_config={ 'attn_impl': attention_impl, - 'alibi': alibi, + **pos_emb_config, }, use_cache=True, init_config={ @@ -1204,7 +1516,8 @@ def test_model_to(attention_impl: str, alibi: bool): mpt = mpt.to('cpu') # verify the model still works - if attention_impl == 'torch': + if attention_impl == 'torch' and not ( + pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail'): with torch.autocast('cpu', dtype=torch.bfloat16, enabled=True): _ = mpt(input_ids.to('cpu'), attention_mask=attention_mask.to('cpu')) @@ -1221,7 +1534,8 @@ def test_model_to(attention_impl: str, alibi: bool): mpt = mpt.float() # verify the model still works - if attention_impl == 'torch': + if attention_impl == 'torch' and not ( + pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail'): _ = mpt(input_ids.to('cpu'), attention_mask=attention_mask.to('cpu')) mpt = mpt.half() @@ -1258,21 +1572,50 @@ def test_alibi_vs_hf(): ('triton', 'gpu'), ('torch', 'gpu'), ]) -@pytest.mark.parametrize('alibi', [True, False]) +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': False, + 'rope': False +}, { + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +}]) @pytest.mark.parametrize('output_attentions', [True, False]) @pytest.mark.parametrize('output_hidden_states', [True, False]) def test_forward_with_output_attentions_and_output_hidden_states( - attn_impl: str, device: str, alibi: bool, output_attentions: bool, - output_hidden_states: bool): + attn_impl: str, device: str, pos_emb_config: dict, + output_attentions: bool, output_hidden_states: bool): # Test that model forward with output_attentions_and_output_hidden_states if not torch.cuda.is_available() and device == 'gpu': pytest.skip( f'This test requires CUDA to be available in order to run with {attn_impl} attention.' ) - if alibi and attn_impl == 'flash': + if pos_emb_config['alibi'] and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') if output_attentions and attn_impl in ['flash', 'triton']: pytest.skip(f'output_attentions only implemented with torch attention.') + if pos_emb_config['rope'] and pos_emb_config['rope_impl'] == 'dail' and ( + device != 'gpu' or not is_flash_v2_installed()): + pytest.skip( + f'dail implementation of rope requires gpu and flash attention 2.') composer_device = get_device(device) @@ -1289,10 +1632,8 @@ def test_forward_with_output_attentions_and_output_hidden_states( resid_pdrop=0.2, attn_config={ 'attn_impl': attn_impl, - 'alibi': alibi, + **pos_emb_config, }, - attn_impl=attn_impl, - alibi=alibi, use_cache=True, init_config={ 'name': 'baseline_', diff --git a/tests/test_rope_dail_vs_hf.py b/tests/test_rope_dail_vs_hf.py new file mode 100644 index 0000000000..598e308546 --- /dev/null +++ b/tests/test_rope_dail_vs_hf.py @@ -0,0 +1,145 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from composer.core.precision import get_precision_context +from omegaconf import OmegaConf as om + +from llmfoundry.models.layers.attention import is_flash_v2_installed +from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding + + +@pytest.mark.gpu +@pytest.mark.parametrize('clip_qkv', [True, False]) +@pytest.mark.parametrize('qk_ln', [True, False]) +@pytest.mark.parametrize( + 'attn_type', + ['multihead_attention', 'multiquery_attention', 'grouped_query_attention']) +@pytest.mark.parametrize('seq_len', [1, 233, 2048]) +def test_rope_dail_vs_hf(clip_qkv: bool, + qk_ln: bool, + attn_type: str, + seq_len: int, + device: str = 'cuda'): + # compare rope rotations for the dail vs hf implementations + if not is_flash_v2_installed(): + pytest.skip('dail implementation of rope requires flash attention 2.') + + from llmfoundry.models.layers import attention + + cfg = om.create({ + 'attn_impl': 'flash', + 'd_model': 128, + 'n_heads': 4, + 'attn_pdrop': 0, + 'clip_qkv': clip_qkv, + 'qk_ln': qk_ln, + }) + + batch_size = 2 + assert cfg.d_model % cfg.n_heads == 0 + if attn_type == 'grouped_query_attention': + cfg.kv_n_heads = 2 + + attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device) + attn1 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device) + + attn1.load_state_dict(attn0.state_dict()) + x0 = torch.randn(batch_size, seq_len, cfg.d_model).to(device) + x1 = x0.clone().detach() + x0.requires_grad = True + x1.requires_grad = True + attention_mask = torch.ones(batch_size, seq_len).to(device).bool() + + with get_precision_context('amp_bf16'): + dail_rope_config = { + 'rope_theta': 10000, + 'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + } + } + hf_rope_config = { + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + } + } + + dail_rope = gen_rotary_embedding( + rope_head_dim=cfg.d_model // cfg.n_heads, + rope_impl=dail_rope_config['rope_impl'], + rope_theta=dail_rope_config['rope_theta'], + rope_dail_config=dail_rope_config['rope_dail_config'], + rope_hf_config={}, + max_seq_len=seq_len).to('cuda') + dail_rope_w_meta_info = { + 'impl': 'dail', + 'rotary_emb': dail_rope, + 'offset_info': 0, + 'seq_len': seq_len, + } + + hf_rope = gen_rotary_embedding( + rope_head_dim=cfg.d_model // cfg.n_heads, + rope_impl=hf_rope_config['rope_impl'], + rope_theta=hf_rope_config['rope_theta'], + rope_dail_config={}, + rope_hf_config=hf_rope_config['rope_hf_config'], + max_seq_len=seq_len).to('cuda') + pos = torch.arange(seq_len).unsqueeze(0).to(device='cuda') + # adjust the position indices to account for padding tokens + pos = torch.clamp( + pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1), + min=0, + ) + hf_rope_w_meta_info = { + 'impl': 'hf', + 'rotary_emb': hf_rope, + 'offset_info': pos, + 'seq_len': seq_len, + } + + y0, _, _ = attn0(x0, + past_key_value=None, + attn_bias=None, + attention_mask=attention_mask, + rotary_emb_w_meta_info=dail_rope_w_meta_info, + is_causal=True) + + y1, _, _ = attn1(x1, + past_key_value=None, + attn_bias=None, + attention_mask=attention_mask, + rotary_emb_w_meta_info=hf_rope_w_meta_info, + is_causal=True) + + y0 *= attention_mask.unsqueeze(-1) + y1 *= attention_mask.unsqueeze(-1) + + loss0 = y0.sum() + loss1 = y1.sum() + + loss0.backward() + loss1.backward() + + torch.testing.assert_close(y0, y1, rtol=1e-2, atol=1e-2) + + torch_name_param_map = {n: p for n, p in attn1.named_parameters()} + for n, p in attn0.named_parameters(): + tp = torch_name_param_map[n] + assert p.grad is not None + assert tp.grad is not None + torch.testing.assert_close(p, tp, rtol=1e-2, atol=1e-2) + # Relaxed to a l2-norm based check. + assert torch.norm(tp.grad - p.grad) <= 1e-2 + 1e-2 * torch.norm(p.grad) + + assert x0.grad is not None + assert x1.grad is not None + # Relaxed to a l2-norm based check. + assert torch.norm(x0.grad - x1.grad) <= 1e-2 + 1e-2 * torch.norm(x0.grad)