diff --git a/.github/workflows/smoketest.yaml b/.github/workflows/smoketest.yaml
new file mode 100644
index 0000000000..0bf3968753
--- /dev/null
+++ b/.github/workflows/smoketest.yaml
@@ -0,0 +1,41 @@
+name: Smoketest
+on:
+  push:
+    branches:
+      - main
+      - release/*
+  pull_request:
+    branches:
+      - main
+      - release/*
+  workflow_dispatch:
+# Cancel old runs when a new commit is pushed to the same branch if not on main or dev
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+  cancel-in-progress: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/dev' }}
+defaults:
+  run:
+    working-directory: .
+jobs:
+  smoketest:
+    runs-on: ubuntu-20.04
+    timeout-minutes: 10
+    strategy:
+      matrix:
+        python_version:
+          - "3.9"
+          - "3.10"
+    steps:
+      - uses: actions/checkout@v3
+      - uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python_version }}
+      - name: Setup
+        run: |
+          set -ex
+          python -m pip install --upgrade 'pip<23' wheel
+          python -m pip install --upgrade .
+          python -m pip install pytest==7.2.1 pytest_codeblocks==0.16.1
+      - name: Run checks
+        run: |
+          pytest tests/test_smoketest.py
diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py
index 85f96aadb9..87504d26b3 100644
--- a/llmfoundry/__init__.py
+++ b/llmfoundry/__init__.py
@@ -4,6 +4,26 @@
 import torch
 
 try:
+    import warnings
+
+    # bitsandbytes is a very noisy library. A lot of it is print statements that we can't easily suppress,
+    # but we can at least suppress a bunch of spurious warnings.
+    warnings.filterwarnings('ignore',
+                            category=UserWarning,
+                            module='bitsandbytes')
+
+    import logging
+
+    from llmfoundry.utils.logging_utils import SpecificWarningFilter
+
+    # Filter out Hugging Face warning for not using a pinned revision of the model
+    hf_dynamic_modules_logger = logging.getLogger(
+        'transformers.dynamic_module_utils')
+    new_files_warning_filter = SpecificWarningFilter(
+        'A new version of the following files was downloaded from')
+
+    hf_dynamic_modules_logger.addFilter(new_files_warning_filter)
+
     # Before importing any transformers models, we need to disable transformers flash attention if
     # we are in an environment with flash attention version <2. Transformers hard errors on a not properly
     # gated import otherwise.
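
For context on the filters added to llmfoundry/__init__.py above: a logging.Filter whose filter() method returns False drops matching records before they reach any handler. A minimal, self-contained sketch of the same idea (the class and variable names here are illustrative, not part of the patch):

import logging


class SubstringFilter(logging.Filter):
    """Drop any log record whose message contains the given substring."""

    def __init__(self, needle: str):
        super().__init__()
        self.needle = needle

    def filter(self, record: logging.LogRecord) -> bool:
        # Returning False suppresses the record; True lets it through.
        return self.needle not in record.getMessage()


# Attach to the same Hugging Face logger the patch targets.
logger = logging.getLogger('transformers.dynamic_module_utils')
logger.addFilter(
    SubstringFilter('A new version of the following files was downloaded from'))
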
diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 44743d6eb7..52a30b4b7f 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -44,6 +44,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 from streaming import StreamingDataset
 from transformers import PreTrainedTokenizerBase
 
+from llmfoundry.utils.logging_utils import SpecificWarningFilter
+
 log = logging.getLogger(__name__)
 
 __all__ = ['dataset_constructor']
@@ -245,7 +247,7 @@ def wrapper(func: Callable) -> Callable:
 
     def print_registered_tasks(self) -> None:
         tasks = sorted(self._task_preprocessing_registry.keys())
-        print('\n'.join(tasks))
+        log.info('\n'.join(tasks))
 
     def get_preprocessing_fn_from_dict(
         self,
@@ -363,6 +365,15 @@ def build_from_hf(
         with dist.local_rank_zero_download_and_wait(signal_file_path):
             pass
 
+    hf_tokenization_logger = logging.getLogger(
+        'transformers.tokenization_utils_base')
+    sequence_length_warning_filter = SpecificWarningFilter(
+        'Token indices sequence length is longer than the specified maximum sequence length'
+    )
+
+    # We will trim examples later in the collate_fn, so we want to silence this warning from Hugging Face
+    hf_tokenization_logger.addFilter(sequence_length_warning_filter)
+
     error: Optional[Exception] = None
     filtered_dataset = None
     try:
@@ -468,6 +479,9 @@ def filter_long_or_empty_examples(example: Dict) -> bool:
         log.error('Error during data prep')
         raise error
     log.debug('All ranks finished data prep')
+
+    hf_tokenization_logger.removeFilter(sequence_length_warning_filter)
+
     assert filtered_dataset is not None
     return filtered_dataset
diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index d52633a09b..fcac57d817 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -20,7 +20,7 @@
 from composer.utils import dist
 from omegaconf import DictConfig
 from torch import nn
-from transformers import (AutoConfig, AutoModelForCausalLM,
+from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedModel,
                           PreTrainedTokenizerBase)
 
 from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
@@ -102,20 +102,27 @@ def __init__(self, om_model_config: Union[DictConfig,
                     'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
                     + 'Please install flash_attn==2.3.2`.')
 
+        requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager'
         config = AutoConfig.from_pretrained(
             om_model_config.pretrained_model_name_or_path,
             trust_remote_code=trust_remote_code,
             use_auth_token=use_auth_token,
+            attn_implementation=requested_attention_implementation,
+            use_cache=
+            False,  # Necessary due to https://github.com/huggingface/transformers/issues/28056
         )
 
-        # This is not how you are supposed to set this, but transformers currently only
-        # supports enabling flash attention 2 when using the from_pretrained API.
-        # We need to support it for both from_pretrained and from_config, so we have to
-        # set the private attribute here. This will just skip all of transformers'
-        # validation logic that it is ok to use flash attention 2, so we check
-        # whether it is installed above, and whether the chosen config supports it here.
-        # https://github.com/huggingface/transformers/issues/26878
-        config._flash_attn_2_enabled = use_flash_attention_2
+        # This is not ideal, however Hugging Face's _autoset_attn_implementation function
+        # forces you to load the model in fp16/bf16 if you want to use flash attention. Rather than loading
+        # the model and then casting it back to fp32, we are monkeypatching their check.
+        # https://github.com/huggingface/transformers/issues/28052
+        def _autoset_attn_implementation_monkeypatch(
+                cls, config, *args, **kwargs):  # type: ignore
+            config._attn_implementation = requested_attention_implementation
+            return config
+
+        PreTrainedModel._autoset_attn_implementation = classmethod(
+            _autoset_attn_implementation_monkeypatch)
 
         # set config overrides
         for k, v in om_model_config.get('config_overrides', {}).items():
@@ -184,7 +191,8 @@ def __init__(self, om_model_config: Union[DictConfig,
                     trust_remote_code=trust_remote_code,
                     use_auth_token=use_auth_token,
                     load_in_8bit=load_in_8bit,
-                    config=config)
+                    config=config,
+                )
             else:
                 model = AutoModelForCausalLM.from_config(
                     config,
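
The hf_causal_lm.py change above works by replacing the PreTrainedModel._autoset_attn_implementation classmethod so the requested attention implementation is recorded without transformers' dtype-gated validation. The general pattern of swapping a classmethod at runtime looks like the following toy sketch (classes and names are purely illustrative, not transformers code):

class Base:

    @classmethod
    def configure(cls, config: dict) -> dict:
        # Pretend this performs validation we want to skip.
        config['validated'] = True
        return config


def _configure_passthrough(cls, config: dict) -> dict:
    # Drop-in replacement: return the config untouched.
    return config


# Rebind the classmethod on the class object itself.
Base.configure = classmethod(_configure_passthrough)

assert Base.configure({'attn': 'flash_attention_2'}) == {'attn': 'flash_attention_2'}
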
diff --git a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
index 7257b98bd8..39de2ba59c 100644
--- a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
+++ b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
@@ -23,9 +23,11 @@
     'OpenAICausalLMEvalWrapper',
     'OpenAIChatAPIEvalWrapper',
 ]
-from openai.types.chat.chat_completion import ChatCompletion
-from openai.types.completion import Completion
-from openai.types.completion_choice import Logprobs
+
+if TYPE_CHECKING:
+    from openai.types.chat.chat_completion import ChatCompletion
+    from openai.types.completion import Completion
+    from openai.types.completion_choice import Logprobs
 
 MAX_RETRIES = 10
 
@@ -99,7 +101,7 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
                 'role':
                     'system',
                 'content':
-                    model_cfg.get('sytsem_role_prompt',
+                    model_cfg.get('system_role_prompt',
                                   'Please complete the following text: ')
             }, {
                 'role':
                     'user',
@@ -201,7 +203,7 @@ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
 
         return torch.stack(output_logits_batch).to(batch['input_ids'].device)
 
-    def process_result(self, completion: Optional[ChatCompletion]):
+    def process_result(self, completion: Optional['ChatCompletion']):
         if completion is None:
             raise ValueError("Couldn't generate model output")
 
@@ -234,7 +236,7 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
             logprobs=5,
             temperature=0.0)
 
-    def process_result(self, completion: Optional[Completion]):
+    def process_result(self, completion: Optional['Completion']):
         if completion is None:
             raise ValueError("Couldn't generate model output")
 
diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py
index e18e611ca6..560e8c31fc 100644
--- a/llmfoundry/models/layers/ffn.py
+++ b/llmfoundry/models/layers/ffn.py
@@ -4,7 +4,9 @@
 """MPT Blocks used for the MPT Model."""
 
 import logging
-from typing import Any, Optional, Union
+from copy import deepcopy
+from functools import partial
+from typing import Any, Callable, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -18,6 +20,36 @@
 
 log = logging.getLogger(__name__)
 
+_FFN_ACT_FN_DEFAULT = {
+    'name': 'gelu',
+    'approximate': 'none',
+}
+
+
+def resolve_ffn_act_fn(
+    config: Optional[dict] = None,) -> Callable[[torch.Tensor], torch.Tensor]:
+    """Resolve the activation function for the feed-forward network.
+
+    Args:
+        config (Optional[dict]): The configuration dictionary for the activation function.
+            The dict config must specify the 'name' of a torch.nn.functional activation
+            function. All other key-value pairs are bound to the function as a partial.
+
+    Returns:
+        Callable[[torch.Tensor], torch.Tensor]: The activation function.
+    """
+    if config is None:
+        config = _FFN_ACT_FN_DEFAULT
+    config = deepcopy(config)
+    name = config.pop('name')
+    if not hasattr(torch.nn.functional, name):
+        raise ValueError(f'Unrecognised activation function name ({name}).')
+    act = getattr(torch.nn.functional, name)
+    return partial(act, **config)
+
+
+_DEFAULT_ACT_FN = resolve_ffn_act_fn(_FFN_ACT_FN_DEFAULT)
+
 
 def resolve_ffn_hidden_size(
     d_model: int,
@@ -55,6 +87,7 @@ def __init__(
         expansion_ratio: Union[int, float],
         fc_type: str = 'torch',
         ffn_hidden_size: Optional[int] = None,
+        act_fn: Callable[[torch.Tensor], torch.Tensor] = _DEFAULT_ACT_FN,
         device: Optional[str] = None,
         bias: bool = True,
     ):
@@ -72,7 +105,7 @@ def __init__(
             ffn_hidden_size,
             **self.fc_kwargs,
         )
-        self.act = nn.GELU(approximate='none')
+        self.act = act_fn
         self.down_proj = FC_CLASS_REGISTRY[fc_type](
             ffn_hidden_size,
             d_model,
@@ -92,6 +125,7 @@ def __init__(
         expansion_ratio: Union[int, float],
         fc_type: str = 'torch',
         ffn_hidden_size: Optional[int] = None,
+        act_fn: Callable[[torch.Tensor], torch.Tensor] = _DEFAULT_ACT_FN,
         device: Optional[str] = None,
         bias: bool = True,
     ):
@@ -100,6 +134,7 @@ def __init__(
             expansion_ratio=expansion_ratio,
             fc_type=fc_type,
             ffn_hidden_size=ffn_hidden_size,
+            act_fn=act_fn,
             device=device,
             bias=bias,
         )
@@ -128,6 +163,7 @@ def build_ffn(
     expansion_ratio: Union[int, float],
     fc_type: str = 'torch',
    ffn_hidden_size: Optional[int] = None,
+    ffn_act_fn: Optional[dict] = None,
     device: Optional[str] = None,
     bias: bool = True,
     **kwargs: Any,
@@ -142,6 +178,7 @@ def build_ffn(
             d_model=d_model,
             expansion_ratio=expansion_ratio,
             fc_type=fc_type,
+            act_fn=resolve_ffn_act_fn(ffn_act_fn),
             ffn_hidden_size=ffn_hidden_size,
             device=device,
             bias=bias,
@@ -150,6 +187,10 @@ def build_ffn(
         assert te is not None
         ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio,
                                                   ffn_hidden_size)
+        if ffn_act_fn is not None:
+            raise ValueError(
+                f'Transformer Engine block does not support custom activation functions.'
+            )
         return te.LayerNormMLP(
             hidden_size=d_model,
             ffn_hidden_size=ffn_hidden_size,
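
resolve_ffn_act_fn above turns a small config dict into a callable by looking up 'name' on torch.nn.functional and binding every remaining key as a keyword argument via functools.partial. A hedged usage sketch, assuming the import path introduced by this diff:

import torch

from llmfoundry.models.layers.ffn import resolve_ffn_act_fn

# 'name' selects the torch.nn.functional activation; other keys become partial kwargs.
gelu_tanh = resolve_ffn_act_fn({'name': 'gelu', 'approximate': 'tanh'})
silu = resolve_ffn_act_fn({'name': 'silu'})

x = torch.randn(2, 4)
print(gelu_tanh(x).shape, silu(x).shape)  # torch.Size([2, 4]) torch.Size([2, 4])
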
+ """ + super().__init__() + self.message_to_suppress = message_to_suppress + + def filter(self, record: logging.LogRecord) -> bool: + return self.message_to_suppress not in record.getMessage() diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 369a894720..5c74b9fd8f 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -26,6 +26,8 @@ build_tokenizer) from llmfoundry.utils.config_utils import pop_config, process_init_device +log = logging.getLogger(__name__) + def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, num_retries: int) -> ComposerModel: @@ -65,7 +67,7 @@ def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, if retries >= num_retries: raise e else: - print( + log.info( f'Got exception {str(e)} while loading model {model_cfg.name}. {num_retries-retries} retries remaining' ) @@ -89,7 +91,7 @@ def load_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, if retries >= num_retries: raise e else: - print( + log.info( f'Got exception {str(e)} while loading model {model_cfg.name}. {num_retries-retries} retries remaining' ) @@ -116,7 +118,7 @@ def evaluate_model( icl_subset_num_batches: Optional[int], ): - print(f'Evaluating model: {model_cfg.model_name}', flush=True) + log.info(f'Evaluating model: {model_cfg.model_name}') # Build tokenizer and model tokenizer_cfg: Dict[str, Any] = om.to_container(model_cfg.tokenizer, @@ -200,7 +202,7 @@ def evaluate_model( if torch.cuda.is_available(): torch.cuda.synchronize() b = time.time() - print(f'Ran {model_cfg.model_name} eval in: {b-a} seconds') + log.info(f'Ran {model_cfg.model_name} eval in: {b-a} seconds') return (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df) @@ -215,7 +217,7 @@ def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]: must_exist=False, default_value=None) if eval_gauntlet_config: - print( + warnings.warn( 'Use of the key `model_gauntlet` is deprecated, please use the key `eval_gauntlet`' ) diff --git a/scripts/train/train.py b/scripts/train/train.py index 809f2fb09c..2c1099ff00 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -34,6 +34,8 @@ process_init_device, update_batch_size_info) +log = logging.getLogger(__name__) + def validate_config(cfg: DictConfig): """Validates compatible model and dataloader selection.""" @@ -138,17 +140,17 @@ def build_composer_peft_model( + f'Error encountered: {e}') # 1) loads a hf model, 2) adds peft modules, 3) wraps it in a ComposerHFCausalLM. 
diff --git a/scripts/train/train.py b/scripts/train/train.py
index 809f2fb09c..2c1099ff00 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -34,6 +34,8 @@
                                        process_init_device,
                                        update_batch_size_info)
 
+log = logging.getLogger(__name__)
+
 
 def validate_config(cfg: DictConfig):
     """Validates compatible model and dataloader selection."""
@@ -138,17 +140,17 @@ def build_composer_peft_model(
                             + f'Error encountered: {e}')
 
     # 1) loads a hf model, 2) adds peft modules, 3) wraps it in a ComposerHFCausalLM.
-    print('Building Lora config...')
+    log.info('Building Lora config...')
     lora_cfg = LoraConfig(**lora_args)
 
-    print('Building model from HuggingFace checkpoint...')
+    log.info('Building model from HuggingFace checkpoint...')
     model = MPTForCausalLM.from_pretrained(pretrained_model_name_or_path,
                                            trust_remote_code=True)
-    print('Model built!')
+    log.info('Model built!')
 
-    print('Adding Lora modules...')
+    log.info('Adding Lora modules...')
     model = get_peft_model(model, lora_cfg)
-    print('Lora modules added!')
+    log.info('Lora modules added!')
 
     model = ComposerHFCausalLM(model, tokenizer)
 
@@ -163,7 +165,7 @@ def print_trainable_parameters(model: torch.nn.Module) -> None:
         all_param += param.numel()
         if param.requires_grad:
             trainable_params += param.numel()
-    print(
+    log.info(
         f'trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}'
     )
@@ -260,9 +262,9 @@ def main(cfg: DictConfig) -> Trainer:
                                               must_exist=False,
                                               default_value=None)
     if eval_gauntlet_config is not None:
-        print(
-            'Use of the key `model_gauntlet` is deprecated, please use the key `eval_gauntlet`'
-        )
+        warnings.warn(
+            'Use of the key `model_gauntlet` is deprecated, please use the key `eval_gauntlet`',
+            DeprecationWarning)
     icl_subset_num_batches: Optional[int] = pop_config(cfg,
                                                        'icl_subset_num_batches',
                                                        must_exist=False,
@@ -398,7 +400,7 @@ def main(cfg: DictConfig) -> Trainer:
         autoresume_default = True
 
     if cfg.get('autoresume') is None and autoresume_default:
-        print('As run_name, save_folder, and save_latest_filename are set, \
+        log.info('As run_name, save_folder, and save_latest_filename are set, \
                 changing autoresume default to True...')
 
     autoresume: bool = pop_config(cfg,
@@ -514,7 +516,7 @@ def main(cfg: DictConfig) -> Trainer:
     ] if algorithm_configs else None
 
     # Dataloaders
-    print('Building train loader...')
+    log.info('Building train loader...')
     train_loader = build_dataloader(
         train_loader_config,
         tokenizer,
@@ -525,7 +527,7 @@ def main(cfg: DictConfig) -> Trainer:
         mosaicml_logger.log_metrics({'data_validated': time.time()})
 
     ## Evaluation
-    print('Building eval loader...')
+    log.info('Building eval loader...')
     eval_icl_seq_len: int = icl_seq_len if icl_seq_len else max_seq_len
     evaluators, _, eval_gauntlet_callback = build_evaluators(
         eval_loader_config,
@@ -541,7 +543,7 @@ def main(cfg: DictConfig) -> Trainer:
         callbacks.append(eval_gauntlet_callback)
 
     # Build Model
-    print('Initializing model...')
+    log.info('Initializing model...')
     with init_context:
         if lora_config is not None:  # frozen model + trainable lora modules
             model: ComposerHFCausalLM = build_composer_peft_model(
@@ -570,7 +572,7 @@ def main(cfg: DictConfig) -> Trainer:
         evaluators = add_metrics_to_eval_loaders(evaluators, train_metrics)
 
     # Build the Trainer
-    print('Building trainer...')
+    log.info('Building trainer...')
     trainer = Trainer(
         run_name=run_name,
         seed=seed,
@@ -609,7 +611,7 @@ def main(cfg: DictConfig) -> Trainer:
         compile_config=compile_config,
     )
 
-    print('Logging config')
+    log.info('Logging config')
     log_config(logged_cfg)
     torch.cuda.empty_cache()
     gc.collect()
@@ -618,10 +620,10 @@ def main(cfg: DictConfig) -> Trainer:
     if eval_first and trainer.state.timestamp.batch.value == 0:
         trainer.eval()
 
-    print('Starting training...')
+    log.info('Starting training...')
     trainer.fit()
 
-    print('Done.')
+    log.info('Done.')
 
     return trainer
diff --git a/setup.py b/setup.py
index 9853aa17bf..2283e60d9c 100644
--- a/setup.py
+++ b/setup.py
@@ -48,11 +48,11 @@ install_requires = [
     'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.17.1,<0.18',
-    'accelerate>=0.20,<0.21',  # for HF inference `device_map`
-    'transformers>=4.34.1,<4.35',
+    'accelerate>=0.25,<0.26',  # for HF inference `device_map`
+    'transformers>=4.36,<4.37',
     'mosaicml-streaming>=0.7.1,<0.8',
     'torch>=2.1,<2.1.1',
-    'datasets>=2.14.5,<2.15',
+    'datasets==2.15.0',
     'fsspec==2023.6.0',  # newer version results in a bug in datasets that duplicates data
     'sentencepiece==0.1.97',
     'einops==0.5.0',
diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py
index 94a2d66c6e..28fb9219f8 100644
--- a/tests/a_scripts/inference/test_convert_composer_to_hf.py
+++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -258,7 +258,7 @@ def test_callback_inits():
 @pytest.mark.parametrize(
     'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
     [('3ba', '2ba', '4ba', 2, 2), ('1dur', '2ba', '1ep', 1, 2)])
-@patch('os.cpu_count', MagicMock(return_value=None))
+@patch('os.cpu_count', MagicMock(return_value=1))
 def test_huggingface_conversion_callback_interval(
         tmp_path: pathlib.Path, log_to_mlflow: bool, hf_save_interval: str,
         save_interval: str, max_duration: str, expected_hf_checkpoints: int,
@@ -381,7 +381,7 @@ def test_huggingface_conversion_callback_interval(
 @pytest.mark.parametrize(
     'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
     [('1ba', '1ba', '1ba', 1, 1)])
-@patch('os.cpu_count', MagicMock(return_value=None))
+@patch('os.cpu_count', MagicMock(return_value=1))
 def test_huggingface_conversion_callback(
         model: str,
         tmp_path: pathlib.Path,
diff --git a/tests/fixtures/data.py b/tests/fixtures/data.py
index 16dd01347d..9ba053ffe8 100644
--- a/tests/fixtures/data.py
+++ b/tests/fixtures/data.py
@@ -26,7 +26,7 @@ def tiny_ft_dataset_path(tmp_path: Path, dataset_size: int = 4) -> Path:
 
 
 @fixture
-@patch('os.cpu_count', MagicMock(return_value=None))
+@patch('os.cpu_count', MagicMock(return_value=1))
 def tiny_ft_dataloader(tiny_ft_dataset_path: Path,
                        mpt_tokenizer: PreTrainedTokenizerBase,
                        max_seq_len: int = 128,
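
The test changes above switch the mocked os.cpu_count from None to 1, presumably so any code that derives worker counts from it receives an integer. For reference, this is the standard unittest.mock decorator pattern (a standalone sketch, not copied from the test suite):

import os
from unittest.mock import MagicMock, patch


@patch('os.cpu_count', MagicMock(return_value=1))
def derive_num_workers() -> int:
    # Anything calling os.cpu_count() inside this function sees the mocked value of 1.
    return max(1, (os.cpu_count() or 1) // 2)


assert derive_num_workers() == 1
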
diff --git a/tests/models/layers/test_huggingface_flash.py b/tests/models/layers/test_huggingface_flash.py
index 70c08c4eb1..411aab77a2 100644
--- a/tests/models/layers/test_huggingface_flash.py
+++ b/tests/models/layers/test_huggingface_flash.py
@@ -159,9 +159,11 @@ def test_attn_patch_integration(patch: str):
 
 
 @pytest.mark.gpu
+@pytest.mark.world_size(2)
 @pytest.mark.parametrize('model_name', ['llama2', 'mistral'])
 @pytest.mark.parametrize('use_flash_attention_2', [True, False])
-def test_flash2(model_name: str, use_flash_attention_2: bool):
+@pytest.mark.parametrize('init_device', ['cpu', 'mixed', 'meta'])
+def test_flash2(model_name: str, use_flash_attention_2: bool, init_device: str):
     if model_name == 'llama2':
         if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
             pytest.skip(
@@ -177,7 +179,7 @@ def test_flash2(model_name: str, use_flash_attention_2: bool):
             },
             'use_auth_token': True,
             'pretrained': False,
-            'init_device': 'cpu',
+            'init_device': init_device,
         }
 
         tokenizer_name = 'meta-llama/Llama-2-7b-hf'
@@ -228,21 +230,27 @@ def test_flash2(model_name: str, use_flash_attention_2: bool):
     model = COMPOSER_MODEL_REGISTRY[model_cfg['name']](model_cfg, tokenizer)
 
     # check that it actually used flash attention 2
-    assert model.model.config._flash_attn_2_enabled if use_flash_attention_2 else not model.model.config._flash_attn_2_enabled
+    assert model.model.config._attn_implementation == (
+        'flash_attention_2' if use_flash_attention_2 else 'eager')
     attention_layer = rgetattr(
         rgetattr(model, attention_layers_attr)[0], attention_attr)
     assert isinstance(attention_layer, flash_attn_class)
 
-    tokenized_input = tokenizer(['Hello world blah blah', 'Goodbye world'],
-                                return_tensors='pt',
-                                padding=True)
-    tokenized_input['labels'] = tokenized_input['input_ids'].clone()
-
-    tokenized_input = {k: v.cuda() for k, v in tokenized_input.items()}
-    model.to('cuda')
-
-    with get_precision_context('amp_bf16'):
-        # We're just testing that flash attention 2 runs okay
-        outputs = model(tokenized_input)
-        loss = outputs.loss
-        loss.backward()
+    # Skip attempting to run forward/backward when some devices have meta params
+    # because we are not instantiating a full Trainer here, which contains the logic
+    # to move params off of meta device.
+    if init_device == 'cpu':
+        tokenized_input = tokenizer(
+            ['Hello world blah blah', 'Goodbye world'],
+            return_tensors='pt',
+            padding=True)
+        tokenized_input['labels'] = tokenized_input['input_ids'].clone()
+
+        tokenized_input = {k: v.cuda() for k, v in tokenized_input.items()}
+        model.to('cuda')
+
+        with get_precision_context('amp_bf16'):
+            # We're just testing that flash attention 2 runs okay
+            outputs = model(tokenized_input)
+            loss = outputs.loss
+            loss.backward()
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 6d48d115fd..3b2fc22ee3 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -351,7 +351,25 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2):
     pytest.param('flash', torch.float16, marks=pytest.mark.gpu),
     pytest.param('flash', torch.bfloat16, marks=pytest.mark.gpu)])
 @pytest.mark.parametrize('ffn_type', ['mptmlp', 'mptgeglu'])
-def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str):
+@pytest.mark.parametrize('ffn_act_fn', [
+    None,
+    {
+        'name': 'gelu',
+        'approximate': 'tanh',
+    },
+    {
+        'name': 'silu',
+    },
+    {
+        'name': 'relu',
+        'inplace': True,
+    },
+    pytest.param({'name': 'relu5'},
+                 marks=pytest.mark.xfail(reason='invalid choice.',
+                                         strict=True)),
+])
+def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str,
+                     ffn_act_fn: dict):
     conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
     with open(conf_path) as f:
         test_cfg = om.load(f)
@@ -363,6 +381,7 @@ def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str):
         test_cfg.model.ffn_config['ffn_type'] = ffn_type
     else:
         test_cfg.model.setdefault('ffn_config', {'ffn_type': ffn_type})
+    test_cfg.model.ffn_config['ffn_act_fn'] = ffn_act_fn
     test_cfg.model.init_device = 'cuda:0'
     test_cfg.device = 'cuda:0'
 
@@ -516,12 +535,34 @@ def test_opt_wrapping():
 @pytest.mark.parametrize('tie_word_embeddings', [True, False])
 @pytest.mark.parametrize('expansion_ratio,ffn_hidden_size', [
     (2, None),
-    (1.231, None),
+    pytest.param(1.231,
+                 None,
+                 marks=pytest.mark.xfail(
+                     reason='d_model * expansion_ratio must be an integer.',
+                     strict=True)),
     (2, 128),
     (2, 256),
 ])
+@pytest.mark.parametrize('ffn_act_fn', [
+    None,
+    {
+        'name': 'gelu',
+        'approximate': 'tanh',
+    },
+    {
+        'name': 'silu',
+    },
+    {
+        'name': 'relu',
+        'inplace': True,
+    },
+    pytest.param({'name': 'relu5'},
+                 marks=pytest.mark.xfail(reason='invalid choice.',
+                                         strict=True)),
+])
 def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool,
-                      expansion_ratio: Union[int, float], ffn_hidden_size: int):
+                      expansion_ratio: Union[int, float], ffn_hidden_size: int,
+                      ffn_act_fn: dict):
     # Test that the config constructs the model as expected.
     hf_config = MPTConfig(
         init_device='cpu',
@@ -541,11 +582,9 @@ def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool,
         ffn_config={
             'ffn_type': 'mptmlp',
             'ffn_hidden_size': ffn_hidden_size,
+            'ffn_act_fn': ffn_act_fn,
         },
     )
-    if hf_config.d_model * hf_config.expansion_ratio != int(
-            hf_config.d_model * hf_config.expansion_ratio):
-        pytest.xfail('d_model * expansion_ratio must be an integer.')
 
     mpt = MPTForCausalLM(hf_config)
 
@@ -1901,7 +1940,7 @@ def test_hf_init(tmp_path: pathlib.Path,
     precision = Precision('amp_bf16')
 
     hf_config = MPTConfig(
-        init_device=init_device,
+        init_device='cpu',
         d_model=32,
         n_heads=4,
         n_layers=1,
diff --git a/tests/test_smoketest.py b/tests/test_smoketest.py
new file mode 100644
index 0000000000..a43925e506
--- /dev/null
+++ b/tests/test_smoketest.py
@@ -0,0 +1,16 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from llmfoundry import callbacks, data, models, optim, tokenizers, utils
+
+
+# This very simple test is just to use the above imports, which check and make sure we can import all the top-level
+# modules from foundry. This is mainly useful for checking that we have correctly conditionally imported all optional
+# dependencies.
+def test_smoketest():
+    assert callbacks
+    assert data
+    assert models
+    assert optim
+    assert tokenizers
+    assert utils
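
Tying the FFN changes together, the new ffn_act_fn knob is consumed through ffn_config, as test_mpt_creation exercises above. A hedged end-to-end sketch, assuming MPTConfig and MPTForCausalLM are importable from llmfoundry.models.mpt and with dimensions chosen arbitrarily for illustration:

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

hf_config = MPTConfig(
    d_model=128,
    n_heads=4,
    n_layers=2,
    expansion_ratio=2,
    max_seq_len=256,
    vocab_size=50368,
    ffn_config={
        'ffn_type': 'mptmlp',
        # Resolved by resolve_ffn_act_fn into a partial over torch.nn.functional.silu.
        'ffn_act_fn': {
            'name': 'silu'
        },
    },
)
model = MPTForCausalLM(hf_config)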