From 9250e84676f342dabe2fafafa1274e7a3fac184b Mon Sep 17 00:00:00 2001 From: bcui19 Date: Wed, 2 Aug 2023 14:20:04 -0400 Subject: [PATCH] Adding pyright to pre-commit (#477) --- .github/workflows/code-quality.yaml | 1 - .pre-commit-config.yaml | 10 + llmfoundry/callbacks/fdiff_callback.py | 6 +- llmfoundry/callbacks/generate_callback.py | 13 +- .../callbacks/model_gauntlet_callback.py | 28 ++- .../callbacks/monolithic_ckpt_callback.py | 4 +- llmfoundry/data/data.py | 1 + llmfoundry/data/denoising.py | 26 +-- llmfoundry/data/finetuning/collator.py | 2 +- llmfoundry/data/finetuning/dataloader.py | 16 +- llmfoundry/data/finetuning/tasks.py | 13 +- llmfoundry/data/packing.py | 5 +- llmfoundry/data/text_data.py | 22 +-- llmfoundry/models/hf/hf_causal_lm.py | 18 +- llmfoundry/models/hf/hf_fsdp.py | 30 +-- llmfoundry/models/hf/hf_prefix_lm.py | 19 +- llmfoundry/models/hf/hf_t5.py | 19 +- llmfoundry/models/hf/model_wrapper.py | 28 +-- llmfoundry/models/layers/attention.py | 174 ++++++++++-------- llmfoundry/models/layers/blocks.py | 47 +++-- llmfoundry/models/layers/custom_embedding.py | 1 - llmfoundry/models/layers/ffn.py | 11 +- llmfoundry/models/layers/norm.py | 44 ++--- llmfoundry/models/mpt/configuration_mpt.py | 16 +- llmfoundry/models/mpt/modeling_mpt.py | 70 +++---- llmfoundry/models/utils/adapt_tokenizer.py | 11 +- .../models/utils/hf_prefixlm_converter.py | 32 ++-- llmfoundry/models/utils/meta_init_context.py | 21 ++- llmfoundry/models/utils/param_init_fns.py | 43 +++-- llmfoundry/optim/adaptive_lion.py | 30 +-- llmfoundry/optim/lion.py | 15 +- llmfoundry/utils/builders.py | 44 +++-- llmfoundry/utils/config_utils.py | 11 +- llmfoundry/utils/huggingface_hub_utils.py | 14 +- pyproject.toml | 12 +- scripts/data_prep/convert_dataset_hf.py | 5 +- scripts/data_prep/convert_dataset_json.py | 10 - scripts/eval/eval.py | 32 +++- scripts/inference/benchmarking/benchmark.py | 13 +- scripts/inference/convert_hf_mpt_to_ft.py | 7 +- scripts/inference/convert_hf_to_onnx.py | 6 +- scripts/inference/hf_chat.py | 36 ++-- scripts/inference/hf_generate.py | 21 ++- scripts/inference/run_mpt_with_ft.py | 17 +- scripts/misc/convert_examples_ckpt.py | 7 +- scripts/train/benchmarking/collect_results.py | 14 +- .../train/benchmarking/submit_benchmarks.py | 75 +++++--- .../train/finetune_example/preprocessing.py | 8 +- scripts/train/train.py | 34 ++-- setup.py | 2 +- tests/conftest.py | 2 +- tests/test_dataloader.py | 23 ++- tests/test_flash_triton_torch.py | 40 +++- tests/test_hf_config.py | 10 +- tests/test_hf_conversion_script.py | 8 +- tests/test_hf_mpt_gen.py | 17 +- tests/test_hf_v_mpt.py | 6 +- tests/test_icl_datasets.py | 17 +- tests/test_init_fn.py | 48 ++--- tests/test_model.py | 149 ++++++++------- tests/test_onnx.py | 13 +- tests/test_tokenizer.py | 2 +- tests/test_training.py | 8 +- 63 files changed, 840 insertions(+), 647 deletions(-) diff --git a/.github/workflows/code-quality.yaml b/.github/workflows/code-quality.yaml index 6f78284c84..77f1b55e0e 100644 --- a/.github/workflows/code-quality.yaml +++ b/.github/workflows/code-quality.yaml @@ -24,7 +24,6 @@ jobs: strategy: matrix: python_version: - - '3.8' - '3.9' - '3.10' pip_deps: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 881a6bafec..66990493ae 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -89,6 +89,16 @@ repos: entry: yamllint language: python types: [file, yaml] +- repo: local + hooks: + - id: pyright + name: pyright + entry: pyright + language: node + types: [python] + pass_filenames: false + args: [--warnings] + additional_dependencies: ["pyright@1.1.256"] - repo: https://github.com/trufflesecurity/trufflehog.git rev: v3.40.0 hooks: diff --git a/llmfoundry/callbacks/fdiff_callback.py b/llmfoundry/callbacks/fdiff_callback.py index bcef73875d..3c6064932d 100644 --- a/llmfoundry/callbacks/fdiff_callback.py +++ b/llmfoundry/callbacks/fdiff_callback.py @@ -10,13 +10,15 @@ class FDiffMetrics(Callback): - """Rate of chage of metrics. + """Rate of change of metrics. tracks and plots the rate of change of metrics effectively taking the numerical derivative of the metrics """ - def __init__(self, diff_train_metrics=False, diff_eval_metrics=True): + def __init__(self, + diff_train_metrics: bool = False, + diff_eval_metrics: bool = True): self.diff_train_metrics = diff_train_metrics self.diff_eval_metrics = diff_eval_metrics diff --git a/llmfoundry/callbacks/generate_callback.py b/llmfoundry/callbacks/generate_callback.py index 476f9a0948..b6596fbc6a 100644 --- a/llmfoundry/callbacks/generate_callback.py +++ b/llmfoundry/callbacks/generate_callback.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """Periodically log generations to wandb from a set of prompts.""" -from typing import List, Union, cast +from typing import Any, List, Union, cast import torch import wandb @@ -16,7 +16,8 @@ class Generate(Callback): - def __init__(self, prompts: List[str], batch_log_interval: int, **kwargs): + def __init__(self, prompts: List[str], batch_log_interval: int, + **kwargs: Any): """Periodically log generations to wandb from a set of prompts. In the main view for a run, there will be a table that will show the _last_ logged generations. @@ -57,6 +58,11 @@ def generate(self, state: State, logger: Logger): tokenizer = cast(Tokenizer, state.model.tokenizer) device = state.device + if not hasattr(model.model, 'generate'): + raise ValueError( + f'Cannot generate from model {model.model.__class__.__name__} because it does not have a `generate` method' + ) + # stash the original original value of padding_side because generation requires left padding original_padding_side = tokenizer.padding_side tokenizer.padding_side = 'left' @@ -74,9 +80,10 @@ def generate(self, state: State, logger: Logger): dummy_input = device.tensor_to_device(dummy_input) with get_precision_context(state.precision): with torch.no_grad(): + assert isinstance(model.model, torch.nn.Module) _ = model.model(input_ids=dummy_input) - output_token_ids = model.model.generate( + output_token_ids = model.model.generate( # type: ignore input_ids=tokenized_input['input_ids'], attention_mask=tokenized_input['attention_mask'], synced_gpus=True, diff --git a/llmfoundry/callbacks/model_gauntlet_callback.py b/llmfoundry/callbacks/model_gauntlet_callback.py index cd81eafbec..e97d800a9a 100644 --- a/llmfoundry/callbacks/model_gauntlet_callback.py +++ b/llmfoundry/callbacks/model_gauntlet_callback.py @@ -35,7 +35,8 @@ class ModelGauntlet(Callback): weighting (Weighting): The weighting scheme used to balance different tasks within each category. Either assign them all equal weight, assign them weight proportional to the dataset size, or assign them weight proportional to the log2 of the dataset size. - substract_random_baseline (bool): Flag determining whether to subtract random baseline accuracy + Options are 'EQUAL', 'SAMPLE_SZ', and 'LOG_SAMPLE_SZ'. + subtract_random_baseline (bool): Flag determining whether to subtract random baseline accuracy from the performance on each individual benchmark before aggregating. rescale_accuracy (bool): Flag determining whether to rescale the accuracy on each benchmark by (1-random_baseline_accuracy) before aggregating. Using this ensures that all benchmarks max out at 1.0. @@ -45,7 +46,7 @@ class ModelGauntlet(Callback): def __init__(self, logger_keys: dict, categories: dict, - weighting: Weighting = Weighting.EQUAL, + weighting: str = 'EQUAL', subtract_random_baseline: bool = True, rescale_accuracy: bool = True, benchmark_sizes: Optional[dict] = None): @@ -69,10 +70,16 @@ def __init__(self, for benchmark in category['benchmarks']: bench_name = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot" - cumulative_samples = max( - sum(count for name, count in benchmark_sizes.items() - if name.startswith(bench_name)), 1) + if self.weighting != Weighting.EQUAL: + assert benchmark_sizes is not None + cumulative_samples = max( + sum(count for name, count in benchmark_sizes.items() + if name.startswith(bench_name)), 1) + else: + cumulative_samples = -1 # pyright + + weight = None if self.weighting == Weighting.EQUAL: weight = 1 elif self.weighting == Weighting.SAMPLE_SZ: @@ -80,16 +87,21 @@ def __init__(self, elif self.weighting == Weighting.LOG_SAMPLE_SZ: weight = max(math.log(cumulative_samples, 2), 1) + assert weight is not None benchmark['weighting'] = weight - def compute_averages(self, logger_data): + def compute_averages(self, logger_data: Logger): results = {} pat = re.compile( - 'metrics/(.*?)/(\d+)-shot(/.*?)?/InContextLearning(.*)') + 'metrics/(.*?)/(\d+)-shot(/.*?)?/InContextLearning(.*)' # type: ignore + ) for key in self.logger_keys: match = pat.match(key) - val = logger_data.data[key][0][1].item() + + # TODO(bmosaicml) This needs to be factored for this callback to work as a normal callback + # and therefore for the typing to be fixed + val = logger_data.data[key][0][1].item() # type: ignore if match: eval_name = match.group(1) diff --git a/llmfoundry/callbacks/monolithic_ckpt_callback.py b/llmfoundry/callbacks/monolithic_ckpt_callback.py index afca099832..71d1a93f7d 100644 --- a/llmfoundry/callbacks/monolithic_ckpt_callback.py +++ b/llmfoundry/callbacks/monolithic_ckpt_callback.py @@ -72,7 +72,9 @@ def _save_checkpoint(self, state: State, logger: Logger): ) if self.upload_to_object_store else contextlib.nullcontext( enter_result=save_dir) with dir_context_mgr as temp_save_dir: - save_path = str(Path(temp_save_dir) / Path(filename)) + save_path = str( + Path(temp_save_dir) / # type: ignore + Path(filename)) dirname = os.path.dirname(save_path) if dirname: os.makedirs(dirname, exist_ok=True) diff --git a/llmfoundry/data/data.py b/llmfoundry/data/data.py index 384d71ec0a..ef758dfcef 100644 --- a/llmfoundry/data/data.py +++ b/llmfoundry/data/data.py @@ -96,6 +96,7 @@ def __init__( 'eos_text' if eos_text_provided else 'bos_text') warnings.warn( f'The provided tokenizer adds special tokens, but you also specified {message}. This may result ' + + 'in duplicated special tokens. Please be sure this is what you intend.' ) diff --git a/llmfoundry/data/denoising.py b/llmfoundry/data/denoising.py index 8a953f5841..443777668c 100644 --- a/llmfoundry/data/denoising.py +++ b/llmfoundry/data/denoising.py @@ -13,7 +13,7 @@ from omegaconf import DictConfig from omegaconf import OmegaConf as om from torch.utils.data import DataLoader -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers import PreTrainedTokenizerBase from llmfoundry.data.packing import BinPackWrapper from llmfoundry.data.text_data import StreamingTextDataset @@ -26,16 +26,15 @@ # HuggingFace hardcodes the ignore index to -100 _HF_IGNORE_INDEX = -100 -Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - # Required signature of any `prefix_function` (see below) -PREFIX_FUNCTION = Callable[[float, Optional[float], Tokenizer], Sequence[int]] +PREFIX_FUNCTION = Callable[[float, Optional[float], PreTrainedTokenizerBase], + Sequence[int]] def ul2_prefix_function( mask_ratio: float, mean_length: Optional[float], - tokenizer: Tokenizer, + tokenizer: PreTrainedTokenizerBase, ) -> Sequence[int]: """Generates prefixes based on UL2 paper. @@ -132,7 +131,7 @@ class MixtureOfDenoisersCollator: def __init__( self, - tokenizer: Tokenizer, + tokenizer: PreTrainedTokenizerBase, max_seq_length: int, decoder_only_format: bool = False, span_mean_lengths_and_ratios: Optional[List] = None, @@ -352,7 +351,7 @@ def __call__(self, examples: List[Dict[str, def build_text_denoising_dataloader( cfg: DictConfig, - tokenizer: Tokenizer, + tokenizer: PreTrainedTokenizerBase, device_batch_size: int, ) -> DataLoader[Dict]: """Constructor function for a Mixture of Denoisers dataloader. @@ -527,7 +526,7 @@ def noise_token_sequence( prefix_tokens: Optional[Sequence[int]], max_raw_length: int, max_seq_length: int, - tokenizer: Tokenizer, + tokenizer: PreTrainedTokenizerBase, sentinel_token_ids: np.ndarray, decoder_only_format: bool, context_eos: bool, @@ -678,7 +677,8 @@ def _sample_span_lengths(total_tokens: int, num_spans: int) -> np.ndarray: """ span_markers = np.less(np.arange(total_tokens - 1), num_spans - 1)[np.random.permutation(total_tokens - 1)] - span_start_indicator = np.concatenate([[0], span_markers]) + span_start_indicator = np.concatenate([[0], + span_markers]) # type: ignore span_id = np.cumsum(span_start_indicator).reshape(-1, 1) spans = np.arange(num_spans).reshape(1, -1) span_lengths = np.sum(span_id == spans, axis=0) @@ -715,12 +715,13 @@ def _apply_mask(tokens: Union[torch.Tensor, Sequence[int], np.ndarray], # Ensure there's an end-of-sentence token at the end if ensure_eos and (noised_tokens[-1] != eos_token_id): - noised_tokens = np.concatenate([noised_tokens, [eos_token_id]]) + noised_tokens = np.concatenate([noised_tokens, + [eos_token_id]]) # type: ignore return noised_tokens # Masking at previous token - prev_token_mask = np.concatenate([[0], mask[:-1]]) + prev_token_mask = np.concatenate([[0], mask[:-1]]) # type: ignore # Decompose mask into start-of-span mask and non-start-of-span mask start_of_noise_span_token = np.logical_and(mask, @@ -739,7 +740,8 @@ def _apply_mask(tokens: Union[torch.Tensor, Sequence[int], np.ndarray], # Ensure there's an end-of-sentence token at the end if ensure_eos and (noised_tokens[-1] != eos_token_id): - noised_tokens = np.concatenate([noised_tokens, [eos_token_id]]) + noised_tokens = np.concatenate([noised_tokens, + [eos_token_id]]) # type: ignore return noised_tokens diff --git a/llmfoundry/data/finetuning/collator.py b/llmfoundry/data/finetuning/collator.py index bcaef79527..1e402c4647 100644 --- a/llmfoundry/data/finetuning/collator.py +++ b/llmfoundry/data/finetuning/collator.py @@ -336,7 +336,7 @@ def _process_and_batch_encoder_decoder( return batch -def ensure_list(x: Union[List, torch.Tensor]): +def ensure_list(x: Union[List, torch.Tensor]) -> List: if isinstance(x, torch.Tensor): x = list(x.flatten()) assert isinstance(x, list) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index df07acf7cf..86aa7e0815 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -3,13 +3,12 @@ import logging import os -from typing import Union import torch from composer.utils import dist, get_file, parse_uri from omegaconf import DictConfig from torch.utils.data import DataLoader -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers import PreTrainedTokenizerBase from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator from llmfoundry.data.finetuning.tasks import dataset_constructor @@ -20,10 +19,9 @@ # HuggingFace hardcodes the ignore index to -100 _HF_IGNORE_INDEX = -100 -Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - -def build_finetuning_dataloader(cfg: DictConfig, tokenizer: Tokenizer, +def build_finetuning_dataloader(cfg: DictConfig, + tokenizer: PreTrainedTokenizerBase, device_batch_size: int) -> DataLoader: """Builds a finetuning dataloader for training or evaluating. @@ -115,6 +113,7 @@ def build_finetuning_dataloader(cfg: DictConfig, tokenizer: Tokenizer, if tokenizer.pad_token is None: # type: ignore tokenizer.pad_token = tokenizer.eos_token + dataset = None # for pyright if cfg.dataset.get('remote') is not None: dataset = dataset_constructor.build_from_streaming( tokenizer=tokenizer, @@ -166,6 +165,7 @@ def build_finetuning_dataloader(cfg: DictConfig, tokenizer: Tokenizer, collate_fn, dataloader_batch_size = _build_collate_fn( cfg.dataset, tokenizer, device_batch_size) + assert dataset is not None return DataLoader( dataset, collate_fn=collate_fn, @@ -235,7 +235,8 @@ def _validate_config(dataset_cfg: DictConfig): ) -def _build_hf_dataset_from_remote(cfg: DictConfig, tokenizer: Tokenizer): +def _build_hf_dataset_from_remote(cfg: DictConfig, + tokenizer: PreTrainedTokenizerBase): """Builds a dataset from a remote object store. This function supports 'jsonl', 'csv', and 'parquet' file formats for the dataset. It will attempt to download @@ -313,7 +314,8 @@ def _build_hf_dataset_from_remote(cfg: DictConfig, tokenizer: Tokenizer): return dataset -def _build_collate_fn(dataset_cfg: DictConfig, tokenizer: Tokenizer, +def _build_collate_fn(dataset_cfg: DictConfig, + tokenizer: PreTrainedTokenizerBase, device_batch_size: int): collate_fn = Seq2SeqFinetuningCollator( tokenizer=tokenizer, diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index 509df3adee..9c01aab49c 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -39,14 +39,13 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: import datasets as hf_datasets from omegaconf import DictConfig from streaming import StreamingDataset -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers import PreTrainedTokenizerBase __all__ = ['dataset_constructor'] -Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - -def _tokenize_formatted_example(example: Dict[str, Any], tokenizer: Tokenizer): +def _tokenize_formatted_example(example: Dict[str, Any], + tokenizer: PreTrainedTokenizerBase): if ('prompt' not in example) or ('response' not in example): raise KeyError( 'Unable to tokenize example because it has not been properly formatted. ' +\ @@ -86,7 +85,7 @@ class StreamingFinetuningDataset(StreamingDataset): def __init__(self, local: str, - tokenizer: Tokenizer, + tokenizer: PreTrainedTokenizerBase, remote: Optional[str] = None, split: Optional[str] = None, shuffle: bool = False, @@ -162,7 +161,7 @@ def print_registered_tasks(self): tasks = sorted(self._task_preprocessing_registry.keys()) print('\n'.join(tasks)) - def get_preprocessing_fn_from_dict(self, mapping: dict): + def get_preprocessing_fn_from_dict(self, mapping: Union[Dict, DictConfig]): """Get a preprocessing function from a dictionary. The dictionary maps column names in the dataset to "prompt" and "response". @@ -256,7 +255,7 @@ def get_preprocessing_fn_from_str(self, return preprocessing_fn def build_from_hf(self, cfg: DictConfig, max_seq_len: int, - tokenizer: Tokenizer): + tokenizer: PreTrainedTokenizerBase): """Load a HuggingFace Datasets, preprocess, and tokenize. Note: This function will drop examples where the prompt is longer than the max_seq_len diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index 1e0a6cec58..1bd03b42b4 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -6,6 +6,8 @@ import numpy as np import torch +from omegaconf import DictConfig +from transformers import PreTrainedTokenizerBase class BinPackWrapper: @@ -312,7 +314,8 @@ def parse_args() -> Namespace: raise ValueError('`num_packing_ratios` must be a positive integer.') return args - def build_dataloader(cfg, tokenizer, device_batch_size): + def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, + device_batch_size: int): if cfg.name == 'text': return build_text_dataloader(cfg, tokenizer, device_batch_size) elif cfg.name == 'text_denoising': diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index 6dff05bc50..b3765c25ef 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -5,7 +5,7 @@ import os from itertools import islice -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union import numpy as np import torch @@ -14,9 +14,7 @@ from omegaconf import OmegaConf as om from streaming import Stream, StreamingDataset from torch.utils.data import DataLoader -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast - -Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] +from transformers import PreTrainedTokenizerBase class StreamingTextDataset(StreamingDataset): @@ -70,7 +68,7 @@ class StreamingTextDataset(StreamingDataset): """ def __init__(self, - tokenizer: Tokenizer, + tokenizer: PreTrainedTokenizerBase, max_seq_len: int, streams: Optional[Sequence[Stream]] = None, remote: Optional[str] = None, @@ -90,7 +88,7 @@ def __init__(self, shuffle_algo: str = 'py1b', shuffle_seed: int = 9176, shuffle_block_size: int = 1 << 18, - **kwargs: Dict[str, Any]): + **kwargs: Any): group_method = kwargs.pop('group_method', None) if group_method is not None: @@ -99,7 +97,7 @@ def __init__(self, 'concatenate, use the --concat_tokens ' + 'argument when creating your MDS dataset with concat_c4.py') - if kwargs is not None and len(kwargs) > 0: + if len(kwargs) > 0: raise ValueError( f'StreamingTextDataset() got an unexpected keyword argument: {kwargs}' ) @@ -136,7 +134,7 @@ def __init__(self, self.max_seq_len = max_seq_len # How to tokenize a text sample to a token sample - def _tokenize(self, text_sample): + def _tokenize(self, text_sample: Mapping): if self.tokenizer._pad_token is None: # Some tokenizers (e.g. GPT2 tokenizer) have no padding token which causes bugs raise RuntimeError( @@ -147,7 +145,7 @@ def _tokenize(self, text_sample): padding='max_length', max_length=self.max_seq_len) - def _read_binary_tokenized_sample(self, sample): + def _read_binary_tokenized_sample(self, sample: Dict[str, Any]): return torch.from_numpy( np.frombuffer(sample['tokens'], dtype=np.int64)[:self.max_seq_len].copy()) @@ -172,8 +170,8 @@ class ConcatenatedSequenceCollatorWrapper: def __init__( self, base_collator: Callable, - eos_token_id=None, - bos_token_id=None, + eos_token_id: Optional[int] = None, + bos_token_id: Optional[int] = None, ): self.base_collator = base_collator if (eos_token_id is None) and (bos_token_id is None): @@ -215,7 +213,7 @@ def get_sequence_id_from_batch( def build_text_dataloader( cfg: DictConfig, - tokenizer: Tokenizer, + tokenizer: PreTrainedTokenizerBase, device_batch_size: int, ): assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}' diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index 318825082b..b9277c03b1 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -16,8 +16,8 @@ LanguageCrossEntropy, LanguagePerplexity) from composer.utils import dist from omegaconf import DictConfig -from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedTokenizer, - PreTrainedTokenizerFast) +from transformers import (AutoConfig, AutoModelForCausalLM, + PreTrainedTokenizerBase) from llmfoundry.models.hf.hf_fsdp import hf_get_init_device from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss @@ -35,14 +35,12 @@ __all__ = ['ComposerHFCausalLM'] -Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - class ComposerHFCausalLM(HuggingFaceModelWithZLoss): """Configures a :class:`.HuggingFaceModel` around a Causal LM. Args: - om_model_config (DictConfig | PeftModel | transformers.PreTrainedModel): either n omegaconf dictionary used to configure the model, or an instantiated model object from the peft or transformers library. + om_model_config (DictConfig | PeftModel | transformers.PreTrainedModel): either an omegaconf dictionary used to configure the model, or an instantiated model object from the peft or transformers library. if DictConfig, the following keys are required: cfg.pretrained_model_name_or_path (str): The name of or local path to the HF Causal LM (e.g., `gpt2` to instantiate a GPT2LMHeadModel). @@ -58,8 +56,10 @@ class ComposerHFCausalLM(HuggingFaceModelWithZLoss): tokenizer (PreTrainedTokenizer): The tokenizer that the model will use. """ - def __init__(self, om_model_config: _om_model_config_type, - tokenizer: Tokenizer): + def __init__( + self, + om_model_config: _om_model_config_type, # type: ignore + tokenizer: PreTrainedTokenizerBase): # set up training and eval metrics train_metrics = [ @@ -102,8 +102,8 @@ def __init__(self, om_model_config: _om_model_config_type, ] if extra_keys: raise ValueError( - f'Config dict override got unknown keys. ' - f'Extra keys: {extra_keys}. ' + f'Config dict override got unknown keys. ' + + f'Extra keys: {extra_keys}. ' + f'Expected (a subset of) keys: {list(attr.keys())}.' ) getattr(config, k).update(v) diff --git a/llmfoundry/models/hf/hf_fsdp.py b/llmfoundry/models/hf/hf_fsdp.py index 5aa65e4f4e..4a0b76e640 100644 --- a/llmfoundry/models/hf/hf_fsdp.py +++ b/llmfoundry/models/hf/hf_fsdp.py @@ -5,7 +5,7 @@ # which is MIT licensed import functools -from typing import Any, Iterable, List +from typing import Any, Iterable, List, Optional import torch from transformers import PreTrainedModel @@ -67,7 +67,12 @@ def hf_get_causal_base_model(model: PreTrainedModel): return model.get_decoder() decoder_attrs = ('transformer', 'model.decoder', 'gpt_neox') - return findattr(model, decoder_attrs) + causal_base_model = findattr(model, decoder_attrs) + if causal_base_model is None: + raise ValueError( + f'Unable to FSDP-wrap model {model}. Please open a github issue to add support.' + ) + return causal_base_model def hf_get_hidden_layers(model: PreTrainedModel): @@ -92,7 +97,7 @@ def hf_get_hidden_layers(model: PreTrainedModel): return findattr(model, hidden_layers_attrs) -def hf_get_init_device(init_device: str): +def hf_get_init_device(init_device: Optional[str]): """Returns the appropriate device to initialize models.""" from composer.utils import dist if init_device == 'mixed': @@ -105,7 +110,8 @@ def hf_get_init_device(init_device: str): # /end helper functions -def prepare_hf_model_for_fsdp(model: PreTrainedModel, init_device: str) -> None: +def prepare_hf_model_for_fsdp(model: PreTrainedModel, + init_device: Optional[str]) -> None: """FSDP wrap a HuggingFace model. Call specific functions @@ -119,13 +125,14 @@ def prepare_hf_model_for_fsdp(model: PreTrainedModel, init_device: str) -> None: def prepare_hf_causal_lm_model_for_fsdp(model: PreTrainedModel, - init_device: str) -> None: + init_device: Optional[str]) -> None: """FSDP wrap a HuggingFace decoder. Wrap any model for FSDP which follows one of the 3 existing conventions from HuggingFace for decoder-only LLMs. """ causal_base_model = hf_get_causal_base_model(model) + # OPT has an extra layer of wrapping, so special case here if isinstance(causal_base_model, OPTDecoder): model.model._fsdp_wrap = False @@ -133,7 +140,7 @@ def prepare_hf_causal_lm_model_for_fsdp(model: PreTrainedModel, lm_head = model.get_output_embeddings() # some models (OPT) implement .get_input_embeddings for the causal subclass # but all of them implement it for the base model - tied_embeddings = causal_base_model.get_input_embeddings() # type: ignore + tied_embeddings = causal_base_model.get_input_embeddings() modules = { 'base_model': causal_base_model, 'model_block': model_block, @@ -148,7 +155,7 @@ def prepare_hf_causal_lm_model_for_fsdp(model: PreTrainedModel, 'follow common layer/weight naming conventions.') block_type = type(model_block[0]) # type: ignore if init_device == 'mixed': - # For FSDP with models with different device intiailizations, `mixed`, which + # For FSDP with models with different device initializations, `mixed`, which # initializes the model on rank 0 on `cpu` and on all other ranks on `meta,`` # we need to tag all child modules that are torch.nn.Modules with `_fsdp_wrap`. for child in model.children(): @@ -165,11 +172,12 @@ def prepare_hf_causal_lm_model_for_fsdp(model: PreTrainedModel, if model.config.tie_word_embeddings and not model.config.model_type == 'mpt': raise ValueError( - 'The passed in HuggingFaceModel has tied word embeddings ' - 'and the passed in initialization device is `mixed.` ' + 'The passed in HuggingFaceModel has tied word embeddings ' + + 'and the passed in initialization device is `mixed.` ' + 'In order to support this initialization scheme, we would need to break ' + + 'the weight tying. As a result, either use a different initialization scheme ' - 'or in the model config set `tie_word_embeddings=False.`') + + 'or in the model config set `tie_word_embeddings=False.`') else: # When using the HF LM models, # the weights of the self.lm_head and self.transformer.wte are tied. @@ -189,7 +197,7 @@ def prepare_hf_causal_lm_model_for_fsdp(model: PreTrainedModel, def prepare_hf_enc_dec_model_for_fsdp(model: PreTrainedModel, - init_device: str) -> None: + init_device: Optional[str]) -> None: """Wrap an encoder/decoder HF model. This works for T5, BART, Pegasus, PegasusX, but not all enc/dec (ProphetNet) diff --git a/llmfoundry/models/hf/hf_prefix_lm.py b/llmfoundry/models/hf/hf_prefix_lm.py index 863db0d08a..7152dfae70 100644 --- a/llmfoundry/models/hf/hf_prefix_lm.py +++ b/llmfoundry/models/hf/hf_prefix_lm.py @@ -5,13 +5,13 @@ from __future__ import annotations -from typing import Mapping, Union +from typing import Mapping, MutableMapping from composer.metrics.nlp import LanguageCrossEntropy, MaskedAccuracy from composer.utils import dist from omegaconf import DictConfig -from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedTokenizer, - PreTrainedTokenizerFast) +from transformers import (AutoConfig, AutoModelForCausalLM, + PreTrainedTokenizerBase) from llmfoundry.models.hf.hf_fsdp import hf_get_init_device from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss @@ -22,8 +22,6 @@ __all__ = ['ComposerHFPrefixLM'] -Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - # HuggingFace hardcodes the ignore index to -100 _HF_IGNORE_INDEX = -100 @@ -68,7 +66,8 @@ class ComposerHFPrefixLM(HuggingFaceModelWithZLoss): tokenizer (PreTrainedTokenizer): The tokenizer that the model will use. """ - def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer): + def __init__(self, om_model_config: DictConfig, + tokenizer: PreTrainedTokenizerBase): config = AutoConfig.from_pretrained( om_model_config.pretrained_model_name_or_path, trust_remote_code=om_model_config.get('trust_remote_code', True), @@ -87,8 +86,8 @@ def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer): extra_keys = [_k for _k in v.keys() if _k not in attr.keys()] if extra_keys: raise ValueError( - f'Config dict override got unknown keys. ' - f'Extra keys: {extra_keys}. ' + f'Config dict override got unknown keys. ' + + f'Extra keys: {extra_keys}. ' + f'Expected (a subset of) keys: {list(attr.keys())}.') getattr(config, k).update(v) else: @@ -101,7 +100,7 @@ def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer): init_device = om_model_config.get('init_device', 'cpu') # Get the device we want to initialize, and use the - # reolved version to initialize the HF model + # resolved version to initialize the HF model resolved_init_device = hf_get_init_device(init_device) # We need to have all non-zero local ranks be not-pretrained @@ -145,7 +144,7 @@ def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer): return composer_model - def forward(self, batch): + def forward(self, batch: MutableMapping): # Add bidirectional_mask if it is missing and can be constructed add_bidirectional_mask_if_missing(batch) return super().forward(batch) diff --git a/llmfoundry/models/hf/hf_t5.py b/llmfoundry/models/hf/hf_t5.py index e3b26e4b26..690a0de447 100644 --- a/llmfoundry/models/hf/hf_t5.py +++ b/llmfoundry/models/hf/hf_t5.py @@ -5,13 +5,13 @@ from __future__ import annotations -from typing import Mapping, Union +from typing import Mapping from composer.metrics.nlp import LanguageCrossEntropy, MaskedAccuracy from composer.utils import dist from omegaconf import DictConfig -from transformers import (AutoConfig, PreTrainedTokenizer, - PreTrainedTokenizerFast, T5ForConditionalGeneration) +from transformers import (AutoConfig, PreTrainedTokenizerBase, + T5ForConditionalGeneration) from llmfoundry.models.hf.hf_fsdp import hf_get_init_device from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss @@ -20,8 +20,6 @@ __all__ = ['ComposerHFT5'] -Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - # HuggingFace hardcodes the ignore index to -100 _HF_IGNORE_INDEX = -100 @@ -29,7 +27,7 @@ class ComposerHFT5(HuggingFaceModelWithZLoss): """Configures a :class:`.HuggingFaceModel` around a T5. - Note: This function uses `transformers.T5ForConditionalGenration`. Future releases + Note: This function uses `transformers.T5ForConditionalGeneration`. Future releases will expand support to more general classes of HF Encoder-Decoder models. Args: @@ -57,7 +55,8 @@ class ComposerHFT5(HuggingFaceModelWithZLoss): tokenizer (PreTrainedTokenizer): The tokenizer that the model will use. """ - def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer): + def __init__(self, om_model_config: DictConfig, + tokenizer: PreTrainedTokenizerBase): config = AutoConfig.from_pretrained( om_model_config.pretrained_model_name_or_path, trust_remote_code=om_model_config.get('trust_remote_code', True), @@ -76,8 +75,8 @@ def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer): extra_keys = [_k for _k in v.keys() if _k not in attr.keys()] if extra_keys: raise ValueError( - f'Config dict override got unknown keys. ' - f'Extra keys: {extra_keys}. ' + f'Config dict override got unknown keys. ' + + f'Extra keys: {extra_keys}. ' + f'Expected (a subset of) keys: {list(attr.keys())}.') getattr(config, k).update(v) else: @@ -94,7 +93,7 @@ def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer): init_device = om_model_config.get('init_device', 'cpu') # Get the device we want to initialize, and use the - # reolved version to initialize the HF model + # resolved version to initialize the HF model resolved_init_device = hf_get_init_device(init_device) # We need to have all non-zero local ranks be not-pretrained diff --git a/llmfoundry/models/hf/model_wrapper.py b/llmfoundry/models/hf/model_wrapper.py index ec4cebd3ae..d59c243f9d 100644 --- a/llmfoundry/models/hf/model_wrapper.py +++ b/llmfoundry/models/hf/model_wrapper.py @@ -7,21 +7,20 @@ import inspect from collections import UserDict -from typing import List, Optional, Union +from typing import List, Mapping, Optional import torch import transformers from composer.models.huggingface import HuggingFaceModel from torchmetrics import Metric -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers import PreTrainedTokenizerBase +from transformers.utils.generic import ModelOutput from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp # HuggingFace hardcodes the ignore index to -100 _HF_IGNORE_INDEX = -100 -Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - class HuggingFaceModelWithZLoss(HuggingFaceModel): """Wrapper around HuggingFaceModel. @@ -41,7 +40,7 @@ class HuggingFaceModelWithZLoss(HuggingFaceModel): def __init__(self, model: transformers.PreTrainedModel, - tokenizer: Optional[Tokenizer] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, metrics: Optional[List[Metric]] = None, eval_metrics: Optional[List[Metric]] = None, z_loss: float = 0.0, @@ -68,7 +67,7 @@ def __init__(self, self.model.param_init_fn = lambda module: self.model._init_weights( module) - def forward(self, batch): + def forward(self, batch: Mapping): if isinstance(batch, dict) or isinstance(batch, UserDict): # Further input validation is left to the huggingface forward call batch = { @@ -81,7 +80,7 @@ def forward(self, batch): ) return output - def loss(self, outputs, batch): + def loss(self, outputs: ModelOutput, batch: Mapping): if self.config.use_return_dict: loss, logits = outputs['loss'], outputs['logits'] else: @@ -103,18 +102,3 @@ def loss(self, outputs, batch): else: outputs[0] += z_loss return outputs[0] - - # def eval_forward(self, batch, outputs: Optional[Any] = None): - # if 'generate_output' in batch: - # self.labels = batch.pop('labels') - # return self.model.generate( - # batch['input_ids'], - # attention_mask=batch['attention_mask'], - # max_new_tokens=512, - # do_sample=True, - # top_p=0.90, - # top_k=0, - # no_repeat_ngram_size=3, - # ) - - # return super().eval_forward(batch, outputs) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index d7bf4a87ea..34692e600b 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -5,7 +5,7 @@ import math import warnings -from typing import Optional +from typing import List, Optional, Tuple import torch import torch.nn as nn @@ -32,20 +32,21 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, def scaled_multihead_dot_product_attention( - query, - key, - value, - n_heads, - past_key_value=None, - softmax_scale=None, - attn_bias=None, - key_padding_mask=None, - is_causal=False, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, -): + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + n_heads: int, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + softmax_scale: Optional[float] = None, + attn_bias: Optional[torch.Tensor] = None, + key_padding_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + dropout_p: float = 0.0, + training: bool = False, + needs_weights: bool = False, + multiquery: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, + torch.Tensor]]]: q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads) kv_n_heads = 1 if multiquery else n_heads k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads) @@ -91,9 +92,9 @@ def scaled_multihead_dot_product_attention( if key_padding_mask is not None: if attn_bias is not None: warnings.warn( - 'Propogating key_padding_mask to the attention module ' +\ + 'Propagating key_padding_mask to the attention module ' +\ 'and applying it within the attention module can cause ' +\ - 'unneccessary computation/memory usage. Consider integrating ' +\ + 'unnecessary computation/memory usage. Consider integrating ' +\ 'into attn_bias once and passing that to each attention ' +\ 'module instead.' ) @@ -126,7 +127,10 @@ def scaled_multihead_dot_product_attention( return out, None, past_key_value -def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]): +def check_valid_inputs(*tensors: torch.Tensor, + valid_dtypes: Optional[List[torch.dtype]] = None): + if valid_dtypes is None: + valid_dtypes = [torch.float16, torch.bfloat16] for tensor in tensors: if tensor.dtype not in valid_dtypes: raise TypeError(f'{tensor.dtype=} must be in {valid_dtypes=}.') @@ -135,20 +139,21 @@ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]): def flash_attn_fn( - query, - key, - value, - n_heads, - past_key_value=None, - softmax_scale=None, - attn_bias=None, - key_padding_mask=None, - is_causal=False, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, -): + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + n_heads: int, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + softmax_scale: Optional[float] = None, + attn_bias: Optional[torch.Tensor] = None, + key_padding_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + dropout_p: float = 0.0, + training: bool = False, + needs_weights: bool = False, + multiquery: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, + torch.Tensor]]]: try: from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip except: @@ -229,20 +234,21 @@ def flash_attn_fn( def triton_flash_attn_fn( - query, - key, - value, - n_heads, - past_key_value=None, - softmax_scale=None, - attn_bias=None, - key_padding_mask=None, - is_causal=False, - dropout_p=0.0, - training=False, - needs_weights=False, - multiquery=False, -): + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + n_heads: int, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + softmax_scale: Optional[float] = None, + attn_bias: Optional[torch.Tensor] = None, + key_padding_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, + dropout_p: float = 0.0, + training: bool = False, + needs_weights: bool = False, + multiquery: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, + torch.Tensor]]]: try: from llmfoundry.models.layers.flash_attn_triton import flash_attn_func except: @@ -260,9 +266,13 @@ def triton_flash_attn_fn( # default recommendation is to install this variant raise RuntimeError( 'Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU ' + + 'and `pip install .[gpu]` if installing from llm-foundry source or ' + + '`pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` ' + + 'if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). ' + + 'Note: (1) requires you have CMake and PyTorch already installed.' ) @@ -284,6 +294,7 @@ def triton_flash_attn_fn( if dropout_p: raise NotImplementedError( f'Dropout not implemented for attn_impl: triton.') + dropout_p = dropout_p if training else 0.0 if needs_weights: raise NotImplementedError( @@ -319,10 +330,10 @@ def triton_flash_attn_fn( value = value.repeat(1, 1, n_heads, 1) reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) - attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, - softmax_scale) + attn_output = flash_attn_func( # type: ignore + query, key, value, attn_bias, reset_is_causal, softmax_scale) - output = attn_output.view(*attn_output.shape[:2], -1) + output = attn_output.view(*attn_output.shape[:2], -1) # type: ignore return output, None, past_key_value @@ -330,7 +341,7 @@ def triton_flash_attn_fn( class MultiheadAttention(nn.Module): """Multi-head self attention. - Using torch or triton attention implemetation enables user to also use + Using torch or triton attention implementation enables user to also use additive bias. """ @@ -409,12 +420,12 @@ def __init__( def forward( self, - x, - past_key_value=None, - attn_bias=None, - attention_mask=None, - is_causal=True, - needs_weights=False, + x: torch.Tensor, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attn_bias: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + is_causal: bool = True, + needs_weights: bool = False, ): qkv = self.Wqkv(x) @@ -452,7 +463,7 @@ def forward( class MultiQueryAttention(nn.Module): """Multi-Query self attention. - Using torch or triton attention implemetation enables user to also use + Using torch or triton attention implementation enables user to also use additive bias. """ @@ -536,13 +547,14 @@ def __init__( def forward( self, - x, - past_key_value=None, - attn_bias=None, - attention_mask=None, - is_causal=True, - needs_weights=False, - ): + x: torch.Tensor, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attn_bias: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + is_causal: bool = True, + needs_weights: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ + torch.Tensor, torch.Tensor]]]: qkv = self.Wqkv(x) if self.clip_qkv: @@ -578,8 +590,8 @@ def forward( return self.out_proj(context), attn_weights, past_key_value -def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, - use_sequence_id): +def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool, + prefix_lm: bool, causal: bool, use_sequence_id: bool): if attn_impl == 'flash': return None elif attn_impl in ['torch', 'triton']: @@ -595,13 +607,13 @@ def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, def build_attn_bias( - attn_impl, - attn_bias, - n_heads, - seq_len, - causal=False, - alibi=False, - alibi_bias_max=8, + attn_impl: str, + attn_bias: torch.Tensor, + n_heads: int, + seq_len: int, + causal: bool = False, + alibi: bool = False, + alibi_bias_max: int = 8, ): if attn_impl == 'flash': return None @@ -623,7 +635,9 @@ def build_attn_bias( raise ValueError(f'{attn_impl=} is an invalid setting.') -def gen_slopes(n_heads, alibi_bias_max=8, device=None): +def gen_slopes(n_heads: int, + alibi_bias_max: int = 8, + device: Optional[torch.device] = None): _n_heads = 2**math.ceil(math.log2(n_heads)) m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device) m = m.mul(alibi_bias_max / _n_heads) @@ -639,12 +653,12 @@ def gen_slopes(n_heads, alibi_bias_max=8, device=None): def build_alibi_bias( - n_heads, - seq_len, - full=False, - alibi_bias_max=8, - device=None, - dtype=None, + n_heads: int, + seq_len: int, + full: bool = False, + alibi_bias_max: int = 8, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index faaca18eac..cec14e5d2a 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -3,13 +3,12 @@ """GPT Blocks used for the GPT Model.""" -from typing import Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY -from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY @@ -21,32 +20,39 @@ def __init__( d_model: int, n_heads: int, expansion_ratio: int, - attn_config: Dict = { - 'attn_type': 'multihead_attention', - 'attn_pdrop': 0.0, - 'attn_impl': 'triton', - 'qk_ln': False, - 'clip_qkv': None, - 'softmax_scale': None, - 'prefix_lm': False, - 'attn_uses_sequence_id': False, - 'alibi': False, - 'alibi_bias_max': 8, - }, - ffn_config: Dict = { - 'ffn_type': 'mptmlp', - }, + attn_config: Optional[Dict] = None, + ffn_config: Optional[Dict] = None, resid_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', verbose: int = 0, fc_type: str = 'torch', device: Optional[str] = None, - **kwargs, + **kwargs: Any, ): + if attn_config is None: + attn_config = { + 'attn_type': 'multihead_attention', + 'attn_pdrop': 0.0, + 'attn_impl': 'triton', + 'qk_ln': False, + 'clip_qkv': None, + 'softmax_scale': None, + 'prefix_lm': False, + 'attn_uses_sequence_id': False, + 'alibi': False, + 'alibi_bias_max': 8, + } + + if ffn_config is None: + ffn_config = { + 'ffn_type': 'mptmlp', + } + del kwargs # unused, just to capture any extra args from the config super().__init__() norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] + assert isinstance(attn_config['attn_type'], str) attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] self.norm_1 = norm_class(d_model, device=device) @@ -79,11 +85,12 @@ def __init__( def forward( self, x: torch.Tensor, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, attn_bias: Optional[torch.Tensor] = None, attention_mask: Optional[torch.ByteTensor] = None, is_causal: bool = True, - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ + torch.Tensor, torch.Tensor]]]: a = self.norm_1(x) b, attn_weights, past_key_value = self.attn( a, diff --git a/llmfoundry/models/layers/custom_embedding.py b/llmfoundry/models/layers/custom_embedding.py index 7eca7fea7b..20a2be3a55 100644 --- a/llmfoundry/models/layers/custom_embedding.py +++ b/llmfoundry/models/layers/custom_embedding.py @@ -1,7 +1,6 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index d3a77f009c..a02558102f 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -3,14 +3,12 @@ """GPT Blocks used for the GPT Model.""" -from typing import Optional +from typing import Any, Optional import torch import torch.nn as nn -from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY -from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY try: import transformer_engine.pytorch as te @@ -44,7 +42,7 @@ def __init__( ) self.down_proj._is_residual = True # type: ignore - def forward(self, x): + def forward(self, x: torch.Tensor): return self.down_proj(self.act(self.up_proj(x))) @@ -62,11 +60,11 @@ def build_ffn( expansion_ratio: int, fc_type: str = 'torch', device: Optional[str] = None, - **kwargs, + **kwargs: Any, ): ffn_type = kwargs.pop('ffn_type') if ffn_type == 'mptmlp': - if kwargs is not None and len(kwargs) > 0: + if len(kwargs) > 0: raise ValueError( f'MPTMLP got an unexpected keyword argument: {kwargs}') return MPTMLP( @@ -76,6 +74,7 @@ def build_ffn( device=device, ) elif ffn_type == 'te_ln_mlp': + assert te is not None return te.LayerNormMLP( hidden_size=d_model, ffn_hidden_size=d_model * expansion_ratio, diff --git a/llmfoundry/models/layers/norm.py b/llmfoundry/models/layers/norm.py index 8eff7b545e..fabe0a8ccb 100644 --- a/llmfoundry/models/layers/norm.py +++ b/llmfoundry/models/layers/norm.py @@ -1,12 +1,12 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, Type +from typing import Dict, List, Optional, Type, Union import torch -def _cast_if_autocast_enabled(tensor): +def _cast_if_autocast_enabled(tensor: torch.Tensor): if torch.is_autocast_enabled(): if tensor.device.type == 'cuda': dtype = torch.get_autocast_gpu_dtype() @@ -22,11 +22,11 @@ class LPLayerNorm(torch.nn.LayerNorm): def __init__( self, - normalized_shape, - eps=1e-05, - elementwise_affine=True, - device=None, - dtype=None, + normalized_shape: Union[int, List[int], torch.Size], + eps: float = 1e-05, + elementwise_affine: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): super().__init__( normalized_shape=normalized_shape, @@ -36,7 +36,7 @@ def __init__( dtype=dtype, ) - def forward(self, x): + def forward(self, x: torch.Tensor): module_device = x.device downcast_x = _cast_if_autocast_enabled(x) downcast_weight = _cast_if_autocast_enabled( @@ -53,7 +53,9 @@ def forward(self, x): ) -def rms_norm(x, weight=None, eps=1e-5): +def rms_norm(x: torch.Tensor, + weight: Optional[torch.Tensor] = None, + eps: float = 1e-5): output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) if weight is not None: return output * weight @@ -64,11 +66,11 @@ class RMSNorm(torch.nn.Module): def __init__( self, - normalized_shape, - eps=1e-5, - weight=True, - dtype=None, - device=None, + normalized_shape: Union[int, List[int], torch.Size], + eps: float = 1e-5, + weight: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, ): super().__init__() self.eps = eps @@ -78,7 +80,7 @@ def __init__( else: self.register_parameter('weight', None) - def forward(self, x): + def forward(self, x: torch.Tensor): return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) @@ -86,11 +88,11 @@ class LPRMSNorm(RMSNorm): def __init__( self, - normalized_shape, - eps=1e-5, - weight=True, - dtype=None, - device=None, + normalized_shape: Union[int, List[int], torch.Size], + eps: float = 1e-5, + weight: bool = True, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, ): super().__init__( normalized_shape=normalized_shape, @@ -100,7 +102,7 @@ def __init__( device=device, ) - def forward(self, x): + def forward(self, x: torch.Tensor): downcast_x = _cast_if_autocast_enabled(x) downcast_weight = _cast_if_autocast_enabled( self.weight) if self.weight is not None else self.weight diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index f387d5307a..7558a4b9ba 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -4,7 +4,7 @@ """A HuggingFace-style model configuration.""" import warnings -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Union from transformers import PretrainedConfig @@ -62,7 +62,7 @@ def __init__( use_cache: bool = False, init_config: Dict = init_config_defaults, fc_type: str = 'torch', - **kwargs, + **kwargs: Any, ): """The MPT configuration class. @@ -119,7 +119,7 @@ def __init__( init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes. --- See llmfoundry.models.utils.param_init_fns.py for info on other param init config options - fc_type (str): choose fc layer implementaion. Options: torch and te. te layers support fp8 when using H100 GPUs. + fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs. """ self.d_model = d_model self.n_heads = n_heads @@ -141,6 +141,7 @@ def __init__( self.use_cache = use_cache self.init_config = init_config self.fc_type = fc_type + self.bias = None if 'name' in kwargs: del kwargs['name'] if 'loss_fn' in kwargs: @@ -153,7 +154,8 @@ def __init__( self._validate_config() - def _set_config_defaults(self, config, config_defaults): + def _set_config_defaults(self, config: Dict[str, Any], + config_defaults: Dict[str, Any]): # set config defaults for k, v in config_defaults.items(): if k not in config: @@ -218,11 +220,13 @@ def _validate_config(self): if self.fc_type == 'te' or self.ffn_config['ffn_type'] == 'te_ln_mlp': try: import transformer_engine.pytorch as te + del te # unused except: raise ImportError( - 'TransformerEngine import fail. `fc_type: te` requires TransformerEngine be installed.' + 'TransformerEngine import fail. `fc_type: te` requires TransformerEngine be installed. ' + + 'The required version of transformer_engine also requires FlashAttention v1.0.6 is installed:\n' - 'pip install flash-attn==1.0.6 --no-build-isolation \n' + + 'pip install flash-attn==1.0.6 --no-build-isolation \n' + 'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156' ) if self.ffn_config['ffn_type'] == 'mptmlp': diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 168520abe9..744886cbbd 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -8,7 +8,7 @@ import math import warnings -from typing import List, Optional, Tuple, Union +from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union import torch import torch.nn as nn @@ -23,16 +23,18 @@ from composer.utils import dist from omegaconf import DictConfig from omegaconf import OmegaConf as om -from transformers import (PreTrainedModel, PreTrainedTokenizer, - PreTrainedTokenizerFast) +from transformers import PreTrainedModel, PreTrainedTokenizerBase from transformers.modeling_outputs import (BaseModelOutputWithPast, CausalLMOutputWithPast) from llmfoundry.models.layers.attention import attn_bias_shape, build_attn_bias from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.layers.custom_embedding import SharedEmbedding -from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY -from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn +from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY +from llmfoundry.models.layers.ffn import \ + FFN_CLASS_REGISTRY as FFN_CLASS_REGISTRY +from llmfoundry.models.layers.ffn import MPTMLP as MPTMLP +from llmfoundry.models.layers.ffn import build_ffn as build_ffn from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY from llmfoundry.models.mpt.configuration_mpt import MPTConfig @@ -56,13 +58,11 @@ ) try: - from llmfoundry.models.layers.flash_attn_triton import flash_attn_func + from llmfoundry.models.layers.flash_attn_triton import flash_attn_func as flash_attn_func except: pass # isort: on -Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - class MPTPreTrainedModel(PreTrainedModel): config_class = MPTConfig @@ -165,8 +165,8 @@ def set_input_embeddings(self, value: nn.Embedding): @torch.no_grad() def _attn_bias( self, - device, - dtype, + device: torch.device, + dtype: torch.dtype, attention_mask: Optional[torch.ByteTensor] = None, prefix_mask: Optional[torch.ByteTensor] = None, sequence_id: Optional[torch.LongTensor] = None, @@ -307,10 +307,12 @@ def forward( if use_cache is not None else self.config.use_cache) if attention_mask is not None: - attention_mask = attention_mask.bool() + attention_mask = attention_mask.bool( + ) # type: ignore (TODO to figure out the right type here) if prefix_mask is not None: - prefix_mask = prefix_mask.bool() + prefix_mask = prefix_mask.bool( + ) # type: ignore (TODO to figure out the right type here) # These args are passed in by keyword in huggingface's generate function # https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206 @@ -379,6 +381,7 @@ def forward( if S + past_position > self.config.max_seq_len: raise ValueError( f'Cannot forward input with past sequence length {past_position} and current sequence length ' + + f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.' ) pos = torch.arange( @@ -460,7 +463,7 @@ def forward( ) # Param Initialization, needed for device='meta' fast initialization - def param_init_fn(self, module): + def param_init_fn(self, module: nn.Module): init_fn_name = self.config.init_config['name'] MODEL_INIT_REGISTRY[init_fn_name]( module=module, @@ -470,11 +473,11 @@ def param_init_fn(self, module): ) # FSDP Wrap function - def fsdp_wrap_fn(self, module): + def fsdp_wrap_fn(self, module: nn.Module): return isinstance(module, MPTBlock) # Activation Checkpointing - def activation_checkpointing_fn(self, module): + def activation_checkpointing_fn(self, module: nn.Module): return isinstance(module, MPTBlock) @@ -513,16 +516,17 @@ def __init__(self, config: MPTConfig): def get_input_embeddings(self): return self.transformer.wte - def set_input_embeddings(self, value): + def set_input_embeddings(self, value: Union[SharedEmbedding, nn.Embedding]): self.transformer.wte = value def get_output_embeddings(self): return self.transformer.wte - def set_output_embeddings(self, new_embeddings): + def set_output_embeddings(self, new_embeddings: Union[SharedEmbedding, + nn.Embedding]): self.transformer.wte = new_embeddings - def set_decoder(self, decoder): + def set_decoder(self, decoder: MPTModel): self.transformer = decoder def get_decoder(self): @@ -596,7 +600,7 @@ def forward( ) # Param Initialization, needed for device='meta' fast initialization - def param_init_fn(self, module): + def param_init_fn(self, module: nn.Module): init_fn_name = self.config.init_config['name'] MODEL_INIT_REGISTRY[init_fn_name]( module=module, @@ -606,19 +610,20 @@ def param_init_fn(self, module): ) # FSDP Wrap function - def fsdp_wrap_fn(self, module): + def fsdp_wrap_fn(self, module: nn.Module): return isinstance(module, MPTBlock) # Activation Checkpointing - def activation_checkpointing_fn(self, module): + def activation_checkpointing_fn(self, module: nn.Module): return isinstance(module, MPTBlock) def prepare_inputs_for_generation( self, - input_ids, - past_key_values=None, - inputs_embeds=None, - **kwargs, + input_ids: torch.Tensor, + past_key_values: Optional[List[Tuple[torch.Tensor, + torch.Tensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: Any, ): if inputs_embeds is not None: raise NotImplementedError( @@ -657,7 +662,8 @@ def prepare_inputs_for_generation( } @staticmethod - def _reorder_cache(past_key_values, beam_idx): + def _reorder_cache(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], + beam_idx: torch.LongTensor): """Used by HuggingFace generate when using beam search with kv-caching. See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133 @@ -678,7 +684,7 @@ class ComposerMPTCausalLM(HuggingFaceModel): def __init__( self, om_model_config: DictConfig, - tokenizer: Optional[Tokenizer] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, ): resolved_om_model_config = om.to_container(om_model_config, resolve=True) @@ -719,7 +725,9 @@ def __init__( except: raise ValueError( 'Fused Cross Entropy is not installed. Either (1) have a CUDA-compatible GPU ' + + 'and `pip install .[gpu]` if installing from source or `pip install xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v1.0.3#subdirectory=csrc/xentropy` ' + + 'if installing from pypi, or (2) set your config model.loss_fn=torch_crossentropy.' ) elif loss_fn_config == 'torch_crossentropy': @@ -729,12 +737,12 @@ def __init__( f'Specified loss_fn={self.loss_fn} not recognized. `loss_fn` must be one of [`fused_crossentropy`, `torch_crossentropy`].' ) - def get_targets(self, batch): + def get_targets(self, batch: Mapping): targets = torch.roll(batch['labels'], shifts=-1) targets[:, -1] = -100 return targets - def forward(self, batch): + def forward(self, batch: MutableMapping): if self.model.transformer.prefix_lm: add_bidirectional_mask_if_missing(batch) # Note: prefix_mask is only used if model.prefix_lm is True @@ -746,12 +754,12 @@ def forward(self, batch): inputs_embeds=batch.get('inputs_embeds', None), ) - def loss(self, outputs, batch): + def loss(self, outputs: CausalLMOutputWithPast, batch: Mapping): targets = self.get_targets(batch) return self.loss_fn(outputs.logits.view(-1, outputs.logits.size(-1)), targets.view(-1)) - def flops_per_batch(self, batch): + def flops_per_batch(self, batch: Mapping): # Note: this computation does not take into account padding, and assumes # that the dataset has been constructed without padding. Additionally, we # assume the backward pass is approximately 2x the forward pass diff --git a/llmfoundry/models/utils/adapt_tokenizer.py b/llmfoundry/models/utils/adapt_tokenizer.py index 72145fe742..df98ba6895 100644 --- a/llmfoundry/models/utils/adapt_tokenizer.py +++ b/llmfoundry/models/utils/adapt_tokenizer.py @@ -1,19 +1,16 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Union +from typing import Any -from transformers import (AutoTokenizer, PreTrainedTokenizer, - PreTrainedTokenizerFast) - -Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] +from transformers import AutoTokenizer, PreTrainedTokenizerBase # For consistency with T5 Tokenizer, which is what this adaptation aims to mimic, # we hardcode there to be 100 sentinel tokens NUM_SENTINEL_TOKENS: int = 100 -def adapt_tokenizer_for_denoising(tokenizer: Tokenizer): +def adapt_tokenizer_for_denoising(tokenizer: PreTrainedTokenizerBase): """Adds sentinel tokens and padding token (if missing). Expands the tokenizer vocabulary to include sentinel tokens @@ -52,7 +49,7 @@ class AutoTokenizerForMOD(AutoTokenizer): """ @classmethod - def from_pretrained(cls, *args, **kwargs): + def from_pretrained(cls, *args: Any, **kwargs: Any): """See `AutoTokenizer.from_pretrained` docstring.""" tokenizer = super().from_pretrained(*args, **kwargs) adapt_tokenizer_for_denoising(tokenizer) diff --git a/llmfoundry/models/utils/hf_prefixlm_converter.py b/llmfoundry/models/utils/hf_prefixlm_converter.py index 872249738d..ae8ed444c8 100644 --- a/llmfoundry/models/utils/hf_prefixlm_converter.py +++ b/llmfoundry/models/utils/hf_prefixlm_converter.py @@ -13,7 +13,7 @@ import math import warnings from types import MethodType -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, List, MutableMapping, Optional, Tuple, Union import torch from transformers.models.bloom.modeling_bloom import ( @@ -188,8 +188,7 @@ def call_og_forward(): # Return the outputs return output - def generate(self: CAUSAL_GPT_TYPES, *args: tuple, **kwargs: Dict[str, - Any]): + def generate(self: CAUSAL_GPT_TYPES, *args: Any, **kwargs: Any): """Wraps original generate to enable PrefixLM attention.""" attn_modules = _get_attn_modules(model) @@ -343,7 +342,7 @@ def forward( # type: ignore output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - **deprecated_arguments) -> Union[Tuple[ + **deprecated_arguments: Any) -> Union[Tuple[ torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: if deprecated_arguments.pop('position_ids', False) is not False: # `position_ids` could have been `torch.Tensor` or `None` so @@ -441,9 +440,9 @@ def forward( # type: ignore ) use_cache = False - def create_custom_forward(module): + def create_custom_forward(module: torch.nn.Module): - def custom_forward(*inputs): + def custom_forward(*inputs: Any): # None for past_key_value return module(*inputs, use_cache=use_cache, @@ -526,7 +525,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - **deprecated_arguments + **deprecated_arguments: Any, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: """Replacement forward method for BloomCausalLM.""" if deprecated_arguments.pop('position_ids', False) is not False: @@ -592,7 +591,8 @@ def prepare_inputs_for_generation(self: BloomForCausalLM, past: Optional[torch.Tensor] = None, attention_mask: Optional[ torch.Tensor] = None, - **kwargs) -> dict: + **kwargs: Any) -> dict: + del kwargs # unused # only last token for input_ids if past is not None if past: input_ids = input_ids[:, -1].unsqueeze(-1) # type: ignore @@ -600,7 +600,7 @@ def prepare_inputs_for_generation(self: BloomForCausalLM, # has been encoded into `past` bidirectional_mask = None - # the cache may be in the stardard format (e.g. in contrastive + # the cache may be in the standard format (e.g. in contrastive # search), convert to bloom's format if needed if past[0][0].shape[0] == input_ids.shape[0]: past = self._convert_to_bloom_cache(past) @@ -656,12 +656,16 @@ def _convert_opt_causal_lm_to_prefix_lm( # Modified from transformers.models.bloom.modeling_opt.OPTDecoder._prepare_decoder_attn_mask # https://github.com/huggingface/transformers/blob/v4.25.1/src/transformers/models/opt/modeling_opt.py#L532 - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, - inputs_embeds, past_key_values_length): + def _prepare_decoder_attention_mask(self: torch.nn.Module, + attention_mask: Optional[torch.Tensor], + input_shape: Tuple[int, int], + inputs_embeds: Optional[torch.Tensor], + past_key_values_length: int): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None if input_shape[-1] > 1: + assert inputs_embeds is not None # 'g' indicates generation mode. Causal mask replaced with 0. if self.bidirectional_mask == 'g': bsz, src_length = input_shape @@ -679,6 +683,7 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, # Make use of the batch-specific `bidirectional_mask` attribute # set by the parent module in its (new) `forward` method wrapper if self.bidirectional_mask is not None: + assert attention_mask is not None # The two masks should have the same size assert attention_mask.shape == self.bidirectional_mask.shape @@ -691,6 +696,7 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, expanded_bidirectional_mask, combined_attention_mask) if attention_mask is not None: + assert inputs_embeds is not None # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, @@ -758,7 +764,7 @@ def call_og_forward(): # Return the outputs return outputs - def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]): + def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Any): """Wraps original generate to enable PrefixLM-style attention.""" # Flag the child module to use generation-style attention masking self.model.decoder.bidirectional_mask = 'g' @@ -867,7 +873,7 @@ def convert_hf_causal_lm_to_prefix_lm( ) -def add_bidirectional_mask_if_missing(batch: Dict[str, Any]): +def add_bidirectional_mask_if_missing(batch: MutableMapping): """Attempts to add bidirectional_mask to batch if missing. Raises: diff --git a/llmfoundry/models/utils/meta_init_context.py b/llmfoundry/models/utils/meta_init_context.py index 8ff4c0a5b9..eee30f9357 100644 --- a/llmfoundry/models/utils/meta_init_context.py +++ b/llmfoundry/models/utils/meta_init_context.py @@ -15,9 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Modified from https://github.com/huggingface/accelerate/blob/main/src/accelerate/big_modeling.py - from contextlib import contextmanager +# Modified from https://github.com/huggingface/accelerate/blob/main/src/accelerate/big_modeling.py +from typing import Any, Callable, Optional import torch import torch.nn as nn @@ -80,18 +80,21 @@ def init_on_device(device: torch.device, include_buffers: bool = False): if include_buffers: old_register_buffer = nn.Module.register_buffer - def register_empty_parameter(module, name, param): + def register_empty_parameter(module: torch.nn.Module, name: str, + param: Optional[torch.nn.Parameter]): old_register_parameter(module, name, param) if param is not None: param_cls = type(module._parameters[name]) - kwargs = module._parameters[name].__dict__ + kwargs = module._parameters[name].__dict__ # type: ignore module._parameters[name] = param_cls( - module._parameters[name].to(device), **kwargs) + module._parameters[name].to(device), **kwargs) # type: ignore - def register_empty_buffer(module, name, buffer): + def register_empty_buffer(module: torch.nn.Module, name: str, + buffer: Optional[torch.Tensor]): old_register_buffer(module, name, buffer) if buffer is not None: - module._buffers[name] = module._buffers[name].to(device) + module._buffers[name] = module._buffers[name].to( # type: ignore + device) # Patch tensor creation if include_buffers: @@ -102,9 +105,9 @@ def register_empty_buffer(module, name, buffer): else: tensor_constructors_to_patch = {} - def patch_tensor_constructor(fn): + def patch_tensor_constructor(fn: Callable): - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any): kwargs['device'] = device return fn(*args, **kwargs) diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index 3df6387b1d..0dbb4e4a6f 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -5,7 +5,7 @@ import warnings from collections.abc import Sequence from functools import partial -from typing import Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union import torch from torch import nn @@ -22,7 +22,7 @@ def torch_default_param_init_fn_( module: nn.Module, verbose: int = 0, - **kwargs, + **kwargs: Any, ): del kwargs # unused, just to capture any extra args from the config if verbose > 1: @@ -33,11 +33,11 @@ def torch_default_param_init_fn_( module.reset_parameters() # type: ignore -def fused_init_helper_(module: nn.Module, init_fn_): +def fused_init_helper_(module: nn.Module, init_fn_: Callable): # parameter initialization is often based on the parameters shape. # If a layer is fused, initialization should be based on the shapes # of the original tensor instead of the shape of the fused tensor. - # Layers which are fused should have the _fused attibute defined. + # Layers which are fused should have the _fused attribute defined. # The first element of _fused is the dimension along which the tensor is fused. # This is followed by an iterable of split indices." @@ -56,14 +56,14 @@ def fused_init_helper_(module: nn.Module, init_fn_): def generic_param_init_fn_( module: nn.Module, - init_fn_, + init_fn_: Callable, n_layers: int, d_model: Optional[int] = None, init_div_is_residual: Union[int, float, str, bool] = True, emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, verbose: int = 0, - **kwargs, + **kwargs: Any, ): del kwargs # unused, just to capture any extra args from the config if verbose > 1: @@ -82,8 +82,9 @@ def generic_param_init_fn_( elif isinstance(init_div_is_residual, float) or isinstance( init_div_is_residual, int): div_is_residual = init_div_is_residual - elif isinstance(init_div_is_residual, - str) and init_div_is_residual.isnumeric(): + elif isinstance( + init_div_is_residual, # type: ignore + str) and init_div_is_residual.isnumeric(): # do not trust YAML parsing to always convert numbers to numbers div_is_residual = float(init_div_is_residual) else: @@ -107,12 +108,13 @@ def generic_param_init_fn_( else: init_fn_(module.weight) if module.bias is not None: + assert isinstance(module.bias, torch.Tensor) torch.nn.init.zeros_(module.bias) if init_div_is_residual is not False and getattr( module, '_is_residual', False): with torch.no_grad(): - module.weight.div_(div_is_residual) + module.weight.div_(div_is_residual) # type: ignore elif isinstance(module, nn.Embedding): # Embedding @@ -204,13 +206,15 @@ def generic_param_init_fn_( init_fn_(module.fc1_weight) if module.fc1_bias is not None: + assert isinstance(module.fc1_bias, torch.Tensor) torch.nn.init.zeros_(module.fc1_bias) init_fn_(module.fc2_weight) if module.fc2_bias is not None: + assert isinstance(module.fc2_bias, torch.Tensor) torch.nn.init.zeros_(module.fc2_bias) with torch.no_grad(): - module.fc2_weight.div_(div_is_residual) + module.fc2_weight.div_(div_is_residual) # type: ignore else: for _ in module.parameters(recurse=False): @@ -220,7 +224,7 @@ def generic_param_init_fn_( ) -def _normal_init_(std, mean=0.0): +def _normal_init_(std: float, mean: float = 0.0): return partial(torch.nn.init.normal_, mean=mean, std=std) @@ -233,7 +237,7 @@ def _normal_param_init_fn_( emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, verbose: int = 0, - **kwargs, + **kwargs: Any, ): del kwargs # unused, just to capture any extra args from the config init_fn_ = _normal_init_(std=std) @@ -263,7 +267,7 @@ def baseline_param_init_fn_( emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, verbose: int = 0, - **kwargs, + **kwargs: Any, ): del kwargs # unused, just to capture any extra args from the config if init_std is None: @@ -290,7 +294,7 @@ def small_param_init_fn_( emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, verbose: int = 0, - **kwargs, + **kwargs: Any, ): del kwargs # unused, just to capture any extra args from the config # very close to kaiming normal @@ -315,7 +319,7 @@ def neox_param_init_fn_( emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, verbose: int = 0, - **kwargs, + **kwargs: Any, ): """From section 2.3.1 of GPT-NeoX-20B: @@ -351,7 +355,7 @@ def kaiming_uniform_param_init_fn_( fan_mode: str = 'fan_in', init_nonlinearity: str = 'leaky_relu', verbose: int = 0, - **kwargs, + **kwargs: Any, ): del kwargs # unused, just to capture any extra args from the config @@ -389,7 +393,7 @@ def kaiming_normal_param_init_fn_( fan_mode: str = 'fan_in', init_nonlinearity: str = 'leaky_relu', verbose: int = 0, - **kwargs, + **kwargs: Any, ): del kwargs # unused, just to capture any extra args from the config @@ -425,7 +429,7 @@ def xavier_uniform_param_init_fn_( emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, init_gain: float = 0, verbose: int = 0, - **kwargs, + **kwargs: Any, ): del kwargs # unused, just to capture any extra args from the config xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain) @@ -457,8 +461,9 @@ def xavier_normal_param_init_fn_( emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, init_gain: float = 0, verbose: int = 0, - **kwargs, + **kwargs: Any, ): + del kwargs # unused, just to capture any extra args from the config xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain) if verbose > 1: diff --git a/llmfoundry/optim/adaptive_lion.py b/llmfoundry/optim/adaptive_lion.py index 115aec90d6..a36e2da192 100644 --- a/llmfoundry/optim/adaptive_lion.py +++ b/llmfoundry/optim/adaptive_lion.py @@ -3,7 +3,7 @@ import logging import math -from typing import Callable, Optional, Tuple +from typing import Callable, Dict, Iterable, Optional, Tuple, Union import torch from composer.utils import dist @@ -59,7 +59,7 @@ class DecoupledAdaLRLion(Optimizer): } def __init__(self, - params, + params: Union[Iterable[torch.Tensor], Iterable[dict]], lr: float = 1e-4, betas: Tuple[float, float] = (0.9, 0.99), weight_decay: float = 0.0, @@ -76,6 +76,7 @@ def __init__(self, if weight_decay >= 1e-3: log.warning( f'You are using a high value of `weight_decay={weight_decay}` for the `DecoupledLionW` optimizer. Are you sure you want to do this? ' + + f'Your model\'s weights will be multiplied by {1.0 - weight_decay} on every step!' ) @@ -91,7 +92,9 @@ def __init__(self, self.min_scale = min_scale @staticmethod - def lionw(p, grad, exp_avg, lr, initial_lr, wd, beta1, beta2) -> None: + def lionw(p: torch.Tensor, grad: torch.Tensor, exp_avg: torch.Tensor, + lr: float, initial_lr: float, wd: float, beta1: float, + beta2: float) -> None: # stepweight decay if wd != 0: decay_factor = (lr / initial_lr) if initial_lr else 1.0 @@ -178,14 +181,14 @@ def step(self, closure: Optional[Callable] = None): return loss - def dist_reduce_metrics(self, optimizer_metrics): + def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]): for metric in optimizer_metrics: if metric.startswith('l2_norm'): reduced = optimizer_metrics[metric] if dist.get_world_size() > 1: dist.all_reduce(reduced, reduce_operation='SUM') - optimizer_metrics[metric] = math.sqrt(reduced) + optimizer_metrics[metric] = torch.tensor(math.sqrt(reduced)) elif metric.startswith('cosine'): reduced = optimizer_metrics[metric] if dist.get_world_size() > 1: @@ -209,7 +212,7 @@ def dist_reduce_metrics(self, optimizer_metrics): return optimizer_metrics - def pre_reduce_metrics(self, optimizer_metrics): + def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]): """Preprocess metrics to reduce across ranks correctly.""" # Sort L2 norms first so they are squared before other metrics, which depend on squared values metrics = optimizer_metrics.keys() @@ -303,11 +306,11 @@ class DecoupledClipLion(Optimizer): } def __init__(self, - params, + params: Union[Iterable[torch.Tensor], Iterable[dict]], lr: float = 1e-4, betas: Tuple[float, float] = (0.9, 0.99), weight_decay: float = 0.0, - outlier_threshold=5.0): + outlier_threshold: float = 5.0): if lr <= 0.: raise Exception(f'Invalid LR: {lr}. LR must be > 0') if not all([0. <= beta <= 1. for beta in betas]): @@ -317,6 +320,7 @@ def __init__(self, if weight_decay >= 1e-3: log.warning( f'You are using a high value of `weight_decay={weight_decay}` for the `DecoupledLionW` optimizer. Are you sure you want to do this? ' + + f'Your model\'s weights will be multiplied by {1.0 - weight_decay} on every step!' ) @@ -329,7 +333,9 @@ def __init__(self, self.outlier_threshold = outlier_threshold @staticmethod - def lionw(p, grad, exp_avg, lr, initial_lr, wd, beta1, beta2) -> None: + def lionw(p: torch.Tensor, grad: torch.Tensor, exp_avg: torch.Tensor, + lr: float, initial_lr: float, wd: float, beta1: float, + beta2: float) -> None: # stepweight decay if wd != 0: decay_factor = (lr / initial_lr) if initial_lr else 1.0 @@ -385,14 +391,14 @@ def step(self, closure: Optional[Callable] = None): return loss - def dist_reduce_metrics(self, optimizer_metrics): + def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]): for metric in optimizer_metrics: if metric.startswith('l2_norm'): reduced = optimizer_metrics[metric] if dist.get_world_size() > 1: dist.all_reduce(reduced, reduce_operation='SUM') - optimizer_metrics[metric] = math.sqrt(reduced) + optimizer_metrics[metric] = torch.tensor(math.sqrt(reduced)) elif metric.startswith('cosine'): reduced = optimizer_metrics[metric] if dist.get_world_size() > 1: @@ -416,7 +422,7 @@ def dist_reduce_metrics(self, optimizer_metrics): return optimizer_metrics - def pre_reduce_metrics(self, optimizer_metrics): + def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]): """Preprocess metrics to reduce across ranks correctly.""" # Sort L2 norms first so they are squared before other metrics, which depend on squared values metrics = optimizer_metrics.keys() diff --git a/llmfoundry/optim/lion.py b/llmfoundry/optim/lion.py index 820530373f..2c469f2fcd 100644 --- a/llmfoundry/optim/lion.py +++ b/llmfoundry/optim/lion.py @@ -3,7 +3,7 @@ import logging import math -from typing import Callable, Optional, Tuple +from typing import Callable, Dict, Iterable, Optional, Tuple, Union import torch from composer.utils import dist @@ -38,7 +38,7 @@ class DecoupledLionW(Optimizer): def __init__( self, - params, + params: Union[Iterable[torch.Tensor], Iterable[dict]], lr: float = 1e-4, betas: Tuple[float, float] = (0.9, 0.99), weight_decay: float = 0.0, @@ -52,6 +52,7 @@ def __init__( if weight_decay >= 1e-3: log.warning( f'You are using a high value of `weight_decay={weight_decay}` for the `DecoupledLionW` optimizer. Are you sure you want to do this? ' + + f'Your model\'s weights will be multiplied by {1.0 - weight_decay} on every step!' ) @@ -63,7 +64,9 @@ def __init__( group['initial_lr'] = group['lr'] @staticmethod - def lionw(p, grad, exp_avg, lr, initial_lr, wd, beta1, beta2) -> None: + def lionw(p: torch.Tensor, grad: torch.Tensor, exp_avg: torch.Tensor, + lr: float, initial_lr: float, wd: float, beta1: float, + beta2: float) -> None: # stepweight decay if wd != 0: decay_factor = (lr / initial_lr) if initial_lr else 1.0 @@ -103,14 +106,14 @@ def step(self, closure: Optional[Callable] = None): return loss - def dist_reduce_metrics(self, optimizer_metrics): + def dist_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]): for metric in optimizer_metrics: if metric.startswith('l2_norm'): reduced = optimizer_metrics[metric] if dist.get_world_size() > 1: dist.all_reduce(reduced, reduce_operation='SUM') - optimizer_metrics[metric] = math.sqrt(reduced) + optimizer_metrics[metric] = torch.tensor(math.sqrt(reduced)) elif metric.startswith('cosine'): reduced = optimizer_metrics[metric] if dist.get_world_size() > 1: @@ -132,7 +135,7 @@ def dist_reduce_metrics(self, optimizer_metrics): return optimizer_metrics - def pre_reduce_metrics(self, optimizer_metrics): + def pre_reduce_metrics(self, optimizer_metrics: Dict[str, torch.Tensor]): """Preprocess metrics to reduce across ranks correctly.""" # Sort L2 norms first so they are squared before other metrics, which depend on squared values metrics = optimizer_metrics.keys() diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index bf16ddc663..b2cd535fe9 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -2,8 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import os -from typing import Union +from typing import Any, Dict, Optional, Union +import torch from composer import algorithms from composer.callbacks import (EarlyStopper, LRMonitor, MemoryMonitor, OptimizerMonitor, RuntimeEstimator, @@ -17,10 +18,9 @@ CosineAnnealingWithWarmupScheduler, LinearWithWarmupScheduler) from composer.utils import dist -from omegaconf import DictConfig +from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om -from transformers import (AutoTokenizer, PreTrainedTokenizer, - PreTrainedTokenizerFast) +from transformers import AutoTokenizer, PreTrainedTokenizerBase from llmfoundry.callbacks import (FDiffMetrics, Generate, GlobalLRScaling, LayerFreezing, MonolithicCheckpointSaver, @@ -28,10 +28,8 @@ from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion, DecoupledLionW) -Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - -def build_callback(name, kwargs): +def build_callback(name: str, kwargs: Dict[str, Any]): if name == 'lr_monitor': return LRMonitor() elif name == 'memory_monitor': @@ -64,7 +62,7 @@ def build_callback(name, kwargs): raise ValueError(f'Not sure how to build callback: {name}') -def build_logger(name, kwargs): +def build_logger(name: str, kwargs: Dict[str, Any]): if name == 'wandb': return WandBLogger(**kwargs) elif name == 'tensorboard': @@ -73,7 +71,7 @@ def build_logger(name, kwargs): raise ValueError(f'Not sure how to build logger: {name}') -def build_algorithm(name, kwargs): +def build_algorithm(name: str, kwargs: Dict[str, Any]): if name == 'gradient_clipping': return algorithms.GradientClipping(**kwargs) elif name == 'alibi': @@ -88,7 +86,7 @@ def build_algorithm(name, kwargs): raise ValueError(f'Not sure how to build algorithm: {name}') -def build_optimizer(cfg, model): +def build_optimizer(cfg: DictConfig, model: torch.nn.Module): if cfg.name == 'decoupled_adamw': return DecoupledAdamW(model.parameters(), lr=cfg.lr, @@ -119,7 +117,7 @@ def build_optimizer(cfg, model): raise ValueError(f'Not sure how to build optimizer: {cfg.name}') -def build_scheduler(cfg): +def build_scheduler(cfg: DictConfig): if cfg.name == 'constant_with_warmup': return ConstantWithWarmupScheduler(t_warmup=cfg.t_warmup) elif cfg.name == 'cosine_with_warmup': @@ -132,7 +130,7 @@ def build_scheduler(cfg): raise ValueError(f'Not sure how to build scheduler: {cfg.name}') -def build_tokenizer(om_tokenizer_config: DictConfig,) -> Tokenizer: +def build_tokenizer(om_tokenizer_config: DictConfig) -> PreTrainedTokenizerBase: os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1' os.environ['TOKENIZERS_PARALLELISM'] = 'false' @@ -155,20 +153,26 @@ def build_tokenizer(om_tokenizer_config: DictConfig,) -> Tokenizer: return tokenizer -def build_icl_evaluators(icl_tasks, - tokenizer, - default_max_seq_len, - default_batch_size, - destination_dir=os.getcwd()): +def build_icl_evaluators(icl_tasks: Union[str, ListConfig], + tokenizer: PreTrainedTokenizerBase, + default_max_seq_len: int, + default_batch_size: int, + destination_dir: Optional[str] = None): + if destination_dir is None: + destination_dir = os.getcwd() evaluators = [] logger_keys = [] + + icl_tasks_list = None if isinstance(icl_tasks, str): print(f'Extracting ICL task config from path: {icl_tasks}') with open(icl_tasks, 'r') as icl_f: icl_task_cfg = om.load(icl_f) - icl_tasks = icl_task_cfg.icl_tasks + icl_tasks_list = icl_task_cfg.icl_tasks + else: + icl_tasks_list = icl_tasks - def _validate_cfg(icl_cfg): + def _validate_cfg(icl_cfg: DictConfig): assert 'label' in icl_cfg assert 'dataset_uri' in icl_cfg and icl_cfg.dataset_uri is not None assert 'icl_task_type' in icl_cfg @@ -203,7 +207,7 @@ def _validate_cfg(icl_cfg): if 'batch_size' not in icl_cfg: icl_cfg.batch_size = default_batch_size - for icl_cfg in icl_tasks: + for icl_cfg in icl_tasks_list: _validate_cfg(icl_cfg) for num_fewshot in list(icl_cfg.num_fewshot): if tokenizer.pad_token_id is None: diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index a7fe93ef4e..6a4f1a8a4a 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -4,7 +4,7 @@ import contextlib import math import warnings -from typing import Union +from typing import Dict, Optional, Union from composer.utils import dist from omegaconf import DictConfig @@ -18,15 +18,16 @@ def calculate_batch_size_info(global_batch_size: int, if global_batch_size % dist.get_world_size() != 0: raise ValueError( f'Global batch size {global_batch_size} is not divisible by {dist.get_world_size()} ' + + 'as a result, the batch size would be truncated, please adjust `global_batch_size` ' - f'to be divisible by world size, {dist.get_world_size()}.') + + f'to be divisible by world size, {dist.get_world_size()}.') device_batch_size = global_batch_size // dist.get_world_size() if device_microbatch_size == 'auto': device_grad_accum = 'auto' elif isinstance(device_microbatch_size, int): if device_microbatch_size > device_batch_size: print( - f'WARNING: device_microbatch_size > device_batch_size, ' + f'WARNING: device_microbatch_size > device_batch_size, ' + f'will be reduced from {device_microbatch_size} -> {device_batch_size}.' ) device_microbatch_size = device_batch_size @@ -55,7 +56,7 @@ def update_batch_size_info(cfg: DictConfig): return cfg -def process_init_device(model_cfg: DictConfig, fsdp_config: dict): +def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]): # Restrict model init_device to 'meta' and 'cpu', # using 'cuda' vs. 'cuda:id' is tricky and can lead to common user errors # when multiple GPUs are available. @@ -73,7 +74,7 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: dict): if model_cfg.init_device == 'mixed': if fsdp_config is None: raise NotImplementedError( - 'Using init_device `mixed` is only supported with FSDP. ' + 'Using init_device `mixed` is only supported with FSDP. ' + 'Please add a FSDP config.') # Always set `sync_module_states` to True for mixed initialization if not fsdp_config.get('sync_module_states', False): diff --git a/llmfoundry/utils/huggingface_hub_utils.py b/llmfoundry/utils/huggingface_hub_utils.py index 647a070870..4b837d2e67 100644 --- a/llmfoundry/utils/huggingface_hub_utils.py +++ b/llmfoundry/utils/huggingface_hub_utils.py @@ -34,6 +34,8 @@ def find_module_file(module_name: str) -> str: raise ValueError(f'Invalid input: {module_name=}') module = importlib.import_module(module_name) module_file = module.__file__ + if module_file is None: + raise ValueError(f'Could not find file for module: {module_name}') return module_file @@ -50,17 +52,18 @@ def process_file(file_path: str, folder_path: str) -> List[str]: nodes_to_remove = [] for node in ast.walk(tree): # convert any llmfoundry imports into relative imports - if isinstance(node, - ast.ImportFrom) and node.module.startswith('llmfoundry'): + if isinstance( + node, ast.ImportFrom + ) and node.module is not None and node.module.startswith('llmfoundry'): module_path = find_module_file(node.module) node.module = convert_to_relative_import(node.module, parent_module_name) # recursively process any llmfoundry files new_files_to_process.append(module_path) # remove any imports from composer or omegaconf - elif isinstance( - node, ast.ImportFrom) and (node.module.startswith('composer') or - node.module.startswith('omegaconf')): + elif isinstance(node, ast.ImportFrom) and node.module is not None and ( + node.module.startswith('composer') or + node.module.startswith('omegaconf')): nodes_to_remove.append(node) # remove the Composer* class elif isinstance(node, @@ -83,6 +86,7 @@ def process_file(file_path: str, folder_path: str) -> List[str]: new_filename = file_path.split('/')[-2] + '.py' new_file_path = os.path.join(folder_path, new_filename) with open(new_file_path, 'w') as f: + assert new_tree is not None f.write(ast.unparse(new_tree)) return new_files_to_process diff --git a/pyproject.toml b/pyproject.toml index 444206c5df..efa8a7b582 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,9 +19,10 @@ include = [ "llmfoundry/*" ] + # Pyright [tool.pyright] -exclude = ['env-**', 'venv*'] +exclude = ['env-**', 'venv*', '**/flash_attn_triton.py'] ignore = ['llmfoundry/models/layers/flash_attn_triton.py'] stubPath = "" # suppress useless 'stubPath is not a valid directory' errors @@ -36,7 +37,7 @@ reportUnusedVariable = "error" reportDuplicateImport = "error" reportWildcardImportFromLibrary = "error" reportUntypedFunctionDecorator = "warning" -reportPrivateImportUsage = "warning" +reportPrivateImportUsage = "none" reportUndefinedVariable = "error" strictParameterNoneValue = true reportPropertyTypeMismatch = "error" @@ -46,15 +47,16 @@ reportInvalidTypeVarUse = "error" reportOverlappingOverload = "error" reportUninitializedInstanceVariable = "error" reportInvalidStringEscapeSequence = "error" -reportMissingParameterType = "warning" # TODO: make this an error -reportCallInDefaultInitializer = "none" # TODO: make this an error -reportUnnecessaryComparison = "warning" +reportMissingParameterType = "error" +reportCallInDefaultInitializer = "error" +reportUnnecessaryComparison = "error" reportSelfClsParameterName = "error" reportImplicitStringConcatenation = "warning" # TODO: make this an error reportInvalidStubStatement = "error" reportIncompleteStub = "error" reportUnsupportedDunderAll = "error" reportUnusedCoroutine = "error" +reportMissingImports = "none" # Pytest [tool.pytest.ini_options] diff --git a/scripts/data_prep/convert_dataset_hf.py b/scripts/data_prep/convert_dataset_hf.py index 3acf0fdff1..50dd30c45d 100644 --- a/scripts/data_prep/convert_dataset_hf.py +++ b/scripts/data_prep/convert_dataset_hf.py @@ -12,7 +12,7 @@ import datasets as hf_datasets import psutil from streaming import MDSWriter -from torch.utils.data import DataLoader, IterableDataset +from torch.utils.data import DataLoader, Dataset, IterableDataset from tqdm import tqdm from transformers import AutoTokenizer, PreTrainedTokenizerBase @@ -252,7 +252,8 @@ def _est_progress_denominator(total_samples: int, chars_per_sample: int, return total_samples * est_tokens_per_sample // max_length -def build_dataloader(dataset, batch_size, num_workers) -> DataLoader: +def build_dataloader(dataset: Dataset, batch_size: int, + num_workers: Optional[int]) -> DataLoader: if num_workers is None: # Multiple workers is only supported on linux machines if 'linux' or 'macos' in platform.platform().lower(): diff --git a/scripts/data_prep/convert_dataset_json.py b/scripts/data_prep/convert_dataset_json.py index ae3cb0cdf6..54c0bfa814 100644 --- a/scripts/data_prep/convert_dataset_json.py +++ b/scripts/data_prep/convert_dataset_json.py @@ -131,16 +131,6 @@ def build_hf_dataset( return dataset -def _est_progress_denominator(total_samples: int, chars_per_sample: int, - chars_per_token: int, mode: ConcatMode, - max_length: int): - est_tokens_per_sample = chars_per_sample // chars_per_token - if mode == ConcatMode.NO_CONCAT: - return total_samples - elif mode == ConcatMode.CONCAT_TOKENS: - return total_samples * est_tokens_per_sample // max_length - - def generate_samples( loader: DataLoader, truncate_num_samples: Optional[int] = None diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index e51622a78f..58dcbc940a 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -5,14 +5,17 @@ import re import sys import time -from typing import List +from typing import Dict, List, Optional import pandas as pd import torch from composer.loggers import InMemoryLogger, LoggerDestination +from composer.models.base import ComposerModel from composer.trainer import Trainer from composer.utils import dist, get_device, reproducibility +from omegaconf import DictConfig from omegaconf import OmegaConf as om +from transformers import PreTrainedTokenizerBase from llmfoundry.callbacks import ModelGauntlet from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY @@ -21,7 +24,9 @@ from llmfoundry.utils.config_utils import process_init_device -def load_model(model_cfg, tokenizer, fsdp_config, num_retries): +def load_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, + fsdp_config: Optional[Dict], + num_retries: int) -> Optional[ComposerModel]: init_context = process_init_device(model_cfg, fsdp_config) retries = 0 @@ -41,7 +46,8 @@ def load_model(model_cfg, tokenizer, fsdp_config, num_retries): ) -def evaluate_model(model_cfg, run_name, model_gauntlet_df): +def evaluate_model(model_cfg: DictConfig, cfg: DictConfig, run_name: str, + model_gauntlet_df: Optional[pd.DataFrame]): print(f'Evaluating model: {model_cfg.model_name}', flush=True) # Build tokenizer and model tokenizer = build_tokenizer(model_cfg.tokenizer) @@ -68,6 +74,7 @@ def evaluate_model(model_cfg, run_name, model_gauntlet_df): fsdp_config = cfg.get('fsdp_config', None) fsdp_config = om.to_container( fsdp_config, resolve=True) if fsdp_config is not None else None + assert isinstance(fsdp_config, Dict) or fsdp_config is None composer_model = load_model(model_cfg.model, tokenizer, fsdp_config, cfg.get('num_retries', 3)) @@ -86,6 +93,7 @@ def evaluate_model(model_cfg, run_name, model_gauntlet_df): load_path = model_cfg.get('load_path', None) + assert composer_model is not None trainer = Trainer( run_name=run_name, model=composer_model, @@ -111,7 +119,7 @@ def evaluate_model(model_cfg, run_name, model_gauntlet_df): model_gauntlet, model_gauntlet_df) -def main(cfg): +def main(cfg: DictConfig): cfg.dist_timeout = cfg.get('dist_timeout', 600.0) if cfg.get('run_name') is None: cfg.run_name = os.environ.get('RUN_NAME', 'llm') @@ -121,14 +129,16 @@ def main(cfg): model_gauntlet_df = None models_df = None + composite_scores = None for model_cfg in cfg.models: (in_memory_logger, logger_keys, model_gauntlet_callback, model_gauntlet, - model_gauntlet_df) = evaluate_model(model_cfg, cfg.run_name, + model_gauntlet_df) = evaluate_model(model_cfg, cfg, cfg.run_name, model_gauntlet_df) if model_gauntlet_callback is not None: + # TODO(bmosaicml) This needs to be refactored to fix the typing issue composite_scores = model_gauntlet_callback.eval_end( - None, in_memory_logger) + None, in_memory_logger) # type: ignore benchmark_to_taxonomy = {} if model_gauntlet is not None: @@ -147,6 +157,7 @@ def main(cfg): models_df = pd.concat([models_df, model_results], ignore_index=True) if model_gauntlet_df is not None and model_gauntlet is not None and model_gauntlet_df is not None: + assert composite_scores is not None row = {'model_name': model_cfg['model_name']} row.update({ t.name: composite_scores[f'metrics/model_gauntlet/{t.name}'] @@ -166,10 +177,12 @@ def main(cfg): print(models_df.to_markdown(index=False)) -def calculate_markdown_results(logger_keys, logger_data, benchmark_to_taxonomy, - model_name): +def calculate_markdown_results(logger_keys: List[str], logger_data: Dict, + benchmark_to_taxonomy: Dict[str, str], + model_name: str): results = {} - pat = re.compile('metrics/(.*?)/(\d+)-shot(/.*?)?/InContextLearning(.*)') + pat = re.compile( + 'metrics/(.*?)/(\d+)-shot(/.*?)?/InContextLearning(.*)') # type: ignore for key in logger_keys: match = pat.match(key) val = logger_data[key][0][1].item() @@ -253,4 +266,5 @@ def calculate_markdown_results(logger_keys, logger_data, benchmark_to_taxonomy, yaml_cfg = om.load(f) cli_cfg = om.from_cli(args_list) cfg = om.merge(yaml_cfg, cli_cfg) + assert isinstance(cfg, DictConfig) main(cfg) diff --git a/scripts/inference/benchmarking/benchmark.py b/scripts/inference/benchmarking/benchmark.py index 8209fe21bd..d2e51bb7a5 100644 --- a/scripts/inference/benchmarking/benchmark.py +++ b/scripts/inference/benchmarking/benchmark.py @@ -6,13 +6,13 @@ from contextlib import nullcontext import torch -# You can use this to load the model weights +from omegaconf import DictConfig from omegaconf import OmegaConf as om from llmfoundry import COMPOSER_MODEL_REGISTRY -def get_dtype(dtype): +def get_dtype(dtype: str): if dtype == 'fp32': return torch.float32 elif dtype == 'fp16': @@ -21,18 +21,18 @@ def get_dtype(dtype): return torch.bfloat16 else: raise NotImplementedError( - f'dtype {dtype} is not supported. ' + f'dtype {dtype} is not supported. ' + f'We only support fp32, fp16, and bf16 currently') -def compare_dtype(dtype, param_dtype): +def compare_dtype(dtype: torch.dtype, param_dtype: torch.dtype): if dtype != param_dtype: raise ValueError( - f'dtype type is: {dtype} but model dtype is: {param_dtype}. ' + f'dtype type is: {dtype} but model dtype is: {param_dtype}. ' + f"The expected dtype and model dtype don't match.") -def main(config): +def main(config: DictConfig): if config.device is not None: device = config.device else: @@ -128,4 +128,5 @@ def main(config): yaml_config = om.load(f) cli_config = om.from_cli(args_list) config = om.merge(yaml_config, cli_config) + assert isinstance(config, DictConfig) main(config) diff --git a/scripts/inference/convert_hf_mpt_to_ft.py b/scripts/inference/convert_hf_mpt_to_ft.py index 8a94a42fbe..ceb4c5c770 100644 --- a/scripts/inference/convert_hf_mpt_to_ft.py +++ b/scripts/inference/convert_hf_mpt_to_ft.py @@ -24,10 +24,9 @@ import argparse import configparser import os -from typing import Any, Dict, List +from typing import Any, Dict, Tuple, Union import numpy as np -import torch import transformers @@ -41,7 +40,7 @@ def get_weight_data_type(data_type: str): def write_zero_bias(weight_name: str, weight_file_path: str, - bias_shape: List[int]) -> None: + bias_shape: Union[Tuple[int, ...], int]) -> None: """Write zeros for bias. MPT model might not have bias while FT expects bias. @@ -49,7 +48,7 @@ def write_zero_bias(weight_name: str, weight_file_path: str, Args: weight_name (str): Name of the weight tensor. weight_file_path (str): Output path for storing the weight (NOT zero bias). - bias_shape (List[int]): Shape of the bias array. + bias_shape (Union[Tuple[int, ...], int]): Shape of the bias array. """ if 'weight' not in weight_file_path: raise RuntimeError( diff --git a/scripts/inference/convert_hf_to_onnx.py b/scripts/inference/convert_hf_to_onnx.py index 8a653c13e9..dd7a6f7a62 100644 --- a/scripts/inference/convert_hf_to_onnx.py +++ b/scripts/inference/convert_hf_to_onnx.py @@ -30,7 +30,7 @@ import os from argparse import ArgumentTypeError from pathlib import Path -from typing import Optional +from typing import Optional, Union import torch from composer.utils import (maybe_create_object_store_from_uri, parse_uri, @@ -38,7 +38,7 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer -def str2bool(v): +def str2bool(v: Union[str, bool]): if isinstance(v, bool): return v if v.lower() in ('yes', 'true', 't', 'y', '1'): @@ -49,7 +49,7 @@ def str2bool(v): raise ArgumentTypeError('Boolean value expected.') -def str_or_bool(v): +def str_or_bool(v: Union[str, bool]): if isinstance(v, bool): return v if v.lower() in ('yes', 'true', 't', 'y', '1'): diff --git a/scripts/inference/hf_chat.py b/scripts/inference/hf_chat.py index 28f9af90b8..4f938f999e 100644 --- a/scripts/inference/hf_chat.py +++ b/scripts/inference/hf_chat.py @@ -5,15 +5,13 @@ import warnings from argparse import ArgumentParser, ArgumentTypeError, Namespace from contextlib import nullcontext -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import torch from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, - PreTrainedTokenizer, PreTrainedTokenizerFast, + PreTrainedModel, PreTrainedTokenizerBase, StoppingCriteria, StoppingCriteriaList, TextStreamer) -Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - class ChatFormatter: """A class for formatting the chat history. @@ -57,13 +55,14 @@ class Conversation: cli_instructions: The instructions to display to the user. """ - def __init__( - self, - model, - tokenizer: Tokenizer, - chat_format: ChatFormatter, - generate_kwargs: Dict[str, Any], - stop_tokens: List[str] = ['<|endoftext|>', '<|im_end|>']) -> None: + def __init__(self, + model: PreTrainedModel, + tokenizer: PreTrainedTokenizerBase, + chat_format: ChatFormatter, + generate_kwargs: Dict[str, Any], + stop_tokens: Optional[List[str]] = None) -> None: + if stop_tokens is None: + stop_tokens = ['<|endoftext|>', '<|im_end|>'] self.model = model self.tokenizer = tokenizer self.chat_format = chat_format @@ -77,7 +76,8 @@ def __init__( class StopOnTokens(StoppingCriteria): def __call__(self, input_ids: torch.LongTensor, - scores: torch.FloatTensor, **kwargs) -> bool: + scores: torch.FloatTensor, **kwargs: Any) -> bool: + del kwargs # unused for stop_id in stop_token_ids: if input_ids[0][-1] == stop_id: return True @@ -173,11 +173,11 @@ def get_dtype(dtype: str): return torch.bfloat16 else: raise NotImplementedError( - f'dtype {dtype} is not supported. ' - f'We only support fp32, fp16, and bf16 currently') + f'dtype {dtype} is not supported. ' + + 'We only support fp32, fp16, and bf16 currently') -def str2bool(v): +def str2bool(v: Union[str, bool]): if isinstance(v, bool): return v if v.lower() in ('yes', 'true', 't', 'y', '1'): @@ -305,7 +305,9 @@ def main(args: Namespace) -> None: except Exception as e: raise RuntimeError( 'If you are having auth problems, try logging in via `huggingface-cli login` ' + + 'or by setting the environment variable `export HUGGING_FACE_HUB_TOKEN=... ' + + 'using your access token from https://huggingface.co/settings/tokens.' ) from e @@ -324,9 +326,11 @@ def main(args: Namespace) -> None: model.to(device) except Exception as e: raise RuntimeError( - 'Unable to load HF model. ' + 'Unable to load HF model. ' + 'If you are having auth problems, try logging in via `huggingface-cli login` ' + + 'or by setting the environment variable `export HUGGING_FACE_HUB_TOKEN=... ' + + 'using your access token from https://huggingface.co/settings/tokens.' ) from e diff --git a/scripts/inference/hf_generate.py b/scripts/inference/hf_generate.py index 5b28e0b598..cc331acd14 100644 --- a/scripts/inference/hf_generate.py +++ b/scripts/inference/hf_generate.py @@ -1,6 +1,5 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 - import itertools import os import random @@ -8,14 +7,14 @@ import warnings from argparse import ArgumentParser, ArgumentTypeError, Namespace from contextlib import nullcontext +from typing import Dict, Union import numpy as np import torch -from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, - pipeline) +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer -def get_dtype(dtype): +def get_dtype(dtype: str): if dtype == 'fp32': return torch.float32 elif dtype == 'fp16': @@ -28,7 +27,7 @@ def get_dtype(dtype): f'We only support fp32, fp16, and bf16 currently') -def str2bool(v): +def str2bool(v: Union[str, bool]): if isinstance(v, bool): return v if v.lower() in ('yes', 'true', 't', 'y', '1'): @@ -39,7 +38,7 @@ def str2bool(v): raise ArgumentTypeError('Boolean value expected.') -def str_or_bool(v): +def str_or_bool(v: Union[str, bool]): if isinstance(v, bool): return v if v.lower() in ('yes', 'true', 't', 'y', '1'): @@ -220,9 +219,11 @@ def main(args: Namespace) -> None: model.to(device) except Exception as e: raise RuntimeError( - 'Unable to load HF model. ' - 'If you are having auth problems, try logging in via `huggingface-cli login` ' +\ - 'or by setting the environment variable `export HUGGING_FACE_HUB_TOKEN=... ' +\ + 'Unable to load HF model. ' + + 'If you are having auth problems, try logging in via `huggingface-cli login` ' + + + 'or by setting the environment variable `export HUGGING_FACE_HUB_TOKEN=... ' + + 'using your access token from https://huggingface.co/settings/tokens.' ) from e @@ -261,7 +262,7 @@ def main(args: Namespace) -> None: print(f'\nGenerate kwargs:\n{generate_kwargs}') # Generate function with correct context managers - def _generate(encoded_inp): + def _generate(encoded_inp: Dict[str, torch.Tensor]): with torch.no_grad(): with autocast_context: return model.generate( diff --git a/scripts/inference/run_mpt_with_ft.py b/scripts/inference/run_mpt_with_ft.py index 83fb5180cd..df07967091 100644 --- a/scripts/inference/run_mpt_with_ft.py +++ b/scripts/inference/run_mpt_with_ft.py @@ -34,7 +34,6 @@ dir_path = os.path.dirname(os.path.realpath(__file__)) sys.path.append(os.path.join(dir_path, '../../..')) -import examples.pytorch.gpt.utils.gpt_token_encoder as encoder from examples.pytorch.gpt.utils import comm, gpt_decoder from examples.pytorch.gpt.utils.parallel_gpt import ParallelGPT @@ -134,7 +133,7 @@ def main(): type=float, default=0., help= - 'presence penalty. Similar to repetition, but addive rather than multiplicative.' + 'presence penalty. Similar to repetition, but additive rather than multiplicative.' ) parser.add_argument('--min_length', type=int, @@ -182,8 +181,8 @@ def main(): type=int, default=0, choices=[0, 1], - help='The level of quantization to perform.' - ' 0: No quantization. All computation in data_type' + help='The level of quantization to perform.' + + ' 0: No quantization. All computation in data_type' + ' 1: Quantize weights to int8, all compute occurs in fp16/bf16. Not supported when data_type is fp32' ) parser.add_argument( @@ -198,15 +197,15 @@ def main(): type=int, default=0, choices=[0, 1, 2], - help='Whether to compute the cumulative log probsbility of sentences.' - ' 0: do not return the cumulative log probs ' - ' 1: return the cumulative log probs of generated sequences' + help='Whether to compute the cumulative log probsbility of sentences.' + + ' 0: do not return the cumulative log probs' + + ' 1: return the cumulative log probs of generated sequences' + ' 2: return the cumulative log probs of sequences') parser.add_argument('--shared_contexts_ratio', type=float, default=0.0, - help='Triggers the shared context optimization when' - 'compact_size <= shared_contexts_ratio * batch_size' + help='Triggers the shared context optimization when ' + + 'compact_size <= shared_contexts_ratio * batch_size ' + 'A value of 0.0 deactivate the optimization') parser.add_argument( '--use_gpt_decoder_ops', diff --git a/scripts/misc/convert_examples_ckpt.py b/scripts/misc/convert_examples_ckpt.py index 0cb9c6f1e9..a533aec72d 100644 --- a/scripts/misc/convert_examples_ckpt.py +++ b/scripts/misc/convert_examples_ckpt.py @@ -79,7 +79,7 @@ def convert_examples_ckpt( local_ckpt_path = Path(tmp_dir.name) / 'local-composer-checkpoint.pt' # create object store if output_path - _, _, local_folder_path = parse_uri(output_path) + _, _, local_folder_path = parse_uri(str(output_path)) object_store = maybe_create_object_store_from_uri(str(output_path)) if object_store is not None: local_output_path = tempfile.TemporaryDirectory().name @@ -181,13 +181,14 @@ def convert_examples_ckpt( param_idx] = param_name # Save weights - file_path = str(Path(local_output_path) / checkpoint_path.split('/')[-1]) + file_path = str( + Path(local_output_path) / str(checkpoint_path).split('/')[-1]) print(f'Writing converted output to {file_path}') torch.save(composer_state_dict, file_path) if object_store is not None: remote_file_path = os.path.join(local_folder_path, - checkpoint_path.split('/')[-1]) + str(checkpoint_path).split('/')[-1]) print(f'Uploading from {file_path} to {remote_file_path}') object_store.upload_object(remote_file_path, file_path) diff --git a/scripts/train/benchmarking/collect_results.py b/scripts/train/benchmarking/collect_results.py index c0dfd877de..050390b743 100644 --- a/scripts/train/benchmarking/collect_results.py +++ b/scripts/train/benchmarking/collect_results.py @@ -4,14 +4,14 @@ import argparse import csv import math -from typing import Any, Dict +from typing import Any, Dict, List, Union from mcli import sdk as msdk GPU_AVAILABLE_FLOPS = 312_000_000_000_000 -def str_to_bool(value): +def str_to_bool(value: Union[bool, str]): # helper fn if isinstance(value, bool): return value @@ -45,12 +45,12 @@ def parse_args(): return parser.parse_args() -def get_runs(args): +def get_runs(args: argparse.Namespace): runs = [r for r in msdk.get_runs() if args.project in r.name] for filter in args.filters: runs = [r for r in runs if filter in r.name] - def sort_key(r): + def sort_key(r: msdk.Run): model_name = r.name.split('-')[2] num_gpu = r.config.gpu_num if model_name[-1] == 'm': @@ -69,7 +69,7 @@ def sort_key(r): return runs -def filter_runs(runs): +def filter_runs(runs: List[msdk.Run]): pop_runs = [] for run in runs: if run.status == msdk.RunStatus('FAILED'): @@ -102,7 +102,7 @@ def filter_runs(runs): return runs -def parse_run(run) -> Dict[str, Any]: +def parse_run(run: msdk.Run) -> Dict[str, Any]: n_params = micro_batchsize = throughput = -1 model_name = run.name.split('-')[2] @@ -203,7 +203,7 @@ def parse_run(run) -> Dict[str, Any]: } -def main(args): +def main(args: argparse.Namespace): runs = get_runs(args) runs = filter_runs(runs) diff --git a/scripts/train/benchmarking/submit_benchmarks.py b/scripts/train/benchmarking/submit_benchmarks.py index c476dadf5c..f7db0613ef 100644 --- a/scripts/train/benchmarking/submit_benchmarks.py +++ b/scripts/train/benchmarking/submit_benchmarks.py @@ -1,9 +1,9 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 - import argparse import math import os +from typing import Any, Dict, List, Optional, Tuple, Union import requests import yaml @@ -26,7 +26,7 @@ def _get_cluster_info(): CLUSTER_INFO = _get_cluster_info() -def str_to_bool(value): +def str_to_bool(value: Union[bool, str]): # helper fn if isinstance(value, bool): return value @@ -169,11 +169,17 @@ def parse_args(): return parser.parse_args() -def get_max_seq_lens(pows=[9, 14]): +def get_max_seq_lens(pows: Optional[List[int]] = None): + if pows is None: + pows = [9, 14] return [2**n for n in range(pows[0], pows[1] + 1)] -def get_global_train_batch_sizes(max_seq_len, pows, batch_sizes=[]): +def get_global_train_batch_sizes(max_seq_len: int, + pows: List[int], + batch_sizes: Optional[List[int]] = None): + if batch_sizes is None: + batch_sizes = [] if pows: # global batch size in tokens (defualt: .5M thru 8M) global_train_token_counts = [2**n for n in range(pows[0], pows[1] + 1)] @@ -182,7 +188,7 @@ def get_global_train_batch_sizes(max_seq_len, pows, batch_sizes=[]): return batch_sizes -def get_parameters(yaml_file): +def get_parameters(yaml_file: str): local_yamls = False if 'https' in yaml_file else True if local_yamls: # Load the YAML into a parameters dictionary @@ -197,11 +203,11 @@ def get_parameters(yaml_file): return parameters -def get_cluster_gpu_types(cluster): +def get_cluster_gpu_types(cluster: str): return [gpu_info[0] for gpu_info in CLUSTER_INFO[cluster]] -def get_gpu_types(clusters): +def get_gpu_types(clusters: List[str]): gpu_types = set() for c in clusters: for g in get_cluster_gpu_types(c): @@ -209,7 +215,7 @@ def get_gpu_types(clusters): return gpu_types -def get_gpu_nums(clusters, gpu_types): +def get_gpu_nums(clusters: List[str], gpu_types: List[str]): max_gpus_per_run = 1 for c in clusters: for gpu_info in CLUSTER_INFO[c]: @@ -223,26 +229,26 @@ def get_gpu_nums(clusters, gpu_types): return gpu_nums -def get_valid_gpu_lim(cluster, gpu_type): +def get_valid_gpu_lim(cluster: str, gpu_type: str): for gpu_info in CLUSTER_INFO[cluster]: if gpu_info[0] == gpu_type: return gpu_info[1] raise ValueError -def mod_parameters(parameters, - max_seq_len, - global_train_batch_size, - precision, - fsdp_config_mixed_precision='DEFAULT', - fsdp_config_activation_checkpointing=None, - run_name='', - data_remote=None, - max_duration='30ba', - eval_interval=0, - microbatch_size=None, - wandb=True, - pad_vocab_multiple=None): +def mod_parameters(parameters: Dict[str, Any], + max_seq_len: int, + global_train_batch_size: int, + precision: str, + fsdp_config_mixed_precision: str = 'DEFAULT', + fsdp_config_activation_checkpointing: Optional[bool] = None, + run_name: str = '', + data_remote: Optional[str] = None, + max_duration: str = '30ba', + eval_interval: int = 0, + microbatch_size: Optional[Union[int, str]] = None, + wandb: bool = True, + pad_vocab_multiple: Optional[int] = None): if run_name: parameters['run_name'] = run_name if data_remote is not None: @@ -310,7 +316,10 @@ def mod_parameters(parameters, return parameters -def get_integrations(project, git_branch=None, git_commit=None, wandb=True): +def get_integrations(project: str, + git_branch: Optional[str] = None, + git_commit: Optional[str] = None, + wandb: bool = True): integrations = [] if git_branch and git_commit: @@ -339,7 +348,8 @@ def get_integrations(project, git_branch=None, git_commit=None, wandb=True): return integrations -def run_config(config, args): +def run_config(config: Tuple[str, int, int, str, str, int, str], + args: argparse.Namespace): model_yaml, max_seq_len, global_train_batch_size, cluster, gpu_type, gpu_num, precision = config integrations = get_integrations( @@ -383,6 +393,7 @@ def run_config(config, args): print(f'Shortening {_name} to {name} ({name_len_lim} chars)') microbatch_size = args.microbatch_size or 'auto' + assert isinstance(microbatch_size, (int, str)) parameters = mod_parameters( parameters, max_seq_len, @@ -416,7 +427,10 @@ def run_config(config, args): print(f'run = {name}') -def run_check_capacity(model_yaml, gpu_num, gpu_type, p_multiplier=16): +def run_check_capacity(model_yaml: str, + gpu_num: int, + gpu_type: str, + p_multiplier: int = 16): _params = model_yaml.replace('.yaml', '') params, mult = int(_params[:-1]), _params[-1] if mult == 'm': @@ -436,7 +450,7 @@ def run_check_capacity(model_yaml, gpu_num, gpu_type, p_multiplier=16): return True -def run_check_dtms(num_gpus, dtms, batch_size): +def run_check_dtms(num_gpus: int, dtms: int, batch_size: int): if num_gpus * dtms > batch_size: print( f'WARNING: Cannot run with {batch_size=} on {num_gpus=} with {dtms=} ({num_gpus*dtms=}).' @@ -477,9 +491,12 @@ def run_check_dtms(num_gpus, dtms, batch_size): global_train_batch_size) if run: - config = (model_yaml, max_seq_len, - global_train_batch_size, cluster, - gpu_type, gpu_num, precision) + config: Tuple[str, int, int, str, str, int, + str] = ( + model_yaml, max_seq_len, + global_train_batch_size, + cluster, gpu_type, + gpu_num, precision) print(config) run_config(config, args) n_jobs += 1 diff --git a/scripts/train/finetune_example/preprocessing.py b/scripts/train/finetune_example/preprocessing.py index 59e9098f37..adfa3c5cce 100644 --- a/scripts/train/finetune_example/preprocessing.py +++ b/scripts/train/finetune_example/preprocessing.py @@ -31,15 +31,19 @@ } """ -from typing import Dict +from typing import Dict, List, Union -def multiple_choice(inp: Dict[str, str]) -> Dict[str, str]: +def multiple_choice( + inp: Dict[str, Union[str, List[str], int]]) -> Dict[str, str]: PROMPT_FORMAT = '{query}\nOptions:{options}\nAnswer: ' options = '' + assert isinstance(inp['choices'], List) for option in inp['choices']: options += f'\n - {option}' query = inp['query'] + + assert isinstance(inp['gold'], int) return { 'prompt': PROMPT_FORMAT.format(query=query, options=options), 'response': inp['choices'][inp['gold']], diff --git a/scripts/train/train.py b/scripts/train/train.py index 3cbfecca6d..c6a86503e3 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -1,9 +1,9 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 - import os import sys import warnings +from typing import Dict import torch from composer import Trainer @@ -11,13 +11,12 @@ from composer.utils import dist, get_device, reproducibility from omegaconf import DictConfig from omegaconf import OmegaConf as om -from transformers import PreTrainedTokenizer +from transformers import PreTrainedTokenizerBase from llmfoundry import (COMPOSER_MODEL_REGISTRY, ComposerHFCausalLM, MPTForCausalLM, build_finetuning_dataloader, build_text_denoising_dataloader) from llmfoundry.data.text_data import build_text_dataloader -from llmfoundry.models.utils import init_empty_weights, init_on_device from llmfoundry.utils.builders import (build_algorithm, build_callback, build_icl_evaluators, build_logger, build_optimizer, build_scheduler, @@ -26,7 +25,7 @@ update_batch_size_info) -def validate_config(cfg): +def validate_config(cfg: DictConfig): """Validates compatible model and dataloader selection.""" loaders = [cfg.train_loader] if 'eval_loader' in cfg: @@ -65,6 +64,7 @@ def validate_config(cfg): 'fp8' in cfg.precision): warnings.warn( "fp8 only supported for te.Linear layers. Either set `cfg.model.fc_typ='te'` or " + + "`cfg.model.ffn_config.ffn_type='te_ln_mlp'` to enable layers using fp8 precision." ) @@ -77,20 +77,21 @@ def validate_config(cfg): if fsdp_config is not None and act_ckpt == True and act_ckpt_reentrant == False: warnings.warn( '`te.Linear` layers do not support activation_checkpointing with ' - '`activation_checkpointing_reentrant = False`. ' + + '`activation_checkpointing_reentrant = False`. ' + 'Setting cfg.fsdp_config.activation_checkpointing_reentrant=True.' ) cfg.fsdp_config.activation_checkpointing_reentrant = True if 'te' in cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp'): warnings.warn( - '`te.LayerNormMLP` requires has issues with torch._dynamo. ' + '`te.LayerNormMLP` requires has issues with torch._dynamo. ' + 'Setting `torch._dynamo.config.suppress_errors = True` and falling back to eager.' ) - torch._dynamo.config.suppress_errors = True + torch._dynamo.config.suppress_errors = True # type: ignore -def build_composer_model(model_cfg, tokenizer): +def build_composer_model(model_cfg: DictConfig, + tokenizer: PreTrainedTokenizerBase): warnings.filterwarnings( action='ignore', message='Torchmetrics v0.9 introduced a new argument class property') @@ -102,14 +103,15 @@ def build_composer_model(model_cfg, tokenizer): def build_composer_peft_model( model_cfg: DictConfig, lora_cfg: DictConfig, - tokenizer: PreTrainedTokenizer) -> ComposerHFCausalLM: + tokenizer: PreTrainedTokenizerBase) -> ComposerHFCausalLM: try: from peft import LoraConfig, get_peft_model except ImportError as e: raise ImportError( 'Error importing from peft. Please verify that peft and peft utils ' - 'are installed by running `pip install -e .[peft]` from `llm-foundry/`.' - f'Error encountered: {e}') + + + 'are installed by running `pip install -e .[peft]` from `llm-foundry/`. ' + + f'Error encountered: {e}') # 1) loads a hf model, 2) adds peft modules, 3) wraps it in a ComposerHFCausalLM. print('Building Lora config...') @@ -129,7 +131,7 @@ def build_composer_peft_model( return model -def print_trainable_parameters(model) -> None: +def print_trainable_parameters(model: torch.nn.Module) -> None: # Prints the number of trainable parameters in the model. trainable_params = 0 all_param = 0 @@ -142,7 +144,8 @@ def print_trainable_parameters(model) -> None: ) -def build_dataloader(cfg, tokenizer, device_batch_size): +def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, + device_batch_size: int): if cfg.name == 'text': return build_text_dataloader( cfg, @@ -166,7 +169,7 @@ def build_dataloader(cfg, tokenizer, device_batch_size): raise ValueError(f'Not sure how to build dataloader with config: {cfg}') -def main(cfg): +def main(cfg: DictConfig): # Check for incompatibilities between the model and data loaders validate_config(cfg) @@ -194,6 +197,7 @@ def main(cfg): fsdp_config = cfg.get('fsdp_config', None) fsdp_config = om.to_container(fsdp_config, resolve=True) if fsdp_config else None + assert isinstance(fsdp_config, Dict) or fsdp_config is None if dist.get_world_size() == 1 and fsdp_config is not None: warnings.warn( 'FSDP is not applicable for single-GPU training. Reverting to DDP.') @@ -228,6 +232,7 @@ def main(cfg): print('Building eval loader...') evaluators = [] if 'eval_loader' in cfg: + assert model.train_metrics is not None eval_loader = Evaluator(label='eval', dataloader=build_dataloader( cfg.eval_loader, tokenizer, @@ -337,4 +342,5 @@ def main(cfg): yaml_cfg = om.load(f) cli_cfg = om.from_cli(args_list) cfg = om.merge(yaml_cfg, cli_cfg) + assert isinstance(cfg, DictConfig) main(cfg) diff --git a/setup.py b/setup.py index 43b7939fba..b2f017e8a7 100644 --- a/setup.py +++ b/setup.py @@ -71,7 +71,7 @@ 'pytest>=7.2.1,<8', 'pytest_codeblocks>=0.16.1,<0.17', 'pytest-cov>=4,<5', - 'pyright==1.1.296', + 'pyright==1.1.256', 'toml>=0.10.2,<0.11', 'packaging>=21,<23', ] diff --git a/tests/conftest.py b/tests/conftest.py index 4cc8b0bd7c..b39ebd66a9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -81,7 +81,7 @@ def pytest_sessionfinish(session: pytest.Session, exitstatus: int): @pytest.fixture(autouse=True) -def clear_cuda_cache(request): +def clear_cuda_cache(request: pytest.FixtureRequest): """Clear memory between GPU tests.""" marker = request.node.get_closest_marker('gpu') if marker is not None and torch.cuda.is_available(): diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index eed86528bd..7039814d42 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -1,11 +1,11 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 - import os import shutil import sys import tempfile from argparse import Namespace +from typing import Optional import pytest import torch @@ -23,24 +23,26 @@ from scripts.data_prep.convert_dataset_hf import main as main_hf -def get_config(conf_path='yamls/mpt/125m.yaml'): +def get_config(conf_path: str = 'yamls/mpt/125m.yaml'): os.environ['TOKENIZERS_PARALLELISM'] = 'false' with open(conf_path) as f: test_cfg = om.load(f) return test_cfg -def get_data_local(tokenizer_name, pretokenize): +def get_data_local(tokenizer_name: str, pretokenize: bool): return f'my-copy-c4-{tokenizer_name}-pretokenize-{pretokenize}' -def get_abs_data_path(data_local): +def get_abs_data_path(data_local: str): return os.path.join(os.getcwd(), data_local) @pytest.mark.parametrize('tokenizer_name', ['gpt2', 'facebook/opt-125m']) @pytest.mark.parametrize('pretokenize', [False, True]) -def test_correct_padding(tokenizer_name, pretokenize, batch_size=4): +def test_correct_padding(tokenizer_name: str, + pretokenize: bool, + batch_size: int = 4): if tokenizer_name == 'gpt2' and not pretokenize: pytest.xfail('Must pretokenize data if using "gpt2" tokenizer') @@ -127,7 +129,8 @@ def test_correct_padding(tokenizer_name, pretokenize, batch_size=4): @pytest.mark.parametrize(('eos_token_id', 'bos_token_id'), [(5, None), (None, 5), pytest.param(5, 5, marks=pytest.mark.xfail)]) -def test_sequence_id_wrapper(eos_token_id, bos_token_id): +def test_sequence_id_wrapper(eos_token_id: Optional[int], + bos_token_id: Optional[int]): wrapper = ConcatenatedSequenceCollatorWrapper( lambda x: x, # placeholder eos_token_id=eos_token_id, @@ -150,7 +153,8 @@ def test_sequence_id_wrapper(eos_token_id, bos_token_id): @pytest.mark.parametrize('decoder_only_format', [True, False]) @pytest.mark.parametrize('pretokenize', [True, False]) @pytest.mark.parametrize('packing_ratio', [None, 5.5]) -def test_denoising_dataloader(decoder_only_format, pretokenize, packing_ratio): +def test_denoising_dataloader(decoder_only_format: bool, pretokenize: bool, + packing_ratio: Optional[float]): # Use the datasets just built in the last test tokenizer_name = 'facebook/opt-125m' data_local = get_data_local(tokenizer_name, pretokenize) @@ -219,8 +223,9 @@ def test_denoising_dataloader(decoder_only_format, pretokenize, packing_ratio): @pytest.mark.parametrize('decoder_only_format', [True, False]) @pytest.mark.parametrize('allow_pad_trimming', [True, False]) @pytest.mark.parametrize('packing_ratio', [10.0, None]) -def test_finetuning_dataloader(decoder_only_format, allow_pad_trimming, - packing_ratio): +def test_finetuning_dataloader(decoder_only_format: bool, + allow_pad_trimming: bool, + packing_ratio: Optional[float]): # Use the datasets just built in the last test tokenizer_name = 'gpt2' if decoder_only_format else 't5-base' max_seq_len = 2048 if decoder_only_format else 1024 diff --git a/tests/test_flash_triton_torch.py b/tests/test_flash_triton_torch.py index 0daba54bfb..d029f4fe4d 100644 --- a/tests/test_flash_triton_torch.py +++ b/tests/test_flash_triton_torch.py @@ -7,7 +7,10 @@ from omegaconf import OmegaConf as om -def allclose_helper(t0, t1, rtol=1e-2, atol=1e-2): +def allclose_helper(t0: torch.Tensor, + t1: torch.Tensor, + rtol: float = 1e-2, + atol: float = 1e-2): return torch.allclose(t0, t1, rtol=rtol, atol=atol) @@ -18,13 +21,13 @@ def allclose_helper(t0, t1, rtol=1e-2, atol=1e-2): @pytest.mark.parametrize('qk_ln', [True, False]) @pytest.mark.parametrize('alibi', [True, False]) @pytest.mark.parametrize('multiquery', [True, False]) -def test_attn_impl(attn_impl_0, - attn_impl_1, - clip_qkv, - qk_ln, - alibi, - multiquery, - device='cuda'): +def test_attn_impl(attn_impl_0: str, + attn_impl_1: str, + clip_qkv: bool, + qk_ln: bool, + alibi: bool, + multiquery: bool, + device: str = 'cuda'): """Compare all attn impl with each other. Includes testing with and without attn_clip_qkv, attn_qk_ln, and alibi. @@ -62,7 +65,7 @@ def test_attn_impl(attn_impl_0, attention_mask = torch.ones(n, s).to(device).bool() - def gen_bias(attn_impl): + def gen_bias(attn_impl: str): causal = True attn_bias = None bs = attention.attn_bias_shape(attn_impl, @@ -118,15 +121,19 @@ def gen_bias(attn_impl): torch_name_param_map = {n: p for n, p in attn1.named_parameters()} for n, p in attn0.named_parameters(): tp = torch_name_param_map[n] + assert p.grad is not None + assert tp.grad is not None assert allclose_helper(p, tp) assert allclose_helper(p.grad, tp.grad) + assert x0.grad is not None + assert x1.grad is not None assert allclose_helper(x0.grad, x1.grad) @pytest.mark.gpu @pytest.mark.parametrize('attn_impl', ['flash', 'triton', 'torch']) -def test_vs_mha(attn_impl, device='cuda'): +def test_vs_mha(attn_impl: str, device: str = 'cuda'): """Compare diff attn_impl to torch.nn.MultiheadAttention.""" from llmfoundry.models.layers import attention @@ -194,6 +201,19 @@ def gen_tca_mask(): loss0.backward() loss1.backward() + assert y0 is not None + assert y1 is not None + assert tmhsa.out_proj.bias.grad is not None + assert mmhsa.out_proj.bias.grad is not None + assert tmhsa.out_proj.weight.grad is not None + assert mmhsa.out_proj.weight.grad is not None + assert tmhsa.in_proj_bias.grad is not None + assert mmhsa.Wqkv.bias.grad is not None + assert tmhsa.in_proj_weight.grad is not None + assert mmhsa.Wqkv.weight.grad is not None + assert x0.grad is not None + assert x1.grad is not None + assert allclose_helper(y0, y1) assert allclose_helper(tmhsa.out_proj.bias.grad, mmhsa.out_proj.bias.grad) diff --git a/tests/test_hf_config.py b/tests/test_hf_config.py index 4911aa313a..e5399e716c 100644 --- a/tests/test_hf_config.py +++ b/tests/test_hf_config.py @@ -4,7 +4,7 @@ import tempfile from copy import deepcopy from pathlib import Path -from typing import Mapping +from typing import Any, Dict, Mapping import pytest import torch @@ -51,8 +51,8 @@ strict=True)), ]) def test_hf_config_override( - model_cfg_overrides, - conf_path='scripts/train/yamls/pretrain/testing.yaml', + model_cfg_overrides: Dict[str, Any], + conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml', ): AutoConfig.register('mpt', MPTConfig) AutoModelForCausalLM.register(MPTConfig, MPTForCausalLM) @@ -85,12 +85,12 @@ def test_hf_config_override( # load hf causal lm model with config_overrides hf_model_config = deepcopy(test_cfg) - model_cfg = { + model_cfg = DictConfig({ 'name': 'hf_causal_lm', 'pretrained_model_name_or_path': save_path, 'pretrained': False, 'config_overrides': model_cfg_overrides, - } + }) hf_model_config.model = model_cfg hf_model = COMPOSER_MODEL_REGISTRY[hf_model_config.model.name]( diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index fca7bcb2a6..d5372911d5 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import os +import pathlib import sys from composer import Trainer @@ -37,14 +38,15 @@ def delete_transformers_cache(): def get_config( - conf_path='scripts/train/yamls/pretrain/testing.yaml') -> DictConfig: + conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml' +) -> DictConfig: os.environ['TOKENIZERS_PARALLELISM'] = 'false' with open(conf_path) as f: test_cfg = om.load(f) return cast(DictConfig, test_cfg) -def test_convert_and_generate_torch(tmp_path): +def test_convert_and_generate_torch(tmp_path: pathlib.Path): delete_transformers_cache() cfg = get_config() @@ -84,7 +86,7 @@ def test_convert_and_generate_torch(tmp_path): @pytest.mark.gpu -def test_convert_and_generate_triton(tmp_path): +def test_convert_and_generate_triton(tmp_path: pathlib.Path): delete_transformers_cache() cfg = get_config() diff --git a/tests/test_hf_mpt_gen.py b/tests/test_hf_mpt_gen.py index 8b1c4df4cb..969f42eca3 100644 --- a/tests/test_hf_mpt_gen.py +++ b/tests/test_hf_mpt_gen.py @@ -4,6 +4,7 @@ import pytest from composer.core.precision import get_precision_context from composer.utils import get_device, reproducibility +from omegaconf import DictConfig from omegaconf import OmegaConf as om from llmfoundry import COMPOSER_MODEL_REGISTRY @@ -13,19 +14,20 @@ @pytest.mark.gpu @pytest.mark.parametrize('device', ['cpu', 'gpu']) @pytest.mark.parametrize('attn_impl', ['triton', 'torch']) -def test_init_hfhub_mpt(device, attn_impl): +def test_init_hfhub_mpt(device: str, attn_impl: str): if device == 'cpu' and attn_impl == 'triton': pytest.skip(f'{attn_impl=} not implemented for {device=}.') - device = get_device(device) + composer_device = get_device(device) with open('scripts/train/yamls/pretrain/testing.yaml') as f: test_cfg = om.load(f) + assert isinstance(test_cfg, DictConfig) reproducibility.seed_all(test_cfg.get('seed', 42)) attn_uses_sequence_id = True if test_cfg.get('eos_token_id', None) is not None else False - test_cfg.model = { + test_cfg.model = DictConfig({ 'name': 'hf_causal_lm', 'pretrained_model_name_or_path': 'mosaicml/mpt-7b', 'pretrained': False, @@ -39,7 +41,7 @@ def test_init_hfhub_mpt(device, attn_impl): 'attn_uses_sequence_id': attn_uses_sequence_id, }, }, - } + }) # build tokenizer tokenizer = build_tokenizer(test_cfg.tokenizer) @@ -50,11 +52,12 @@ def test_init_hfhub_mpt(device, attn_impl): test_cfg.n_params = sum(p.numel() for p in model.parameters()) model.eval() - model = device.module_to_device(model) + model = composer_device.module_to_device(model) - with get_precision_context('amp_bf16' if device.name == 'gpu' else 'fp32'): + with get_precision_context('amp_bf16' if composer_device.name == + 'gpu' else 'fp32'): _ = model.generate( - device.tensor_to_device( + composer_device.tensor_to_device( tokenizer('hello', return_tensors='pt')['input_ids']), max_new_tokens=10, ) diff --git a/tests/test_hf_v_mpt.py b/tests/test_hf_v_mpt.py index c6f600455d..22c9241037 100644 --- a/tests/test_hf_v_mpt.py +++ b/tests/test_hf_v_mpt.py @@ -1,7 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 - import warnings +from typing import Optional import pytest import torch @@ -40,7 +40,8 @@ ('torch', 0.0, False, None, True), ('triton', 0.0, False, None, True), ]) -def test_compare_hf_v_mpt(attn_impl, dropout, alibi, mask_val, no_attn_mask): +def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool, + mask_val: Optional[int], no_attn_mask: bool): warnings.filterwarnings( action='ignore', message='Torchmetrics v0.9 introduced a new argument class property') @@ -144,6 +145,7 @@ def test_compare_hf_v_mpt(attn_impl, dropout, alibi, mask_val, no_attn_mask): model_cfg.max_seq_len), dtype=torch.int64).to(device) # mask out some tokens + assert mask_val is not None batch['attention_mask'][:, model_cfg.max_seq_len // 2:] = mask_val kpm = batch['attention_mask'].view(*batch['attention_mask'].shape, 1) diff --git a/tests/test_icl_datasets.py b/tests/test_icl_datasets.py index 687e5d1032..524aac9fd0 100644 --- a/tests/test_icl_datasets.py +++ b/tests/test_icl_datasets.py @@ -2,18 +2,19 @@ # SPDX-License-Identifier: Apache-2.0 import os +import pathlib import random import shutil from pathlib import Path import pytest from omegaconf import OmegaConf as om -from transformers import AutoTokenizer +from transformers import AutoTokenizer, PreTrainedTokenizerBase from llmfoundry.utils.builders import build_icl_evaluators -def load_icl_config(conf_path='tests/test_tasks.yaml'): +def load_icl_config(conf_path: str = 'tests/test_tasks.yaml'): with open(conf_path) as f: test_cfg = om.load(f) return test_cfg @@ -32,7 +33,9 @@ def tmp_dir(): shutil.rmtree(dirpath) -def run_test(dir, tokenizer, bos_tok=''): +def run_test(dir: pathlib.Path, + tokenizer: PreTrainedTokenizerBase, + bos_tok: str = ''): task_cfg = load_icl_config() evaluators, _ = build_icl_evaluators(task_cfg.icl_tasks, tokenizer, @@ -85,21 +88,21 @@ def run_test(dir, tokenizer, bos_tok=''): assert answer == ' feared violence' -def test_icl_task_loading_gpt2_tokenizer(tmp_dir): +def test_icl_task_loading_gpt2_tokenizer(tmp_dir: pathlib.Path): tokenizer = AutoTokenizer.from_pretrained('gpt2') run_test(tmp_dir, tokenizer) -def test_icl_task_loading_gptj_tokenizer(tmp_dir): +def test_icl_task_loading_gptj_tokenizer(tmp_dir: pathlib.Path): tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-j-6b') run_test(tmp_dir, tokenizer) -def test_icl_task_loading_opt_tokenizer(tmp_dir): +def test_icl_task_loading_opt_tokenizer(tmp_dir: pathlib.Path): tokenizer = AutoTokenizer.from_pretrained('facebook/opt-6.7b') run_test(tmp_dir, tokenizer, '') -def test_icl_task_loading_gptneox_tokenizer(tmp_dir): +def test_icl_task_loading_gptneox_tokenizer(tmp_dir: pathlib.Path): tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b') run_test(tmp_dir, tokenizer) diff --git a/tests/test_init_fn.py b/tests/test_init_fn.py index a9b037b718..9355e7a277 100644 --- a/tests/test_init_fn.py +++ b/tests/test_init_fn.py @@ -1,14 +1,15 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 - import math from collections import OrderedDict from collections.abc import Sequence from functools import partial +from typing import Dict, List, Optional, Tuple, Union import pytest import torch from composer.utils import reproducibility +from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om from torch import nn @@ -17,14 +18,14 @@ class MLP(nn.Module): - def __init__(self, cfg): + def __init__(self, cfg: Union[ListConfig, DictConfig]): super().__init__() self.fc1 = nn.Linear(cfg.in_features, cfg.out_features, bias=True) self.ln_1 = nn.LayerNorm(cfg.out_features) self.fc2 = nn.Linear(cfg.out_features, cfg.out_features, bias=True) self.fc2._is_residual = True # type: ignore - def forward(self, x): + def forward(self, x: torch.Tensor): y = self.ln_1(self.fc1(x)) res = y y = self.fc2(y) @@ -62,7 +63,7 @@ def test_div_is_residual(is_residual: bool): @pytest.mark.parametrize('fused', [True, False]) -def test_fused_init_helper(fused): +def test_fused_init_helper(fused: bool): reproducibility.seed_all(7) in_features, out_features = 8, 32 @@ -77,7 +78,7 @@ def test_fused_init_helper(fused): if fused: fc._fused = (0, (cfg.out_features // 2,)) # type: ignore - def init_fn_(weight): + def init_fn_(weight: torch.Tensor): # dummy init based on layer width with torch.no_grad(): out_features, _ = weight.shape[:2] @@ -106,10 +107,10 @@ def init_fn_(weight): reason='generic_param_init_fn_ does not init Conv layers', strict=True)), ]) -def test_all_params_init(module): +def test_all_params_init(module: torch.nn.Module): fill_val = torch.finfo(torch.float16).max - def max_fill_init_(weight): + def max_fill_init_(weight: torch.Tensor): # init param with max value with torch.no_grad(): weight.fill_(fill_val) @@ -131,10 +132,10 @@ def max_fill_init_(weight): ('emb_init_uniform_lim', [-1, 4]), ('emb_init_uniform_lim', 0), ('emb_init_uniform_lim', [1, 1]) ]) -def test_emb_init(emb_init_cfg): +def test_emb_init(emb_init_cfg: Optional[Tuple[str, Union[int, List[int]]]]): reproducibility.seed_all(7) - cfg = { + cfg: Dict[str, Union[int, List[int]]] = { 'vocab_size': 64, 'in_features': 16, 'out_features': 32, @@ -142,31 +143,34 @@ def test_emb_init(emb_init_cfg): } if emb_init_cfg is not None: cfg[emb_init_cfg[0]] = emb_init_cfg[1] - cfg = om.create(cfg) + dict_cfg = om.create(cfg) model = nn.Sequential( OrderedDict([ - ('emb', nn.Embedding(cfg.vocab_size, cfg.in_features)), - ('fc1', nn.Linear(cfg.in_features, cfg.out_features, bias=True)), - ('ln1', nn.LayerNorm(cfg.out_features)), + ('emb', nn.Embedding(dict_cfg.vocab_size, dict_cfg.in_features)), + ('fc1', + nn.Linear(dict_cfg.in_features, dict_cfg.out_features, bias=True)), + ('ln1', nn.LayerNorm(dict_cfg.out_features)), ('act1', nn.ReLU()), - ('fc2', nn.Linear(cfg.out_features, cfg.out_features, bias=True)), + ('fc2', + nn.Linear(dict_cfg.out_features, dict_cfg.out_features, + bias=True)), ])) - model.apply(partial(MODEL_INIT_REGISTRY['kaiming_normal_'], **cfg)) + model.apply(partial(MODEL_INIT_REGISTRY['kaiming_normal_'], **dict_cfg)) - if cfg.get('emb_init_std') is not None: - emb_init_std = cfg.get('emb_init_std') + if dict_cfg.get('emb_init_std') is not None: + emb_init_std = dict_cfg.get('emb_init_std') if emb_init_std == 0: assert (model.emb.weight == 0).all() # type: ignore - elif cfg.get('emb_init_uniform_lim') is not None: - emb_init_uniform_lim = cfg.get('emb_init_uniform_lim') + elif dict_cfg.get('emb_init_uniform_lim') is not None: + emb_init_uniform_lim = dict_cfg.get('emb_init_uniform_lim') if emb_init_uniform_lim == 0: assert (model.emb.weight == 0).all() # type: ignore elif isinstance(emb_init_uniform_lim, Sequence): assert len(emb_init_uniform_lim) <= 2 if len(emb_init_uniform_lim ) == 2 and emb_init_uniform_lim[0] == emb_init_uniform_lim[1]: - assert ( - model.emb.weight == emb_init_uniform_lim[0] # type: ignore - ).all() + assert isinstance(model.emb, torch.nn.Embedding) + assert (model.emb.weight == emb_init_uniform_lim[0] + ).all() # type: ignore diff --git a/tests/test_model.py b/tests/test_model.py index 5266b82622..35d34e3626 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,12 +1,12 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 - import contextlib import copy import gc import os +import pathlib import warnings -from typing import cast +from typing import Any, Dict, Union, cast from unittest import mock import pytest @@ -17,11 +17,11 @@ from composer.optim import DecoupledAdamW from composer.trainer.dist_strategy import prepare_fsdp_module from composer.utils import dist, get_device, reproducibility -from omegaconf import DictConfig +from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, - PreTrainedTokenizer, PreTrainedTokenizerFast, - pipeline) + PreTrainedModel, PreTrainedTokenizer, + PreTrainedTokenizerFast, pipeline) from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.bloom.modeling_bloom import build_alibi_tensor @@ -35,7 +35,8 @@ def get_config( - conf_path='scripts/train/yamls/pretrain/testing.yaml') -> DictConfig: + conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml' +) -> DictConfig: os.environ['TOKENIZERS_PARALLELISM'] = 'false' print(conf_path) with open(conf_path) as f: @@ -43,7 +44,7 @@ def get_config( return cast(DictConfig, test_cfg) -def get_objs(conf_path='scripts/train/yamls/pretrain/testing.yaml'): +def get_objs(conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'): warnings.filterwarnings( action='ignore', message='Torchmetrics v0.9 introduced a new argument class property') @@ -88,7 +89,7 @@ def get_objs(conf_path='scripts/train/yamls/pretrain/testing.yaml'): return test_cfg, model, optimizer -def gen_random_batch(batch_size, test_cfg): +def gen_random_batch(batch_size: int, test_cfg: Union[DictConfig, ListConfig]): # generate input batch of random data, suitable for a Causal or Prefix LM batch = {} batch['input_ids'] = torch.randint( @@ -107,7 +108,8 @@ def gen_random_batch(batch_size, test_cfg): return batch -def gen_random_enc_dec_batch(batch_size, vocab_size, max_seq_len, device): +def gen_random_enc_dec_batch(batch_size: int, vocab_size: int, max_seq_len: int, + device: str): # generate input batch of random data, suitable for a T5 batch = {} batch['input_ids'] = torch.randint(low=0, @@ -125,7 +127,7 @@ def gen_random_enc_dec_batch(batch_size, vocab_size, max_seq_len, device): return batch -def test_full_forward_and_backward(batch_size=2): +def test_full_forward_and_backward(batch_size: int = 2): test_cfg, model, optimizer = get_objs( conf_path='scripts/train/yamls/pretrain/testing.yaml') @@ -143,7 +145,7 @@ def test_full_forward_and_backward(batch_size=2): assert not torch.equal(original_params, updated_params) -def test_attention_mechanism(batch_size=2): +def test_attention_mechanism(batch_size: int = 2): test_cfg, model, _ = get_objs( conf_path='scripts/train/yamls/pretrain/testing.yaml') @@ -201,7 +203,8 @@ def test_attention_mechanism(batch_size=2): @pytest.mark.parametrize('prefixlm', [False, True]) -def test_full_forward_and_backward_gpt2_small(prefixlm, batch_size=2): +def test_full_forward_and_backward_gpt2_small(prefixlm: bool, + batch_size: int = 2): warnings.filterwarnings( action='ignore', message='Torchmetrics v0.9 introduced a new argument class property') @@ -249,7 +252,7 @@ def test_full_forward_and_backward_gpt2_small(prefixlm, batch_size=2): assert not torch.equal(original_params, updated_params) -def test_full_forward_and_backward_t5_small(batch_size=2): +def test_full_forward_and_backward_t5_small(batch_size: int = 2): warnings.filterwarnings( action='ignore', message='Torchmetrics v0.9 introduced a new argument class property') @@ -296,7 +299,7 @@ def test_full_forward_and_backward_t5_small(batch_size=2): [('torch', torch.float16), ('torch', torch.bfloat16), pytest.param('flash', torch.float16, marks=pytest.mark.gpu), pytest.param('flash', torch.bfloat16, marks=pytest.mark.gpu)]) -def test_determinism(attn_impl: str, precision): +def test_determinism(attn_impl: str, precision: torch.dtype): if not torch.cuda.is_available(): pytest.skip( 'This test requires CUDA to be available in order to run with bfloat16 precision.' @@ -420,7 +423,7 @@ def test_loss_fn(): @pytest.mark.parametrize('prefixlm', [False, True]) -def test_opt_wrapping(prefixlm): +def test_opt_wrapping(prefixlm: bool): conf = { 'model': { 'name': 'hf_prefix_lm' if prefixlm else 'hf_causal_lm', @@ -449,7 +452,7 @@ def test_opt_wrapping(prefixlm): @pytest.mark.parametrize('norm_type', NORM_CLASS_REGISTRY.keys()) @pytest.mark.parametrize('no_bias', [False, True]) -def test_mpt_creation(norm_type, no_bias): +def test_mpt_creation(norm_type: str, no_bias: bool): # Test that the config constructs the model as expected. hf_config = MPTConfig( init_device='cpu', @@ -485,6 +488,7 @@ def test_mpt_creation(norm_type, no_bias): for block in mpt.transformer.blocks: assert isinstance(block, MPTBlock) assert block.norm_1.weight.shape == torch.Size([d_model]) + assert block.norm_2 is not None assert block.norm_2.weight.shape == torch.Size([d_model]) assert block.ffn.up_proj.weight.shape == torch.Size( [hf_config.d_model * hf_config.expansion_ratio, hf_config.d_model]) @@ -499,7 +503,7 @@ def test_mpt_creation(norm_type, no_bias): ('triton', 'gpu'), ('torch', 'gpu')]) @pytest.mark.parametrize('alibi', [True, False]) -def test_forward_with_padding(attention_impl, device, alibi): +def test_forward_with_padding(attention_impl: str, device: str, alibi: bool): # Test that different placement of padding does not affect the output. if not torch.cuda.is_available() and device == 'gpu': pytest.skip( @@ -509,7 +513,7 @@ def test_forward_with_padding(attention_impl, device, alibi): pytest.skip(f'alibi only implemented with torch and triton attention.') reproducibility.seed_all(1234) - device = get_device(device) + composer_device = get_device(device) hf_config = MPTConfig( init_device='cpu', @@ -531,41 +535,43 @@ def test_forward_with_padding(attention_impl, device, alibi): ) mpt = MPTForCausalLM(hf_config) mpt.eval() - mpt = device.module_to_device(mpt) + mpt = composer_device.module_to_device(mpt) - with get_precision_context('amp_bf16' if device.name == 'gpu' else 'fp32'): + with get_precision_context('amp_bf16' if composer_device.name == + 'gpu' else 'fp32'): # padding on the right side of the input right_padding_input_ids = torch.tensor( [[11274, 16390, 11, 50256, 50256, 50256], [11274, 16390, 11, 50256, 50256, 50256]]) - right_padding_input_ids = device.tensor_to_device( + right_padding_input_ids = composer_device.tensor_to_device( right_padding_input_ids) right_padding_attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]]).bool() - right_padding_attention_mask = device.tensor_to_device( + right_padding_attention_mask = composer_device.tensor_to_device( right_padding_attention_mask) # padding in the middle of the input middle_padding_input_ids = torch.tensor( [[11274, 16390, 50256, 50256, 50256, 11], [11274, 16390, 50256, 50256, 50256, 11]]) - middle_padding_input_ids = device.tensor_to_device( + middle_padding_input_ids = composer_device.tensor_to_device( middle_padding_input_ids) middle_padding_attention_mask = torch.tensor([[1, 1, 0, 0, 0, 1], [1, 1, 0, 0, 0, 1]]).bool() - middle_padding_attention_mask = device.tensor_to_device( + middle_padding_attention_mask = composer_device.tensor_to_device( middle_padding_attention_mask) # padding on the left side of the input left_padding_input_ids = torch.tensor( [[50256, 50256, 50256, 11274, 16390, 11], [50256, 50256, 50256, 11274, 16390, 11]]) - left_padding_input_ids = device.tensor_to_device(left_padding_input_ids) + left_padding_input_ids = composer_device.tensor_to_device( + left_padding_input_ids) left_padding_attention_mask = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 1, 1]]).bool() - left_padding_attention_mask = device.tensor_to_device( + left_padding_attention_mask = composer_device.tensor_to_device( left_padding_attention_mask) # a single batch with padding in different places @@ -573,10 +579,11 @@ def test_forward_with_padding(attention_impl, device, alibi): [11274, 16390, 11, 50256, 50256, 50256], # right padding [11274, 16390, 50256, 50256, 50256, 11] ]) # middle padding - batched_input_ids = device.tensor_to_device(batched_input_ids) + batched_input_ids = composer_device.tensor_to_device(batched_input_ids) batched_attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 0, 0, 0, 1]]).bool() - batched_attention_mask = device.tensor_to_device(batched_attention_mask) + batched_attention_mask = composer_device.tensor_to_device( + batched_attention_mask) right_padding_output = mpt( right_padding_input_ids, @@ -615,7 +622,7 @@ def test_forward_with_padding(attention_impl, device, alibi): @pytest.mark.parametrize('attention_impl', ['torch', 'triton']) -def test_advanced_mask_building(attention_impl): +def test_advanced_mask_building(attention_impl: str): # Test that the correct attention mask is created when both # prefix_mask and sequence_id are used hf_config = MPTConfig( @@ -674,7 +681,7 @@ def test_advanced_mask_building(attention_impl): ('triton', 'gpu'), ('torch', 'gpu')]) @pytest.mark.parametrize('alibi', [True, False]) -def test_generate(attention_impl, device, alibi): +def test_generate(attention_impl: str, device: str, alibi: bool): # Test that generate works, and produces the same output with or without # padding in the input. if not torch.cuda.is_available() and device == 'gpu': @@ -685,7 +692,7 @@ def test_generate(attention_impl, device, alibi): pytest.skip(f'alibi only implemented with torch and triton attention.') reproducibility.seed_all(1234) - device = get_device(device) + composer_device = get_device(device) hf_config = MPTConfig( init_device='cpu', @@ -703,35 +710,39 @@ def test_generate(attention_impl, device, alibi): ) mpt = MPTForCausalLM(hf_config) mpt.eval() - mpt = device.module_to_device(mpt) + mpt = composer_device.module_to_device(mpt) # padding on the left of the input left_padding_input_ids = torch.tensor( [[50256, 50256, 50256, 11274, 16390, 11], [50256, 50256, 50256, 11274, 16390, 11]]) - left_padding_input_ids = device.tensor_to_device(left_padding_input_ids) + left_padding_input_ids = composer_device.tensor_to_device( + left_padding_input_ids) left_padding_attention_mask = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 1, 1]]) - left_padding_attention_mask = device.tensor_to_device( + left_padding_attention_mask = composer_device.tensor_to_device( left_padding_attention_mask) # no padding in the input no_padding_input_ids = torch.tensor([[11274, 16390, 11], [11274, 16390, 11]]) - no_padding_input_ids = device.tensor_to_device(no_padding_input_ids) + no_padding_input_ids = composer_device.tensor_to_device( + no_padding_input_ids) no_padding_attention_mask = torch.tensor([[1, 1, 1], [1, 1, 1]]) - no_padding_attention_mask = device.tensor_to_device( + no_padding_attention_mask = composer_device.tensor_to_device( no_padding_attention_mask) # a single batch with different amounts of left padding in the input batched_input_ids = torch.tensor([[50256, 50256, 50256, 11274, 16390, 11], [50256, 50256, 16, 11274, 16390, 11]]) - batched_input_ids = device.tensor_to_device(batched_input_ids) + batched_input_ids = composer_device.tensor_to_device(batched_input_ids) batched_attention_mask = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 0, 1, 1, 1, 1]]).bool() - batched_attention_mask = device.tensor_to_device(batched_attention_mask) + batched_attention_mask = composer_device.tensor_to_device( + batched_attention_mask) - with get_precision_context('amp_bf16' if device.name == 'gpu' else 'fp32'): + with get_precision_context('amp_bf16' if composer_device.name == + 'gpu' else 'fp32'): # check that a batch with different amounts of padding doesn't crash # and produces the right output shape batched_generation = mpt.generate(input_ids=batched_input_ids, @@ -763,7 +774,8 @@ def test_generate(attention_impl, device, alibi): @pytest.mark.gpu @pytest.mark.parametrize('world_size', [1, 2]) @pytest.mark.parametrize('use_cache', [False, True]) -def test_generate_with_device_map(tmp_path, world_size, use_cache): +def test_generate_with_device_map(tmp_path: pathlib.Path, world_size: int, + use_cache: bool): if not torch.cuda.is_available(): pytest.skip(f'This test requires CUDA to be available.') if not torch.cuda.device_count() >= world_size: @@ -809,14 +821,15 @@ def test_generate_with_device_map(tmp_path, world_size, use_cache): device_map=device_map, ) with torch.autocast('cuda', dtype=torch.bfloat16): - out = pipe( + _ = pipe( 'The quick fox jumped over', max_length=10, do_sample=True, ) -def check_hf_model_equivalence(model1, model2): +def check_hf_model_equivalence(model1: PreTrainedModel, + model2: PreTrainedModel): # Checks that two huggingface models are equivalent (config and # parameters) expected_model_config_dict = model1.config.to_dict() @@ -837,7 +850,7 @@ def check_hf_model_equivalence(model1, model2): torch.testing.assert_close(p1, p2) -def test_save_from_pretrained(tmp_path): +def test_save_from_pretrained(tmp_path: pathlib.Path): # Test that MPT can be used with the HuggingFace # save_pretrained/from_pretrained api. hf_config = MPTConfig( @@ -862,7 +875,7 @@ def test_save_from_pretrained(tmp_path): @pytest.mark.parametrize('alibi', [True, False]) -def test_forward_with_cache_and_padding(alibi): +def test_forward_with_cache_and_padding(alibi: bool): # Tests that the result is the same with or without padding when using kv caching hf_config = MPTConfig( init_device='cpu', @@ -935,7 +948,7 @@ def test_forward_with_cache_and_padding(alibi): ('torch', 'gpu'), ]) @pytest.mark.parametrize('alibi', [True, False]) -def test_forward_with_cache(attn_impl, device, alibi): +def test_forward_with_cache(attn_impl: str, device: str, alibi: bool): # Test that model forward with and without the key-value cache produces the # same output. if not torch.cuda.is_available() and device == 'gpu': @@ -945,7 +958,7 @@ def test_forward_with_cache(attn_impl, device, alibi): if alibi and attn_impl == 'flash': pytest.skip(f'alibi only implemented with torch and triton attention.') - device = get_device(device) + composer_device = get_device(device) hf_config = MPTConfig( init_device='cpu', @@ -970,15 +983,17 @@ def test_forward_with_cache(attn_impl, device, alibi): ) reproducibility.seed_all(1234) mpt = MPTForCausalLM(hf_config) - mpt = device.module_to_device(mpt) + mpt = composer_device.module_to_device(mpt) mpt.eval() - with get_precision_context('amp_bf16' if device.name == 'gpu' else 'fp32'): + with get_precision_context('amp_bf16' if composer_device.name == + 'gpu' else 'fp32'): reproducibility.seed_all(1234) first_input_ids = torch.tensor([[11274, 16390, 11]]) - first_input_ids = device.tensor_to_device(first_input_ids) + first_input_ids = composer_device.tensor_to_device(first_input_ids) first_attention_mask = torch.tensor([[1, 1, 1]]).bool() - first_attention_mask = device.tensor_to_device(first_attention_mask) + first_attention_mask = composer_device.tensor_to_device( + first_attention_mask) # start with passing the first three tokens through first_output = mpt(first_input_ids, attention_mask=first_attention_mask) @@ -1001,9 +1016,10 @@ def test_forward_with_cache(attn_impl, device, alibi): reproducibility.seed_all(1234) second_input_ids = torch.tensor([[11274, 16390, 11, 11274]]) - second_input_ids = device.tensor_to_device(second_input_ids) + second_input_ids = composer_device.tensor_to_device(second_input_ids) second_attention_mask = torch.tensor([[1, 1, 1, 1]]).bool() - second_attention_mask = device.tensor_to_device(second_attention_mask) + second_attention_mask = composer_device.tensor_to_device( + second_attention_mask) # pass through the fourth token by itself, using the key-value cache second_output = mpt( @@ -1043,7 +1059,7 @@ def test_forward_with_cache(attn_impl, device, alibi): @pytest.mark.parametrize('alibi', [True, False]) -def test_generate_with_past_kv(alibi): +def test_generate_with_past_kv(alibi: bool): hf_config = MPTConfig( init_device='cpu', d_model=128, @@ -1103,7 +1119,8 @@ def test_generate_with_past_kv(alibi): 'top_p': 0.95 }]) @pytest.mark.parametrize('alibi', [True, False]) -def test_generation_kwargs_dont_crash(generation_kwargs, alibi): +def test_generation_kwargs_dont_crash(generation_kwargs: Dict[str, Any], + alibi: bool): hf_config = MPTConfig( init_device='cpu', d_model=128, @@ -1134,7 +1151,7 @@ def test_generation_kwargs_dont_crash(generation_kwargs, alibi): @pytest.mark.gpu @pytest.mark.parametrize('attention_impl', ['torch', 'flash', 'triton']) @pytest.mark.parametrize('alibi', [True, False]) -def test_model_to(attention_impl, alibi): +def test_model_to(attention_impl: str, alibi: bool): # test that moving the model to diff devices and dtypes in diff ways does not break the model if not torch.cuda.is_available(): pytest.skip( @@ -1238,7 +1255,8 @@ def test_alibi_vs_hf(): @pytest.mark.parametrize('output_attentions', [True, False]) @pytest.mark.parametrize('output_hidden_states', [True, False]) def test_forward_with_output_attentions_and_output_hidden_states( - attn_impl, device, alibi, output_attentions, output_hidden_states): + attn_impl: str, device: str, alibi: bool, output_attentions: bool, + output_hidden_states: bool): # Test that model forward with output_attentions_and_output_hidden_states if not torch.cuda.is_available() and device == 'gpu': pytest.skip( @@ -1249,7 +1267,7 @@ def test_forward_with_output_attentions_and_output_hidden_states( if output_attentions and attn_impl in ['flash', 'triton']: pytest.skip(f'output_attentions only implemented with torch attention.') - device = get_device(device) + composer_device = get_device(device) n_layers = 2 @@ -1276,15 +1294,16 @@ def test_forward_with_output_attentions_and_output_hidden_states( ) reproducibility.seed_all(1234) mpt = MPTForCausalLM(hf_config) - mpt = device.module_to_device(mpt) + mpt = composer_device.module_to_device(mpt) mpt.eval() - with get_precision_context('amp_bf16' if device.name == 'gpu' else 'fp32'): + with get_precision_context('amp_bf16' if composer_device.name == + 'gpu' else 'fp32'): reproducibility.seed_all(1234) input_ids = torch.tensor([[11274, 16390, 11]]) - input_ids = device.tensor_to_device(input_ids) + input_ids = composer_device.tensor_to_device(input_ids) attention_mask = torch.tensor([[1, 1, 1]]).bool() - attention_mask = device.tensor_to_device(attention_mask) + attention_mask = composer_device.tensor_to_device(attention_mask) # start with passing the first three tokens through outputs = mpt( @@ -1303,7 +1322,7 @@ def test_forward_with_output_attentions_and_output_hidden_states( @pytest.mark.gpu @pytest.mark.parametrize('init_device', ['cpu', 'meta', 'mixed']) @pytest.mark.parametrize('world_size', [2]) -def test_hf_init(tmp_path, +def test_hf_init(tmp_path: pathlib.Path, init_device: str, world_size: int, batch_size: int = 1): @@ -1366,7 +1385,9 @@ def test_hf_init(tmp_path, trust_remote_code=True) tokenizer = build_tokenizer(test_cfg.tokenizer) - optimizer = DecoupledAdamW(model.parameters(), lr=1e-5, betas=[0.9, 0.99]) + optimizer = DecoupledAdamW(model.parameters(), + lr=1e-5, + betas=tuple([0.9, 0.99])) prepare_fsdp_module(model, optimizer, fsdp_config, precision, device, False) @@ -1388,7 +1409,7 @@ def test_hf_init(tmp_path, @pytest.mark.gpu -def test_head_dim_8_triton_mqa_attn(batch_size=2): +def test_head_dim_8_triton_mqa_attn(batch_size: int = 2): test_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml') test_cfg.device = torch.cuda.current_device() diff --git a/tests/test_onnx.py b/tests/test_onnx.py index 1957fe2443..f8ff2c9508 100644 --- a/tests/test_onnx.py +++ b/tests/test_onnx.py @@ -1,6 +1,8 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import pathlib + import torch from composer.utils import reproducibility from transformers import AutoConfig, AutoModelForCausalLM @@ -24,7 +26,7 @@ def gen_random_batch(batch_size: int, vocab_size: int, max_seq_len: int): return batch -def test_onnx_export(tmp_path): +def test_onnx_export(tmp_path: pathlib.Path): reproducibility.seed_all(42) AutoConfig.register('mpt', MPTConfig) AutoModelForCausalLM.register(MPTConfig, MPTForCausalLM) @@ -69,13 +71,14 @@ def test_onnx_export(tmp_path): with torch.no_grad(): orig_out = mpt(**sample_input) - import onnx # type: ignore - import onnx.checker # type: ignore - import onnxruntime as ort # type: ignore + import onnx + import onnx.checker + import onnxruntime as ort + from onnx import checker _ = onnx.load(str(tmp_path / 'mpt.onnx')) - onnx.checker.check_model(str(tmp_path / 'mpt.onnx')) + checker.check_model(str(tmp_path / 'mpt.onnx')) ort_session = ort.InferenceSession(str(tmp_path / 'mpt.onnx')) diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index 43b013e57a..5f1e826177 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -5,7 +5,7 @@ from transformers import AutoTokenizer -def get_config(conf_path='scripts/train/yamls/pretrain/mpt-125m.yaml'): +def get_config(conf_path: str = 'scripts/train/yamls/pretrain/mpt-125m.yaml'): with open(conf_path) as f: test_cfg = om.load(f) return test_cfg diff --git a/tests/test_training.py b/tests/test_training.py index 0f1c7a4bf2..7b783318d1 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -1,12 +1,13 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 - import os import sys import warnings +from typing import Optional, Union import pytest import torch +from omegaconf import DictConfig from omegaconf import OmegaConf as om # Add repo root to path so we can import scripts and test it @@ -15,10 +16,11 @@ from scripts.train.train import main -def gpt_tiny_cfg(conf_path='scripts/train/yamls/pretrain/mpt-125m.yaml'): +def gpt_tiny_cfg(conf_path: str = 'scripts/train/yamls/pretrain/mpt-125m.yaml'): """Create gpt tiny cfg.""" with open(conf_path) as f: test_cfg = om.load(f) + assert isinstance(test_cfg, DictConfig) # removes requirement to download / process train set test_cfg.train_loader.dataset = test_cfg.eval_loader.dataset @@ -51,7 +53,7 @@ def gpt_tiny_cfg(conf_path='scripts/train/yamls/pretrain/mpt-125m.yaml'): reason='testing with cuda requires GPU')), ]) @pytest.mark.parametrize('logit_scale', [None, 0.036, 'inv_sqrt_d_model']) -def test_train(device, logit_scale): +def test_train(device: str, logit_scale: Optional[Union[float, str]]): if not os.path.isdir('./my-copy-c4/val'): pytest.xfail('c4 dataset not set up as expected')