diff --git a/tests/conftest.py b/tests/conftest.py
index b39ebd66a9..545dc7e38f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,12 +1,10 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0

-import gc
 import os
 from typing import List, Optional

 import pytest
-import torch
 from composer.utils import reproducibility

 # Allowed options for pytest.mark.world_size()
@@ -18,6 +16,13 @@
 # Enforce deterministic mode before any tests start.
 reproducibility.configure_deterministic_mode()

+# Add the path of any pytest fixture files you want to make global
+pytest_plugins = [
+    'tests.fixtures.autouse',
+    'tests.fixtures.models',
+    'tests.fixtures.data',
+]
+

 def _add_option(parser: pytest.Parser,
                 name: str,
@@ -78,12 +83,3 @@ def pytest_collection_modifyitems(config: pytest.Config,
 def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
     if exitstatus == 5:
         session.exitstatus = 0  # Ignore no-test-ran errors
-
-
-@pytest.fixture(autouse=True)
-def clear_cuda_cache(request: pytest.FixtureRequest):
-    """Clear memory between GPU tests."""
-    marker = request.node.get_closest_marker('gpu')
-    if marker is not None and torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        gc.collect()  # Only gc on GPU tests as it 2x slows down CPU tests
diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py
new file mode 100644
index 0000000000..f6c1f9f3ab
--- /dev/null
+++ b/tests/fixtures/__init__.py
@@ -0,0 +1,2 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py
new file mode 100644
index 0000000000..c51ccfacb0
--- /dev/null
+++ b/tests/fixtures/autouse.py
@@ -0,0 +1,39 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+
+import pytest
+import torch
+from composer.utils import dist, get_device, reproducibility
+
+
+@pytest.fixture(autouse=True)
+def initialize_dist(request: pytest.FixtureRequest):
+    """Initialize the default PyTorch distributed process group for tests."""
+    # should we just always initialize dist like in train.py?
+    _default = pytest.mark.world_size(1).mark
+    world_size = request.node.get_closest_marker('world_size', _default).args[0]
+    gpu = request.node.get_closest_marker('gpu')
+    if world_size > 1:
+        dist.initialize_dist(get_device('gpu' if gpu is not None else 'cpu'))
+
+
+@pytest.fixture(autouse=True)
+def clear_cuda_cache(request: pytest.FixtureRequest):
+    """Clear memory between GPU tests."""
+    marker = request.node.get_closest_marker('gpu')
+    if marker is not None and torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        gc.collect()  # Only gc on GPU tests as it 2x slows down CPU tests
+
+
+@pytest.fixture
+def random_seed() -> int:
+    return 17
+
+
+@pytest.fixture(autouse=True)
+def seed_all(random_seed: int):
+    """Sets the seed for reproducibility."""
+    reproducibility.seed_all(random_seed)
diff --git a/tests/fixtures/data.py b/tests/fixtures/data.py
new file mode 100644
index 0000000000..39032146b6
--- /dev/null
+++ b/tests/fixtures/data.py
@@ -0,0 +1,58 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from pathlib import Path
+
+from composer.utils import dist
+from omegaconf import DictConfig
+from pytest import fixture
+from torch.utils.data import DataLoader
+from transformers import PreTrainedTokenizerBase
+
+from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
+from tests.data_utils import make_tiny_ft_dataset
+
+
+@fixture
+def tiny_ft_dataset_path(tmp_path: Path, dataset_size: int = 4) -> Path:
+    """Creates a tiny dataset and returns the path."""
+    tiny_dataset_path = tmp_path / 'test-ift-data-small'
+    tiny_dataset_path.mkdir(exist_ok=True)
+    tiny_dataset_file = tiny_dataset_path / 'train.jsonl'
+    if dist.get_world_size() == 1 or dist.get_global_rank() == 0:
+        make_tiny_ft_dataset(path=str(tiny_dataset_file), size=dataset_size)
+    return tiny_dataset_path
+
+
+@fixture
+def tiny_ft_dataloader(tiny_ft_dataset_path: Path,
+                       mpt_tokenizer: PreTrainedTokenizerBase,
+                       max_seq_len: int = 128,
+                       device_batch_size: int = 1) -> DataLoader:
+    dataloader_cfg = DictConfig({
+        'name': 'finetuning',
+        'dataset': {
+            'hf_name': str(tiny_ft_dataset_path),
+            'split': 'train',
+            'max_seq_len': max_seq_len,
+            'decoder_only_format': True,
+            'allow_pad_trimming': False,
+            'packing_ratio': None,
+            'shuffle': True,
+        },
+        'drop_last': False,
+        'num_workers': 4,
+        'pin_memory': False,
+        'prefetch_factor': 2,
+        'persistent_workers': False,
+        'timeout': 0
+    })
+
+    dataloader = build_finetuning_dataloader(
+        dataloader_cfg,
+        mpt_tokenizer,
+        device_batch_size,
+    ).dataloader
+
+    assert isinstance(dataloader, DataLoader)
+    return dataloader
diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py
new file mode 100644
index 0000000000..1b1ef86302
--- /dev/null
+++ b/tests/fixtures/models.py
@@ -0,0 +1,70 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable
+
+from omegaconf import DictConfig
+from pytest import fixture
+from transformers import PreTrainedTokenizerBase
+
+from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM
+from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY
+from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM
+from llmfoundry.utils.builders import build_tokenizer
+
+
+def _build_model(config: DictConfig, tokenizer: PreTrainedTokenizerBase):
+    model = COMPOSER_MODEL_REGISTRY[config.name](config, tokenizer)
+    return model
+
+
+@fixture
+def mpt_tokenizer():
+    return build_tokenizer('EleutherAI/gpt-neox-20b', {})
+
+
+@fixture
+def build_tiny_mpt(
+    mpt_tokenizer: PreTrainedTokenizerBase
+) -> Callable[..., ComposerMPTCausalLM]:
+
+    def build(**kwargs: Any) -> ComposerMPTCausalLM:
+        config = DictConfig({
+            'name': 'mpt_causal_lm',
+            'd_model': 128,
+            'n_heads': 4,
+            'n_layers': 2,
+            'expansion_ratio': 2,
+        })
+        config.update(kwargs)
+        model = _build_model(config, mpt_tokenizer)
+        assert isinstance(model, ComposerMPTCausalLM)
+        return model
+
+    return build
+
+
+@fixture
+def build_tiny_hf_mpt(
+    mpt_tokenizer: PreTrainedTokenizerBase
+) -> Callable[..., ComposerHFCausalLM]:
+
+    def build(**kwargs: Any) -> ComposerHFCausalLM:
+        config_overrides = {
+            'd_model': 128,
+            'n_heads': 4,
+            'n_layers': 2,
+            'expansion_ratio': 2,
+        }
+        config_overrides.update(kwargs)
+        config = DictConfig({
+            'name': 'hf_causal_lm',
+            'pretrained_model_name_or_path': 'mosaicml/mpt-7b',
+            'pretrained': False,
+            'config_overrides': config_overrides,
+        })
+        model = _build_model(config, mpt_tokenizer)
+        assert isinstance(model, ComposerHFCausalLM)
+        return model
+
+    return build
diff --git a/tests/test_data_prep_scripts.py b/tests/test_data_prep_scripts.py
index 4c555ea9a2..4fe5ed7e64 100644
--- a/tests/test_data_prep_scripts.py
+++ b/tests/test_data_prep_scripts.py
@@ -2,9 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0

 import os
-import shutil
 import sys
 from argparse import Namespace
+from pathlib import Path

 # Add repo root to path so we can import scripts and test it
 repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
@@ -13,17 +13,16 @@
 from scripts.data_prep.convert_dataset_json import main as main_json


-def test_download_script_from_api():
+def test_download_script_from_api(tmp_path: Path):
     # test calling it directly
-    path = os.path.join(os.getcwd(), 'my-copy-c4-1')
-    shutil.rmtree(path, ignore_errors=True)
+    path = os.path.join(tmp_path, 'my-copy-c4-1')
     main_hf(
         Namespace(
             **{
                 'dataset': 'c4',
                 'data_subset': 'en',
                 'splits': ['val_xsmall'],
-                'out_root': './my-copy-c4-1',
+                'out_root': path,
                 'compression': None,
                 'concat_tokens': None,
                 'bos_text': None,
@@ -32,18 +31,16 @@ def test_download_script_from_api():
                 'num_workers': None
             }))
     assert os.path.exists(path)
-    shutil.rmtree(path, ignore_errors=False)


-def test_json_script_from_api():
+def test_json_script_from_api(tmp_path: Path):
     # test calling it directly
-    path = os.path.join(os.getcwd(), 'my-copy-arxiv-1')
-    shutil.rmtree(path, ignore_errors=True)
+    path = os.path.join(tmp_path, 'my-copy-arxiv-1')
     main_json(
         Namespace(
             **{
                 'path': 'scripts/data_prep/example_data/arxiv.jsonl',
-                'out_root': './my-copy-arxiv-1',
+                'out_root': path,
                 'compression': None,
                 'split': 'train',
                 'concat_tokens': None,
@@ -53,4 +50,3 @@ def test_json_script_from_api():
                 'num_workers': None
             }))
     assert os.path.exists(path)
-    shutil.rmtree(path, ignore_errors=False)
diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py
index 145d4a5885..e6fe8eb438 100644
--- a/tests/test_flash_triton_torch.py
+++ b/tests/test_flash_triton_torch.py
@@ -3,7 +3,6 @@

 import pytest
 import torch
-from composer.utils import reproducibility
 from omegaconf import OmegaConf as om


@@ -39,8 +38,6 @@ def test_attn_impl(attn_impl_0: str,
     if alibi and (attn_impl_0 == 'flash' or attn_impl_1 == 'flash'):
         pytest.xfail('flash attn does not support alibi')

-    reproducibility.seed_all(7)
-
     cfg = om.create({
         'attn_impl': 'flash',
         'd_model': 128,
@@ -135,8 +132,6 @@ def test_vs_mha(attn_impl: str, device: str = 'cuda'):
     """Compare diff attn_impl to torch.nn.MultiheadAttention."""
     from llmfoundry.models.layers import attention

-    reproducibility.seed_all(17)
-
     cfg = om.create({
         'attn_impl': attn_impl,
         'd_model': 256,
@@ -234,8 +229,6 @@ def test_grouped_attention_heads(attn_impl: str,
     """Ensure grouped_query_attention runs w/ diff n_heads & kv_n_heads."""
     from llmfoundry.models.layers import attention

-    reproducibility.seed_all(17)
-
     cfg = om.create({
         'attn_impl': attn_impl,
         'd_model': 256,
@@ -273,8 +266,6 @@ def test_grouped_query_invalid_heads(attn_impl: str, device: str = 'cuda'):
     """Check indivisble combinations of grouped_query_attention."""
     from llmfoundry.models.layers import attention

-    reproducibility.seed_all(17)
-
     cfg = om.create({
         'attn_impl': attn_impl,
         'd_model': 256,
diff --git a/tests/test_hf_config.py b/tests/test_hf_config.py
index 5b3bb3d150..b47f267c55 100644
--- a/tests/test_hf_config.py
+++ b/tests/test_hf_config.py
@@ -9,7 +9,6 @@

 import pytest
 import torch
-from composer.utils import reproducibility
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
 from transformers import AutoModelForCausalLM
@@ -93,8 +92,6 @@ def test_hf_config_override(
     with open(conf_path) as f:
         test_cfg = om.load(f)

-    reproducibility.seed_all(test_cfg.seed)
-
     # Build Model
     # For fast initialization, use `meta` device
     print('Initializing model...')
diff --git a/tests/test_hf_mpt_gen.py b/tests/test_hf_mpt_gen.py
index cc357141ba..ea133c64fa 100644
--- a/tests/test_hf_mpt_gen.py
+++ b/tests/test_hf_mpt_gen.py
@@ -1,167 +1,51 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0

-from pathlib import Path
-from typing import Any, Dict
-from unittest.mock import Mock
+from typing import Callable

 import pytest
-from composer.callbacks import Generate as ComposerGenerate
 from composer.core.precision import get_precision_context
-from composer.trainer import Trainer
-from composer.utils import get_device, reproducibility
-from omegaconf import DictConfig
-from omegaconf import OmegaConf as om
+from composer.utils import get_device
+from transformers import PreTrainedTokenizerBase

-from llmfoundry import COMPOSER_MODEL_REGISTRY
-from llmfoundry.data.finetuning import build_finetuning_dataloader
-from llmfoundry.utils import build_tokenizer
-from tests.data_utils import make_tiny_ft_dataset
+from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM


 @pytest.mark.gpu
 @pytest.mark.parametrize('device', ['cpu', 'gpu'])
 @pytest.mark.parametrize('attn_impl', ['triton', 'torch'])
-def test_init_hfhub_mpt(device: str, attn_impl: str):
+def test_init_hfhub_mpt(
+    device: str,
+    attn_impl: str,
+    build_tiny_hf_mpt: Callable[..., ComposerHFCausalLM],
+    mpt_tokenizer: PreTrainedTokenizerBase,
+):
     if device == 'cpu' and attn_impl == 'triton':
         pytest.skip(f'{attn_impl=} not implemented for {device=}.')
     composer_device = get_device(device)

-    with open('scripts/train/yamls/pretrain/testing.yaml') as f:
-        test_cfg = om.load(f)
-
-    assert isinstance(test_cfg, DictConfig)
-    reproducibility.seed_all(test_cfg.get('seed', 42))
-
-    attn_uses_sequence_id = True if test_cfg.get('eos_token_id',
-                                                 None) is not None else False
-    test_cfg.model = DictConfig({
-        'name': 'hf_causal_lm',
-        'pretrained_model_name_or_path': 'mosaicml/mpt-7b',
-        'pretrained': False,
-        'config_overrides': {
-            'd_model': 128,
-            'n_heads': 4,
-            'n_layers': 2,
-            'expansion_ratio': 2,
-            'attn_config': {
-                'attn_impl': attn_impl,
-                'attn_uses_sequence_id': attn_uses_sequence_id,
-            },
-        },
+    model = build_tiny_hf_mpt(attn_config={
+        'attn_impl': attn_impl,
+        'attn_uses_sequence_id': False,
     })
-
-    # build tokenizer
-    tokenizer_cfg: Dict[str,
-                        Any] = om.to_container(test_cfg.tokenizer,
-                                               resolve=True)  # type: ignore
-    tokenizer_name = tokenizer_cfg['name']
-    tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
-    tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
-
-    # build model
-    model = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model,
-                                                         tokenizer)
-    test_cfg.n_params = sum(p.numel() for p in model.parameters())
+    model = composer_device.module_to_device(model)

     model.eval()
-    model = composer_device.module_to_device(model)

     with get_precision_context('amp_bf16' if composer_device.name ==
                                'gpu' else 'fp32'):
         _ = model.generate(
             composer_device.tensor_to_device(
-                tokenizer('hello', return_tensors='pt')['input_ids']),
+                mpt_tokenizer('hello', return_tensors='pt')['input_ids']),
             max_new_tokens=10,
         )


-def test_init_hfhub_mpt_cpu():
-    test_init_hfhub_mpt(device='cpu', attn_impl='torch')
-
-
-@pytest.mark.gpu
-def test_mpt_generate_callback(tmpdir: Path):
-    composer_device = get_device('gpu')
-    reproducibility.seed_all(42)
-    max_seq_len = 128
-
-    # testing dataset and dataloader
-    dataset_size = 5
-
-    tiny_dataset_path = tmpdir / 'test-ift-data-small'
-    tiny_dataset_path.mkdir()
-    tiny_dataset_file = tiny_dataset_path / 'train.jsonl'
-    make_tiny_ft_dataset(path=str(tiny_dataset_file), size=dataset_size)
-
-    dataloader_cfg = DictConfig({
-        'name': 'finetuning',
-        'dataset': {
-            'hf_name': str(tiny_dataset_path),
-            'split': 'train',
-            'max_seq_len': max_seq_len,
-            'decoder_only_format': True,
-            'allow_pad_trimming': False,
-            'packing_ratio': None,
-            'shuffle': True,
-        },
-        'drop_last': False,
-        'num_workers': 4,
-        'pin_memory': False,
-        'prefetch_factor': 2,
-        'persistent_workers': False,
-        'timeout': 0
-    })
-
-    # build tokenizer
-    tokenizer = build_tokenizer('EleutherAI/gpt-neox-20b', {})
-
-    # build mpt model
-    model_config = DictConfig({
-        'name': 'mpt_causal_lm',
-        'config_overrides': {
-            'd_model': 128,
-            'n_heads': 4,
-            'n_layers': 2,
-            'expansion_ratio': 2,
-        },
-    })
-    model = COMPOSER_MODEL_REGISTRY[model_config.name](model_config, tokenizer)
-    model = composer_device.module_to_device(model)
-
-    # generate callback
-    prompts = [
-        'The best banana bread recipe is',
-        '2+2=',
-        'how much wood could a woodchuck chuck',
-    ]
-    gen_interval = 1
-    generate = ComposerGenerate(
-        prompts,
-        interval=f'{gen_interval}ba',
-        max_new_tokens=5,
-        batch_size=len(prompts),
-        use_cache=True,
-    )
-    generate.generate = Mock(wraps=generate.generate, autospec=True)
-
-    # build trainer
-    device_batch_size = 1
-    train_dataloader = build_finetuning_dataloader(
-        dataloader_cfg,
-        tokenizer,
-        device_batch_size,
-    )
-
-    trainer = Trainer(
-        model=model,
-        train_dataloader=train_dataloader,
-        device=composer_device,
-        max_duration=f'{gen_interval}ba',
-        callbacks=[generate],
-    )
-    trainer.logger.log_table = Mock()
-    trainer.fit()
-
-    generate.generate.assert_called_once()
-    trainer.logger.log_table.assert_called_once()
+def test_init_hfhub_mpt_cpu(
+    build_tiny_hf_mpt: Callable[..., ComposerHFCausalLM],
+    mpt_tokenizer: PreTrainedTokenizerBase,
+):
+    test_init_hfhub_mpt(device='cpu',
+                        attn_impl='torch',
+                        build_tiny_hf_mpt=build_tiny_hf_mpt,
+                        mpt_tokenizer=mpt_tokenizer)
diff --git a/tests/test_hf_v_mpt.py b/tests/test_hf_v_mpt.py
index 82e2d05550..46172faf35 100644
--- a/tests/test_hf_v_mpt.py
+++ b/tests/test_hf_v_mpt.py
@@ -5,7 +5,6 @@

 import pytest
 import torch
-from composer.utils import reproducibility
 from omegaconf import OmegaConf as om

 from llmfoundry import COMPOSER_MODEL_REGISTRY
@@ -52,10 +51,6 @@ def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool,
     batch_size = 2  # set batch size
     device = 'cuda'  # set decive

-    # ensure reproducibility
-    seed = 17
-    reproducibility.seed_all(seed)  # set seed
-
     # get hf gpt2 cfg
     hf_cfg = om.create({
         'model': {
@@ -154,11 +149,9 @@ def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool,

     # UTIL: can be used to verify that models are not the same at init
     with torch.autocast(device_type='cuda', dtype=torch.float16):
-        torch.manual_seed(seed)
         hf_model_fwd = hf_model(batch)['logits']
         if kpm is not None:
             hf_model_fwd *= kpm
-        torch.manual_seed(seed)
         model_fwd = model(batch).logits
         if kpm is not None:
             model_fwd *= kpm
@@ -208,11 +201,9 @@ def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool,
     model.load_state_dict(_hf_model_statedict)

     with torch.autocast(device_type=device, dtype=torch.float16):
-        torch.manual_seed(seed)
         hf_model_fwd = hf_model(batch)['logits']
         if kpm is not None:
             hf_model_fwd *= kpm
-        torch.manual_seed(seed)
         model_fwd = model(batch).logits
         if kpm is not None:
             model_fwd *= kpm
diff --git a/tests/test_init_fn.py b/tests/test_init_fn.py
index b054bac186..6be2c5ca42 100644
--- a/tests/test_init_fn.py
+++ b/tests/test_init_fn.py
@@ -8,7 +8,6 @@

 import pytest
 import torch
-from composer.utils import reproducibility
 from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as om
 from torch import nn
@@ -35,8 +34,6 @@ def forward(self, x: torch.Tensor):

 @pytest.mark.parametrize('is_residual', [True, False])
 def test_div_is_residual(is_residual: bool):
-    reproducibility.seed_all(7)
-
     in_features, out_features = 8, 32
     cfg = om.create({
         'in_features': in_features,
@@ -64,8 +61,6 @@ def test_div_is_residual(is_residual: bool):

 @pytest.mark.parametrize('fused', [True, False])
 def test_fused_init_helper(fused: bool):
-    reproducibility.seed_all(7)
-
     in_features, out_features = 8, 32
     cfg = om.create({
         'in_features': in_features,
@@ -133,8 +128,6 @@ def max_fill_init_(weight: torch.Tensor):
     ('emb_init_uniform_lim', [1, 1])
 ])
 def test_emb_init(emb_init_cfg: Optional[Tuple[str, Union[int, List[int]]]]):
-    reproducibility.seed_all(7)
-
     cfg: Dict[str, Union[int, List[int]]] = {
         'vocab_size': 64,
         'in_features': 16,
diff --git a/tests/test_model.py b/tests/test_model.py
index 6ea530731a..67166bef68 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -16,7 +16,7 @@
 from composer.core.precision import Precision, get_precision_context
 from composer.optim import DecoupledAdamW
 from composer.trainer.dist_strategy import prepare_fsdp_module
-from composer.utils import dist, get_device, reproducibility
+from composer.utils import dist, get_device
 from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as om
 from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedModel,
@@ -56,8 +56,6 @@ def get_objs(conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'):
         message='Torchmetrics v0.9 introduced a new argument class property')
     test_cfg = get_config(conf_path=conf_path)

-    reproducibility.seed_all(test_cfg.seed)
-
     # Read FSDP Config as a dict
     fsdp_config = test_cfg.get('fsdp_config', None)
     fsdp_config = om.to_container(fsdp_config,
@@ -316,7 +314,6 @@ def test_determinism(attn_impl: str, precision: torch.dtype):
         pytest.skip(
             'This test requires CUDA to be available in order to run with bfloat16 precision.'
         )
-    reproducibility.seed_all(1111)

     conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
     with open(conf_path) as f:
@@ -394,8 +391,6 @@ def test_loss_fn():
         'init_std': 0.02,
     }

-    reproducibility.seed_all(test_cfg.get('global_seed', 42))
-
     tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer)
     tokenizer = build_tokenizer(test_cfg.tokenizer.name,
                                 tokenizer_cfg.get('kwargs', {}))
@@ -537,7 +532,6 @@ def test_forward_with_padding(attention_impl: str, device: str, alibi: bool):
     if alibi and attention_impl == 'flash':
         pytest.skip(f'alibi only implemented with torch and triton attention.')

-    reproducibility.seed_all(1234)
     composer_device = get_device(device)

     hf_config = MPTConfig(
@@ -716,7 +710,6 @@ def test_generate(attention_impl: str, device: str, alibi: bool):
     if alibi and attention_impl == 'flash':
         pytest.skip(f'alibi only implemented with torch and triton attention.')

-    reproducibility.seed_all(1234)
     composer_device = get_device(device)

     hf_config = MPTConfig(
@@ -776,14 +769,12 @@ def test_generate(attention_impl: str, device: str, alibi: bool):
         use_cache=False)
     assert batched_generation.shape == (2, 6 + 5)

-    reproducibility.seed_all(1234)
     generation_with_left_padding = mpt.generate(
         input_ids=left_padding_input_ids,
         attention_mask=left_padding_attention_mask,
         max_new_tokens=5,
         use_cache=False)
     assert generation_with_left_padding.shape == (2, 6 + 5)
-    reproducibility.seed_all(1234)
     generation_with_no_padding = mpt.generate(
         input_ids=no_padding_input_ids,
         attention_mask=no_padding_attention_mask,
@@ -1007,14 +998,12 @@ def test_forward_with_cache(attn_impl: str, device: str, alibi: bool):
             'init_std': 0.02,
         },
     )
-    reproducibility.seed_all(1234)
     mpt = MPTForCausalLM(hf_config)
     mpt = composer_device.module_to_device(mpt)
     mpt.eval()

     with get_precision_context('amp_bf16' if composer_device.name ==
                                'gpu' else 'fp32'):
-        reproducibility.seed_all(1234)
         first_input_ids = torch.tensor([[11274, 16390, 11]])
         first_input_ids = composer_device.tensor_to_device(first_input_ids)
         first_attention_mask = torch.tensor([[1, 1, 1]]).bool()
@@ -1040,7 +1029,6 @@ def test_forward_with_cache(attn_impl: str, device: str, alibi: bool):
     assert all(past_key_value[1].shape == (1, 3, 128)
               for past_key_value in first_output.past_key_values)

-    reproducibility.seed_all(1234)
     second_input_ids = torch.tensor([[11274, 16390, 11, 11274]])
     second_input_ids = composer_device.tensor_to_device(second_input_ids)
     second_attention_mask = torch.tensor([[1, 1, 1, 1]]).bool()
@@ -1070,7 +1058,6 @@ def test_forward_with_cache(attn_impl: str, device: str, alibi: bool):
     assert all(past_key_value[1].shape == (1, 4, 128)
               for past_key_value in second_output.past_key_values)

-    reproducibility.seed_all(1234)
     # pass through the first four tokens without the key-value cache
     full_output = mpt(second_input_ids,
                       attention_mask=second_attention_mask)
@@ -1205,7 +1192,6 @@ def test_model_to(attention_impl: str, alibi: bool):
             'init_std': 0.02,
         },
     )
-    reproducibility.seed_all(1234)
     mpt = MPTForCausalLM(hf_config)
     mpt = mpt.bfloat16()
     mpt = mpt.to('cuda')
@@ -1318,14 +1304,12 @@ def test_forward_with_output_attentions_and_output_hidden_states(
             'init_std': 0.02,
         },
     )
-    reproducibility.seed_all(1234)
     mpt = MPTForCausalLM(hf_config)
     mpt = composer_device.module_to_device(mpt)
     mpt.eval()

     with get_precision_context('amp_bf16' if composer_device.name ==
                                'gpu' else 'fp32'):
-        reproducibility.seed_all(1234)
         input_ids = torch.tensor([[11274, 16390, 11]])
         input_ids = composer_device.tensor_to_device(input_ids)
         attention_mask = torch.tensor([[1, 1, 1]]).bool()
diff --git a/tests/test_mpt_gen.py b/tests/test_mpt_gen.py
index 06ddccd479..c52b765480 100644
--- a/tests/test_mpt_gen.py
+++ b/tests/test_mpt_gen.py
@@ -1,19 +1,21 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0

-from typing import List, Optional, Tuple
-from unittest.mock import patch
+from typing import Callable, List, Optional, Tuple
+from unittest.mock import Mock, patch

 import pytest
 import torch
+from composer import Trainer
+from composer.callbacks import Generate as ComposerGenerate
 from composer.core.precision import get_precision_context
-from composer.utils import dist, get_device, reproducibility
-from omegaconf import DictConfig
+from composer.utils import dist, get_device
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.utils.data import DataLoader
+from transformers import PreTrainedTokenizerBase

-from llmfoundry import COMPOSER_MODEL_REGISTRY
-from llmfoundry.models.mpt.modeling_mpt import MPTForCausalLM
-from llmfoundry.utils import build_tokenizer
+from llmfoundry.models.mpt.modeling_mpt import (ComposerMPTCausalLM,
+                                                MPTForCausalLM)

 EOS_TOKEN_ID = 0

@@ -55,44 +57,72 @@ def forward(
 @pytest.mark.parametrize('use_alibi', [True, False])
 @patch('llmfoundry.models.mpt.modeling_mpt.MPTForCausalLM',
        new=MockMPTForCausalLM)
-def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool):
+def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool,
+                                build_tiny_mpt: Callable[...,
+                                                          ComposerMPTCausalLM],
+                                mpt_tokenizer: PreTrainedTokenizerBase):
     """Tests mpt generation with mutiple gpus.

     and generations of different lengths.
     """
-    composer_device = get_device('gpu')
-    dist.initialize_dist(composer_device)
-    reproducibility.seed_all(42)
-
-    model_config = DictConfig({
-        'name': 'mpt_causal_lm',
-        'd_model': 128,
-        'n_heads': 4,
-        'n_layers': 2,
-        'expansion_ratio': 2,
-        'no_bias': False,
-        'use_cache': True,
-        'attn_config': {
-            'attn_impl': attn_impl,
-            'attn_uses_sequence_id': False,
-            'alibi': use_alibi
-        },
-    })
-
-    # build tokenizer
-    tokenizer = build_tokenizer('EleutherAI/gpt-neox-20b', {})
-
-    # build model
-    model = COMPOSER_MODEL_REGISTRY[model_config.name](model_config, tokenizer)
-    model = composer_device.module_to_device(model)
+    device = get_device('gpu')
+
+    model = build_tiny_mpt(attn_config={
+        'attn_impl': attn_impl,
+        'attn_uses_sequence_id': False,
+        'alibi': use_alibi
+    },)
+    model = device.module_to_device(model)
+
     model.eval()

     model.model = FSDP(model.model)

     with get_precision_context('amp_bf16'):
-        _ = model.generate(composer_device.tensor_to_device(
-            tokenizer('hello', return_tensors='pt')['input_ids']),
+        _ = model.generate(device.tensor_to_device(
+            mpt_tokenizer('hello', return_tensors='pt')['input_ids']),
                            max_new_tokens=3,
                            eos_token_id=EOS_TOKEN_ID,
                            use_cache=True,
                            synced_gpus=True)
+
+
+@pytest.mark.gpu
+def test_mpt_generate_callback(build_tiny_mpt: Callable[...,
+                                                        ComposerMPTCausalLM],
+                               tiny_ft_dataloader: DataLoader):
+    device = get_device('gpu')
+
+    # build mpt model
+    model = build_tiny_mpt()
+    model = device.module_to_device(model)
+
+    # generate callback
+    prompts = [
+        'The best banana bread recipe is',
+        '2+2=',
+        'how much wood could a woodchuck chuck',
+    ]
+    gen_interval = 1
+    generate = ComposerGenerate(
+        prompts,
+        interval=f'{gen_interval}ba',
+        max_new_tokens=5,
+        batch_size=len(prompts),
+        use_cache=True,
+    )
+    generate.generate = Mock(wraps=generate.generate, autospec=True)
+
+    # build trainer
+    trainer = Trainer(
+        model=model,
+        train_dataloader=tiny_ft_dataloader,
+        device=device,
+        max_duration=f'{gen_interval}ba',
+        callbacks=[generate],
+    )
+    trainer.logger.log_table = Mock()
+    trainer.fit()
+
+    generate.generate.assert_called_once()
+    trainer.logger.log_table.assert_called_once()
diff --git a/tests/test_onnx.py b/tests/test_onnx.py
index 4ccb8e4112..d0e01746eb 100644
--- a/tests/test_onnx.py
+++ b/tests/test_onnx.py
@@ -4,7 +4,6 @@
 import pathlib

 import torch
-from composer.utils import reproducibility
 from transformers import AutoModelForCausalLM

 from llmfoundry import MPTConfig, MPTForCausalLM
@@ -27,7 +26,6 @@ def gen_random_batch(batch_size: int, vocab_size: int, max_seq_len: int):


 def test_onnx_export(tmp_path: pathlib.Path):
-    reproducibility.seed_all(42)
     from transformers.models.auto.configuration_auto import CONFIG_MAPPING
     CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
     AutoModelForCausalLM.register(MPTConfig, MPTForCausalLM)