From 613457a1cb426ddd601b1b7ee44430be0d8f5ff7 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Mon, 27 Nov 2023 15:51:30 -0800 Subject: [PATCH 01/14] Bump composer version to min 0.17.1 (#762) --- setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index afdfce8d48..9bf2ef2cb0 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ ] install_requires = [ - 'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.17,<0.18', + 'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.17.1,<0.18', 'accelerate>=0.20,<0.21', # for HF inference `device_map` 'transformers>=4.34.1,<4.35', 'mosaicml-streaming>=0.7.1,<0.8', @@ -84,11 +84,11 @@ ] extra_deps['databricks'] = [ - 'mosaicml[databricks]>=0.17,<0.18', + 'mosaicml[databricks]>=0.17.1,<0.18', ] extra_deps['tensorboard'] = [ - 'mosaicml[tensorboard]>=0.17,<0.18', + 'mosaicml[tensorboard]>=0.17.1,<0.18', ] extra_deps['gpu'] = [ From 34d04ea689a4e4b08af7e9e911338fd4bc2983c1 Mon Sep 17 00:00:00 2001 From: bandish-shah <86627118+bandish-shah@users.noreply.github.com> Date: Mon, 27 Nov 2023 20:52:14 -0800 Subject: [PATCH 02/14] Update Docker image release logic so that we can release new images to prod from workflow_dispatch (#763) --- .github/workflows/docker.yaml | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index 13a835356c..f6dac79fe5 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -69,19 +69,17 @@ jobs: GIT_SHA=$(echo ${{ github.sha }} | cut -c1-7) echo "IMAGE_TAG=${GIT_SHA}" >> ${GITHUB_ENV} - if [ "${{ github.event_name }}" == "push" ]; then - echo "Triggered by push event." - PROD_REPO="mosaicml/llm-foundry" - IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest" - IMAGE_CACHE="${PROD_REPO}:${{matrix.name}}-buildcache" - elif [ "${{ github.event_name }}" == "pull_request" ]; then + if [ "${{ github.event_name }}" == "pull_request" ]; then echo "Triggered by pull_request event." 
STAGING_REPO="mosaicml/ci-staging" IMAGE_TAG="${STAGING_REPO}:${{matrix.name}}-${GIT_SHA}" IMAGE_CACHE="${STAGING_REPO}:${{matrix.name}}-buildcache" else - echo "Triggered by unknown event: ${{ github.event_name }}" - exit 1 + # Triggered by push or workflow_dispatch event + echo "Triggered by ${{ github.event_name }} event, releasing to prod" + PROD_REPO="mosaicml/llm-foundry" + IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest" + IMAGE_CACHE="${PROD_REPO}:${{matrix.name}}-buildcache" fi echo "IMAGE_TAG=${IMAGE_TAG}" >> ${GITHUB_ENV} From 4f399bf5895b52490c1e43cd0f7d1492724bfa47 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Tue, 28 Nov 2023 12:54:45 -0800 Subject: [PATCH 03/14] Fix tiktoken wrapper (#761) --- llmfoundry/tokenizers/tiktoken.py | 169 ++++++++++++++---------------- tests/test_tiktoken.py | 62 ++++++----- 2 files changed, 111 insertions(+), 120 deletions(-) diff --git a/llmfoundry/tokenizers/tiktoken.py b/llmfoundry/tokenizers/tiktoken.py index 8e258cce74..6110f565df 100644 --- a/llmfoundry/tokenizers/tiktoken.py +++ b/llmfoundry/tokenizers/tiktoken.py @@ -1,8 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 - -import warnings -from typing import Any, Dict, List, Optional, Tuple, Union +from functools import lru_cache +from typing import Any, Dict, List, Optional, Tuple import torch from transformers import PreTrainedTokenizer @@ -10,6 +9,38 @@ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible.""" +# Taken from +# https://github.com/huggingface/transformers/blob/8aca43bdb3cb9a5020f6d57589d85679dc873b1c/src/transformers/models/gpt2/tokenization_gpt2.py#L62-L84 +@lru_cache() +def bytes_to_unicode(): + """Returns list of utf-8 byte and a mapping to unicode strings. + + We specifically avoids mapping to whitespace/control characters the bpe code + barfs on. + + The reversible bpe codes work on unicode strings. This means you need a + large # of unicode characters in your vocab if you want to avoid UNKs. When + you're at something like a 10B token dataset you end up needing around 5K + for decent coverage. This is a significant percentage of your normal, say, + 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and + unicode strings. + """ + bs = (list(range(ord('!'), + ord('~') + 1)) + list(range(ord('¡'), + ord('¬') + 1)) + + list(range(ord('®'), + ord('ÿ') + 1))) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + class TiktokenTokenizerWrapper(PreTrainedTokenizer): """A thin wrapper around tiktoken to make it compatible with Hugging Face. 
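For context on the hunk above: the `bytes_to_unicode` table is what makes the wrapper's new `str <-> int` vocab lossless. Every byte value 0-255 gets a printable unicode stand-in, so token strings can carry arbitrary utf-8 bytes without the decode errors the old approach hit. Below is a minimal standalone sketch of that round trip; it is not part of the patch, and the demo string and variable names are illustrative only.

# Minimal sketch (illustrative, not from the patch) of the GPT-2-style
# byte <-> unicode round trip that the tiktoken wrapper relies on.
from functools import lru_cache


@lru_cache()
def bytes_to_unicode():
    # Printable byte values map to themselves; the rest are shifted above 2**8
    # so every byte has a printable stand-in character.
    bs = (list(range(ord('!'), ord('~') + 1)) +
          list(range(ord('¡'), ord('¬') + 1)) +
          list(range(ord('®'), ord('ÿ') + 1)))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    return dict(zip(bs, [chr(c) for c in cs]))


byte_encoder = bytes_to_unicode()                # int byte -> printable str
byte_decoder = {v: k for k, v in byte_encoder.items()}

raw = 'naïve café'.encode('utf-8')               # arbitrary utf-8 bytes
token_text = ''.join(byte_encoder[b] for b in raw)
recovered = bytes(byte_decoder[c] for c in token_text).decode('utf-8')
assert recovered == 'naïve café'                 # lossless round trip

This is the same mapping the wrapper's `__init__` uses (in the next hunk) to build its `decoder`/`encoder` dicts, which is why `convert_tokens_to_string` can decode through `byte_decoder` without losing non-utf-8-representable tokens.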
@@ -93,6 +124,28 @@ def pickle_Encoding(enc: Encoding): self.add_eos_token = add_eos_token self.use_default_system_prompt = use_default_system_prompt + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + + self.decoder: Dict[int, str] = {} + for i in range(self.encoding.n_vocab): + try: + self.encoding.decode_single_token_bytes(i) + except KeyError: + continue + # Taken from + # https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee + decoding = ''.join([ + bytes_to_unicode()[ord(char)] for char in + self.encoding.decode_single_token_bytes(i).decode('latin-1') + ]) + self.decoder[i] = decoding + + self.encoder: Dict[str, int] = {} + for i in range(self.encoding.n_vocab): + if i in self.decoder: + self.encoder[self.decoder[i]] = i + super().__init__(model_name=model_name, encoding_name=encoding_name, add_bos_token=add_bos_token, @@ -135,122 +188,54 @@ def default_chat_template(self): return template def get_vocab(self) -> Dict[str, int]: - """Returns vocab as a dict. - - Note: This function does not work properly due to difference in assumptions between tiktoken and Hugging Face tokenizers. - Most uses do not need to use get_vocab, so this is not a priority to fix. - """ - warnings.warn( - 'get_vocab does not work properly with TiktokenTokenizerWrapper. Please do not rely on it being perfectly correct.' - + - ' It will be called once init just to get the size of the vocab inside the base class.' - ) - - vocab = {} - for i in range(self.vocab_size): - try: - # need to try this first, so that we get a proper KeyError, - # otherwise it crashes in the rust code - _ = self.encoding.decode_single_token_bytes(i) - vocab[self.encoding.decode([i])] = i - except KeyError: - pass - + """Returns vocab as a dict.""" # As far as I can tell, we don't require get_vocab to completely work, # but when using additional_special_tokens, Hugging Face determines the next # token index to add with len(self.get_vocab()) so we need the _size_ of this dictionary to be correct. + vocab_clone = self.encoder.copy() extra_id_index = 0 candidate_extra_id = f'' indices_to_fill_in = {i for i in range(self.vocab_size)} - set( - vocab.values()) + vocab_clone.values()) # Add enough indices to make get_vocab() the right length for index_to_add in indices_to_fill_in: # Make sure we don't overwrite a token that already exists - while candidate_extra_id in vocab: + while candidate_extra_id in vocab_clone: extra_id_index += 1 candidate_extra_id = f'' # Get an index to add and add the item - vocab[candidate_extra_id] = index_to_add + vocab_clone[candidate_extra_id] = index_to_add - return vocab + return vocab_clone - def _tokenize(self, text: str) -> List[int]: - """Returns a tokenized string. - - Note: We have slightly redefined the expected contract between this method and - the _convert_token_to_id method. Normally, this method turns a string, into a list of strings, - and then the _convert_token_to_id method turns that list of strings into a list of integers. - However, not all vocab indices can be decoded into a string, so instead we just return the integers - from this function, and have adjusted the _convert_token_to_id method to handle integers as well as strings. - The only use of _tokenize that I could find was in this way, so this _should_ be safe. 
- """ + def _tokenize(self, text: str) -> List[str]: + """Returns a tokenized string.""" if not isinstance(text, str): raise ValueError( f'Expected a string input to _tokenize but got {type(text)}.') - tokens = [t for t in self.encoding.encode(text, allowed_special='all')] + tokens = [ + self.decoder[t] + for t in self.encoding.encode(text, allowed_special='all') + ] return tokens - def _convert_token_to_id(self, token: Union[int, str]) -> int: - """Converts a token (str) into an id using the vocab.""" - if isinstance(token, int): - return token + def _convert_token_to_id(self, token: str) -> Optional[int]: + """Converts a token (str) in an id using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) - return self.encoding.encode(token, allowed_special='all')[0] - - def _convert_id_to_token(self, index: int) -> str: - """Converts an index (integer) into a token (str) using the vocab.""" - return self.encoding.decode([index]) + def _convert_id_to_token(self, index: int) -> Optional[str]: + """Converts an index (integer) in a token (str) using the vocab.""" + return self.decoder.get(index) def convert_tokens_to_string(self, tokens: List[str]) -> str: """Converts a sequence of tokens (string) in a single string.""" - return ''.join(tokens) - - def convert_ids_to_tokens( - self, - ids: Union[int, List[int]], - skip_special_tokens: bool = False) -> Union[str, List[str]]: - """Converts a single index or a sequence of indices into a token or a. - - sequence of tokens, using the vocabulary and added tokens. - - Args: - ids (`int` or `List[int]`): - The token id (or token ids) to convert to tokens. - skip_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not to remove special tokens in the decoding. - - Returns: - `str` or `List[str]`: The decoded token(s). - """ - if isinstance(ids, int): - if ids in self.added_tokens_decoder: - return str(self.added_tokens_decoder[ids]) - - return self._convert_id_to_token(ids) - - # current_stream will collect multiple tokens, and then separately add items - # for each added token. This is done so that decode works properly with token ids - # that cannot be represented naively in utf-8. 
- tokens = [] - current_stream = [] - for index in ids: - if skip_special_tokens and index in self.all_special_ids: - continue - - if index in self.added_tokens_decoder: - tokens.append(self.encoding.decode(current_stream)) - current_stream = [] - tokens.append(str(self.added_tokens_decoder[index])) - else: - current_stream.append(index) - - if len(current_stream) > 0: - tokens.append(self.encoding.decode(current_stream)) - return tokens + text = ''.join(tokens) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8') + return text def build_inputs_with_special_tokens( self, diff --git a/tests/test_tiktoken.py b/tests/test_tiktoken.py index 5bd10c82d3..fe3db41d50 100644 --- a/tests/test_tiktoken.py +++ b/tests/test_tiktoken.py @@ -7,7 +7,8 @@ import pytest import transformers -from llmfoundry import TiktokenTokenizerWrapper +from llmfoundry.tokenizers.tiktoken import (TiktokenTokenizerWrapper, + bytes_to_unicode) from tests.horrible_strings import HORRIBLE_STRINGS from tests.test_hf_conversion_script import check_hf_tokenizer_equivalence @@ -29,18 +30,12 @@ TEST_STRINGS += HORRIBLE_STRINGS -MODEL_OR_ENCODING_NAME_TO_NON_UTF8_TOKENS = { - 'gpt-4': 77, - 'gpt-3.5-turbo': 77, - 'text-davinci-003': 14, - 'cl100k_base': 77, -} - MODEL_ENCODING_NAME_PARAMETRIZATION = [ ('gpt-4', None), ('gpt-3.5-turbo', None), ('text-davinci-003', None), (None, 'cl100k_base'), + ('gpt2', None), ] MULTI_TURN_CHAT_ML = [[{ @@ -120,6 +115,31 @@ def test_tiktoken_simple(model_name: Optional[str], assert reloaded_wrapped_output == wrapped_output +@pytest.mark.parametrize('model_name,encoding_name', + MODEL_ENCODING_NAME_PARAMETRIZATION) +def test_tiktoken_tokenize_with_ids(model_name: Optional[str], + encoding_name: Optional[str], + tmp_path: pathlib.Path): + wrapped_tokenizer, reloaded_wrapped_tokenizer, original_tokenizer = get_tokenizers_for_testing( + model_name, encoding_name, tmp_path) + + for string in TEST_STRINGS: + wrapped_output = wrapped_tokenizer.tokenize(string) + original_output = original_tokenizer.encode(string, + allowed_special='all') + reloaded_wrapped_output = reloaded_wrapped_tokenizer.tokenize(string) + + assert all([isinstance(t, str) for t in wrapped_output]) + assert len(wrapped_output) == len(original_output) + assert wrapped_output == reloaded_wrapped_output + + redone_token_ids = wrapped_tokenizer.convert_tokens_to_ids( + wrapped_output) + assert redone_token_ids == original_output + assert wrapped_tokenizer.convert_ids_to_tokens( + redone_token_ids) == wrapped_output + + @pytest.mark.parametrize('model_name,encoding_name', MODEL_ENCODING_NAME_PARAMETRIZATION) def test_tiktoken_roundtrip(model_name: Optional[str], @@ -201,31 +221,17 @@ def test_tiktoken_vocab(model_name: Optional[str], encoding_name: Optional[str], reloaded_wrapped_vocab = reloaded_wrapped_tokenizer.get_vocab() assert wrapped_vocab == reloaded_wrapped_vocab - didnt_match = [] for key, value in wrapped_vocab.items(): # Skip checking the extra ids we pad the vocab with if key.startswith(''): continue - if original_tokenizer.encode(key, allowed_special='all') == [value]: - continue - else: - didnt_match.append( - (key, original_tokenizer.encode(key, - allowed_special='all'), value)) - - # Decode is lossy because some bytes are not representable in utf-8 - # see https://github.com/openai/tiktoken/blob/39f29cecdb6fc38d9a3434e5dd15e4de58cf3c80/tiktoken/core.py#L245-L247 - # This means that the str: int vocab mapping doesn't work. Would have to look more into how other HF tokenizers handle this. 
- model_or_encoding_name = model_name or encoding_name - if model_or_encoding_name is not None: - expected_didnt_match = MODEL_OR_ENCODING_NAME_TO_NON_UTF8_TOKENS.get( - model_or_encoding_name) - assert len(didnt_match) == expected_didnt_match - else: - raise NotImplementedError( - 'Add the new tokenizer and how many tokens in the vocab are not utf8 representable.' - ) + expected_decoding = ''.join([ + bytes_to_unicode()[ord(char)] + for char in original_tokenizer.decode_single_token_bytes( + value).decode('latin-1') + ]) + assert expected_decoding == key @pytest.mark.parametrize('model_name,encoding_name', From 5f21855cb35987ec73fe8b3a515f6ae3db903d56 Mon Sep 17 00:00:00 2001 From: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> Date: Tue, 28 Nov 2023 16:18:18 -0800 Subject: [PATCH 04/14] enable param group configuration in llm-foundry (#760) * enable param group configuration in llm-foundry * add doc string * add debug logs * add test, fix bug * spell check; mark test gpu * updt to use RegEx search * Apply suggestions from code review Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> * updt with dakinggg pr comments --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- llmfoundry/optim/lion8b.py | 22 ++++--- llmfoundry/utils/builders.py | 118 +++++++++++++++++++++++++++++++++-- tests/test_builders.py | 89 +++++++++++++++++++++++++- 3 files changed, 211 insertions(+), 18 deletions(-) diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py index 2c2e6e2d35..9d1d1dda71 100644 --- a/llmfoundry/optim/lion8b.py +++ b/llmfoundry/optim/lion8b.py @@ -1,7 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Dict, Iterable, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union import torch @@ -58,15 +58,17 @@ class DecoupledLionW_8bit(torch.optim.Optimizer): device, or b) step() is executed on a non-CUDA parameter. """ - def __init__(self, - params: Iterable[torch.Tensor], - lr: float = 1e-3, - betas: Tuple[float, float] = (0.9, 0.99), - weight_decay: float = 0, - quantize: bool = True, - compress_state_dict: bool = False, - error_correction: bool = False, - _fused: bool = True): # XXX this flag is mostly for testing... + def __init__( + self, + params: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0, + quantize: bool = True, + compress_state_dict: bool = False, + error_correction: bool = False, + _fused: bool = True, # XXX this flag is mostly for testing... 
+ ): if lr < 0.0: raise ValueError('Invalid learning rate: {}'.format(lr)) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index c31917efc6..14196c3ef9 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -1,10 +1,13 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import functools import logging import os +import re import warnings -from typing import Any, Dict, List, Optional, Tuple, Union +from collections import OrderedDict +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from composer import algorithms @@ -155,18 +158,121 @@ def build_algorithm(name: str, kwargs: Dict[str, Any]) -> Algorithm: raise ValueError(f'Not sure how to build algorithm: {name}') +def _extract_param_groups( + model: torch.nn.Module, + optimizer_config: Dict[str, Any], +) -> Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: + """Extracts parameter groups defined in the optimizer config. + + The optimizer_config defines the optimizer args. It can additionally have key + `disable_grad` which is a string or list of strings. If a string matches a + parameter name, then that parameter will have `requires_grad=False`. This is + useful for freezing parameters. It can additionally have a key + `param_groups` which is a list of dicts. In this dict, key `param_str_match` + defines a string; if a parameter name contains this string, then it will be + in this parameter group. This is useful for grouping parameters together. + The dict can also contain any other key that is a valid optimizer arg. + Note: to handle name overlap conflicts, params are assigned to parameter + groups and added to `param_groups` in the order that `param_str_match` appear + in `param_groups`. + + Usage + To disable gradient for all parameters that contain the string "norm" or "bias": + ``` + optimizer_config: { + "name": "decoupled_lionw", + "lr": 1e-3, + "weight_decay": 1e-2, + "betas": [0.9, 0.999], + "eps": 1e-8, + "disable_grad": ["norm", "bias"] + } + ``` + + To create and modify the optimizer parameters for all parameters that contain + the string "norm" and "bias" separately: + ``` + optimizer_config: { + "name": "decoupled_lionw", + "lr": 1e-3, + "weight_decay": 1e-2, + "betas": [0.9, 0.999], + "eps": 1e-8, + "param_groups": [ + { + "param_str_match": "norm", + "lr": 1e-4, + "weight_decay": 0.0, + }, + { + "param_str_match": "bias", + "lr": 5e-4, + "weight_decay": 0.0, + }, + ], + } + ``` + + Args: + model (torch.nn.Module): model to extract parameters from + optimizer_config (Dict[str, Any]): optimizer config + + Returns: + Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: an iterable of + torch.Tensor's or dict's. Specifies what Tensors should be optimized + and their param groupings. 
+ """ + if 'disable_grad' in optimizer_config.keys(): + str_matches = optimizer_config.pop('disable_grad') + if isinstance(str_matches, str): + str_matches = [str_matches] + for str_match in str_matches: + for n, p in model.named_parameters(): + if re.search(str_match, n): + p.requires_grad = False + log.debug(f'Setting `{n}.requires_grad = False`.') + + param_groups_config = optimizer_config.pop('param_groups', None) + if param_groups_config is not None: + params = [] + param_dict = OrderedDict((n, p) for n, p in model.named_parameters()) + + log.debug(f'Default optimizer settings: {optimizer_config}.') + for param_group_config in param_groups_config: + str_match = param_group_config.pop('param_str_match') + filter_fn = functools.partial(re.search, str_match) + param_names = [n for n in param_dict.keys() if filter_fn(n)] + group_params = {'params': [param_dict.pop(n) for n in param_names]} + group_params.update(param_group_config) + + log.debug( + f'Creating optimizer param_group with parameters: {param_names} ' +\ + f'(extracted using {str_match=}). The param_group optimizer ' +\ + f'setting overrides are: {param_group_config}.') + + params.append(group_params) + + params.insert(0, {'params': param_dict.values()}) + return params + + return model.parameters() + + def build_optimizer(model: torch.nn.Module, name: str, optimizer_config: Dict[str, Any]) -> Optimizer: + + params = _extract_param_groups(model, optimizer_config) + if name == 'decoupled_adamw': - return DecoupledAdamW(model.parameters(), **optimizer_config) + return DecoupledAdamW(params, **optimizer_config) elif name == 'decoupled_lionw': - return DecoupledLionW(model.parameters(), **optimizer_config) + return DecoupledLionW(params, **optimizer_config) elif name == 'clip_lion': - return DecoupledClipLion(model.parameters(), **optimizer_config) + return DecoupledClipLion(params, **optimizer_config) elif name == 'adalr_lion': - return DecoupledAdaLRLion(model.parameters(), **optimizer_config) + return DecoupledAdaLRLion(params, **optimizer_config) elif name == 'decoupled_lionw_8b': - return DecoupledLionW_8bit(model.parameters(), **optimizer_config) + return DecoupledLionW_8bit(params, **optimizer_config) else: raise ValueError(f'Not sure how to build optimizer: {name}') diff --git a/tests/test_builders.py b/tests/test_builders.py index 237e27b52b..7ac179720e 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -1,17 +1,22 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import re import unittest.mock as mock -from typing import Union +from copy import deepcopy +from typing import Any, Dict, Union import pytest +import torch +import torch.nn as nn from composer.callbacks import Generate from omegaconf import OmegaConf as om from transformers import PreTrainedTokenizerBase from llmfoundry.callbacks import HuggingFaceCheckpointer from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper -from llmfoundry.utils.builders import build_callback, build_tokenizer +from llmfoundry.utils.builders import (build_callback, build_optimizer, + build_tokenizer) @pytest.mark.parametrize('tokenizer_name,tokenizer_kwargs', [ @@ -110,3 +115,83 @@ def test_build_hf_checkpointer_callback(): assert isinstance(kwargs['mlflow_logging_config'], dict) assert isinstance(kwargs['mlflow_logging_config']['metadata'], dict) assert kwargs['mlflow_logging_config'] == mlflow_logging_config_dict + + +class _DummyModule(nn.Module): + + def __init__(self, device: str = 'cpu', dtype: torch.dtype = 
torch.float32): + super().__init__() + self.linear0 = nn.Linear(4, 3, device=device, dtype=dtype) + self.norm0 = nn.LayerNorm(3, device=device, dtype=dtype) + self.linear1 = nn.Linear(3, 5, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # type:ignore + return self.linear1(self.norm0(self.linear0(x))) + + +@pytest.mark.parametrize('name, optimizer_config', [ + ('decoupled_adamw', {}), + ('decoupled_lionw', {}), + ('clip_lion', {}), + ('adalr_lion', {}), + pytest.param('decoupled_lionw_8b', {}, marks=pytest.mark.gpu), +]) +@pytest.mark.parametrize('opt_additional_config', [ + { + 'disable_grad': 'norm' + }, + { + 'disable_grad': ['norm', 'bias'] + }, + { + 'param_groups': [{ + 'param_str_match': 'norm', + 'lr': 1e-9, + 'weight_decay': 0.0, + },] + }, + { + 'param_groups': [{ + 'param_str_match': 'no.*.bias', + 'lr': 1e-9, + 'weight_decay': 0.0, + },] + }, + { + 'param_groups': [{ + 'param_str_match': 'norm', + 'lr': 1e-4, + 'weight_decay': 0.0, + },], + 'disable_grad': ['bias'], + }, +]) +def test_build_optimizer(name: str, optimizer_config: Dict[str, Any], + opt_additional_config: Dict[str, Any]): + model = _DummyModule() + optimizer_config.update(deepcopy(opt_additional_config)) + optimizer = build_optimizer(model, name, optimizer_config) + + if 'disable_grad' in opt_additional_config.keys(): + disable_grad = opt_additional_config['disable_grad'] + if isinstance(disable_grad, str): + disable_grad = [disable_grad] + for n, p in model.named_parameters(): + for k in disable_grad: + if re.search(k, n): + assert not p.requires_grad + + if 'param_groups' in opt_additional_config.keys(): + for param_group_config, param_group in zip( + opt_additional_config['param_groups'], + optimizer.param_groups[1:]): + param_group_config = deepcopy(param_group_config) + param_str_match = param_group_config.pop('param_str_match') + + for k, v in param_group_config.items(): + assert param_group[k] == v + + param_ids = [id(p) for p in param_group['params']] + for n, p in model.named_parameters(): + if re.search(param_str_match, n): + assert id(p) in param_ids From 3a96b69965189876ff3bccceebb26d991e9bea72 Mon Sep 17 00:00:00 2001 From: Anna Date: Wed, 29 Nov 2023 10:29:07 -0800 Subject: [PATCH 05/14] Add script for doing bulk generation against an endpoint (#765) * Add script for doing bulk generation against an endpoint * more logging * warn * fix * format * asdfads * Add warning * updates * folder -> file * remove blank line * Support remote input * prompts -> inputs --- llmfoundry/utils/prompt_files.py | 58 +++++++ scripts/inference/endpoint_generate.py | 223 +++++++++++++++++++++++++ scripts/inference/hf_generate.py | 31 ++-- tests/test_prompt_files.py | 18 ++ 4 files changed, 309 insertions(+), 21 deletions(-) create mode 100644 llmfoundry/utils/prompt_files.py create mode 100644 scripts/inference/endpoint_generate.py create mode 100644 tests/test_prompt_files.py diff --git a/llmfoundry/utils/prompt_files.py b/llmfoundry/utils/prompt_files.py new file mode 100644 index 0000000000..40de19907a --- /dev/null +++ b/llmfoundry/utils/prompt_files.py @@ -0,0 +1,58 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import List, Optional + +PROMPTFILE_PREFIX = 'file::' + + +def load_prompts(prompts: List[str], + prompt_delimiter: Optional[str] = None) -> List[str]: + """Loads a set of prompts, both free text and from file. 
+ + Args: + prompts (List[str]): List of free text prompts and prompt files + prompt_delimiter (Optional str): Delimiter for text file + If not provided, assumes the prompt file is a single prompt (non-delimited) + + Returns: + List of prompt string(s) + """ + prompt_strings = [] + for prompt in prompts: + if prompt.startswith(PROMPTFILE_PREFIX): + prompts = load_prompts_from_file(prompt, prompt_delimiter) + prompt_strings.extend(prompts) + else: + prompt_strings.append(prompt) + return prompt_strings + + +def load_prompts_from_file(prompt_path: str, + prompt_delimiter: Optional[str] = None) -> List[str]: + """Load a set of prompts from a text fie. + + Args: + prompt_path (str): Path for text file + prompt_delimiter (Optional str): Delimiter for text file + If not provided, assumes the prompt file is a single prompt (non-delimited) + + Returns: + List of prompt string(s) + """ + if not prompt_path.startswith(PROMPTFILE_PREFIX): + raise ValueError(f'prompt_path_str must start with {PROMPTFILE_PREFIX}') + + _, prompt_file_path = prompt_path.split(PROMPTFILE_PREFIX, maxsplit=1) + prompt_file_path = os.path.expanduser(prompt_file_path) + if not os.path.isfile(prompt_file_path): + raise FileNotFoundError( + f'{prompt_file_path=} does not match any existing files.') + + with open(prompt_file_path, 'r') as f: + prompt_string = f.read() + + if prompt_delimiter is None: + return [prompt_string] + return [i for i in prompt_string.split(prompt_delimiter) if i] diff --git a/scripts/inference/endpoint_generate.py b/scripts/inference/endpoint_generate.py new file mode 100644 index 0000000000..e78fecf59b --- /dev/null +++ b/scripts/inference/endpoint_generate.py @@ -0,0 +1,223 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Batch generate text completion results from an endpoint. + +Warning: This script is experimental and could change or be removed at any time +""" + +import asyncio +import copy +import logging +import math +import os +import tempfile +import time +from argparse import ArgumentParser, Namespace + +import pandas as pd +import requests +from composer.utils import (get_file, maybe_create_object_store_from_uri, + parse_uri) + +from llmfoundry.utils import prompt_files as utils + +logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') +log = logging.getLogger(__name__) + +ENDPOINT_API_KEY_ENV: str = 'ENDPOINT_API_KEY' +ENDPOINT_URL_ENV: str = 'ENDPOINT_URL' + +PROMPT_DELIMITER = '\n' + + +def parse_args() -> Namespace: + """Parse commandline arguments.""" + parser = ArgumentParser( + description='Call prompts against a text completions endpoint') + + ##### + # Path Parameters + parser.add_argument( + '-i', + '--inputs', + nargs='+', + help=f'List of strings, local datafiles (starting with {utils.PROMPTFILE_PREFIX}),' +\ + ' and/or remote object stores' + ) + parser.add_argument( + '--prompt-delimiter', + default='\n', + help= + 'Prompt delimiter for txt files. 
By default, a file is a single prompt') + + parser.add_argument('-o', + '--output-folder', + required=True, + help='Remote folder to save the output') + + ##### + # Generation Parameters + parser.add_argument( + '--rate-limit', + type=int, + default=75, + help='Max number of calls to make to the endpoint in a second') + parser.add_argument( + '--batch-size', + type=int, + default=10, + help='Max number of calls to make to the endpoint in a single request') + + ##### + # Endpoint Parameters + parser.add_argument( + '-e', + '--endpoint', + type=str, + help= + f'OpenAI-compatible text completions endpoint to query on. If not set, will read from {ENDPOINT_URL_ENV}' + ) + + parser.add_argument('--max-tokens', type=int, default=100) + parser.add_argument('--temperature', type=float, default=1.0) + parser.add_argument('--top-k', type=int, default=50) + parser.add_argument('--top-p', type=float, default=1.0) + return parser.parse_args() + + +async def main(args: Namespace) -> None: + # This is mildly experimental, so for now imports are not added as part of llm-foundry + try: + import aiohttp + except ImportError as e: + raise ImportError('Please install aiohttp') from e + + try: + from ratelimit import limits, sleep_and_retry + except ImportError as e: + raise ImportError('Please install ratelimit') from e + + if args.batch_size > args.rate_limit: + raise ValueError( + f'Batch size is {args.batch_size} but rate limit is set to {args.rate_limit} / s' + ) + + url = args.endpoint if args.endpoint else os.environ.get(ENDPOINT_URL_ENV) + if not url: + raise ValueError( + f'URL must be provided via --endpoint or {ENDPOINT_URL_ENV}') + + log.info(f'Using endpoint {url}') + + api_key = os.environ.get(ENDPOINT_API_KEY_ENV, '') + if not api_key: + log.warning(f'API key not set in {ENDPOINT_API_KEY_ENV}') + + new_inputs = [] + for prompt in args.inputs: + if prompt.startswith(utils.PROMPTFILE_PREFIX): + new_inputs.append(prompt) + continue + + input_object_store = maybe_create_object_store_from_uri(prompt) + if input_object_store is not None: + local_output_path = tempfile.TemporaryDirectory().name + get_file(prompt, str(local_output_path)) + log.info(f'Downloaded {prompt} to {local_output_path}') + prompt = f'{utils.PROMPTFILE_PREFIX}{local_output_path}' + + new_inputs.append(prompt) + + prompt_strings = utils.load_prompts(new_inputs, args.prompt_delimiter) + + cols = ['batch', 'prompt', 'output'] + param_data = { + 'max_tokens': args.max_tokens, + 'temperature': args.temperature, + 'top_k': args.top_k, + 'top_p': args.top_p, + } + + total_batches = math.ceil(len(prompt_strings) / args.batch_size) + log.info( + f'Generating {len(prompt_strings)} prompts in {total_batches} batches') + + @sleep_and_retry + @limits(calls=total_batches, period=1) # type: ignore + async def generate(session: aiohttp.ClientSession, batch: int, + prompts: list): + data = copy.copy(param_data) + data['prompt'] = prompts + headers = {'Authorization': api_key, 'Content-Type': 'application/json'} + req_start = time.time() + async with session.post(url, headers=headers, json=data) as resp: + if resp.ok: + try: + response = await resp.json() + except requests.JSONDecodeError: + raise Exception( + f'Bad response: {resp.status} {resp.reason}') + else: + raise Exception(f'Bad response: {resp.status} {resp.reason}') + + req_end = time.time() + n_compl = response['usage']['completion_tokens'] + n_prompt = response['usage']['prompt_tokens'] + req_latency = (req_end - req_start) + log.info(f'Completed batch {batch}: {n_compl:,} completion' + 
+ f' tokens using {n_prompt:,} prompt tokens in {req_latency}s') + + res = pd.DataFrame(columns=cols) + + for r in response['choices']: + index = r['index'] + res.loc[len(res)] = [batch, prompts[index], r['text']] + return res + + res = pd.DataFrame(columns=cols) + batch = 0 + + gen_start = time.time() + async with aiohttp.ClientSession() as session: + tasks = [] + + for i in range(total_batches): + prompts = prompt_strings[i * args.batch_size:min( + (i + 1) * args.batch_size, len(prompt_strings))] + + tasks.append(generate(session, batch, prompts)) + batch += 1 + + results = await asyncio.gather(*tasks) + res = pd.concat(results) + + res.reset_index(drop=True, inplace=True) + + gen_end = time.time() + gen_latency = (gen_end - gen_start) + log.info(f'Generated {len(res)} prompts in {gen_latency}s, example data:') + log.info(res.head()) + + with tempfile.TemporaryDirectory() as tmp_dir: + file = 'output.csv' + local_path = os.path.join(tmp_dir, file) + res.to_csv(local_path, index=False) + + output_object_store = maybe_create_object_store_from_uri( + args.output_folder) + if output_object_store is not None: + _, _, output_folder_prefix = parse_uri(args.output_folder) + remote_path = os.path.join(output_folder_prefix, file) + output_object_store.upload_object(remote_path, local_path) + output_object_store.download_object + log.info(f'Uploaded results to {args.output_folder}/{file}') + else: + output_dir, _ = os.path.split(args.output_folder) + os.makedirs(output_dir, exist_ok=True) + os.rename(local_path, args.output_folder) + log.info(f'Saved results to {args.output_folder}') + + +if __name__ == '__main__': + asyncio.run(main(parse_args())) diff --git a/scripts/inference/hf_generate.py b/scripts/inference/hf_generate.py index 45ddc6b63e..6ac645e5b7 100644 --- a/scripts/inference/hf_generate.py +++ b/scripts/inference/hf_generate.py @@ -1,7 +1,6 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 import itertools -import os import random import time import warnings @@ -13,6 +12,8 @@ import torch from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from llmfoundry.utils import prompt_files as utils + def get_dtype(dtype: str): if dtype == 'fp32': @@ -62,9 +63,14 @@ def parse_args() -> Namespace: 'My name is', 'This is an explanation of deep learning to a five year old. Deep learning is', ], - help='Generation prompts. Use syntax "file::/path/to/prompt.txt" to load a ' +\ - 'prompt contained in a txt file.' + help='List of generation prompts or list of delimited files. Use syntax ' +\ + '"file::/path/to/prompt.txt" to load a prompt(s) contained in a txt file.' ) + parser.add_argument( + '--prompt-delimiter', + default=None, + help= + 'Prompt delimiter for txt files. 
By default, a file is a single prompt') parser.add_argument('--max_seq_len', type=int, default=None) parser.add_argument('--max_new_tokens', type=int, default=100) parser.add_argument('--max_batch_size', type=int, default=None) @@ -125,19 +131,6 @@ def parse_args() -> Namespace: return parser.parse_args() -def load_prompt_string_from_file(prompt_path_str: str): - if not prompt_path_str.startswith('file::'): - raise ValueError('prompt_path_str must start with "file::".') - _, prompt_file_path = prompt_path_str.split('file::', maxsplit=1) - prompt_file_path = os.path.expanduser(prompt_file_path) - if not os.path.isfile(prompt_file_path): - raise FileNotFoundError( - f'{prompt_file_path=} does not match any existing files.') - with open(prompt_file_path, 'r') as f: - prompt_string = ''.join(f.readlines()) - return prompt_string - - def maybe_synchronize(): if torch.cuda.is_available(): torch.cuda.synchronize() @@ -163,11 +156,7 @@ def main(args: Namespace) -> None: print(f'Using {model_dtype=}') # Load prompts - prompt_strings = [] - for prompt in args.prompts: - if prompt.startswith('file::'): - prompt = load_prompt_string_from_file(prompt) - prompt_strings.append(prompt) + prompt_strings = utils.load_prompts(args.prompts, args.prompt_delimiter) # Grab config first print(f'Loading HF Config...') diff --git a/tests/test_prompt_files.py b/tests/test_prompt_files.py new file mode 100644 index 0000000000..12a5d02999 --- /dev/null +++ b/tests/test_prompt_files.py @@ -0,0 +1,18 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path + +from llmfoundry.utils import prompt_files as utils + + +def test_load_prompt_strings(tmp_path: Path): + assert utils.load_prompts(['hello', 'goodbye']) == ['hello', 'goodbye'] + + with open(tmp_path / 'prompts.txt', 'w') as f: + f.write('hello goodbye') + + temp = utils.PROMPTFILE_PREFIX + str(tmp_path / 'prompts.txt') + assert utils.load_prompts( + [temp, temp, 'why'], + ' ') == ['hello', 'goodbye', 'hello', 'goodbye', 'why'] From 1191267195367b5ec6093ed7854b8f6daf1be2d3 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Wed, 29 Nov 2023 12:14:02 -0800 Subject: [PATCH 06/14] Only strip object names when creating new output path (#766) --- llmfoundry/utils/data_prep_utils.py | 11 +++++----- tests/test_convert_text_to_mds.py | 31 +++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/llmfoundry/utils/data_prep_utils.py b/llmfoundry/utils/data_prep_utils.py index 75e27b504f..a88e65ee94 100644 --- a/llmfoundry/utils/data_prep_utils.py +++ b/llmfoundry/utils/data_prep_utils.py @@ -96,15 +96,16 @@ def __init__( def __iter__(self): for object_name in self.object_names: - object_name = object_name.strip('/') - output_filename = os.path.join(self.output_folder, object_name) + # Default output_filename, used for local paths. + output_filename = object_name + + # Download objects if remote path. if self.object_store is not None: + output_filename = os.path.join(self.output_folder, + object_name.strip('/')) self.object_store.download_object(object_name=object_name, filename=output_filename, overwrite=True) - else: - # Inputs are local so we do not need to download them. 
- output_filename = object_name with open(output_filename) as _txt_file: txt = _txt_file.read() diff --git a/tests/test_convert_text_to_mds.py b/tests/test_convert_text_to_mds.py index 2d4878ebbb..ab8c25bc2d 100644 --- a/tests/test_convert_text_to_mds.py +++ b/tests/test_convert_text_to_mds.py @@ -188,6 +188,37 @@ def test_single_and_multi_process(merge_shard_groups: Mock, assert n_tokens == expected_n_tokens +def test_local_path(tmp_path: pathlib.Path): + # Input/output folders + input_folder = tmp_path / 'input' + output_folder = tmp_path / 'output' + + # Create input text data + os.makedirs(input_folder, exist_ok=True) + with open(input_folder / 'test.txt', 'w') as f: + f.write('test') + + # Convert text data to mds + convert_text_to_mds( + tokenizer_name='mosaicml/mpt-7b', + output_folder=str(output_folder), + input_folder=str(input_folder), + concat_tokens=1, + eos_text='', + bos_text='', + no_wrap=False, + compression='zstd', + processes=1, + args_str='Namespace()', + reprocess=False, + ) + + # Make sure all the files exist as expected. + assert os.path.exists(output_folder / '.text_to_mds_conversion_done') + assert os.path.exists(output_folder / 'index.json') + assert os.path.exists(output_folder / 'shard.00000.mds.zstd') + + def test_is_already_processed(tmp_path: pathlib.Path): tmp_path_str = str(tmp_path) args_str = 'Namespace(x = 5)' From 3100859905c1ed29e049e7c203cf70da8231f2e6 Mon Sep 17 00:00:00 2001 From: Anna Date: Thu, 30 Nov 2023 14:02:13 -0800 Subject: [PATCH 07/14] Add eval loader to eval script (#742) * Add eval loader to eval script * small input tests * updates * fix typing and formatting * fixes, add tests * remove circular dependency * tests pass * nits + small fixes * add metrics at the end, refactor to put icl/gauntlet as helpers * NOT * metrics instead of models, add unit tests --- llmfoundry/data/dataloader.py | 32 ++++----- llmfoundry/utils/builders.py | 81 +++++++++++++++++++++++ scripts/eval/eval.py | 53 +++++++++++---- scripts/train/train.py | 52 +++++---------- tests/data_utils.py | 98 +++++++++++++++++++++++++++- tests/test_builders.py | 118 +++++++++++++++++++++++++++++++++- tests/test_dataloader.py | 11 ++++ tests/test_eval.py | 89 +++++++++++++++++++++++++ tests/test_eval_inputs.py | 1 + tests/test_train_inputs.py | 2 +- tests/test_training.py | 97 ++-------------------------- 11 files changed, 469 insertions(+), 165 deletions(-) diff --git a/llmfoundry/data/dataloader.py b/llmfoundry/data/dataloader.py index 12741717be..63d47a65d5 100644 --- a/llmfoundry/data/dataloader.py +++ b/llmfoundry/data/dataloader.py @@ -11,6 +11,12 @@ from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader from llmfoundry.data.text_data import build_text_dataloader +LOADER_NAME_TO_FUNCTION = { + 'text': build_text_dataloader, + 'text_denoising': build_text_denoising_dataloader, + 'finetuning': build_finetuning_dataloader, +} + def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size: int) -> DataSpec: @@ -22,23 +28,9 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size (int): The size of the batches (number of examples) that the dataloader will produce. 
""" - if cfg.name == 'text': - return build_text_dataloader( - cfg, - tokenizer, - device_batch_size, - ) - elif cfg.name == 'text_denoising': - return build_text_denoising_dataloader( - cfg, - tokenizer, - device_batch_size, - ) - elif cfg.name == 'finetuning': - return build_finetuning_dataloader( - cfg, - tokenizer, - device_batch_size, - ) - else: - raise ValueError(f'Not sure how to build dataloader with config: {cfg}') + if cfg.name not in LOADER_NAME_TO_FUNCTION: + allowed = ', '.join(LOADER_NAME_TO_FUNCTION.keys()) + raise ValueError(f'Expected dataloader name to be one of {allowed}' + + f' but found name "{cfg.name}" in config: {cfg}') + + return LOADER_NAME_TO_FUNCTION[cfg.name](cfg, tokenizer, device_batch_size) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 14196c3ef9..a672fbee55 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -28,12 +28,14 @@ from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om from torch.optim.optimizer import Optimizer +from torchmetrics import Metric from transformers import AutoTokenizer, PreTrainedTokenizerBase from llmfoundry.callbacks import (EvalGauntlet, FDiffMetrics, GlobalLRScaling, HuggingFaceCheckpointer, LayerFreezing, MonolithicCheckpointSaver, ScheduledGarbageCollector) +from llmfoundry.data.dataloader import build_dataloader from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion, DecoupledLionW, DecoupledLionW_8bit) from llmfoundry.optim.scheduler import InverseSquareRootWithWarmupScheduler @@ -42,6 +44,85 @@ log = logging.getLogger(__name__) +def build_evaluators( + eval_loader_config: Optional[Union[DictConfig, ListConfig]], + icl_tasks_config: Optional[Union[str, ListConfig]], + eval_gauntlet_config: Optional[Union[str, DictConfig]], + *, + tokenizer: PreTrainedTokenizerBase, + device_eval_batch_size: int, + icl_seq_len: int, + icl_subset_num_batches: Optional[int], +) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]: + + evaluators = [] + if eval_loader_config is not None: + evaluators = build_eval_loaders( + eval_loader_config, + tokenizer, + device_eval_batch_size, + ) + + logger_keys = [] + eval_gauntlet_callback = None + if icl_tasks_config is not None: + icl_evaluators, logger_keys, eval_gauntlet_callback = build_icl_data_and_gauntlet( + icl_tasks_config, + eval_gauntlet_config, + tokenizer, + device_eval_batch_size, + icl_seq_len, + icl_subset_num_batches, + ) + evaluators.extend(icl_evaluators) + + return evaluators, logger_keys, eval_gauntlet_callback + + +def build_eval_loaders( + eval_loader_config: Union[DictConfig, ListConfig], + tokenizer: PreTrainedTokenizerBase, + device_eval_batch_size: int, +) -> List[Evaluator]: + evaluators: List[Evaluator] = [] + if isinstance(eval_loader_config, ListConfig): + eval_configs: ListConfig = eval_loader_config + is_multi_eval = True + else: + eval_configs = ListConfig([eval_loader_config]) + is_multi_eval = False + + for eval_config in eval_configs: + eval_dataloader = build_dataloader(eval_config, tokenizer, + device_eval_batch_size) + eval_loader: Evaluator = Evaluator( + label=f'eval/{eval_config.label}' if is_multi_eval else 'eval', + dataloader=eval_dataloader, + # Load the eval data to fail fast. 
metrics will get added + # later in add_metrics_to_eval_loaders, after the model is loaded + metric_names=[], + ) + evaluators.append(eval_loader) + return evaluators + + +def add_metrics_to_eval_loaders( + evaluators: List[Evaluator], + metrics: Dict[str, Metric], +) -> List[Evaluator]: + metric_names = list(metrics.keys()) + eval_loaders, other_evaluators = [], [] + for evaluator in evaluators: + if evaluator.metric_names == []: + evaluator.metric_names = metric_names + eval_loaders.append(evaluator) + else: + other_evaluators.append(evaluator) + + # Put the base eval_loaders first + return eval_loaders + other_evaluators + + def build_icl_data_and_gauntlet( icl_tasks_config: Union[str, ListConfig], eval_gauntlet_config: Optional[Union[str, DictConfig]], diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 02a5d1f862..369a894720 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -6,7 +6,7 @@ import sys import time import warnings -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import pandas as pd import torch @@ -21,13 +21,14 @@ from llmfoundry.models import MPTForCausalLM from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY -from llmfoundry.utils.builders import (build_icl_data_and_gauntlet, - build_logger, build_tokenizer) +from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, + build_evaluators, build_logger, + build_tokenizer) from llmfoundry.utils.config_utils import pop_config, process_init_device def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, - num_retries: int) -> Optional[ComposerModel]: + num_retries: int) -> ComposerModel: try: from peft import PeftModel except ImportError as e: @@ -43,7 +44,8 @@ def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, } retries = 0 - while retries < num_retries: + composer_model_wrapper = None + while retries < num_retries and composer_model_wrapper is None: try: trust_remote_code = model_cfg.get('trust_remote_code', True) use_auth_token = model_cfg.get('use_auth_token', False) @@ -58,7 +60,6 @@ def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, composer_model_wrapper = COMPOSER_MODEL_REGISTRY[model_cfg.name]( peft_model, tokenizer) - return composer_model_wrapper except Exception as e: retries += 1 if retries >= num_retries: @@ -68,19 +69,21 @@ def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, f'Got exception {str(e)} while loading model {model_cfg.name}. {num_retries-retries} retries remaining' ) + assert composer_model_wrapper is not None + return composer_model_wrapper + def load_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, - fsdp_config: Optional[Dict], - num_retries: int) -> Optional[ComposerModel]: + fsdp_config: Optional[Dict], num_retries: int) -> ComposerModel: init_context = process_init_device(model_cfg, fsdp_config) retries = 0 + composer_model = None with init_context: - while retries < num_retries: + while retries < num_retries and composer_model is None: try: composer_model = COMPOSER_MODEL_REGISTRY[model_cfg.name]( model_cfg, tokenizer) - return composer_model except Exception as e: retries += 1 if retries >= num_retries: @@ -90,6 +93,9 @@ def load_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, f'Got exception {str(e)} while loading model {model_cfg.name}. 
{num_retries-retries} retries remaining' ) + assert composer_model is not None + return composer_model + def evaluate_model( model_cfg: DictConfig, @@ -100,6 +106,7 @@ def evaluate_model( max_seq_len: int, device_eval_batch_size: int, eval_gauntlet_config: Optional[Union[str, DictConfig]], + eval_loader_config: Optional[Union[DictConfig, ListConfig]], fsdp_config: Optional[Dict], num_retries: int, loggers_cfg: Dict[str, Any], @@ -118,9 +125,15 @@ def evaluate_model( tokenizer_kwargs = tokenizer_cfg.get('kwargs', {}) tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) - evaluators, logger_keys, eval_gauntlet_callback = build_icl_data_and_gauntlet( - icl_tasks, eval_gauntlet_config, tokenizer, device_eval_batch_size, - max_seq_len, icl_subset_num_batches) + evaluators, logger_keys, eval_gauntlet_callback = build_evaluators( + eval_loader_config, + icl_tasks, + eval_gauntlet_config, + tokenizer=tokenizer, + device_eval_batch_size=device_eval_batch_size, + icl_seq_len=max_seq_len, + icl_subset_num_batches=icl_subset_num_batches, + ) callbacks = [] if eval_gauntlet_callback is not None: @@ -143,6 +156,11 @@ def evaluate_model( composer_model = load_model(model_cfg.model, tokenizer, fsdp_config, num_retries) + # Now add the eval metrics + if eval_loader_config is not None: + train_metrics = composer_model.get_metrics(is_train=True) + evaluators = add_metrics_to_eval_loaders(evaluators, train_metrics) + if eval_gauntlet_df is None and eval_gauntlet_callback is not None: eval_gauntlet_df = pd.DataFrame( columns=['model_name'] + @@ -186,7 +204,7 @@ def evaluate_model( return (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df) -def main(cfg: DictConfig): +def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: om.resolve(cfg) model_configs: ListConfig = pop_config(cfg, 'models', must_exist=True) eval_gauntlet_config: Optional[Union[str, DictConfig]] = pop_config( @@ -228,6 +246,8 @@ def main(cfg: DictConfig): default_value='debug') # Optional Evaluation Parameters with default values + eval_loader_config: Optional[Union[DictConfig, ListConfig]] = pop_config( + cfg, 'eval_loader', must_exist=False, default_value=None) seed: int = pop_config(cfg, 'seed', must_exist=False, default_value=17) dist_timeout: Union[float, int] = pop_config(cfg, 'dist_timeout', @@ -274,6 +294,7 @@ def main(cfg: DictConfig): eval_gauntlet_df = None models_df = None composite_scores = None + trainers = [] for model_cfg in model_configs: (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df) = evaluate_model( @@ -285,6 +306,7 @@ def main(cfg: DictConfig): max_seq_len=max_seq_len, device_eval_batch_size=device_eval_batch_size, eval_gauntlet_config=eval_gauntlet_config, + eval_loader_config=eval_loader_config, fsdp_config=fsdp_config, num_retries=num_retries, loggers_cfg=loggers_cfg, @@ -292,6 +314,7 @@ def main(cfg: DictConfig): precision=precision, eval_gauntlet_df=eval_gauntlet_df, icl_subset_num_batches=icl_subset_num_batches) + trainers.append(trainer) if eval_gauntlet_callback is not None: composite_scores = eval_gauntlet_callback.eval_after_all( @@ -330,6 +353,8 @@ def main(cfg: DictConfig): assert models_df is not None print(models_df.to_markdown(index=False)) + return trainers, eval_gauntlet_df + def calculate_markdown_results(logger_keys: List[str], trainer: Trainer, benchmark_to_taxonomy: Dict[str, str], diff --git a/scripts/train/train.py b/scripts/train/train.py index 88f776375f..809f2fb09c 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -11,7 +11,6 @@ 
import torch from composer import Trainer -from composer.core import Evaluator from composer.core.callback import Callback from composer.loggers import MosaicMLLogger from composer.loggers.mosaicml_logger import (MOSAICML_ACCESS_TOKEN_ENV_VAR, @@ -26,10 +25,11 @@ from llmfoundry import (COMPOSER_MODEL_REGISTRY, ComposerHFCausalLM, MPTForCausalLM) from llmfoundry.data.dataloader import build_dataloader -from llmfoundry.utils.builders import (build_algorithm, build_callback, - build_icl_data_and_gauntlet, - build_logger, build_optimizer, - build_scheduler, build_tokenizer) +from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, + build_algorithm, build_callback, + build_evaluators, build_logger, + build_optimizer, build_scheduler, + build_tokenizer) from llmfoundry.utils.config_utils import (log_config, pop_config, process_init_device, update_batch_size_info) @@ -526,31 +526,16 @@ def main(cfg: DictConfig) -> Trainer: ## Evaluation print('Building eval loader...') - evaluators = [] - eval_loaders = [] - if eval_loader_config is not None: - is_multi_eval = isinstance(eval_loader_config, ListConfig) - eval_configs = eval_loader_config if is_multi_eval else [ - eval_loader_config - ] - for eval_config in eval_configs: - eval_dataloader = build_dataloader(eval_config, tokenizer, - device_eval_batch_size) - eval_loader = Evaluator( - label=f'eval/{eval_config.label}' if is_multi_eval else 'eval', - dataloader=eval_dataloader, - metric_names=[], # we will add these after model is created - ) - eval_loaders.append(eval_loader) - - eval_gauntlet_callback = None - - if icl_tasks_config is not None: - icl_evaluators, _, eval_gauntlet_callback = build_icl_data_and_gauntlet( - icl_tasks_config, eval_gauntlet_config, tokenizer, - device_eval_batch_size, icl_seq_len if icl_seq_len else max_seq_len, - icl_subset_num_batches) - evaluators.extend(icl_evaluators) + eval_icl_seq_len: int = icl_seq_len if icl_seq_len else max_seq_len + evaluators, _, eval_gauntlet_callback = build_evaluators( + eval_loader_config, + icl_tasks_config, + eval_gauntlet_config, + tokenizer=tokenizer, + device_eval_batch_size=device_eval_batch_size, + icl_seq_len=eval_icl_seq_len, + icl_subset_num_batches=icl_subset_num_batches, + ) if eval_gauntlet_callback is not None: callbacks.append(eval_gauntlet_callback) @@ -581,11 +566,8 @@ def main(cfg: DictConfig) -> Trainer: # Now add the eval metrics if eval_loader_config is not None: - assert model.train_metrics is not None - eval_metric_names = list(model.train_metrics.keys()) - for eval_loader in eval_loaders: - eval_loader.metric_names = eval_metric_names - evaluators.insert(0, eval_loader) # Put the base eval_loaders first + train_metrics = model.get_metrics(is_train=True) + evaluators = add_metrics_to_eval_loaders(evaluators, train_metrics) # Build the Trainer print('Building trainer...') diff --git a/tests/data_utils.py b/tests/data_utils.py index 075933de7d..efb4f6d7cf 100644 --- a/tests/data_utils.py +++ b/tests/data_utils.py @@ -1,10 +1,26 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import json import os +import sys + +# Add repo root to path so we can import scripts and test it +repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.append(repo_dir) + +import json +import pathlib +import shutil +from argparse import Namespace from typing import Optional +from omegaconf import DictConfig +from omegaconf import OmegaConf as om + +from scripts.data_prep.convert_dataset_hf import main as main_hf 
# noqa: E402 +from scripts.data_prep.convert_dataset_json import \ + main as main_json # noqa: E402 + def make_tiny_ft_dataset( path: str, @@ -65,3 +81,83 @@ def make_tiny_ft_dataset( for sample in samples: _f.write(json.dumps(sample)) _f.write('\n') + + +def create_c4_dataset_xxsmall(path: pathlib.Path) -> str: + """Creates a small mocked version of the C4 dataset.""" + c4_dir = os.path.join(path, f'my-copy-c4') + downloaded_split = 'val_xxsmall' # very fast to convert + + # Hyperparameters from https://github.com/mosaicml/llm-foundry/blob/340a56658560ebceb2a3aa69d6e37813e415acd0/README.md#L188 + main_hf( + Namespace( + **{ + 'dataset': 'c4', + 'data_subset': 'en', + 'splits': [downloaded_split], + 'out_root': c4_dir, + 'compression': None, + 'concat_tokens': 2048, + 'tokenizer': 'EleutherAI/gpt-neox-20b', + 'tokenizer_kwargs': {}, + 'bos_text': '', + 'eos_text': '<|endoftext|>', + 'no_wrap': False, + 'num_workers': 8 + })) + + # copy the small downloaded_split to other c4 splits for mocking purposes + mocked_splits = ['train', 'val'] + for mocked_split in mocked_splits: + shutil.copytree(os.path.join(c4_dir, 'val_xxsmall'), + os.path.join(c4_dir, mocked_split)) + assert os.path.exists(c4_dir) + return c4_dir + + +def create_arxiv_dataset(path: pathlib.Path) -> str: + """Creates an arxiv dataset.""" + arxiv_dir = os.path.join(path, f'my-copy-arxiv') + downloaded_split = 'train' + + main_json( + Namespace( + **{ + 'path': 'data_prep/example_data/arxiv.jsonl', + 'out_root': arxiv_dir, + 'compression': None, + 'split': downloaded_split, + 'concat_tokens': None, + 'bos_text': None, + 'eos_text': None, + 'no_wrap': False, + 'num_workers': None + })) + + return arxiv_dir + + +def gpt_tiny_cfg(dataset_name: str, device: str): + """Create gpt tiny cfg.""" + conf_path: str = os.path.join(repo_dir, + 'scripts/train/yamls/pretrain/testing.yaml') + with open(conf_path) as f: + test_cfg = om.load(f) + assert isinstance(test_cfg, DictConfig) + + test_cfg.data_local = dataset_name + test_cfg.global_train_batch_size = 8 + test_cfg.device_eval_batch_size = 4 + test_cfg.device_train_microbatch_size = 4 + test_cfg.max_duration = '4ba' + test_cfg.eval_interval = '4ba' + test_cfg.run_name = 'gpt-mini-integration-test' + + if device == 'cpu': + test_cfg.model.init_device = 'cpu' + test_cfg.fsdp_config = None + test_cfg.model.attn_config.attn_impl = 'torch' + test_cfg.model.loss_fn = 'torch_crossentropy' + test_cfg.precision = 'fp32' + + return test_cfg diff --git a/tests/test_builders.py b/tests/test_builders.py index 7ac179720e..5c38ed8602 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -5,17 +5,22 @@ import unittest.mock as mock from copy import deepcopy from typing import Any, Dict, Union +from unittest.mock import MagicMock import pytest import torch import torch.nn as nn from composer.callbacks import Generate +from composer.core import Evaluator +from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om from transformers import PreTrainedTokenizerBase from llmfoundry.callbacks import HuggingFaceCheckpointer from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper -from llmfoundry.utils.builders import (build_callback, build_optimizer, +from llmfoundry.utils.builders import (add_metrics_to_eval_loaders, + build_callback, build_eval_loaders, + build_evaluators, build_optimizer, build_tokenizer) @@ -195,3 +200,114 @@ def test_build_optimizer(name: str, optimizer_config: Dict[str, Any], for n, p in model.named_parameters(): if re.search(param_str_match, 
n): assert id(p) in param_ids + + +def test_build_evaluators_empty(): + evaluators, logger_keys, eval_gauntlet_callback = build_evaluators( + None, + None, + None, + tokenizer=None, # type: ignore + device_eval_batch_size=1, + icl_seq_len=2, + icl_subset_num_batches=3) + assert evaluators == [] + assert logger_keys == [] + assert eval_gauntlet_callback is None + + +def test_build_eval_loaders(monkeypatch: pytest.MonkeyPatch): + tokenizer = TiktokenTokenizerWrapper(model_name='gpt-4') + + eval_loader_cfg = DictConfig({ + 'name': 'text', + 'dataset': { + # mocked, not needed + }, + 'drop_last': False, + 'num_workers': 8, + }) + monkeypatch.setattr('llmfoundry.data.text_data.StreamingTextDataset', + lambda *args, **kwargs: MagicMock()) + eval_loaders = build_eval_loaders(eval_loader_cfg, tokenizer, 2) + + assert len(eval_loaders) == 1 + + assert eval_loaders[0].label == 'eval' + assert eval_loaders[0].dataloader is not None + assert eval_loaders[0].metric_names == [] + + multi_eval_loader_cfg = ListConfig([ + { + 'name': 'text', + 'label': 'test1', + 'dataset': { + # mocked, not needed + }, + 'drop_last': False, + 'num_workers': 8, + }, + { + 'name': 'text', + 'label': 'test2', + 'dataset': { + # mocked, not needed + }, + 'drop_last': False, + 'num_workers': 8, + } + ]) + monkeypatch.setattr('llmfoundry.data.text_data.StreamingTextDataset', + lambda *args, **kwargs: MagicMock()) + eval_loaders2 = build_eval_loaders(multi_eval_loader_cfg, tokenizer, 2) + + assert len(eval_loaders2) == 2 + + assert eval_loaders2[0].label == 'eval/test1' + assert eval_loaders2[0].dataloader is not None + assert eval_loaders2[0].metric_names == [] + + assert eval_loaders2[1].label == 'eval/test2' + assert eval_loaders2[1].dataloader is not None + assert eval_loaders2[1].metric_names == [] + + +def test_add_metrics_to_eval_loaders(): + evaluators = [ + Evaluator( + label='first', + metric_names=['a', 'b'], + dataloader=None, # type: ignore + device_eval_microbatch_size=1, + ), + Evaluator( + label='second', + metric_names=[], + dataloader=None, # type: ignore + device_eval_microbatch_size=1, + ), + Evaluator( + label='third', + metric_names=['c'], + dataloader=None, # type: ignore + device_eval_microbatch_size=1, + ) + ] + + new_evaluators = add_metrics_to_eval_loaders( + evaluators, + { + 'new1': 'foo', + 'new2': 'bar' + }, # type: ignore + ) + assert len(new_evaluators) == 3 + + assert new_evaluators[0].label == 'second' + assert new_evaluators[0].metric_names == ['new1', 'new2'] + + assert new_evaluators[1].label == 'first' + assert new_evaluators[1].metric_names == ['a', 'b'] + + assert new_evaluators[2].label == 'third' + assert new_evaluators[2].metric_names == ['c'] diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index c35d29f74d..2e9039644b 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -21,6 +21,7 @@ from llmfoundry import (build_finetuning_dataloader, build_text_denoising_dataloader) +from llmfoundry.data import build_dataloader from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper, build_text_dataloader, get_tokens_per_batch_func) @@ -740,3 +741,13 @@ def test_token_counting_func_dataloader_setting( actual_token_count = dl.get_num_tokens_in_batch(batch_tokenized) assert actual_token_count == expected_token_count + + +def test_build_unknown_dataloader(): + cfg = DictConfig({ + 'name': 'unknown', + }) + tokenizer = MagicMock() + with pytest.raises(ValueError, + match='Expected dataloader name to be one of'): + _ = build_dataloader(cfg, 
tokenizer, 2) diff --git a/tests/test_eval.py b/tests/test_eval.py index 1217487b70..2fc96bb7ad 100644 --- a/tests/test_eval.py +++ b/tests/test_eval.py @@ -1,16 +1,21 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import copy import os +import pathlib import sys from typing import Any import omegaconf as om import pytest from composer import Trainer +from composer.loggers import InMemoryLogger from llmfoundry import COMPOSER_MODEL_REGISTRY from llmfoundry.utils import build_tokenizer +from tests.data_utils import (create_arxiv_dataset, create_c4_dataset_xxsmall, + gpt_tiny_cfg) # Add repo root to path so we can import scripts and test it repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) @@ -66,3 +71,87 @@ def test_icl_eval(capfd: Any, mock_saved_model_path: Any): assert expected_results in out expected_results = '| model_name | default_average | language_understanding_lite |\n|:-------------|------------------:|------------------------------:|\n| tiny_mpt | 0 | 0 |' assert expected_results in out + + +@pytest.mark.gpu +def test_loader_eval(capfd: Any, mock_saved_model_path: Any, + tmp_path: pathlib.Path): + + c4_dataset_name = create_c4_dataset_xxsmall(tmp_path) + + # Use a training config that already has eval loader configured + test_cfg = gpt_tiny_cfg(c4_dataset_name, 'cpu') + + # define icl eval task + test_cfg.icl_tasks = om.ListConfig([ + om.DictConfig({ + 'label': + 'lambada_openai', + 'dataset_uri': + 'eval/local_data/language_understanding/lambada_openai_small.jsonl', + 'num_fewshot': [0], + 'icl_task_type': + 'language_modeling' + }) + ]) + + # convert the model from a training to eval model + model = test_cfg.pop('model') + eval_model = { + 'model_name': model.get('name'), + 'model': model, + 'load_path': mock_saved_model_path + } + + tokenizer = test_cfg.pop('tokenizer') + eval_model['tokenizer'] = tokenizer + test_cfg.models = [eval_model] + + # Set up multiple eval dataloaders + first_eval_loader = test_cfg.eval_loader + first_eval_loader.label = 'c4' + # Create second eval dataloader using the arxiv dataset. 
+ second_eval_loader = copy.deepcopy(first_eval_loader) + arxiv_dataset_name = create_arxiv_dataset(tmp_path) + second_eval_loader.data_local = arxiv_dataset_name + second_eval_loader.label = 'arxiv' + test_cfg.eval_loader = om.OmegaConf.create( + [first_eval_loader, second_eval_loader]) + + test_cfg.max_duration = '1ba' + test_cfg.eval_interval = '1ba' + test_cfg.loggers = om.DictConfig({'inmemory': om.DictConfig({})}) + + trainers, eval_gauntlet_df = main(test_cfg) + + assert eval_gauntlet_df is None + assert len(trainers) == 1 # one per model + trainer = trainers[0] + + assert isinstance(trainer.logger.destinations, tuple) + + assert len(trainer.logger.destinations) > 0 + inmemorylogger = trainer.logger.destinations[ + 0] # pyright: ignore [reportGeneralTypeIssues] + assert isinstance(inmemorylogger, InMemoryLogger) + print(inmemorylogger.data.keys()) + + # Checks for first eval dataloader + assert 'metrics/eval/c4/LanguageCrossEntropy' in inmemorylogger.data.keys() + assert isinstance( + inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'], list) + assert len( + inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1]) > 0 + assert isinstance( + inmemorylogger.data['metrics/eval/c4/LanguageCrossEntropy'][-1], tuple) + + # Checks for second eval dataloader + assert 'metrics/eval/arxiv/LanguageCrossEntropy' in inmemorylogger.data.keys( + ) + assert isinstance( + inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'], list) + assert len( + inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1]) > 0 + assert isinstance( + inmemorylogger.data['metrics/eval/arxiv/LanguageCrossEntropy'][-1], + tuple) diff --git a/tests/test_eval_inputs.py b/tests/test_eval_inputs.py index 9c7a130a9b..83104b62b7 100644 --- a/tests/test_eval_inputs.py +++ b/tests/test_eval_inputs.py @@ -57,6 +57,7 @@ def test_optional_mispelled_params_raise_warning(self, 'loggers', 'eval_gauntlet', 'fsdp_config', + 'eval_loader', ] old_cfg = copy.deepcopy(cfg) for param in optional_params: diff --git a/tests/test_train_inputs.py b/tests/test_train_inputs.py index bf90f48ef0..2ed1c9c239 100644 --- a/tests/test_train_inputs.py +++ b/tests/test_train_inputs.py @@ -103,7 +103,7 @@ def test_optional_mispelled_params_raise_warning(self, 'save_folder', 'fsdp_config', 'lora_config', - 'eval_loader_config', + 'eval_loader', 'icl_tasks_config', ] old_cfg = copy.deepcopy(cfg) diff --git a/tests/test_training.py b/tests/test_training.py index 8390834d1d..3cd2963100 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -3,9 +3,6 @@ import copy import os import pathlib -import shutil -import sys -from argparse import Namespace from typing import Any, Optional import pytest @@ -14,95 +11,9 @@ from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om -# Add repo root to path so we can import scripts and test it -repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) -sys.path.append(repo_dir) - -from scripts.data_prep.convert_dataset_hf import main as main_hf # noqa: E402 -from scripts.data_prep.convert_dataset_json import \ - main as main_json # noqa: E402 from scripts.train.train import main # noqa: E402 - - -def create_c4_dataset_xsmall(path: pathlib.Path) -> str: - """Creates a small mocked version of the C4 dataset.""" - c4_dir = os.path.join(path, f'my-copy-c4') - downloaded_split = 'val_xxsmall' - main_hf( - Namespace( - **{ - 'dataset': 'c4', - 'data_subset': 'en', - 'splits': [downloaded_split], - 'out_root': c4_dir, - 'compression': None, - 
'concat_tokens': 2048, - 'tokenizer': 'EleutherAI/gpt-neox-20b', - 'tokenizer_kwargs': {}, - 'bos_text': '', - 'eos_text': '<|endoftext|>', - 'no_wrap': False, - 'num_workers': 8 - })) - - # copy the small downloaded_split to other c4 splits for mocking purposes - mocked_splits = ['train', 'val'] - for mocked_split in mocked_splits: - shutil.copytree(os.path.join(c4_dir, 'val_xxsmall'), - os.path.join(c4_dir, mocked_split)) - assert os.path.exists(c4_dir) - return c4_dir - - -def create_arxiv_dataset(path: pathlib.Path) -> str: - """Creates an arxiv dataset.""" - arxiv_dir = os.path.join(path, f'my-copy-arxiv') - downloaded_split = 'train' - - main_json( - Namespace( - **{ - 'path': 'data_prep/example_data/arxiv.jsonl', - 'out_root': arxiv_dir, - 'compression': None, - 'split': downloaded_split, - 'concat_tokens': None, - 'bos_text': None, - 'eos_text': None, - 'no_wrap': False, - 'num_workers': None - })) - - return arxiv_dir - - -def gpt_tiny_cfg(dataset_name: str, device: str): - """Create gpt tiny cfg.""" - conf_path: str = os.path.join(repo_dir, - 'scripts/train/yamls/pretrain/testing.yaml') - with open(conf_path) as f: - test_cfg = om.load(f) - assert isinstance(test_cfg, DictConfig) - - test_cfg.data_local = dataset_name - test_cfg.global_train_batch_size = 1 - test_cfg.device_eval_batch_size = 2 - test_cfg.device_train_microbatch_size = 1 - test_cfg.max_duration = '4ba' - test_cfg.eval_interval = '4ba' - test_cfg.run_name = 'gpt-mini-integration-test' - - test_cfg.model.n_layer = 2 - test_cfg.model.n_embd = 64 - - if device == 'cpu': - test_cfg.model.init_device = 'cpu' - test_cfg.fsdp_config = None - test_cfg.model.attn_config.attn_impl = 'torch' - test_cfg.model.loss_fn = 'torch_crossentropy' - test_cfg.precision = 'fp32' - - return test_cfg +from tests.data_utils import (create_arxiv_dataset, create_c4_dataset_xxsmall, + gpt_tiny_cfg) @pytest.fixture(autouse=False) @@ -122,7 +33,7 @@ def set_correct_cwd(): def test_train_gauntlet(averages: Optional[dict], set_correct_cwd: Any, tmp_path: pathlib.Path): """Test training run with a small dataset.""" - dataset_name = create_c4_dataset_xsmall(tmp_path) + dataset_name = create_c4_dataset_xxsmall(tmp_path) test_cfg = gpt_tiny_cfg(dataset_name, 'cpu') test_cfg.icl_tasks = ListConfig([ DictConfig({ @@ -201,7 +112,7 @@ def test_train_gauntlet(averages: Optional[dict], set_correct_cwd: Any, def test_train_multi_eval(set_correct_cwd: Any, tmp_path: pathlib.Path): """Test training run with multiple eval datasets.""" - c4_dataset_name = create_c4_dataset_xsmall(tmp_path) + c4_dataset_name = create_c4_dataset_xxsmall(tmp_path) test_cfg = gpt_tiny_cfg(c4_dataset_name, 'cpu') # Set up multiple eval dataloaders first_eval_loader = test_cfg.eval_loader From 22ae919c1d6b2b542278399586a1835e0e632bba Mon Sep 17 00:00:00 2001 From: Sam Havens Date: Thu, 30 Nov 2023 17:47:43 -0800 Subject: [PATCH 08/14] Support inputs_embeds (#687) * support inputs_embeds * update tests to test inputs_embeds * make iids optional inputs to fwd * remove check for both iids and inputs_embeds in MPTForCausalLM. It is checked in the base model, and it is actually a common practice to pass both during autoregressive generation. 
Embeds are used first, then once the kvcache is nonempty, iids are used instead * reorder kwargs * add more tests * fix device merge artifact in test_model.oy * fix generate test * yapf --- llmfoundry/models/mpt/modeling_mpt.py | 51 +++++++++-------- tests/test_model.py | 79 +++++++++++++++++++++++++-- 2 files changed, 101 insertions(+), 29 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 274c1b76e5..d6b23c04d0 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -368,7 +368,7 @@ def _apply_sequence_id(self, attn_bias: torch.Tensor, def forward( self, - input_ids: torch.LongTensor, + input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, attention_mask: Optional[torch.ByteTensor] = None, prefix_mask: Optional[torch.ByteTensor] = None, @@ -412,11 +412,6 @@ def forward( 'prefix_mask is a required argument when MPT is configured with prefix_lm=True.' ) - # Raise a not implemented error if input_embeds is not None (this is an arg in huggingface transformers and we need to support it for PEFT) - if inputs_embeds is not None: - raise NotImplementedError( - 'inputs_embeds is not implemented for MPT.') - if self.training: if self.attn_uses_sequence_id and sequence_id is None: raise ValueError( @@ -430,14 +425,25 @@ def forward( 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.' ) - S = input_ids.size(1) + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + 'You cannot specify both input_ids and inputs_embeds.') + elif input_ids is not None: + S = input_ids.size(1) + x = self.wte(input_ids) + input_device = input_ids.device + elif inputs_embeds is not None: + S = inputs_embeds.size(1) + x = inputs_embeds + input_device = inputs_embeds.device + else: + raise ValueError('You must specify input_ids or inputs_embeds') assert ( S <= self.config.max_seq_len ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' rotary_emb_w_meta_info = None - x = self.wte(input_ids) if self.learned_pos_emb or self.rope: past_position = 0 if past_key_values is not None: @@ -467,7 +473,7 @@ def forward( past_position, S + past_position, dtype=torch.long, - device=input_ids.device, + device=input_device, ).unsqueeze(0) if attention_mask is not None: # adjust the position indices to account for padding tokens @@ -652,7 +658,7 @@ def get_decoder(self) -> MPTModel: def forward( self, - input_ids: torch.LongTensor, + input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, attention_mask: Optional[torch.ByteTensor] = None, prefix_mask: Optional[torch.ByteTensor] = None, @@ -669,11 +675,6 @@ def forward( use_cache = (use_cache if use_cache is not None else self.config.use_cache) - # if input_embeds is not none, raise a not implemented error - if inputs_embeds is not None: - raise NotImplementedError( - 'inputs_embeds has to be None (for hf/peft support).') - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.transformer( input_ids=input_ids, past_key_values=past_key_values, @@ -684,6 +685,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, + inputs_embeds=inputs_embeds, ) if self.lm_head is not None: @@ -773,10 +775,6 @@ def prepare_inputs_for_generation( inputs_embeds: 
Optional[torch.Tensor] = None, **kwargs: Any, ) -> Dict[str, Any]: - if inputs_embeds is not None: - raise NotImplementedError( - 'inputs_embeds is not implemented for MPT yet') - attention_mask = kwargs['attention_mask'].bool() if attention_mask[:, -1].sum() != attention_mask.shape[0]: raise NotImplementedError( @@ -787,6 +785,7 @@ def prepare_inputs_for_generation( else: sequence_id = None + # only last token for inputs_ids if past is defined in kwargs if past_key_values is not None: input_ids = input_ids[:, -1].unsqueeze(-1) @@ -800,14 +799,20 @@ def prepare_inputs_for_generation( else: prefix_mask = None - return { - 'input_ids': input_ids, + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + model_inputs = {'input_ids': input_ids} + + model_inputs.update({ 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), - } + }) + return model_inputs @staticmethod def _reorder_cache( @@ -898,7 +903,7 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: add_bidirectional_mask_if_missing(batch) # Note: prefix_mask is only used if model.prefix_lm is True return self.model( - input_ids=batch['input_ids'], + input_ids=batch.get('input_ids', None), attention_mask=batch.get('attention_mask', None), prefix_mask=batch.get('bidirectional_mask', None), sequence_id=batch.get('sequence_id', None), diff --git a/tests/test_model.py b/tests/test_model.py index 4d5b0a4dbc..acb2074ae9 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -5,7 +5,7 @@ import os import pathlib import warnings -from typing import Any, Dict, Union, cast +from typing import Any, Dict, List, Optional, Union, cast from unittest import mock import pytest @@ -94,13 +94,26 @@ def get_objs(conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'): return test_cfg, model, optimizer -def gen_random_batch(batch_size: int, test_cfg: Union[DictConfig, ListConfig]): +def gen_random_batch(batch_size: int, + test_cfg: Union[DictConfig, ListConfig], + inputs: Optional[List[str]] = None): + # inputs can be [], ['input_ids'], ['input_ids', 'inputs_embeds'], and ['inputs_embeds'] + # default to only input ids + if inputs == None: + inputs = ['input_ids'] # generate input batch of random data, suitable for a Causal or Prefix LM batch = {} - batch['input_ids'] = torch.randint( - low=0, - high=test_cfg.model.vocab_size, - size=(batch_size, test_cfg.max_seq_len)).to(test_cfg.device) + for inp in inputs: + if inp == 'input_ids': + batch['input_ids'] = torch.randint( + low=0, + high=test_cfg.model.vocab_size, + size=(batch_size, test_cfg.max_seq_len)).to(test_cfg.device) + if inp == 'inputs_embeds': + batch['inputs_embeds'] = torch.randn( + batch_size, test_cfg.max_seq_len, + test_cfg.model.d_model).to(test_cfg.device) + batch['labels'] = torch.randint(low=0, high=test_cfg.model.vocab_size, size=(batch_size, test_cfg.max_seq_len)).to( @@ -150,6 +163,34 @@ def test_full_forward_and_backward(batch_size: int = 2): assert not torch.equal(original_params, updated_params) +def test_full_forward_and_backward_with_inputs_embeds(batch_size: int = 2): + test_cfg, model, optimizer = get_objs( + conf_path='scripts/train/yamls/pretrain/testing.yaml') + + batch = gen_random_batch(batch_size, test_cfg, inputs=['inputs_embeds']) + + model.train() + original_params = 
next(model.parameters()).clone().data + outputs = model(batch) + loss = model.loss(outputs, batch) + loss.backward() + optimizer.step() + updated_params = next(model.parameters()).clone().data + assert not torch.equal(original_params, updated_params) + + +@pytest.mark.parametrize('inputs', [[], ['input_ids', 'inputs_embeds']]) +def test_invalid_inputs_embeds_input_ids_combinations(inputs: List[str]): + test_cfg, model, _ = get_objs( + conf_path='scripts/train/yamls/pretrain/testing.yaml') + + batch = gen_random_batch(2, test_cfg, inputs=inputs) + + model.train() + with pytest.raises(ValueError): + _ = model(batch) + + def test_attention_mechanism(batch_size: int = 2): test_cfg, model, _ = get_objs( conf_path='scripts/train/yamls/pretrain/testing.yaml') @@ -825,6 +866,9 @@ def test_generate(attention_impl: str, precision: str, pos_emb_config: dict, no_padding_attention_mask = composer_device.tensor_to_device( no_padding_attention_mask) + # inputs_embeds + inputs_embeds = composer_device.tensor_to_device(torch.randn(2, 3, 128)) + # a single batch with different amounts of left padding in the input batched_input_ids = torch.tensor([[50256, 50256, 50256, 11274, 16390, 11], [50256, 50256, 16, 11274, 16390, 11]]) @@ -860,6 +904,29 @@ def test_generate(attention_impl: str, precision: str, pos_emb_config: dict, assert generation_with_no_padding[:, 3:].equal( generation_with_left_padding[:, 6:]) + # check that both/neither ids and embeds do not error + # note that we need to set the BOS token ID for generating from neither + _ = mpt.generate(input_ids=no_padding_input_ids, + inputs_embeds=inputs_embeds, + attention_mask=no_padding_attention_mask, + max_new_tokens=5, + use_cache=False) + _ = mpt.generate(input_ids=no_padding_input_ids, + inputs_embeds=inputs_embeds, + attention_mask=no_padding_attention_mask, + max_new_tokens=5, + use_cache=True) + _ = mpt.generate(input_ids=None, + inputs_embeds=None, + max_new_tokens=5, + use_cache=False, + bos_token_id=50256) + _ = mpt.generate(input_ids=None, + inputs_embeds=None, + max_new_tokens=5, + use_cache=True, + bos_token_id=50256) + @pytest.mark.gpu @pytest.mark.parametrize('world_size', [1, 2]) From 9cf99b7457a6ed0e199a56785dad697bd4a09a58 Mon Sep 17 00:00:00 2001 From: Anna Date: Thu, 30 Nov 2023 19:38:01 -0800 Subject: [PATCH 09/14] Better error message when test does not complete (#769) --- .github/mcp/mcp_pytest.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/mcp/mcp_pytest.py b/.github/mcp/mcp_pytest.py index 5f0aaa147b..b6d74880c8 100644 --- a/.github/mcp/mcp_pytest.py +++ b/.github/mcp/mcp_pytest.py @@ -130,7 +130,7 @@ print(line, end='') print('[GHA] Run completed. 
Waiting for run to finish...') - run = wait_for_run_status(run, status='completed') + run = wait_for_run_status(run, status=RunStatus.COMPLETED) - # Fail if command exited with non-zero exit code or timed out - assert run.status == RunStatus.COMPLETED + # Fail if command exited with non-zero exit code or timed out (didn't reach COMPLETED) + assert run.status == RunStatus.COMPLETED, f'Run did not complete: {run.status} ({run.reason})' From 32dc3bd85134b0362b234abb03d7ebae04bb5ac6 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Fri, 1 Dec 2023 09:39:54 -0800 Subject: [PATCH 10/14] Add codeowners (#770) * add codeowners * precommit --- .github/CODEOWNERS | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 .github/CODEOWNERS diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000000..bbdd4259cd --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,8 @@ +# Require admin approval to modify all files in the root of the repository +# This includes setup.py, the README, and the CODEOWNERS file itself! +/* @mosaicml/composer-team-admins + +# Require admin approval to change the CI build configuration +# All CI Changes should be reviewed for security +/.ci/ @mosaicml/composer-team-admins +/.github/ @mosaicml/composer-team-admins From 6ac01ef73f7a461751c1358b217900ec2f39217a Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Fri, 1 Dec 2023 13:18:34 -0800 Subject: [PATCH 11/14] add single value support to activation_checkpointing_target (#772) * add single value support * check str or list dtype --- llmfoundry/models/mpt/modeling_mpt.py | 6 ++++++ tests/test_fsdp_act_checkpoint.py | 15 ++++++++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index d6b23c04d0..34b8992d3e 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -739,6 +739,12 @@ def fsdp_wrap_fn(self, module: nn.Module) -> bool: def activation_checkpointing_fn(self, module: nn.Module) -> bool: act_ckpt_list = getattr(self.config, 'activation_checkpointing_target', None) or ['MPTBlock'] + if isinstance(act_ckpt_list, str): + act_ckpt_list = [act_ckpt_list] + elif not isinstance(act_ckpt_list, list): + raise ValueError( + f'activation_checkpointing_target must be either a single string or a list, but got {type(act_ckpt_list)}' + ) if 'MPTBlock' in act_ckpt_list or 'mptblock' in act_ckpt_list: if len(act_ckpt_list) > 1: diff --git a/tests/test_fsdp_act_checkpoint.py b/tests/test_fsdp_act_checkpoint.py index 3b9a746708..a7e41a3fc2 100644 --- a/tests/test_fsdp_act_checkpoint.py +++ b/tests/test_fsdp_act_checkpoint.py @@ -1,6 +1,8 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +from typing import Union + import pytest from composer import Trainer from composer.utils import get_device, using_torch_2 @@ -14,11 +16,12 @@ @pytest.mark.world_size(2) @pytest.mark.gpu @pytest.mark.parametrize('activation_checkpointing', [True, False]) -@pytest.mark.parametrize( - 'activation_checkpointing_target', - [[], ['grouped_query_attention'], ['mptblock', 'grouped_query_attention']]) +@pytest.mark.parametrize('activation_checkpointing_target', [ + 'grouped_query_attention', [], ['grouped_query_attention'], + ['mptblock', 'grouped_query_attention'] +]) def test_fsdp_act_checkpoint(activation_checkpointing: bool, - activation_checkpointing_target: list): + activation_checkpointing_target: Union[list, 
str]): device = get_device('gpu') model_cfg = { 'name': 'mpt_causal_lm', @@ -66,7 +69,9 @@ def test_fsdp_act_checkpoint(activation_checkpointing: bool, module = trainer.state.model.model._fsdp_wrapped_module.transformer.blocks[ 0]._fsdp_wrapped_module._fpw_module assert isinstance(module, CheckpointWrapper) - elif activation_checkpointing_target == ['grouped_query_attention']: + elif activation_checkpointing_target == [ + 'grouped_query_attention' + ] or activation_checkpointing_target == 'grouped_query_attention': assert isinstance( trainer.state.model.model._fsdp_wrapped_module.transformer. blocks[0]._fsdp_wrapped_module.attn, CheckpointWrapper) From 6a058a60a603f06f1c32ebcfee156259ef1230db Mon Sep 17 00:00:00 2001 From: Anna Date: Fri, 1 Dec 2023 17:02:04 -0800 Subject: [PATCH 12/14] Reorganize tests to make them easier to find (#768) * Add eval loader to eval script * small input tests * updates * fix typing and formatting * fixes, add tests * remove circular dependency * tests pass * nits + small fixes * add metrics at the end, refactor to put icl/gauntlet as helpers * NOT * metrics instead of models, add unit tests * Move tests into directories * add copyright to inits * fix relative paths * fixes * revert gauntlet test change * Support inputs_embeds (#687) * support inputs_embeds * update tests to test inputs_embeds * make iids optional inputs to fwd * remove check for both iids and inputs_embeds in MPTForCausalLM. It is checked in the base model, and it is actually a common practice to pass both during autoregressive generation. Embeds are used first, then once the kvcache is nonempty, iids are used instead * reorder kwargs * add more tests * fix device merge artifact in test_model.oy * fix generate test * yapf * Better error message when test does not complete (#769) * run script tests first * comment out * ascripts -> scripts * bad dirs * try this * hacks * add a note about a_scripts --------- Co-authored-by: Sam Havens --- scripts/inference/convert_composer_to_hf.py | 30 +++++++---- tests/a_scripts/__init__.py | 6 +++ tests/a_scripts/data_prep/__init__.py | 2 + .../data_prep/test_convert_dataset_hf.py | 28 ++++++++++ .../data_prep/test_convert_dataset_json.py | 27 ++++++++++ .../data_prep}/test_convert_text_to_mds.py | 8 +-- tests/a_scripts/eval/__init__.py | 2 + tests/{ => a_scripts/eval}/test_eval.py | 34 ++++++------ .../{ => a_scripts/eval}/test_eval_inputs.py | 23 ++++---- tests/a_scripts/inference/__init__.py | 2 + .../inference/test_convert_composer_to_hf.py} | 22 +++----- tests/a_scripts/train/__init__.py | 2 + .../train/test_train.py} | 21 ++------ .../train}/test_train_inputs.py | 9 +--- tests/callbacks/__init__.py | 2 + .../test_eval_gauntlet_callback.py} | 0 tests/data/__init__.py | 2 + tests/{ => data}/test_dataloader.py | 5 -- tests/{ => data}/test_icl_datasets.py | 2 +- tests/{ => data}/test_packing.py | 0 tests/{ => data}/test_tasks.yaml | 0 tests/data_utils.py | 21 ++++---- tests/fixtures/autouse.py | 11 ++++ tests/models/__init__.py | 2 + tests/models/hf/__init__.py | 2 + tests/{ => models/hf}/test_hf_config.py | 0 tests/{ => models/hf}/test_hf_mpt_gen.py | 0 tests/{ => models/hf}/test_hf_v_mpt.py | 0 .../models/inference_api_wrapper/__init__.py | 2 + .../test_inference_api_eval_wrapper.py | 0 tests/models/layers/__init__.py | 2 + .../layers}/test_flash_triton_torch.py | 0 .../layers}/test_huggingface_flash.py | 0 .../{ => models}/test_fsdp_act_checkpoint.py | 0 tests/{ => models}/test_model.py | 0 tests/{ => models}/test_mpt_gen.py | 0 tests/{ => 
models}/test_onnx.py | 0 tests/{ => models}/test_rope_dail_vs_hf.py | 0 tests/models/utils/__init__.py | 2 + .../utils/test_param_init_fns.py} | 0 tests/optim/__init__.py | 2 + tests/{ => optim}/test_lion8b.py | 0 tests/{ => optim}/test_scheduler.py | 0 tests/test_data_prep_scripts.py | 52 ------------------- tests/tokenizers/__init__.py | 2 + tests/{ => tokenizers}/test_tiktoken.py | 3 +- tests/{ => tokenizers}/test_tokenizer.py | 0 tests/utils/__init__.py | 2 + tests/{ => utils}/test_builders.py | 0 .../{ => utils}/test_model_download_utils.py | 0 tests/{ => utils}/test_prompt_files.py | 0 51 files changed, 176 insertions(+), 154 deletions(-) create mode 100644 tests/a_scripts/__init__.py create mode 100644 tests/a_scripts/data_prep/__init__.py create mode 100644 tests/a_scripts/data_prep/test_convert_dataset_hf.py create mode 100644 tests/a_scripts/data_prep/test_convert_dataset_json.py rename tests/{ => a_scripts/data_prep}/test_convert_text_to_mds.py (98%) create mode 100644 tests/a_scripts/eval/__init__.py rename tests/{ => a_scripts/eval}/test_eval.py (89%) rename tests/{ => a_scripts/eval}/test_eval_inputs.py (86%) create mode 100644 tests/a_scripts/inference/__init__.py rename tests/{test_hf_conversion_script.py => a_scripts/inference/test_convert_composer_to_hf.py} (99%) create mode 100644 tests/a_scripts/train/__init__.py rename tests/{test_training.py => a_scripts/train/test_train.py} (90%) rename tests/{ => a_scripts/train}/test_train_inputs.py (96%) create mode 100644 tests/callbacks/__init__.py rename tests/{test_eval_gauntlet.py => callbacks/test_eval_gauntlet_callback.py} (100%) create mode 100644 tests/data/__init__.py rename tests/{ => data}/test_dataloader.py (99%) rename tests/{ => data}/test_icl_datasets.py (98%) rename tests/{ => data}/test_packing.py (100%) rename tests/{ => data}/test_tasks.yaml (100%) create mode 100644 tests/models/__init__.py create mode 100644 tests/models/hf/__init__.py rename tests/{ => models/hf}/test_hf_config.py (100%) rename tests/{ => models/hf}/test_hf_mpt_gen.py (100%) rename tests/{ => models/hf}/test_hf_v_mpt.py (100%) create mode 100644 tests/models/inference_api_wrapper/__init__.py rename tests/{ => models/inference_api_wrapper}/test_inference_api_eval_wrapper.py (100%) create mode 100644 tests/models/layers/__init__.py rename tests/{ => models/layers}/test_flash_triton_torch.py (100%) rename tests/{ => models/layers}/test_huggingface_flash.py (100%) rename tests/{ => models}/test_fsdp_act_checkpoint.py (100%) rename tests/{ => models}/test_model.py (100%) rename tests/{ => models}/test_mpt_gen.py (100%) rename tests/{ => models}/test_onnx.py (100%) rename tests/{ => models}/test_rope_dail_vs_hf.py (100%) create mode 100644 tests/models/utils/__init__.py rename tests/{test_init_fn.py => models/utils/test_param_init_fns.py} (100%) create mode 100644 tests/optim/__init__.py rename tests/{ => optim}/test_lion8b.py (100%) rename tests/{ => optim}/test_scheduler.py (100%) delete mode 100644 tests/test_data_prep_scripts.py create mode 100644 tests/tokenizers/__init__.py rename tests/{ => tokenizers}/test_tiktoken.py (99%) rename tests/{ => tokenizers}/test_tokenizer.py (100%) create mode 100644 tests/utils/__init__.py rename tests/{ => utils}/test_builders.py (100%) rename tests/{ => utils}/test_model_download_utils.py (100%) rename tests/{ => utils}/test_prompt_files.py (100%) diff --git a/scripts/inference/convert_composer_to_hf.py b/scripts/inference/convert_composer_to_hf.py index 1b43762473..51afb105c8 100644 --- 
a/scripts/inference/convert_composer_to_hf.py +++ b/scripts/inference/convert_composer_to_hf.py @@ -168,19 +168,11 @@ def parse_args() -> Namespace: return parser.parse_args() -def convert_composer_to_hf(args: Namespace) -> None: +def _convert_composer_to_hf(args: Namespace) -> None: print() print('#' * 30) print('Converting Composer checkpoint to HuggingFace checkpoint format...') - # Register MPT auto classes so that this script works with MPT - # This script will not work without modification for other custom models, - # but will work for other HuggingFace causal LMs - from transformers.models.auto.configuration_auto import CONFIG_MAPPING - CONFIG_MAPPING._extra_content['mpt'] = MPTConfig - MPTConfig.register_for_auto_class() - MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM') - _, _, local_folder_path = parse_uri(args.hf_output_path) config, tokenizer = write_huggingface_pretrained_from_composer_checkpoint( @@ -296,5 +288,25 @@ def convert_composer_to_hf(args: Namespace) -> None: ) +def convert_composer_to_hf(args: Namespace) -> None: + # Register MPT auto classes so that this script works with MPT + # This script will not work without modification for other custom models, + # but will work for other HuggingFace causal LMs + from transformers.models.auto.configuration_auto import CONFIG_MAPPING + CONFIG_MAPPING._extra_content['mpt'] = MPTConfig + MPTConfig.register_for_auto_class() + MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM') + + try: + _convert_composer_to_hf(args) + except Exception as e: + raise e + finally: + # Undo auto registration after running the script + del CONFIG_MAPPING._extra_content['mpt'] + delattr(MPTConfig, '_auto_class') + delattr(MPTForCausalLM, '_auto_class') + + if __name__ == '__main__': convert_composer_to_hf(parse_args()) diff --git a/tests/a_scripts/__init__.py b/tests/a_scripts/__init__.py new file mode 100644 index 0000000000..eb5c1d149e --- /dev/null +++ b/tests/a_scripts/__init__.py @@ -0,0 +1,6 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +# TODO: This test directory is called "a_scripts" to enforce that these tests are run +# first. 
More clean up should be done to ensure tests can be run in any order and +# don't leave around artifacts diff --git a/tests/a_scripts/data_prep/__init__.py b/tests/a_scripts/data_prep/__init__.py new file mode 100644 index 0000000000..f6c1f9f3ab --- /dev/null +++ b/tests/a_scripts/data_prep/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/a_scripts/data_prep/test_convert_dataset_hf.py b/tests/a_scripts/data_prep/test_convert_dataset_hf.py new file mode 100644 index 0000000000..f226b0a4be --- /dev/null +++ b/tests/a_scripts/data_prep/test_convert_dataset_hf.py @@ -0,0 +1,28 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import os +from argparse import Namespace +from pathlib import Path + +from scripts.data_prep.convert_dataset_hf import main as main_hf + + +def test_download_script_from_api(tmp_path: Path): + # test calling it directly + path = os.path.join(tmp_path, 'my-copy-c4-1') + main_hf( + Namespace( + **{ + 'dataset': 'c4', + 'data_subset': 'en', + 'splits': ['val_xsmall'], + 'out_root': path, + 'compression': None, + 'concat_tokens': None, + 'bos_text': None, + 'eos_text': None, + 'no_wrap': False, + 'num_workers': None + })) + assert os.path.exists(path) diff --git a/tests/a_scripts/data_prep/test_convert_dataset_json.py b/tests/a_scripts/data_prep/test_convert_dataset_json.py new file mode 100644 index 0000000000..179b8a701b --- /dev/null +++ b/tests/a_scripts/data_prep/test_convert_dataset_json.py @@ -0,0 +1,27 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import os +from argparse import Namespace +from pathlib import Path + +from scripts.data_prep.convert_dataset_json import main as main_json + + +def test_json_script_from_api(tmp_path: Path): + # test calling it directly + path = os.path.join(tmp_path, 'my-copy-arxiv-1') + main_json( + Namespace( + **{ + 'path': 'scripts/data_prep/example_data/arxiv.jsonl', + 'out_root': path, + 'compression': None, + 'split': 'train', + 'concat_tokens': None, + 'bos_text': None, + 'eos_text': None, + 'no_wrap': False, + 'num_workers': None + })) + assert os.path.exists(path) diff --git a/tests/test_convert_text_to_mds.py b/tests/a_scripts/data_prep/test_convert_text_to_mds.py similarity index 98% rename from tests/test_convert_text_to_mds.py rename to tests/a_scripts/data_prep/test_convert_text_to_mds.py index ab8c25bc2d..cc293a2cdd 100644 --- a/tests/test_convert_text_to_mds.py +++ b/tests/a_scripts/data_prep/test_convert_text_to_mds.py @@ -2,13 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import os -import sys - -import pytest - -# Add repo root to path so we can import scripts and test it -repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) -sys.path.append(repo_dir) import pathlib from concurrent.futures import ProcessPoolExecutor from glob import glob @@ -16,6 +9,7 @@ from unittest.mock import Mock, patch import numpy as np +import pytest from streaming import StreamingDataset from transformers import AutoTokenizer diff --git a/tests/a_scripts/eval/__init__.py b/tests/a_scripts/eval/__init__.py new file mode 100644 index 0000000000..f6c1f9f3ab --- /dev/null +++ b/tests/a_scripts/eval/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/test_eval.py b/tests/a_scripts/eval/test_eval.py similarity index 89% rename from tests/test_eval.py rename to 
tests/a_scripts/eval/test_eval.py index 2fc96bb7ad..e8d86903dc 100644 --- a/tests/test_eval.py +++ b/tests/a_scripts/eval/test_eval.py @@ -4,8 +4,7 @@ import copy import os import pathlib -import sys -from typing import Any +from typing import Any, Union import omegaconf as om import pytest @@ -14,15 +13,10 @@ from llmfoundry import COMPOSER_MODEL_REGISTRY from llmfoundry.utils import build_tokenizer +from scripts.eval.eval import main # noqa: E402 from tests.data_utils import (create_arxiv_dataset, create_c4_dataset_xxsmall, gpt_tiny_cfg) -# Add repo root to path so we can import scripts and test it -repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) -sys.path.append(repo_dir) - -from scripts.eval.eval import main # noqa: E402 - @pytest.fixture(autouse=True) def set_correct_cwd(): @@ -35,11 +29,16 @@ def set_correct_cwd(): os.chdir('..') -@pytest.fixture() -def mock_saved_model_path(): - # load the eval and model config - with open('eval/yamls/test_eval.yaml', 'r', encoding='utf-8') as f: +@pytest.fixture +def eval_cfg(foundry_dir: str) -> Union[om.ListConfig, om.DictConfig]: + yaml_path = os.path.join(foundry_dir, 'scripts/eval/yamls/test_eval.yaml') + with open(yaml_path, 'r', encoding='utf-8') as f: eval_cfg = om.OmegaConf.load(f) + return eval_cfg + + +@pytest.fixture() +def mock_saved_model_path(eval_cfg: Union[om.ListConfig, om.DictConfig]): model_cfg = eval_cfg.models[0] # set device to cpu device = 'cpu' @@ -60,12 +59,11 @@ def mock_saved_model_path(): os.remove(saved_model_path) -def test_icl_eval(capfd: Any, mock_saved_model_path: Any): - with open('eval/yamls/test_eval.yaml', 'r', encoding='utf-8') as f: - test_cfg = om.OmegaConf.load(f) - test_cfg.models[0].load_path = mock_saved_model_path - assert isinstance(test_cfg, om.DictConfig) - main(test_cfg) +def test_icl_eval(eval_cfg: Union[om.ListConfig, om.DictConfig], capfd: Any, + mock_saved_model_path: Any): + eval_cfg.models[0].load_path = mock_saved_model_path + assert isinstance(eval_cfg, om.DictConfig) + main(eval_cfg) out, _ = capfd.readouterr() expected_results = '| Category | Benchmark | Subtask | Accuracy | Number few shot | Model |\n|:----------------------------|:---------------|:----------|-----------:|:------------------|:---------|\n| language_understanding_lite | lambada_openai | | 0 | 0-shot | tiny_mpt |' assert expected_results in out diff --git a/tests/test_eval_inputs.py b/tests/a_scripts/eval/test_eval_inputs.py similarity index 86% rename from tests/test_eval_inputs.py rename to tests/a_scripts/eval/test_eval_inputs.py index 83104b62b7..8694546c4f 100644 --- a/tests/test_eval_inputs.py +++ b/tests/a_scripts/eval/test_eval_inputs.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import copy import os -import sys import warnings import omegaconf @@ -10,10 +9,6 @@ from omegaconf import DictConfig from omegaconf import OmegaConf as om -# Add repo root to path so we can import scripts and test it -repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) -sys.path.append(repo_dir) - from scripts.eval.eval import main # noqa: E402 @@ -21,10 +16,12 @@ class TestHuggingFaceEvalYAMLInputs: """Validate and tests error handling for the input YAML file.""" @pytest.fixture - def cfg(self) -> DictConfig: + def cfg(self, foundry_dir: str) -> DictConfig: """Create YAML cfg fixture for testing purposes.""" - conf_path: str = os.path.join(repo_dir, - 'scripts/eval/yamls/hf_eval.yaml') + conf_path: str = os.path.join( + foundry_dir, + 'scripts/eval/yamls/hf_eval.yaml', + ) with 
open(conf_path, 'r', encoding='utf-8') as config: test_cfg = om.load(config) assert isinstance(test_cfg, DictConfig) @@ -78,15 +75,17 @@ def test_optional_mispelled_params_raise_warning(self, class TestMPTEvalYAMLInputs: @pytest.fixture - def cfg(self) -> DictConfig: + def cfg(self, foundry_dir: str) -> DictConfig: """Create YAML cfg fixture for testing purposes.""" - conf_path: str = os.path.join(repo_dir, - 'scripts/eval/yamls/mpt_eval.yaml') + conf_path: str = os.path.join( + foundry_dir, + 'scripts/eval/yamls/mpt_eval.yaml', + ) with open(conf_path, 'r', encoding='utf-8') as config: test_cfg = om.load(config) test_cfg.icl_tasks[0].dataset_uri = os.path.join( - repo_dir, 'scripts', test_cfg.icl_tasks[0].dataset_uri) + foundry_dir, 'scripts', test_cfg.icl_tasks[0].dataset_uri) # make tests use cpu initialized transformer models only test_cfg.models[0].model.init_device = 'cpu' diff --git a/tests/a_scripts/inference/__init__.py b/tests/a_scripts/inference/__init__.py new file mode 100644 index 0000000000..f6c1f9f3ab --- /dev/null +++ b/tests/a_scripts/inference/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/test_hf_conversion_script.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py similarity index 99% rename from tests/test_hf_conversion_script.py rename to tests/a_scripts/inference/test_convert_composer_to_hf.py index f9191cd701..d21c942dee 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py @@ -4,34 +4,26 @@ import math import os import pathlib -import sys -from typing import Callable -from unittest.mock import ANY, MagicMock, patch - -from composer import Trainer -from composer.loggers import MLFlowLogger -from composer.utils import dist, get_device, using_torch_2 - -from llmfoundry.callbacks import HuggingFaceCheckpointer -from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM - -# Add repo root to path so we can import scripts and test it -repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) -sys.path.append(repo_dir) import shutil from argparse import Namespace -from typing import Optional, cast +from typing import Callable, Optional, cast +from unittest.mock import ANY, MagicMock, patch import pytest import torch import transformers +from composer import Trainer +from composer.loggers import MLFlowLogger +from composer.utils import dist, get_device, using_torch_2 from omegaconf import DictConfig from omegaconf import OmegaConf as om from torch.utils.data import DataLoader from transformers import PreTrainedModel, PreTrainedTokenizerBase from llmfoundry import COMPOSER_MODEL_REGISTRY +from llmfoundry.callbacks import HuggingFaceCheckpointer from llmfoundry.data.finetuning import build_finetuning_dataloader +from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM from llmfoundry.utils.builders import build_optimizer, build_tokenizer from scripts.inference.convert_composer_to_hf import convert_composer_to_hf from tests.data_utils import make_tiny_ft_dataset diff --git a/tests/a_scripts/train/__init__.py b/tests/a_scripts/train/__init__.py new file mode 100644 index 0000000000..f6c1f9f3ab --- /dev/null +++ b/tests/a_scripts/train/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/test_training.py b/tests/a_scripts/train/test_train.py similarity index 90% rename from tests/test_training.py rename to 
tests/a_scripts/train/test_train.py index 3cd2963100..62075383cc 100644 --- a/tests/test_training.py +++ b/tests/a_scripts/train/test_train.py @@ -1,9 +1,8 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 import copy -import os import pathlib -from typing import Any, Optional +from typing import Optional import pytest from composer.loggers import InMemoryLogger @@ -16,22 +15,10 @@ gpt_tiny_cfg) -@pytest.fixture(autouse=False) -def set_correct_cwd(): - if not os.getcwd().endswith('llm-foundry/scripts'): - os.chdir('scripts') - - yield - - if os.getcwd().endswith('llm-foundry/scripts'): - os.chdir('..') - - @pytest.mark.parametrize('averages', [{ 'core_average': ['language_understanding_lite'] }, None]) -def test_train_gauntlet(averages: Optional[dict], set_correct_cwd: Any, - tmp_path: pathlib.Path): +def test_train_gauntlet(averages: Optional[dict], tmp_path: pathlib.Path): """Test training run with a small dataset.""" dataset_name = create_c4_dataset_xxsmall(tmp_path) test_cfg = gpt_tiny_cfg(dataset_name, 'cpu') @@ -40,7 +27,7 @@ def test_train_gauntlet(averages: Optional[dict], set_correct_cwd: Any, 'label': 'lambada_openai', 'dataset_uri': - 'eval/local_data/language_understanding/lambada_openai_small.jsonl', + 'scripts/eval/local_data/language_understanding/lambada_openai_small.jsonl', 'num_fewshot': [0], 'icl_task_type': 'language_modeling' @@ -110,7 +97,7 @@ def test_train_gauntlet(averages: Optional[dict], set_correct_cwd: Any, -1][-1] == 0 -def test_train_multi_eval(set_correct_cwd: Any, tmp_path: pathlib.Path): +def test_train_multi_eval(tmp_path: pathlib.Path): """Test training run with multiple eval datasets.""" c4_dataset_name = create_c4_dataset_xxsmall(tmp_path) test_cfg = gpt_tiny_cfg(c4_dataset_name, 'cpu') diff --git a/tests/test_train_inputs.py b/tests/a_scripts/train/test_train_inputs.py similarity index 96% rename from tests/test_train_inputs.py rename to tests/a_scripts/train/test_train_inputs.py index 2ed1c9c239..17eca26587 100644 --- a/tests/test_train_inputs.py +++ b/tests/a_scripts/train/test_train_inputs.py @@ -3,7 +3,6 @@ import copy import json import os -import sys import warnings import omegaconf @@ -11,10 +10,6 @@ from omegaconf import DictConfig from omegaconf import OmegaConf as om -# Add repo root to path so we can import scripts and test it -repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) -sys.path.append(repo_dir) - from scripts.train.train import main # noqa: E402 @@ -54,10 +49,10 @@ class TestTrainingYAMLInputs: """Validate and tests error handling for the input YAML file.""" @pytest.fixture - def cfg(self) -> DictConfig: + def cfg(self, foundry_dir: str) -> DictConfig: """Create YAML cfg fixture for testing purposes.""" conf_path: str = os.path.join( - repo_dir, 'scripts/train/yamls/pretrain/testing.yaml') + foundry_dir, 'scripts/train/yamls/pretrain/testing.yaml') with open(conf_path, 'r', encoding='utf-8') as config: test_cfg = om.load(config) assert isinstance(test_cfg, DictConfig) diff --git a/tests/callbacks/__init__.py b/tests/callbacks/__init__.py new file mode 100644 index 0000000000..f6c1f9f3ab --- /dev/null +++ b/tests/callbacks/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/test_eval_gauntlet.py b/tests/callbacks/test_eval_gauntlet_callback.py similarity index 100% rename from tests/test_eval_gauntlet.py rename to tests/callbacks/test_eval_gauntlet_callback.py diff --git 
a/tests/data/__init__.py b/tests/data/__init__.py new file mode 100644 index 0000000000..f6c1f9f3ab --- /dev/null +++ b/tests/data/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/test_dataloader.py b/tests/data/test_dataloader.py similarity index 99% rename from tests/test_dataloader.py rename to tests/data/test_dataloader.py index 2e9039644b..0f5f506e22 100644 --- a/tests/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -5,7 +5,6 @@ import pathlib import random import shutil -import sys import tempfile from argparse import Namespace from typing import Literal, Optional, Union @@ -26,10 +25,6 @@ build_text_dataloader, get_tokens_per_batch_func) from llmfoundry.utils.builders import build_tokenizer - -# Add repo root to path so we can import scripts and test it -repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) -sys.path.append(repo_dir) from scripts.data_prep.convert_dataset_hf import main as main_hf from tests.data_utils import make_tiny_ft_dataset diff --git a/tests/test_icl_datasets.py b/tests/data/test_icl_datasets.py similarity index 98% rename from tests/test_icl_datasets.py rename to tests/data/test_icl_datasets.py index 28d12df91d..3a730fdf19 100644 --- a/tests/test_icl_datasets.py +++ b/tests/data/test_icl_datasets.py @@ -10,7 +10,7 @@ from llmfoundry.utils.builders import build_icl_evaluators -def load_icl_config(conf_path: str = 'tests/test_tasks.yaml'): +def load_icl_config(conf_path: str = 'tests/data/test_tasks.yaml'): with open(conf_path) as f: test_cfg = om.load(f) return test_cfg diff --git a/tests/test_packing.py b/tests/data/test_packing.py similarity index 100% rename from tests/test_packing.py rename to tests/data/test_packing.py diff --git a/tests/test_tasks.yaml b/tests/data/test_tasks.yaml similarity index 100% rename from tests/test_tasks.yaml rename to tests/data/test_tasks.yaml diff --git a/tests/data_utils.py b/tests/data_utils.py index efb4f6d7cf..a0ad6bcd13 100644 --- a/tests/data_utils.py +++ b/tests/data_utils.py @@ -1,14 +1,8 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import os -import sys - -# Add repo root to path so we can import scripts and test it -repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) -sys.path.append(repo_dir) - import json +import os import pathlib import shutil from argparse import Namespace @@ -120,10 +114,14 @@ def create_arxiv_dataset(path: pathlib.Path) -> str: arxiv_dir = os.path.join(path, f'my-copy-arxiv') downloaded_split = 'train' + arxiv_path = 'data_prep/example_data/arxiv.jsonl' + if not os.getcwd().endswith('scripts'): + arxiv_path = os.path.join('scripts', arxiv_path) + main_json( Namespace( **{ - 'path': 'data_prep/example_data/arxiv.jsonl', + 'path': arxiv_path, 'out_root': arxiv_dir, 'compression': None, 'split': downloaded_split, @@ -139,8 +137,11 @@ def create_arxiv_dataset(path: pathlib.Path) -> str: def gpt_tiny_cfg(dataset_name: str, device: str): """Create gpt tiny cfg.""" - conf_path: str = os.path.join(repo_dir, - 'scripts/train/yamls/pretrain/testing.yaml') + from tests.fixtures.autouse import REPO_DIR + conf_path: str = os.path.join( + REPO_DIR, + 'scripts/train/yamls/pretrain/testing.yaml', + ) with open(conf_path) as f: test_cfg = om.load(f) assert isinstance(test_cfg, DictConfig) diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index c51ccfacb0..75caa6c941 100644 --- a/tests/fixtures/autouse.py +++ 
b/tests/fixtures/autouse.py @@ -2,11 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 import gc +import os +import sys import pytest import torch from composer.utils import dist, get_device, reproducibility +# Add llm-foundry repo root to path so we can import scripts in the tests +REPO_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')) +sys.path.append(REPO_DIR) + @pytest.fixture(autouse=True) def initialize_dist(request: pytest.FixtureRequest): @@ -33,6 +39,11 @@ def random_seed() -> int: return 17 +@pytest.fixture +def foundry_dir() -> str: + return REPO_DIR + + @pytest.fixture(autouse=True) def seed_all(random_seed: int): """Sets the seed for reproducibility.""" diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 0000000000..f6c1f9f3ab --- /dev/null +++ b/tests/models/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/models/hf/__init__.py b/tests/models/hf/__init__.py new file mode 100644 index 0000000000..f6c1f9f3ab --- /dev/null +++ b/tests/models/hf/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/test_hf_config.py b/tests/models/hf/test_hf_config.py similarity index 100% rename from tests/test_hf_config.py rename to tests/models/hf/test_hf_config.py diff --git a/tests/test_hf_mpt_gen.py b/tests/models/hf/test_hf_mpt_gen.py similarity index 100% rename from tests/test_hf_mpt_gen.py rename to tests/models/hf/test_hf_mpt_gen.py diff --git a/tests/test_hf_v_mpt.py b/tests/models/hf/test_hf_v_mpt.py similarity index 100% rename from tests/test_hf_v_mpt.py rename to tests/models/hf/test_hf_v_mpt.py diff --git a/tests/models/inference_api_wrapper/__init__.py b/tests/models/inference_api_wrapper/__init__.py new file mode 100644 index 0000000000..f6c1f9f3ab --- /dev/null +++ b/tests/models/inference_api_wrapper/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/test_inference_api_eval_wrapper.py b/tests/models/inference_api_wrapper/test_inference_api_eval_wrapper.py similarity index 100% rename from tests/test_inference_api_eval_wrapper.py rename to tests/models/inference_api_wrapper/test_inference_api_eval_wrapper.py diff --git a/tests/models/layers/__init__.py b/tests/models/layers/__init__.py new file mode 100644 index 0000000000..f6c1f9f3ab --- /dev/null +++ b/tests/models/layers/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/test_flash_triton_torch.py b/tests/models/layers/test_flash_triton_torch.py similarity index 100% rename from tests/test_flash_triton_torch.py rename to tests/models/layers/test_flash_triton_torch.py diff --git a/tests/test_huggingface_flash.py b/tests/models/layers/test_huggingface_flash.py similarity index 100% rename from tests/test_huggingface_flash.py rename to tests/models/layers/test_huggingface_flash.py diff --git a/tests/test_fsdp_act_checkpoint.py b/tests/models/test_fsdp_act_checkpoint.py similarity index 100% rename from tests/test_fsdp_act_checkpoint.py rename to tests/models/test_fsdp_act_checkpoint.py diff --git a/tests/test_model.py b/tests/models/test_model.py similarity index 100% rename from tests/test_model.py rename to tests/models/test_model.py diff --git a/tests/test_mpt_gen.py b/tests/models/test_mpt_gen.py similarity index 100% rename from 
tests/test_mpt_gen.py rename to tests/models/test_mpt_gen.py diff --git a/tests/test_onnx.py b/tests/models/test_onnx.py similarity index 100% rename from tests/test_onnx.py rename to tests/models/test_onnx.py diff --git a/tests/test_rope_dail_vs_hf.py b/tests/models/test_rope_dail_vs_hf.py similarity index 100% rename from tests/test_rope_dail_vs_hf.py rename to tests/models/test_rope_dail_vs_hf.py diff --git a/tests/models/utils/__init__.py b/tests/models/utils/__init__.py new file mode 100644 index 0000000000..f6c1f9f3ab --- /dev/null +++ b/tests/models/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/test_init_fn.py b/tests/models/utils/test_param_init_fns.py similarity index 100% rename from tests/test_init_fn.py rename to tests/models/utils/test_param_init_fns.py diff --git a/tests/optim/__init__.py b/tests/optim/__init__.py new file mode 100644 index 0000000000..f6c1f9f3ab --- /dev/null +++ b/tests/optim/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/test_lion8b.py b/tests/optim/test_lion8b.py similarity index 100% rename from tests/test_lion8b.py rename to tests/optim/test_lion8b.py diff --git a/tests/test_scheduler.py b/tests/optim/test_scheduler.py similarity index 100% rename from tests/test_scheduler.py rename to tests/optim/test_scheduler.py diff --git a/tests/test_data_prep_scripts.py b/tests/test_data_prep_scripts.py deleted file mode 100644 index 4fe5ed7e64..0000000000 --- a/tests/test_data_prep_scripts.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2022 MosaicML LLM Foundry authors -# SPDX-License-Identifier: Apache-2.0 - -import os -import sys -from argparse import Namespace -from pathlib import Path - -# Add repo root to path so we can import scripts and test it -repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) -sys.path.append(repo_dir) -from scripts.data_prep.convert_dataset_hf import main as main_hf -from scripts.data_prep.convert_dataset_json import main as main_json - - -def test_download_script_from_api(tmp_path: Path): - # test calling it directly - path = os.path.join(tmp_path, 'my-copy-c4-1') - main_hf( - Namespace( - **{ - 'dataset': 'c4', - 'data_subset': 'en', - 'splits': ['val_xsmall'], - 'out_root': path, - 'compression': None, - 'concat_tokens': None, - 'bos_text': None, - 'eos_text': None, - 'no_wrap': False, - 'num_workers': None - })) - assert os.path.exists(path) - - -def test_json_script_from_api(tmp_path: Path): - # test calling it directly - path = os.path.join(tmp_path, 'my-copy-arxiv-1') - main_json( - Namespace( - **{ - 'path': 'scripts/data_prep/example_data/arxiv.jsonl', - 'out_root': path, - 'compression': None, - 'split': 'train', - 'concat_tokens': None, - 'bos_text': None, - 'eos_text': None, - 'no_wrap': False, - 'num_workers': None - })) - assert os.path.exists(path) diff --git a/tests/tokenizers/__init__.py b/tests/tokenizers/__init__.py new file mode 100644 index 0000000000..f6c1f9f3ab --- /dev/null +++ b/tests/tokenizers/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/test_tiktoken.py b/tests/tokenizers/test_tiktoken.py similarity index 99% rename from tests/test_tiktoken.py rename to tests/tokenizers/test_tiktoken.py index fe3db41d50..60907092c8 100644 --- a/tests/test_tiktoken.py +++ b/tests/tokenizers/test_tiktoken.py @@ -9,8 +9,9 @@ from 
llmfoundry.tokenizers.tiktoken import (TiktokenTokenizerWrapper, bytes_to_unicode) +from tests.a_scripts.inference.test_convert_composer_to_hf import \ + check_hf_tokenizer_equivalence from tests.horrible_strings import HORRIBLE_STRINGS -from tests.test_hf_conversion_script import check_hf_tokenizer_equivalence if TYPE_CHECKING: from tiktoken.core import Encoding diff --git a/tests/test_tokenizer.py b/tests/tokenizers/test_tokenizer.py similarity index 100% rename from tests/test_tokenizer.py rename to tests/tokenizers/test_tokenizer.py diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000000..f6c1f9f3ab --- /dev/null +++ b/tests/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/test_builders.py b/tests/utils/test_builders.py similarity index 100% rename from tests/test_builders.py rename to tests/utils/test_builders.py diff --git a/tests/test_model_download_utils.py b/tests/utils/test_model_download_utils.py similarity index 100% rename from tests/test_model_download_utils.py rename to tests/utils/test_model_download_utils.py diff --git a/tests/test_prompt_files.py b/tests/utils/test_prompt_files.py similarity index 100% rename from tests/test_prompt_files.py rename to tests/utils/test_prompt_files.py From b2e4b0e2f55e0e6e01a3ca94e956298fe3fc581c Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Fri, 1 Dec 2023 17:27:51 -0800 Subject: [PATCH 13/14] Add "completion" alias for response key (#771) --- llmfoundry/data/finetuning/tasks.py | 39 ++++++++++++++++++++------- tests/data/test_dataloader.py | 42 ++++++++++++++++++++++++++--- 2 files changed, 68 insertions(+), 13 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index a29dee7683..4b80ffef54 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -47,25 +47,46 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: __all__ = ['dataset_constructor'] +_ALLOWED_RESPONSE_KEYS = {'response', 'completion'} +_ALLOWED_PROMPT_KEYS = {'prompt'} + def _tokenize_formatted_example( example: Dict[str, Any], tokenizer: PreTrainedTokenizerBase) -> Dict[str, List[int]]: - if ('prompt' not in example) or ('response' not in example): + """Tokenize a formatted example and validate expected keys.""" + example_keys = set(example.keys()) + prompt_keys = example_keys.intersection(_ALLOWED_PROMPT_KEYS) + response_keys = example_keys.intersection(_ALLOWED_RESPONSE_KEYS) + + if len(prompt_keys) != 1: + raise KeyError( + f'Unable to tokenize example because {len(prompt_keys)} of the allowed prompt keys ' +\ + f'were present in {example_keys=}. Please specify exactly one. {_ALLOWED_PROMPT_KEYS=}' + ) + + if len(response_keys) != 1: raise KeyError( - 'Unable to tokenize example because it has not been properly formatted. ' +\ - '"prompt" and "response" are required keys but at least one was missing ' +\ - f'from {example=}.' + f'Unable to tokenize example because {len(response_keys)} of the allowed response keys ' +\ + f'were present in {example_keys=}. Please specify exactly one. 
{_ALLOWED_RESPONSE_KEYS=}' ) - if not isinstance(example['prompt'], str): + + prompt_key = prompt_keys.pop() + response_key = response_keys.pop() + prompt = example[prompt_key] + response = example[response_key] + + if not isinstance(prompt, str): raise TypeError( - f'Unable to tokenize example because "prompt" was not a string. {example=}' + f'Unable to tokenize example because {prompt_key} was not a string. {example=}' ) - if not isinstance(example['response'], str): + + if not isinstance(response, str): raise TypeError( - f'Unable to tokenize example because "response" was not a string. {example=}' + f'Unable to tokenize example because {response_key} was not a string. {example=}' ) - return tokenizer(text=example['prompt'], text_target=example['response']) + + return tokenizer(text=prompt, text_target=response) class StreamingFinetuningDataset(StreamingDataset): diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 0f5f506e22..747021e82a 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -21,6 +21,9 @@ from llmfoundry import (build_finetuning_dataloader, build_text_denoising_dataloader) from llmfoundry.data import build_dataloader +from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS, + _ALLOWED_RESPONSE_KEYS, + _tokenize_formatted_example) from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper, build_text_dataloader, get_tokens_per_batch_func) @@ -355,10 +358,8 @@ def test_finetuning_dataloader_small_data(dataset_size: int, if (dist.get_world_size() * device_batch_size > dataset_size) and drop_last: error_context = pytest.raises(ValueError, match='Your dataset') if invalid_dataset: - error_context = pytest.raises( - TypeError, - match='Unable to tokenize example because "prompt" was not a string' - ) + error_context = pytest.raises(TypeError, + match='Unable to tokenize example') with error_context: _ = build_finetuning_dataloader(cfg, tokenizer, device_batch_size) @@ -367,6 +368,39 @@ def test_finetuning_dataloader_small_data(dataset_size: int, shutil.rmtree(tiny_dataset_folder_path) +def test_tokenize_example_malformed(): + no_keys = {} + no_prompt_key = {'response': 'response'} + no_response_key = {'prompt': 'prompt'} + extra_keys_with_prompt = {'prompt': 'prompt', 'extra': 'extra'} + extra_keys_with_response = {'response': 'response', 'extra': 'extra'} + multiple_allowed_response_keys = { + 'prompt': 'prompt', + 'response': 'response', + 'completion': 'completion' + } + + malformed_examples = [ + no_keys, no_prompt_key, no_response_key, extra_keys_with_prompt, + extra_keys_with_response, multiple_allowed_response_keys + ] + + for example in malformed_examples: + with pytest.raises(KeyError): + _tokenize_formatted_example(example, MagicMock()) + + +def test_tokenize_example_well_formed(): + tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2') + + for prompt_key in _ALLOWED_PROMPT_KEYS: + for response_key in _ALLOWED_RESPONSE_KEYS: + example = {prompt_key: 'prompt', response_key: 'response'} + tokenized_example = _tokenize_formatted_example(example, tokenizer) + assert 'input_ids' in tokenized_example + assert 'labels' in tokenized_example + + @pytest.mark.parametrize('split', ['train', 'custom', 'data']) def test_finetuning_dataloader_custom_split(tmp_path: pathlib.Path, split: str): tokenizer_name = 'gpt2' From 84b5d96df9b1ffb7dbd2edf34b0a03fd4fe4220b Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Mon, 4 Dec 2023 09:06:24 
-0800 Subject: [PATCH 14/14] Shashank/seq id flash attn (#738) * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * Update llmfoundry/models/layers/attention.py Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> * Update llmfoundry/models/mpt/modeling_mpt.py Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> * Update llmfoundry/models/mpt/modeling_mpt.py Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. * .. --------- Co-authored-by: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com> --- llmfoundry/models/layers/attention.py | 104 ++++--- llmfoundry/models/layers/blocks.py | 3 + llmfoundry/models/mpt/configuration_mpt.py | 15 +- llmfoundry/models/mpt/modeling_mpt.py | 142 ++++++++-- tests/models/layers/test_flash_attn.py | 255 ++++++++++++++++++ .../models/layers/test_flash_triton_torch.py | 60 ++++- tests/models/test_model.py | 110 ++++++++ 7 files changed, 613 insertions(+), 76 deletions(-) create mode 100644 tests/models/layers/test_flash_attn.py diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index dd7f40cd19..86e49c315d 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -92,7 +92,6 @@ def scaled_multihead_dot_product_attention( multiquery: bool = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: - if multiquery: warnings.warn( DeprecationWarning( @@ -219,6 +218,9 @@ def flash_attn_fn( training: bool = False, needs_weights: bool = False, multiquery: bool = False, + attention_mask_in_length: Optional[torch.Tensor] = None, + should_repeat_kv_for_gqa: Optional[bool] = True, + sliding_window_size: int = -1, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: try: @@ -249,58 +251,65 @@ def flash_attn_fn( past_key_value = (key, value) - if attn_bias is not None: - # clamp to 0 necessary for torch 2.0 compile() - _s_q = max(0, attn_bias.size(2) - query.size(1)) - _s_k = max(0, attn_bias.size(3) - key.size(1)) - attn_bias = attn_bias[:, :, _s_q:, _s_k:] - if attn_bias is not None: raise NotImplementedError(f'attn_bias not implemented for flash attn.') batch_size, seqlen = query.shape[:2] - if key_padding_mask is None: - key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) - query_padding_mask = key_padding_mask[:, -query.size(1):] + if attention_mask_in_length is None: + if key_padding_mask is None: + key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) + query_padding_mask = key_padding_mask[:, -query.size(1):] + unpadding_function = bert_padding.unpad_input + else: + key_padding_mask = attention_mask_in_length + query_padding_mask = attention_mask_in_length + unpadding_function = bert_padding.unpad_input_for_concatenated_sequences - query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input( + query_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function( query, query_padding_mask) query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads) - key_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input( + key_unpad, _, cu_seqlens_k, max_seqlen_k = unpadding_function( key, key_padding_mask) key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) - value_unpad, _, _, _ = 
bert_padding.unpad_input(value, key_padding_mask) + value_unpad, _, _, _ = unpadding_function(value, key_padding_mask) value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=kv_n_heads) - # multi-query case - if kv_n_heads == 1: - # Expanding a tensor does not allocate new memory, but only creates a new - # view on the existing tensor where a dimension of size one is expanded - # to a larger size by setting the stride to 0. - # - pytorch docs - # - # hopefully the kernels can utilize this and we're jot just wasting BW here - key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, - key_unpad.size(-1)) - value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, - value_unpad.size(-1)) - # grouped query case - elif kv_n_heads < n_heads: - # Each query belong to a group of kv heads of group size n_heads // kv_n_heads - # We repeat each kv head by the group size number to use the underlying MHA kernels - - # since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d) - # we use .view to modify {key, value}_unpad appropriately + if (kv_n_heads < n_heads) and (not is_flash_v2_installed()) and ( + not should_repeat_kv_for_gqa): + raise ValueError( + 'For Grouped Query Attention or Multi Query Attention, should_repeat_kv_for_gqa should be set to True if not using Flash Attention v2.' + ) - key_unpad = repeat_kv_for_gqa( - key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1), - n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1) - value_unpad = repeat_kv_for_gqa( - value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1), - n_heads // kv_n_heads).view(value_unpad.size(0), n_heads, -1) + if should_repeat_kv_for_gqa: + # multi-query case + if kv_n_heads == 1: + # Expanding a tensor does not allocate new memory, but only creates a new + # view on the existing tensor where a dimension of size one is expanded + # to a larger size by setting the stride to 0. 
+ # - pytorch docs + # + # hopefully the kernels can utilize this and we're jot just wasting BW here + key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, + key_unpad.size(-1)) + value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, + value_unpad.size(-1)) + # grouped query case + elif kv_n_heads < n_heads: + # Each query belong to a group of kv heads of group size n_heads // kv_n_heads + # We repeat each kv head by the group size number to use the underlying MHA kernels + + # since repeat_kv_for_gqa expects input dims of (b, s, kv_n_heads, d) + # we use .view to modify {key, value}_unpad appropriately + + key_unpad = repeat_kv_for_gqa( + key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1), + n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1) + value_unpad = repeat_kv_for_gqa( + value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1), + n_heads // kv_n_heads).view(value_unpad.size(0), n_heads, -1) dropout_p = dropout_p if training else 0.0 @@ -331,7 +340,8 @@ def flash_attn_fn( dropout_p=dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, - return_attn_probs=needs_weights) + return_attn_probs=needs_weights, + window_size=(sliding_window_size, sliding_window_size)) else: raise RuntimeError( 'flash-attn==1.0.9 or flash-attn==2.3.2 is required.') @@ -490,6 +500,7 @@ def __init__( fc_type: str = 'torch', device: Optional[str] = None, bias: bool = True, + sliding_window_size: int = -1, ): super().__init__() @@ -500,6 +511,7 @@ def __init__( self.d_model = d_model self.n_heads = n_heads self.kv_n_heads = kv_n_heads + self.sliding_window_size = sliding_window_size self.head_dim = d_model // n_heads @@ -569,6 +581,7 @@ def forward( rotary_emb_w_meta_info: Optional[dict] = None, is_causal: bool = True, needs_weights: bool = False, + attention_mask_in_length: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[ torch.Tensor, torch.Tensor]]]: qkv = self.Wqkv(x) @@ -626,6 +639,14 @@ def forward( query = query.view(bsz, seqlen, self.d_model) key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim) + extra_attn_kwargs = {} + if self.attn_impl == 'flash': + extra_attn_kwargs = { + 'attention_mask_in_length': attention_mask_in_length, + 'should_repeat_kv_for_gqa': not is_flash_v2_installed(), + 'sliding_window_size': self.sliding_window_size, + } + context, attn_weights, past_key_value = self.attn_fn( query, key, @@ -640,6 +661,7 @@ def forward( dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, + **extra_attn_kwargs, ) return self.out_proj(context), attn_weights, past_key_value @@ -665,6 +687,7 @@ def __init__( fc_type: str = 'torch', device: Optional[str] = None, bias: bool = True, + sliding_window_size: int = -1, ): super().__init__( d_model=d_model, @@ -679,6 +702,7 @@ def __init__( fc_type=fc_type, device=device, bias=bias, + sliding_window_size=sliding_window_size, ) @@ -702,6 +726,7 @@ def __init__( fc_type: str = 'torch', device: Optional[str] = None, bias: bool = True, + sliding_window_size: int = -1, ): super().__init__( d_model=d_model, @@ -716,6 +741,7 @@ def __init__( fc_type=fc_type, device=device, bias=bias, + sliding_window_size=sliding_window_size, ) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index 6605807c6b..6db9ff22ca 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -21,6 +21,7 @@ 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, + 'sliding_window_size': -1, 
'alibi': False, 'alibi_bias_max': 8, 'rope': False, @@ -113,6 +114,7 @@ def forward( attention_mask: Optional[torch.ByteTensor] = None, is_causal: bool = True, output_attentions: bool = False, + attention_mask_in_length: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ torch.Tensor, torch.Tensor]]]: a = self.norm_1(x) @@ -124,6 +126,7 @@ def forward( attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions, + attention_mask_in_length=attention_mask_in_length, ) x = x + self.resid_attn_dropout(b) m = x diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index f8022808bf..47fd5ac9e5 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -91,6 +91,7 @@ def __init__( When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates which sub-sequence each token belongs to. Defaults to ``False`` meaning any provided `sequence_id` will be ignored. + sliding_window_size (int): Window size for sliding window local attention. Defaults to -1, which means no sliding window. Query at position i will only attend to keys between [i + seqlen_k - seqlen_q - window_size, i + seqlen_k - seqlen_q + window_size] inclusive. Only works for flash attention v2.3.0 or higher. alibi (bool): Whether to use the alibi bias instead of position embeddings. alibi_bias_max (int): The maximum value of the alibi bias. rope (bool): Whether to use rotary positional embeddings. @@ -221,10 +222,12 @@ def _validate_config(self) -> None: ]: raise NotImplementedError( 'alibi only implemented with torch and triton attention.') - if self.attn_config['attn_uses_sequence_id'] and self.attn_config[ - 'attn_impl'] not in ['torch', 'triton']: + if self.attn_config['attn_uses_sequence_id'] and not ( + self.attn_config['attn_impl'] in ['torch', 'triton'] or + (self.attn_config['attn_impl'] == 'flash' and + is_flash_v2_installed(v2_version='v2.1.2'))): raise NotImplementedError( - 'attn_uses_sequence_id only implemented with torch and triton attention.' + 'attn_uses_sequence_id only implemented with torch, triton, and flash (v2.1.2 or higher) attention.' ) if self.attn_config['rope'] and (self.attn_config['rope_impl'] not in ['dail', 'hf']): @@ -251,6 +254,12 @@ def _validate_config(self) -> None: raise ImportError( 'If using the dail implementation of rope, the flash_attn library v2.0.1 or higher must be installed. Please check the instructions at https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#what-kinds-of-positional-embeddings-does-llm-foundry-support' ) + if self.attn_config['sliding_window_size'] != -1 and not ( + self.attn_config['attn_impl'] == 'flash' and + is_flash_v2_installed(v2_version='v2.3.0')): + raise NotImplementedError( + 'sliding window only implemented with flash attention v2.3.0 or higher.' + ) if self.embedding_fraction > 1 or self.embedding_fraction <= 0: raise ValueError( 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!' 
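A minimal sketch of how the two new `attn_config` options validated above might be enabled together, assuming flash-attn v2.3.0 or higher is installed (the new `_validate_config` checks raise `NotImplementedError` otherwise). The keys mirror the defaults added to `attn_config_defaults` and the `MPTConfig` construction used in the tests later in this patch; the specific sizes are illustrative only.

```
from llmfoundry.models.mpt.configuration_mpt import MPTConfig

# Illustrative values only: a small MPT config that opts into the new
# flash-attention features introduced by this patch.
config = MPTConfig(
    d_model=128,
    n_heads=8,
    n_layers=2,
    max_seq_len=2048,
    attn_config={
        'attn_impl': 'flash',
        # Sequence-id based masking with flash attention (needs flash-attn >= v2.1.2).
        'attn_uses_sequence_id': True,
        # Local attention window (needs flash-attn >= v2.3.0); -1 keeps it disabled.
        'sliding_window_size': 128,
    },
)
```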
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 34b8992d3e..e2d2ee6fbc 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -132,6 +132,114 @@ def gen_rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int, raise ValueError('rope_impl needs to be either dail or hf') +def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int, + attn_uses_sequence_id: bool, attn_impl: str, + attention_mask: Union[torch.Tensor, None]): + """Generates the attention mask used for sequence masking in FA v2. + + Only supports sequence id based sparse attention for no attention masking or attention masking with right padding. + In case of left padding: + 1. Training with left padding is not supported in MPT (see https://github.com/mosaicml/llm-foundry/blob/1eecd4cb8e734499f77f6a35f657b8b20c0adfcb/llmfoundry/models/mpt/modeling_mpt.py#L407). + 2. For generation with left padding, we only have a single sequence id per sample, so we don't need sequence id based sparse attention. + + Args: + sequence_id (Union[None, torch.Tensor]): Tensor containing the sequence id for each token. Shape (batch_size, seq_len). + S (int): Sequence length + attn_uses_sequence_id (bool): Whether the attention uses sequence id based masking. + attn_impl (str): Attention implementation. This function is only creates attention_mask_in_length for flash attention. + attention_mask (Union[torch.Tensor, None]): Attention mask tensor of shape (batch_size, seq_len) + + Returns: + attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none. For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is: + ``` + [ + [2, 3, 0, 0, 0, 0], + [3, 2, 0, 0, 0, 0], + [6, 0, 0, 0, 0, 0] + ] + ``` + , which refers to the 3D-attention mask: + ``` + [ + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0], + [0, 0, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 1] + ], + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1] + ], + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1] + ] + ] + ```. + (The description above is taken verbatim from https://github.com/Dao-AILab/flash-attention/blob/9356a1c0389660d7e231ff3163c1ac17d9e3824a/flash_attn/bert_padding.py#L125 .) + """ + attention_mask_in_length = None + if (sequence_id is not None) and attn_uses_sequence_id and (attn_impl + == 'flash'): + # Check if sequence has left padding. If yes, raise an error. + if (attention_mask is not None) and (attention_mask[:, 0].sum() != + attention_mask.shape[0]): + raise NotImplementedError( + 'Left padding is not supported with flash attention when attn_uses_sequence_id is set to True.' + ) + if S != sequence_id.shape[-1]: + raise ValueError( + f'Sequence length ({S}) does not match length of sequences in sequence_id ({sequence_id.shape[-1]}).' 
+ ) + attention_mask_in_length = torch.nn.functional.one_hot(sequence_id) + if attention_mask is not None: + attention_mask_in_length = attention_mask_in_length.masked_fill( + ~attention_mask.unsqueeze(-1), 0) + attention_mask_in_length = attention_mask_in_length.sum(dim=1) + attention_mask_in_length = torch.nn.functional.pad( + attention_mask_in_length, + (0, S - attention_mask_in_length.shape[-1]), + mode='constant', + value=0) + + return attention_mask_in_length + + +def apply_sequence_id(attn_bias: torch.Tensor, sequence_id: torch.LongTensor, + max_seq_len: int) -> torch.Tensor: + seq_len = sequence_id.shape[-1] + if seq_len > max_seq_len: + raise ValueError( + f'sequence_id sequence length cannot exceed max_seq_len={max_seq_len}' + ) + + # select seq_len subset of attn mask + attn_bias = attn_bias[..., :seq_len, :seq_len] + + # Restrict attention to tokens that share the same value + # in sequence_id + cannot_attend = torch.logical_not( + torch.eq( + sequence_id.view(-1, seq_len, 1), + sequence_id.view(-1, 1, seq_len), + )).unsqueeze(1) + min_val = torch.finfo(attn_bias.dtype).min + attn_bias = attn_bias.masked_fill(cannot_attend, min_val) + + return attn_bias + + class MPTPreTrainedModel(PreTrainedModel): config_class = MPTConfig base_model_prefix = 'model' @@ -286,7 +394,8 @@ def _attn_bias( # If using torch or triton, we incorporate sequence_id (if appropriate) if self.attn_uses_sequence_id and sequence_id is not None: assert isinstance(attn_bias, torch.Tensor) # pyright - attn_bias = self._apply_sequence_id(attn_bias, sequence_id) + attn_bias = apply_sequence_id(attn_bias, sequence_id, + self.config.max_seq_len) # If using torch or triton, we incorporate attention_mask. This will output # None in place of attention_mask since it will not be further needed in the @@ -343,29 +452,6 @@ def _apply_prefix_mask(self, attn_bias: torch.Tensor, return attn_bias - def _apply_sequence_id(self, attn_bias: torch.Tensor, - sequence_id: torch.LongTensor) -> torch.Tensor: - seq_len = sequence_id.shape[-1] - if seq_len > self.config.max_seq_len: - raise ValueError( - f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}' - ) - - # select seq_len subset of attn mask - attn_bias = attn_bias[..., :seq_len, :seq_len] - - # Restrict attention to tokens that share the same value - # in sequence_id - cannot_attend = torch.logical_not( - torch.eq( - sequence_id.view(-1, seq_len, 1), - sequence_id.view(-1, 1, seq_len), - )).unsqueeze(1) - min_val = torch.finfo(attn_bias.dtype).min - attn_bias = attn_bias.masked_fill(cannot_attend, min_val) - - return attn_bias - def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -515,7 +601,12 @@ def forward( prefix_mask=prefix_mask, sequence_id=sequence_id, ) - + attention_mask_in_length = gen_attention_mask_in_length( + sequence_id=sequence_id, + S=S, + attn_uses_sequence_id=self.attn_uses_sequence_id, + attn_impl=self.attn_impl, + attention_mask=attention_mask) # initialize the past key values cache if it should be used presents = () if use_cache else None if use_cache and past_key_values is None: @@ -538,6 +629,7 @@ def forward( attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), + attention_mask_in_length=attention_mask_in_length, ) if presents is not None: presents += (present,) diff --git a/tests/models/layers/test_flash_attn.py b/tests/models/layers/test_flash_attn.py new file mode 100644 index 0000000000..acefd2c42d --- /dev/null +++ 
b/tests/models/layers/test_flash_attn.py @@ -0,0 +1,255 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import math + +import pytest +import torch + +from llmfoundry.models.layers.attention import (flash_attn_fn, + is_flash_v2_installed, + triton_flash_attn_fn) + + +@pytest.mark.gpu +@pytest.mark.skipif( + not is_flash_v2_installed(), + reason='GQA natively only supported by Flash Attention after v2.') +@pytest.mark.parametrize('kv_n_heads', [1, 4, 8]) +def test_gqa_kv_repetition(kv_n_heads: int): + # Test that flash attention v2 with GQA (kv_n_heads < n_heads) works the same + # whether we repeat the kv_n_heads explicitly or flash attention v2 handles it on its own. + d = 128 + n_heads = 8 + seqlen_1 = 6 + bsz = 2 + + query_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(torch.bfloat16).cuda() + query_1.requires_grad = True + key_1 = torch.randn(bsz, seqlen_1, kv_n_heads * d).to(torch.bfloat16).cuda() + key_1.requires_grad = True + value_1 = torch.randn(bsz, seqlen_1, + kv_n_heads * d).to(torch.bfloat16).cuda() + value_1.requires_grad = True + + output_1, _, _ = flash_attn_fn(query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + attention_mask_in_length=None, + should_repeat_kv_for_gqa=True) + + output_1.sum().backward() + + query_2 = query_1.detach().clone() + query_2.requires_grad = True + key_2 = key_1.detach().clone() + key_2.requires_grad = True + value_2 = value_1.detach().clone() + value_2.requires_grad = True + + output_2, _, _ = flash_attn_fn(query=query_2, + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + attention_mask_in_length=None, + should_repeat_kv_for_gqa=False) + + output_2.sum().backward() + assert torch.allclose(output_1, output_2) + assert torch.allclose(query_1.grad, query_2.grad) # type: ignore + assert torch.allclose(key_1.grad, key_2.grad) # type: ignore + assert torch.allclose(value_1.grad, value_2.grad) # type: ignore + + +@pytest.mark.gpu +@pytest.mark.skipif( + not is_flash_v2_installed(v2_version='v2.1.2'), + reason= + 'Using sequence id with flash attention requires flash attention v2.1.2 or higher.' +) +def test_seq_id_masking_FA_v2(): + # Test that flash attention v2 with sequence id masking works correctly. + d = 128 + n_heads = 4 + kv_n_heads = 4 + seqlen_1 = 6 + bsz = 2 + + query_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(torch.bfloat16).cuda() + query_1.requires_grad = True + key_1 = torch.randn(bsz, seqlen_1, kv_n_heads * d).to(torch.bfloat16).cuda() + key_1.requires_grad = True + value_1 = torch.randn(bsz, seqlen_1, + kv_n_heads * d).to(torch.bfloat16).cuda() + value_1.requires_grad = True + + seq_ranges = [ + (0, 3), (3, 5), (5, 6) + ] # Each batch has 3 sequences of length 3, 2, and 1 respectively. 
+ attention_mask_in_length_1 = torch.tensor([[3, 2, 1, 0, 0, 0], + [3, 2, 1, 0, 0, + 0]]).to(torch.int64).cuda() + + output_1, _, _ = flash_attn_fn( + query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + attention_mask_in_length=attention_mask_in_length_1) + + output_1.sum().backward() + + for seq_range in seq_ranges: + query_2 = query_1.detach().clone()[:, seq_range[0]:seq_range[1], :] + query_2.requires_grad = True + key_2 = key_1.detach().clone()[:, seq_range[0]:seq_range[1], :] + key_2.requires_grad = True + value_2 = value_1.detach().clone()[:, seq_range[0]:seq_range[1], :] + value_2.requires_grad = True + + output_2, _, _ = flash_attn_fn(query=query_2, + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=kv_n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + attention_mask_in_length=None) + + output_2.sum().backward() + assert torch.allclose(output_1[:, seq_range[0]:seq_range[1], :], + output_2) + assert torch.allclose( + query_1.grad[:, seq_range[0]:seq_range[1], :], # type: ignore + query_2.grad) # type: ignore + assert torch.allclose( + key_1.grad[:, seq_range[0]:seq_range[1], :], # type: ignore + key_2.grad) # type: ignore + assert torch.allclose( + value_1.grad[:, seq_range[0]:seq_range[1], :], # type: ignore + value_2.grad) # type: ignore + + +@pytest.mark.gpu +@pytest.mark.skipif( + not is_flash_v2_installed(v2_version='v2.3.0'), + reason= + 'Sliding window attention only supported by Flash Attention after v2.3.0.') +@pytest.mark.parametrize('sliding_window_size', [1, 4, 8]) +def test_sliding_window(sliding_window_size: int): + # Test that sliding window attention works as expected. 
+ dtype = torch.bfloat16 + device = 'cuda' + d = 128 + n_heads = 8 + seqlen_1 = 8 + bsz = 2 + + query_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(dtype=dtype, + device=device) + query_1.requires_grad = True + key_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(dtype=dtype, + device=device) + key_1.requires_grad = True + value_1 = torch.randn(bsz, seqlen_1, n_heads * d).to(dtype=dtype, + device=device) + value_1.requires_grad = True + + output_1, _, _ = flash_attn_fn(query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + attention_mask_in_length=None, + should_repeat_kv_for_gqa=True, + sliding_window_size=sliding_window_size) + + output_1.sum().backward() + + query_2 = query_1.detach().clone() + query_2.requires_grad = True + key_2 = key_1.detach().clone() + key_2.requires_grad = True + value_2 = value_1.detach().clone() + value_2.requires_grad = True + + attn_bias_2 = torch.zeros(1, 1, seqlen_1, seqlen_1).to(dtype=dtype, + device=device) + + window_mask_2 = torch.tril( + torch.ones(seqlen_1, seqlen_1), diagonal=-(sliding_window_size + 1)).to( + dtype=dtype, device=device) * torch.finfo(attn_bias_2.dtype).min + attn_bias_2 = attn_bias_2 + window_mask_2 + output_2, _, _ = triton_flash_attn_fn( + query=query_2, + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=attn_bias_2, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + multiquery=False, + ) + + output_2.sum().backward() + + assert torch.allclose(output_1, output_2) + assert torch.norm(query_2.grad - query_1.grad # type: ignore + ) <= 1e-2 + 1e-2 * torch.norm(query_2.grad) + assert torch.norm(key_2.grad - key_1.grad # type: ignore + ) <= 1e-2 + 1e-2 * torch.norm(key_2.grad) + assert torch.norm(value_2.grad - value_1.grad # type: ignore + ) <= 1e-2 + 1e-2 * torch.norm(value_2.grad) diff --git a/tests/models/layers/test_flash_triton_torch.py b/tests/models/layers/test_flash_triton_torch.py index e140f678bc..454fda311d 100644 --- a/tests/models/layers/test_flash_triton_torch.py +++ b/tests/models/layers/test_flash_triton_torch.py @@ -7,7 +7,9 @@ from llmfoundry.models.layers import attention from llmfoundry.models.layers.attention import is_flash_v2_installed -from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding +from llmfoundry.models.mpt.modeling_mpt import (apply_sequence_id, + gen_attention_mask_in_length, + gen_rotary_embedding) def allclose_helper(t0: torch.Tensor, @@ -54,6 +56,7 @@ def allclose_helper(t0: torch.Tensor, @pytest.mark.parametrize( 'attn_type', ['multihead_attention', 'multiquery_attention', 'grouped_query_attention']) +@pytest.mark.parametrize('attn_uses_sequence_id', [True, False]) @pytest.mark.parametrize('pad_attention_mask', [True, False]) def test_attn_impl(attn_impl_0: str, attn_impl_1: str, @@ -61,6 +64,7 @@ def test_attn_impl(attn_impl_0: str, qk_ln: bool, pos_emb_config: dict, attn_type: str, + attn_uses_sequence_id: bool, pad_attention_mask: bool, device: str = 'cuda'): """Compare all attn impl with each other. 
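A small CPU-only sketch of what the new `gen_attention_mask_in_length` helper returns, using hypothetical inputs; the call signature matches the test hunks below, and the expected output follows the docstring added in `modeling_mpt.py` above.

```
import torch

from llmfoundry.models.mpt.modeling_mpt import gen_attention_mask_in_length

# Two samples, each a concatenation of two sub-sequences (lengths 2+4 and 3+3).
sequence_id = torch.LongTensor([[0, 0, 1, 1, 1, 1],
                                [0, 0, 0, 1, 1, 1]])

attention_mask_in_length = gen_attention_mask_in_length(
    sequence_id=sequence_id,
    S=6,
    attn_uses_sequence_id=True,
    attn_impl='flash',
    attention_mask=None)

# Per-sample sub-sequence lengths, right-padded with zeros to S:
# tensor([[2, 4, 0, 0, 0, 0],
#         [3, 3, 0, 0, 0, 0]])
print(attention_mask_in_length)
```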
@@ -77,6 +81,16 @@ def test_attn_impl(attn_impl_0: str, == 'dail') and (not is_flash_v2_installed()): pytest.skip('dail implementation of rope requires flash attention 2.') + if attn_uses_sequence_id and ( + attn_impl_0 == 'flash' or attn_impl_1 + == 'flash') and (not is_flash_v2_installed(v2_version='v2.1.2')): + pytest.skip( + 'Using sequence id with flash attention requires flash attention v2.1.2 or higher.' + ) + + if not (alibi or rope) and attn_uses_sequence_id: + pytest.skip('attn_uses_sequence_id requires alibi or rope.') + cfg = om.create({ 'attn_impl': 'flash', 'd_model': 64, @@ -91,6 +105,14 @@ def test_attn_impl(attn_impl_0: str, if attn_type == 'grouped_query_attention': cfg.kv_n_heads = 2 + sequence_id = None + if attn_uses_sequence_id: + assert n == 2 + assert s >= 4 + sequence_id = torch.LongTensor([[0] * 2 + [1] * (s - 2), + [0] * 4 + [1] * (s - 4) + ]).to(device=device) + cfg.attn_impl = attn_impl_0 attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device) cfg.attn_impl = attn_impl_1 @@ -113,7 +135,7 @@ def gen_bias(attn_impl: str): s, alibi, prefix_lm=False, - use_sequence_id=False, + use_sequence_id=attn_uses_sequence_id, causal=causal) if bs is not None: attn_bias = torch.zeros(*bs, device=device) @@ -126,17 +148,35 @@ def gen_bias(attn_impl: str): alibi=alibi, alibi_bias_max=8, ) + if attn_impl != 'flash' and attn_uses_sequence_id and sequence_id is not None: + assert isinstance(attn_bias, torch.Tensor) # pyright + attn_bias = apply_sequence_id( + attn_bias, + sequence_id, # type: ignore + s) return attn_bias + attention_mask_in_length_0 = gen_attention_mask_in_length( + sequence_id=sequence_id, + S=s, + attn_uses_sequence_id=attn_uses_sequence_id, + attn_impl=attn_impl_0, + attention_mask=attention_mask) + attention_mask_in_length_1 = gen_attention_mask_in_length( + sequence_id=sequence_id, + S=s, + attn_uses_sequence_id=attn_uses_sequence_id, + attn_impl=attn_impl_1, + attention_mask=attention_mask) + x0 = torch.randn(n, s, f).to(device) x1 = x0.clone().detach() x0.requires_grad = True x1.requires_grad = True with torch.autocast(x0.device.type): - attn_bias = gen_bias(attn0.attn_impl) - + attn_bias_0 = gen_bias(attn_impl_0) rotary_emb_w_meta_info = None if rope: rotary_embedding = gen_rotary_embedding( @@ -165,17 +205,19 @@ def gen_bias(attn_impl: str): y0, _, _ = attn0(x0, past_key_value=None, - attn_bias=attn_bias, + attn_bias=attn_bias_0, attention_mask=attention_mask, rotary_emb_w_meta_info=rotary_emb_w_meta_info, - is_causal=True) - attn_bias = gen_bias(attn1.attn_impl) + is_causal=True, + attention_mask_in_length=attention_mask_in_length_0) + attn_bias_1 = gen_bias(attn_impl_1) y1, _, _ = attn1(x1, past_key_value=None, - attn_bias=attn_bias, + attn_bias=attn_bias_1, attention_mask=attention_mask, rotary_emb_w_meta_info=rotary_emb_w_meta_info, - is_causal=True) + is_causal=True, + attention_mask_in_length=attention_mask_in_length_1) y0 *= attention_mask.unsqueeze(-1) y1 *= attention_mask.unsqueeze(-1) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index acb2074ae9..98a556f534 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -555,6 +555,116 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool): assert block.resid_ffn_dropout.p == 0.2 +@pytest.mark.gpu +@pytest.mark.parametrize('attention_impl', ['flash', 'triton', 'torch']) +@pytest.mark.parametrize('pos_emb_config', [{ + 'alibi': True, + 'rope': False +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 
'rope_impl': 'dail', + 'rope_dail_config': { + 'type': 'original', + 'pos_idx_in_fp32': True, + 'xpos_scale_base': 512, + }, +}, { + 'alibi': False, + 'rope': True, + 'rope_theta': 10000, + 'rope_impl': 'hf', + 'rope_hf_config': { + 'type': 'no_scaling', + 'factor': 1.0, + }, +}]) +def test_sequence_id_based_masking(attention_impl: str, pos_emb_config: dict): + # Testing the output of concatenated sequence with sequence id masking vs individual sequences. + alibi = pos_emb_config['alibi'] + if alibi and attention_impl == 'flash': + pytest.skip(f'alibi only implemented with torch and triton attention.') + + rope = pos_emb_config['rope'] + if rope and pos_emb_config[ + 'rope_impl'] == 'dail' and not is_flash_v2_installed(): + pytest.skip( + f'dail implementation of rope requires gpu and flash attention 2.') + + if attention_impl == 'flash' and ( + not is_flash_v2_installed(v2_version='v2.1.2')): + pytest.skip( + 'Using sequence id with flash attention requires flash attention v2.1.2 or higher.' + ) + + composer_device = get_device(None) + + hf_config = MPTConfig( + init_device='cpu', + d_model=128, + n_heads=1, + n_layers=2, + expansion_ratio=2, + max_seq_len=2048, + emb_pdrop=0.1, + resid_pdrop=0.2, + attn_config={ + 'attn_impl': attention_impl, + 'attn_uses_sequence_id': True, + **pos_emb_config, + }, + init_config={ + 'name': 'baseline_', + 'init_std': 0.02, + }, + ) + mpt = MPTForCausalLM(hf_config) + mpt.eval() + mpt = composer_device.module_to_device(mpt) + + with get_precision_context('amp_bf16' if composer_device.name == + 'gpu' else 'fp32'): + # padding on the right side of the input + concatenated_seq_ids = torch.tensor([[11274, 16390, 11, 4332, 323, 423], + [2342, 12, 111, 123, 50256, 342]]) + concatenated_seq_ids = composer_device.tensor_to_device( + concatenated_seq_ids) + + sequence_id = torch.tensor([[0, 0, 0, 1, 2, 2], [0, 0, 0, 1, 2, 2]]) + sequence_id = composer_device.tensor_to_device(sequence_id) + + first_seq_ids = torch.tensor([[11274, 16390, 11], [2342, 12, 111]]) + first_seq_ids = composer_device.tensor_to_device(first_seq_ids) + + second_seq_ids = torch.tensor([[4332], [123]]) + second_seq_ids = composer_device.tensor_to_device(second_seq_ids) + + third_seq_ids = torch.tensor([[323, 423], [50256, 342]]) + third_seq_ids = composer_device.tensor_to_device(third_seq_ids) + + concatenated_seq_output = mpt(concatenated_seq_ids, + sequence_id=sequence_id).logits + first_seq_output = mpt(first_seq_ids).logits + second_seq_output = mpt(second_seq_ids).logits + third_seq_output = mpt(third_seq_ids).logits + + assert torch.allclose(concatenated_seq_output[:, :3], + first_seq_output, + atol=2e-6 if attention_impl == 'torch' else 1e-8) + assert torch.allclose(concatenated_seq_output[:, 3:4], + second_seq_output, + atol=2e-6 if attention_impl == 'torch' else 1e-8) + atol = 1e-8 + if attention_impl == 'torch': + atol = 2e-6 + elif pos_emb_config['rope']: + atol = 2e-2 + assert torch.allclose(concatenated_seq_output[:, 4:6], + third_seq_output, + atol=atol) + + @pytest.mark.parametrize('attention_impl', [ 'torch', pytest.param('flash', marks=pytest.mark.gpu),