Skip to content

Commit

Permalink
Add fixtures (#673)
Browse files Browse the repository at this point in the history
Add fixtures for testing boilerplate, tiny mpt models, and tiny finetune dataset
  • Loading branch information
irenedea committed Oct 25, 2023
1 parent bc687b7 commit ea3279a
Show file tree
Hide file tree
Showing 14 changed files with 272 additions and 243 deletions.
18 changes: 7 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import gc
import os
from typing import List, Optional

import pytest
import torch
from composer.utils import reproducibility

# Allowed options for pytest.mark.world_size()
Expand All @@ -18,6 +16,13 @@
# Enforce deterministic mode before any tests start.
reproducibility.configure_deterministic_mode()

# Add the path of any pytest fixture files you want to make global
pytest_plugins = [
'tests.fixtures.autouse',
'tests.fixtures.models',
'tests.fixtures.data',
]


def _add_option(parser: pytest.Parser,
name: str,
Expand Down Expand Up @@ -78,12 +83,3 @@ def pytest_collection_modifyitems(config: pytest.Config,
def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
if exitstatus == 5:
session.exitstatus = 0 # Ignore no-test-ran errors


@pytest.fixture(autouse=True)
def clear_cuda_cache(request: pytest.FixtureRequest):
"""Clear memory between GPU tests."""
marker = request.node.get_closest_marker('gpu')
if marker is not None and torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect() # Only gc on GPU tests as it 2x slows down CPU tests
2 changes: 2 additions & 0 deletions tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
39 changes: 39 additions & 0 deletions tests/fixtures/autouse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import gc

import pytest
import torch
from composer.utils import dist, get_device, reproducibility


@pytest.fixture(autouse=True)
def initialize_dist(request: pytest.FixtureRequest):
"""Initialize the default PyTorch distributed process group for tests."""
# should we just always initialize dist like in train.py?
_default = pytest.mark.world_size(1).mark
world_size = request.node.get_closest_marker('world_size', _default).args[0]
gpu = request.node.get_closest_marker('gpu')
if world_size > 1:
dist.initialize_dist(get_device('gpu' if gpu is not None else 'cpu'))


@pytest.fixture(autouse=True)
def clear_cuda_cache(request: pytest.FixtureRequest):
"""Clear memory between GPU tests."""
marker = request.node.get_closest_marker('gpu')
if marker is not None and torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect() # Only gc on GPU tests as it 2x slows down CPU tests


@pytest.fixture
def random_seed() -> int:
return 17


@pytest.fixture(autouse=True)
def seed_all(random_seed: int):
"""Sets the seed for reproducibility."""
reproducibility.seed_all(random_seed)
58 changes: 58 additions & 0 deletions tests/fixtures/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path

from composer.utils import dist
from omegaconf import DictConfig
from pytest import fixture
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
from tests.data_utils import make_tiny_ft_dataset


@fixture
def tiny_ft_dataset_path(tmp_path: Path, dataset_size: int = 4) -> Path:
"""Creates a tiny dataset and returns the path."""
tiny_dataset_path = tmp_path / 'test-ift-data-small'
tiny_dataset_path.mkdir(exist_ok=True)
tiny_dataset_file = tiny_dataset_path / 'train.jsonl'
if dist.get_world_size() == 1 or dist.get_global_rank() == 0:
make_tiny_ft_dataset(path=str(tiny_dataset_file), size=dataset_size)
return tiny_dataset_path


@fixture
def tiny_ft_dataloader(tiny_ft_dataset_path: Path,
mpt_tokenizer: PreTrainedTokenizerBase,
max_seq_len: int = 128,
device_batch_size: int = 1) -> DataLoader:
dataloader_cfg = DictConfig({
'name': 'finetuning',
'dataset': {
'hf_name': str(tiny_ft_dataset_path),
'split': 'train',
'max_seq_len': max_seq_len,
'decoder_only_format': True,
'allow_pad_trimming': False,
'packing_ratio': None,
'shuffle': True,
},
'drop_last': False,
'num_workers': 4,
'pin_memory': False,
'prefetch_factor': 2,
'persistent_workers': False,
'timeout': 0
})

dataloader = build_finetuning_dataloader(
dataloader_cfg,
mpt_tokenizer,
device_batch_size,
).dataloader

assert isinstance(dataloader, DataLoader)
return dataloader
70 changes: 70 additions & 0 deletions tests/fixtures/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable

from omegaconf import DictConfig
from pytest import fixture
from transformers import PreTrainedTokenizerBase

from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM
from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY
from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM
from llmfoundry.utils.builders import build_tokenizer


def _build_model(config: DictConfig, tokenizer: PreTrainedTokenizerBase):
model = COMPOSER_MODEL_REGISTRY[config.name](config, tokenizer)
return model


@fixture
def mpt_tokenizer():
return build_tokenizer('EleutherAI/gpt-neox-20b', {})


@fixture
def build_tiny_mpt(
mpt_tokenizer: PreTrainedTokenizerBase
) -> Callable[..., ComposerMPTCausalLM]:

def build(**kwargs: Any) -> ComposerMPTCausalLM:
config = DictConfig({
'name': 'mpt_causal_lm',
'd_model': 128,
'n_heads': 4,
'n_layers': 2,
'expansion_ratio': 2,
})
config.update(kwargs)
model = _build_model(config, mpt_tokenizer)
assert isinstance(model, ComposerMPTCausalLM)
return model

return build


@fixture
def build_tiny_hf_mpt(
mpt_tokenizer: PreTrainedTokenizerBase
) -> Callable[..., ComposerHFCausalLM]:

def build(**kwargs: Any) -> ComposerHFCausalLM:
config_overrides = {
'd_model': 128,
'n_heads': 4,
'n_layers': 2,
'expansion_ratio': 2,
}
config_overrides.update(kwargs)
config = DictConfig({
'name': 'hf_causal_lm',
'pretrained_model_name_or_path': 'mosaicml/mpt-7b',
'pretrained': False,
'config_overrides': config_overrides,
})
model = _build_model(config, mpt_tokenizer)
assert isinstance(model, ComposerHFCausalLM)
return model

return build
18 changes: 7 additions & 11 deletions tests/test_data_prep_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
# SPDX-License-Identifier: Apache-2.0

import os
import shutil
import sys
from argparse import Namespace
from pathlib import Path

# Add repo root to path so we can import scripts and test it
repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
Expand All @@ -13,17 +13,16 @@
from scripts.data_prep.convert_dataset_json import main as main_json


def test_download_script_from_api():
def test_download_script_from_api(tmp_path: Path):
# test calling it directly
path = os.path.join(os.getcwd(), 'my-copy-c4-1')
shutil.rmtree(path, ignore_errors=True)
path = os.path.join(tmp_path, 'my-copy-c4-1')
main_hf(
Namespace(
**{
'dataset': 'c4',
'data_subset': 'en',
'splits': ['val_xsmall'],
'out_root': './my-copy-c4-1',
'out_root': path,
'compression': None,
'concat_tokens': None,
'bos_text': None,
Expand All @@ -32,18 +31,16 @@ def test_download_script_from_api():
'num_workers': None
}))
assert os.path.exists(path)
shutil.rmtree(path, ignore_errors=False)


def test_json_script_from_api():
def test_json_script_from_api(tmp_path: Path):
# test calling it directly
path = os.path.join(os.getcwd(), 'my-copy-arxiv-1')
shutil.rmtree(path, ignore_errors=True)
path = os.path.join(tmp_path, 'my-copy-arxiv-1')
main_json(
Namespace(
**{
'path': 'scripts/data_prep/example_data/arxiv.jsonl',
'out_root': './my-copy-arxiv-1',
'out_root': path,
'compression': None,
'split': 'train',
'concat_tokens': None,
Expand All @@ -53,4 +50,3 @@ def test_json_script_from_api():
'num_workers': None
}))
assert os.path.exists(path)
shutil.rmtree(path, ignore_errors=False)
9 changes: 0 additions & 9 deletions tests/test_flash_triton_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import pytest
import torch
from composer.utils import reproducibility
from omegaconf import OmegaConf as om


Expand Down Expand Up @@ -39,8 +38,6 @@ def test_attn_impl(attn_impl_0: str,
if alibi and (attn_impl_0 == 'flash' or attn_impl_1 == 'flash'):
pytest.xfail('flash attn does not support alibi')

reproducibility.seed_all(7)

cfg = om.create({
'attn_impl': 'flash',
'd_model': 128,
Expand Down Expand Up @@ -135,8 +132,6 @@ def test_vs_mha(attn_impl: str, device: str = 'cuda'):
"""Compare diff attn_impl to torch.nn.MultiheadAttention."""
from llmfoundry.models.layers import attention

reproducibility.seed_all(17)

cfg = om.create({
'attn_impl': attn_impl,
'd_model': 256,
Expand Down Expand Up @@ -234,8 +229,6 @@ def test_grouped_attention_heads(attn_impl: str,
"""Ensure grouped_query_attention runs w/ diff n_heads & kv_n_heads."""
from llmfoundry.models.layers import attention

reproducibility.seed_all(17)

cfg = om.create({
'attn_impl': attn_impl,
'd_model': 256,
Expand Down Expand Up @@ -273,8 +266,6 @@ def test_grouped_query_invalid_heads(attn_impl: str, device: str = 'cuda'):
"""Check indivisble combinations of grouped_query_attention."""
from llmfoundry.models.layers import attention

reproducibility.seed_all(17)

cfg = om.create({
'attn_impl': attn_impl,
'd_model': 256,
Expand Down
3 changes: 0 additions & 3 deletions tests/test_hf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import pytest
import torch
from composer.utils import reproducibility
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from transformers import AutoModelForCausalLM
Expand Down Expand Up @@ -93,8 +92,6 @@ def test_hf_config_override(
with open(conf_path) as f:
test_cfg = om.load(f)

reproducibility.seed_all(test_cfg.seed)

# Build Model
# For fast initialization, use `meta` device
print('Initializing model...')
Expand Down
Loading

0 comments on commit ea3279a

Please sign in to comment.