diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 9cbe0f1af1..36a2d426df 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -88,48 +88,29 @@ def build_algorithm(name: str, kwargs: Dict[str, Any]): raise ValueError(f'Not sure how to build algorithm: {name}') -def build_optimizer(cfg: DictConfig, model: torch.nn.Module): - if cfg.name == 'decoupled_adamw': - return DecoupledAdamW(model.parameters(), - lr=cfg.lr, - betas=cfg.betas, - eps=cfg.eps, - weight_decay=cfg.weight_decay) - elif cfg.name == 'decoupled_lionw': - return DecoupledLionW(model.parameters(), - lr=cfg.lr, - betas=cfg.betas, - weight_decay=cfg.weight_decay) - elif cfg.name == 'clip_lion': - return DecoupledClipLion(model.parameters(), - lr=cfg.lr, - betas=cfg.betas, - weight_decay=cfg.weight_decay, - outlier_threshold=cfg.outlier_threshold) - elif cfg.name == 'adalr_lion': - return DecoupledAdaLRLion(model.parameters(), - lr=cfg.lr, - betas=cfg.betas, - weight_decay=cfg.weight_decay, - outlier_threshold=cfg.outlier_threshold, - timeout=cfg.timeout, - lr_penalty=cfg.lr_penalty, - min_scale=cfg.min_scale) +def build_optimizer(model: torch.nn.Module, name: str, + optimizer_config: Dict[str, Any]): + if name == 'decoupled_adamw': + return DecoupledAdamW(model.parameters(), **optimizer_config) + elif name == 'decoupled_lionw': + return DecoupledLionW(model.parameters(), **optimizer_config) + elif name == 'clip_lion': + return DecoupledClipLion(model.parameters(), **optimizer_config) + elif name == 'adalr_lion': + return DecoupledAdaLRLion(model.parameters(), **optimizer_config) else: - raise ValueError(f'Not sure how to build optimizer: {cfg.name}') + raise ValueError(f'Not sure how to build optimizer: {name}') -def build_scheduler(cfg: DictConfig): - if cfg.name == 'constant_with_warmup': - return ConstantWithWarmupScheduler(t_warmup=cfg.t_warmup) - elif cfg.name == 'cosine_with_warmup': - return CosineAnnealingWithWarmupScheduler(t_warmup=cfg.t_warmup, - alpha_f=cfg.alpha_f) - elif cfg.name == 'linear_decay_with_warmup': - return LinearWithWarmupScheduler(t_warmup=cfg.t_warmup, - alpha_f=cfg.alpha_f) +def build_scheduler(name: str, scheduler_config: Dict[str, Any]): + if name == 'constant_with_warmup': + return ConstantWithWarmupScheduler(**scheduler_config) + elif name == 'cosine_with_warmup': + return CosineAnnealingWithWarmupScheduler(**scheduler_config) + elif name == 'linear_decay_with_warmup': + return LinearWithWarmupScheduler(**scheduler_config) else: - raise ValueError(f'Not sure how to build scheduler: {cfg.name}') + raise ValueError(f'Not sure how to build scheduler: {name}') def build_tokenizer(om_tokenizer_config: DictConfig) -> PreTrainedTokenizerBase: diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index aa210b3b37..09dc5e7e6e 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union from composer.utils import dist -from omegaconf import DictConfig +from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om from llmfoundry.models.utils import init_empty_weights @@ -16,14 +16,24 @@ def pop_config(cfg: DictConfig, key: str, must_exist: bool = True, - default_value: Any = None) -> Any: + default_value: Any = None, + convert: bool = False) -> Any: """Pop a value from the main config file and return it. If the key does not exist, return the default_value or raise a RuntimeError - depending on the must_exist flag. 
+ depending on the must_exist flag. If the convert flag is set to True, then + we will convert the value to a python object using OmegaConf.to_container. """ value = cfg.pop(key, None) - if value is not None: + if value is not None and convert: + if not isinstance(value, DictConfig) and not isinstance( + value, ListConfig): + raise ValueError( + f'The key: {key} has a value: {value} that cannot be \ + converted to a dict or list. Please check your yaml.' + ) + return om.to_container(value) + elif value is not None: return value elif must_exist: raise NameError( diff --git a/mcli/mcli-llama2-finetune.yaml b/mcli/mcli-llama2-finetune.yaml new file mode 100644 index 0000000000..504bbff7af --- /dev/null +++ b/mcli/mcli-llama2-finetune.yaml @@ -0,0 +1,151 @@ +integrations: +- integration_type: git_repo + git_repo: mosaicml/llm-foundry + git_commit: 148c0793907a6afa48a892620e637ef5f90cdaf1 # TODO: repin this after next release + pip_install: -e .[gpu] + ssh_clone: false # Should be true if using a private repo + +command: | + cd llm-foundry/scripts + composer train/train.py /mnt/config/parameters.yaml +image: mosaicml/llm-foundry:1.13.1_cu117-latest +name: llama2-finetune + +compute: + # Note: Finetuning the 70b model requires at least 16x80GB GPUs + gpus: 8 # Number of GPUs to use + ## These configurations are optional + # cluster: TODO # Name of the cluster to use for this run + # gpu_type: a100_80gb # Type of GPU to use. We use a100_80gb in our experiments + +# The below is injected as a YAML file: /mnt/config/parameters.yaml +parameters: + tokenizer_name: meta-llama/Llama-2-7b-hf + max_seq_len: 4096 + global_seed: 17 + + # Run Name + run_name: # If left blank, will be read from env var $RUN_NAME + + # IMPORTANT: Uncomment if using the 70b model + # max_split_size_mb: 512 + + # Model + model: + name: hf_causal_lm + pretrained_model_name_or_path: meta-llama/Llama-2-7b-hf + pretrained: true + # Note: you must have set the HUGGING_FACE_HUB_TOKEN environment variable and have access to the llama2 models + use_auth_token: true + attention_patch_type: triton + + # Tokenizer + tokenizer: + name: ${tokenizer_name} + kwargs: + model_max_length: ${max_seq_len} + + # Dataloaders + train_loader: + name: finetuning + dataset: + hf_name: mosaicml/dolly_hhrlhf + split: train + max_seq_len: ${max_seq_len} + allow_pad_trimming: false + decoder_only_format: true + shuffle: true + # # Use `python llmfoundry/data/packing.py --yaml-path /path/to/this/yaml/ ...` + # # to profile this run's optimal packing_ratio as it depends on GPU count, + # # batch size, sequence length + # packing_ratio: + drop_last: true + num_workers: 8 + pin_memory: false + prefetch_factor: 2 + persistent_workers: true + timeout: 0 + + eval_loader: + name: finetuning + dataset: + hf_name: mosaicml/dolly_hhrlhf + split: test + max_seq_len: ${max_seq_len} + allow_pad_trimming: false + decoder_only_format: true + # packing_ratio: + shuffle: false + drop_last: true + num_workers: 8 + pin_memory: false + prefetch_factor: 2 + persistent_workers: true + timeout: 0 + + # Optimization + scheduler: + name: cosine_with_warmup + t_warmup: 100ba + alpha_f: 0.1 + + # Note: You may want to change learning rate, betas, weight decay + optimizer: + name: decoupled_lionw + lr: 5.0e-7 + betas: + - 0.9 + - 0.95 + weight_decay: 0.0 + + algorithms: + gradient_clipping: + clipping_type: norm + clipping_threshold: 1.0 + + max_duration: 1ep + eval_first: false + eval_interval: 1ep + eval_subset_num_batches: -1 + global_train_batch_size: 64 + + # System + seed: 
${global_seed} + device_eval_batch_size: 8 + device_train_microbatch_size: auto + precision: amp_bf16 + + # FSDP + fsdp_config: + sharding_strategy: FULL_SHARD + mixed_precision: PURE + activation_checkpointing: true + activation_checkpointing_reentrant: false + activation_cpu_offload: false + limit_all_gathers: true + verbose: false + + # Logging + progress_bar: false + log_to_console: true + console_log_interval: 1ba + + callbacks: + speed_monitor: + window_size: 10 + lr_monitor: {} + memory_monitor: {} + runtime_estimator: {} + +# loggers: +# wandb: {} + +# Checkpoint to local filesystem or remote object store +# save_interval: 2000ba +# save_num_checkpoints_to_keep: 1 # Important, this cleans up checkpoints saved to DISK +# save_folder: ./{run_name}/checkpoints +# save_folder: s3://my-bucket/my-folder/{run_name}/checkpoints + +# Load from local filesystem or remote object store +# load_path: ./gpt-1b/checkpoints/latest-rank{rank}.pt +# load_path: s3://my-bucket/my-folder/gpt-1b/checkpoints/latest-rank{rank}.pt diff --git a/scripts/inference/convert_composer_to_hf.py b/scripts/inference/convert_composer_to_hf.py index 26c10b0d8e..72377d0fcd 100644 --- a/scripts/inference/convert_composer_to_hf.py +++ b/scripts/inference/convert_composer_to_hf.py @@ -1,120 +1,34 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -# Note: This script is specifically for converting MPT composer checkpoints to HuggingFace format -# For composer checkpoints containing model that are in the transformers library, see -# https://docs.mosaicml.com/projects/composer/en/latest/api_reference/generated/composer.models.write_huggingface_pretrained_from_composer_checkpoint.html - import json import os +import random +import string import tempfile from argparse import ArgumentParser, Namespace from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union import sentencepiece as spm import torch import transformers +from composer.models.huggingface import get_hf_config_from_composer_state_dict from composer.utils import (get_file, maybe_create_object_store_from_uri, parse_uri, safe_torch_load) from transformers import (AutoConfig, AutoTokenizer, PretrainedConfig, - PreTrainedTokenizer) + PreTrainedTokenizer, PreTrainedTokenizerBase) from llmfoundry import MPTConfig, MPTForCausalLM from llmfoundry.utils.huggingface_hub_utils import \ edit_files_for_hf_compatibility -# TODO: maybe move this functionality to Composer -def get_hf_config_from_composer_state_dict( - state_dict: Dict[str, Any]) -> PretrainedConfig: - if 'state' not in state_dict: - raise RuntimeError( - 'Unexpected composer state dictionary. Did you pass in a full composer checkpoint?' - ) - if 'integrations' not in state_dict[ - 'state'] or 'huggingface' not in state_dict['state']['integrations']: - raise RuntimeError( - 'Did not find HuggingFace related state (e.g., tokenizer) in the provided composer checkpoint!' 
- ) - hf_config_dict = state_dict['state']['integrations']['huggingface'][ - 'model']['config']['content'] - - # Always set init_device='cpu' - hf_config_dict['init_device'] = 'cpu' - - AutoConfig.register('mpt', MPTConfig) - - # backwards compatibility changes - if hf_config_dict['model_type'] == 'mosaic_gpt': - hf_config_dict['model_type'] = 'mpt' - - if 'attn_config' not in hf_config_dict: - attn_config = {} - attn_config['attn_type'] = 'multihead_attention' - attn_config['attn_pdrop'] = hf_config_dict['attn_pdrop'] - del hf_config_dict['attn_pdrop'] - attn_config['attn_impl'] = hf_config_dict['attn_impl'] - del hf_config_dict['attn_impl'] - attn_config['qk_ln'] = hf_config_dict['attn_qk_ln'] - del hf_config_dict['attn_qk_ln'] - attn_config['clip_qkv'] = hf_config_dict['attn_clip_qkv'] - del hf_config_dict['attn_clip_qkv'] - attn_config['softmax_scale'] = hf_config_dict['softmax_scale'] - del hf_config_dict['softmax_scale'] - attn_config['prefix_lm'] = hf_config_dict['prefix_lm'] - del hf_config_dict['prefix_lm'] - attn_config['attn_uses_sequence_id'] = hf_config_dict[ - 'attn_uses_sequence_id'] - del hf_config_dict['attn_uses_sequence_id'] - attn_config['alibi'] = hf_config_dict['alibi'] - del hf_config_dict['alibi'] - attn_config['alibi_bias_max'] = hf_config_dict['alibi_bias_max'] - del hf_config_dict['alibi_bias_max'] - - hf_config_dict['attn_config'] = attn_config - - if 'init_config' not in hf_config_dict: - init_config = {} - - init_config['name'] = hf_config_dict['param_init_fn'] - del hf_config_dict['param_init_fn'] - init_config['fan_mode'] = hf_config_dict['fan_mode'] - del hf_config_dict['fan_mode'] - init_config['init_nonlinearity'] = hf_config_dict['init_nonlinearity'] - del hf_config_dict['init_nonlinearity'] - init_config['init_gain'] = hf_config_dict['init_gain'] - del hf_config_dict['init_gain'] - init_config['init_std'] = hf_config_dict['init_std'] - del hf_config_dict['init_std'] - init_config['init_div_is_residual'] = hf_config_dict[ - 'init_div_is_residual'] - del hf_config_dict['init_div_is_residual'] - init_config['emb_init_std'] = hf_config_dict['emb_init_std'] - del hf_config_dict['emb_init_std'] - init_config['emb_init_uniform_lim'] = hf_config_dict[ - 'emb_init_uniform_lim'] - del hf_config_dict['emb_init_uniform_lim'] - - hf_config_dict['init_config'] = init_config - - if 'mlp_ratio' in hf_config_dict: - hf_config_dict['expansion_ratio'] = hf_config_dict['mlp_ratio'] - del hf_config_dict['mlp_ratio'] - - if 'low_precision_layernorm' in hf_config_dict: - if hf_config_dict['low_precision_layernorm']: - hf_config_dict['norm_type'] = 'low_precision_layernorm' - else: - hf_config_dict['norm_type'] = 'layernorm' - del hf_config_dict['low_precision_layernorm'] - - return AutoConfig.for_model(**hf_config_dict) - - -# TODO: maybe move this functionality to Composer +# TODO: move this functionality to composer once the bug fixes are upstreamed def get_hf_tokenizer_from_composer_state_dict( - state_dict: Dict[str, Any]) -> Optional[PreTrainedTokenizer]: + state_dict: Dict[str, Any], + tokenizer_save_dir: Optional[str] = None +) -> Optional[PreTrainedTokenizer]: if 'state' not in state_dict: raise RuntimeError( 'Unexpected composer state dictionary. Did you pass in a full composer checkpoint?' 
@@ -128,38 +42,51 @@ def get_hf_tokenizer_from_composer_state_dict( 'tokenizer'] hf_tokenizer = None if hf_tokenizer_state != {}: - with tempfile.TemporaryDirectory() as _tmp_dir: - for filename, saved_content in hf_tokenizer_state.items(): - tokenizer_file_path = Path( - _tmp_dir) / f'{filename}{saved_content["file_extension"]}' - if saved_content['file_extension'] == '.json': - with open(tokenizer_file_path, 'w') as _tmp_file: - json.dump(saved_content['content'], _tmp_file) - elif saved_content['file_extension'] == '.txt': - with open(tokenizer_file_path, 'w') as _tmp_file: - for line in saved_content['content']: - _tmp_file.write(line) - _tmp_file.write('\n') - elif saved_content['file_extension'] == '.model': - s = spm.SentencePieceProcessor() - s.load_from_serialized_proto(saved_content['content']) - with open(tokenizer_file_path, 'wb') as _tmp_file: - _tmp_file.write(s.serialized_model_proto()) - hf_tokenizer = AutoTokenizer.from_pretrained(_tmp_dir) - - # remove 'name_or_path' - hf_tokenizer.name_or_path = '' - hf_tokenizer.init_kwargs['name_or_path'] = '' + if tokenizer_save_dir is None: + unique_suffix = ''.join( + random.choices(string.ascii_letters + string.digits, k=6)) + tokenizer_save_dir = os.path.join( + os.getcwd(), f'tokenizer-save-dir-{unique_suffix}') + os.makedirs(tokenizer_save_dir, exist_ok=True) + + for filename, saved_content in hf_tokenizer_state.items(): + # This cannot be a temporary directory because huggingface relies on the slow tokenizer file + # being persistent on disk + tokenizer_file_path = Path( + tokenizer_save_dir + ) / f'{filename}{saved_content["file_extension"]}' + if saved_content['file_extension'] == '.json': + with open(tokenizer_file_path, 'w') as _tmp_file: + json.dump(saved_content['content'], _tmp_file) + elif saved_content['file_extension'] == '.txt': + with open(tokenizer_file_path, 'w') as _tmp_file: + for line in saved_content['content']: + _tmp_file.write(line) + _tmp_file.write('\n') + elif saved_content['file_extension'] == '.py': + with open(tokenizer_file_path, 'w') as _tmp_file: + _tmp_file.write(saved_content['content']) + elif saved_content['file_extension'] == '.model': + s = spm.SentencePieceProcessor() + s.load_from_serialized_proto(saved_content['content']) + with open(tokenizer_file_path, 'wb') as _tmp_file: + _tmp_file.write(s.serialized_model_proto()) + + hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_save_dir) + + # remove 'name_or_path' + hf_tokenizer.name_or_path = '' + hf_tokenizer.init_kwargs['name_or_path'] = '' return hf_tokenizer def write_huggingface_pretrained_from_composer_checkpoint( - checkpoint_path: Union[Path, str], - output_path: Union[Path, str], - output_precision: str = 'fp32', - local_checkpoint_save_location: Optional[Union[Path, - str]] = None) -> None: + checkpoint_path: Union[Path, str], + output_path: Union[Path, str], + output_precision: str = 'fp32', + local_checkpoint_save_location: Optional[Union[Path, str]] = None +) -> Tuple[PretrainedConfig, Optional[PreTrainedTokenizerBase]]: """Convert a Composer checkpoint to a pretrained HF checkpoint folder. 
Write a ``config.json`` and ``pytorch_model.bin``, like @@ -274,12 +201,14 @@ def write_huggingface_pretrained_from_composer_checkpoint( print('Done.') print('#' * 30) + return hf_config, hf_tokenizer + def parse_args() -> Namespace: """Parse commandline arguments.""" parser = ArgumentParser( description= - 'Convert an MPT Composer checkpoint and Omegaconf model config into a standard HuggingFace checkpoint folder, and optionally upload to the hub.' + 'Convert a HuggingFace causal LM in a Composer checkpoint into a standard HuggingFace checkpoint folder, and optionally upload to the hub.' ) parser.add_argument('--composer_path', type=str, required=True) parser.add_argument('--hf_output_path', type=str, required=True) @@ -297,9 +226,16 @@ def parse_args() -> Namespace: def convert_composer_to_hf(args: Namespace) -> None: + # Register MPT auto classes so that this script works with MPT + # This script will not work without modification for other custom models, + # but will work for other HuggingFace causal LMs + AutoConfig.register('mpt', MPTConfig) + MPTConfig.register_for_auto_class() + MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM') + _, _, local_folder_path = parse_uri(args.hf_output_path) - write_huggingface_pretrained_from_composer_checkpoint( + config, tokenizer = write_huggingface_pretrained_from_composer_checkpoint( checkpoint_path=args.composer_path, output_path=local_folder_path, output_precision=args.output_precision, @@ -311,19 +247,18 @@ def convert_composer_to_hf(args: Namespace) -> None: 'bf16': torch.bfloat16, }[args.output_precision] - # register config auto class - MPTConfig.register_for_auto_class() + print(f'Loading model from {local_folder_path}') + if config.model_type == 'mpt': + config.attn_config['attn_impl'] = 'torch' - # register model auto class - MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM') + if config.model_type == 'mpt': + loaded_hf_model = MPTForCausalLM.from_pretrained(local_folder_path, + config=config, + torch_dtype=dtype) + else: + loaded_hf_model = transformers.AutoModelForCausalLM.from_pretrained( + local_folder_path, config=config, torch_dtype=dtype) - print(f'Loading model from {local_folder_path}') - config = MPTConfig.from_pretrained(local_folder_path) - # You have to edit the config this way, because attn_config is a nested dictionary - config.attn_config['attn_impl'] = 'torch' - loaded_hf_model = MPTForCausalLM.from_pretrained(local_folder_path, - config=config, - torch_dtype=dtype) delattr(loaded_hf_model.config, '_name_or_path') loaded_hf_model.save_pretrained(local_folder_path) @@ -332,8 +267,10 @@ def convert_composer_to_hf(args: Namespace) -> None: tokenizer = transformers.AutoTokenizer.from_pretrained(local_folder_path) tokenizer.save_pretrained(local_folder_path) - print('Editing files for HF compatibility...') - edit_files_for_hf_compatibility(local_folder_path) + # Only need to edit files for MPT because it has custom code + if config.model_type == 'mpt': + print('Editing files for HF compatibility...') + edit_files_for_hf_compatibility(local_folder_path) object_store = maybe_create_object_store_from_uri(str(args.hf_output_path)) diff --git a/scripts/train/train.py b/scripts/train/train.py index 7686ee2510..20e05fc83a 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -4,7 +4,7 @@ import os import sys import warnings -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch from composer import Trainer @@ -106,7 +106,7 @@ def 
build_composer_model(model_cfg: DictConfig, def build_composer_peft_model( - model_cfg: DictConfig, lora_cfg: DictConfig, + pretrained_model_name_or_path: str, lora_args: Dict[str, Any], tokenizer: PreTrainedTokenizerBase) -> ComposerHFCausalLM: try: from peft import LoraConfig, get_peft_model @@ -119,11 +119,11 @@ def build_composer_peft_model( # 1) loads a hf model, 2) adds peft modules, 3) wraps it in a ComposerHFCausalLM. print('Building Lora config...') - lora_cfg = LoraConfig(**lora_cfg.args) + lora_cfg = LoraConfig(**lora_args) print('Building model from HuggingFace checkpoint...') - model = MPTForCausalLM.from_pretrained( - model_cfg.pretrained_model_name_or_path, trust_remote_code=True) + model = MPTForCausalLM.from_pretrained(pretrained_model_name_or_path, + trust_remote_code=True) print('Model built!') print('Adding Lora modules...') @@ -214,24 +214,29 @@ def main(cfg: DictConfig): # Mandatory model training configs model_config: DictConfig = pop_config(cfg, 'model', must_exist=True) tokenizer_config: DictConfig = pop_config(cfg, 'tokenizer', must_exist=True) - optimizer_config: DictConfig = pop_config(cfg, 'optimizer', must_exist=True) - scheduler_config: DictConfig = pop_config(cfg, 'scheduler', must_exist=True) + optimizer_config: Dict[str, Any] = pop_config(cfg, + 'optimizer', + must_exist=True, + convert=True) + scheduler_config: Dict[str, Any] = pop_config(cfg, + 'scheduler', + must_exist=True, + convert=True) train_loader_config: DictConfig = pop_config(cfg, 'train_loader', must_exist=True) # Optional fsdp data, fine-tuning, and eval configs - fsdp_dict_config: Optional[DictConfig] = pop_config(cfg, - 'fsdp_config', - must_exist=False, - default_value=None) - fsdp_config: Optional[Dict] = om.to_container( - fsdp_dict_config - ) if fsdp_dict_config is not None else None # type: ignore - lora_config: Optional[DictConfig] = pop_config(cfg, - 'lora', - must_exist=False, - default_value=None) + fsdp_config: Optional[Dict[str, Any]] = pop_config(cfg, + 'fsdp_config', + must_exist=False, + default_value=None, + convert=True) + lora_config: Optional[Dict[str, Any]] = pop_config(cfg, + 'lora', + must_exist=False, + default_value=None, + convert=True) eval_loader_config: Optional[DictConfig] = pop_config(cfg, 'eval_loader', must_exist=False, @@ -392,7 +397,8 @@ def main(cfg: DictConfig): with init_context: if lora_config is not None: # frozen model + trainable lora modules model: ComposerHFCausalLM = build_composer_peft_model( - model_config, lora_config, tokenizer) + model_config.pretrained_model_name_or_path, lora_config['args'], + tokenizer) print_trainable_parameters(model) # should not be 100% else: # standard model @@ -402,6 +408,32 @@ def main(cfg: DictConfig): n_params = sum(p.numel() for p in model.parameters()) logged_cfg.update({'n_params': n_params}) + # Optimizer + optimizer_name: str = optimizer_config.pop('name') + optimizer = build_optimizer(model, optimizer_name, optimizer_config) + + # Scheduler + scheduler_name: str = scheduler_config.pop('name') + scheduler = build_scheduler(scheduler_name, scheduler_config) + + # Loggers + loggers = [ + build_logger(str(name), logger_cfg) + for name, logger_cfg in logger_configs.items() + ] if logger_configs else None + + # Callbacks + callbacks = [ + build_callback(str(name), callback_cfg) + for name, callback_cfg in callback_configs.items() + ] if callback_configs else None + + # Algorithms + algorithms = [ + build_algorithm(str(name), algorithm_cfg) + for name, algorithm_cfg in algorithm_configs.items() + ] if 
algorithm_configs else None + # Dataloaders print('Building train loader...') train_loader = build_dataloader( @@ -449,33 +481,6 @@ def main(cfg: DictConfig): } model_gauntlet_callback = ModelGauntlet(**model_gauntlet) - # Optimizer - optimizer = build_optimizer(optimizer_config, model) - - # Scheduler - scheduler = build_scheduler(scheduler_config) - - # Loggers - loggers = [ - build_logger(str(name), logger_cfg) - for name, logger_cfg in logger_configs.items() - ] if logger_configs else None - - # Callbacks - - callbacks: List[Callback] = [ - build_callback(str(name), callback_cfg) - for name, callback_cfg in callback_configs.items() - ] if callback_configs else [] - - if model_gauntlet_callback is not None: - callbacks.append(model_gauntlet_callback) - - # Algorithms - algorithms = [ - build_algorithm(str(name), algorithm_cfg) - for name, algorithm_cfg in algorithm_configs.items() - ] if algorithm_configs else None # Build the Trainer print('Building trainer...') diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index d5372911d5..1561f965c9 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -22,6 +22,7 @@ from omegaconf import DictConfig from omegaconf import OmegaConf as om +from llmfoundry import COMPOSER_MODEL_REGISTRY from scripts.inference.convert_composer_to_hf import convert_composer_to_hf @@ -43,19 +44,44 @@ def get_config( os.environ['TOKENIZERS_PARALLELISM'] = 'false' with open(conf_path) as f: test_cfg = om.load(f) + return cast(DictConfig, test_cfg) -def test_convert_and_generate_torch(tmp_path: pathlib.Path): +@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2']) +def test_convert_and_generate(model: str, tmp_path: pathlib.Path): delete_transformers_cache() - cfg = get_config() - cfg['model']['init_device'] = 'cpu' - cfg['model']['attn_config']['attn_impl'] = 'torch' + om_cfg = None + if model == 'mpt': + om_cfg = get_config( + conf_path='scripts/train/yamls/pretrain/testing.yaml') + elif model == 'neo': + om_cfg = get_config( + conf_path='scripts/train/yamls/pretrain/gpt-neo-125m.yaml') + om_cfg['model']['config_overrides']['hidden_size'] = 36 + elif model == 'llama2': + if 'HUGGING_FACE_HUB_TOKEN' not in os.environ: + pytest.skip( + 'The CI cluster does not have access to the Llama models, so skip this test.' 
+ ) + om_cfg = get_config( + conf_path='scripts/train/yamls/pretrain/gpt-neo-125m.yaml') + om_cfg['model'][ + 'pretrained_model_name_or_path'] = 'meta-llama/Llama-2-7b-hf' + om_cfg['model']['config_overrides']['num_hidden_layers'] = 2 + om_cfg['model']['use_auth_token'] = True + om_cfg['tokenizer']['name'] = 'meta-llama/Llama-2-7b-hf' + else: + raise ValueError(f'Unknown model {model}') + assert om_cfg is not None + + om_cfg['model']['init_device'] = 'cpu' tokenizer = transformers.AutoTokenizer.from_pretrained( - 'EleutherAI/gpt-neox-20b') - model = ComposerMPTCausalLM(cfg['model'], tokenizer) - trainer = Trainer(model=model) + om_cfg.tokenizer.name, use_auth_token=model == 'llama2') + original_model = COMPOSER_MODEL_REGISTRY[om_cfg['model'].name]( + om_cfg['model'], tokenizer) + trainer = Trainer(model=original_model, device='cpu') trainer.save_checkpoint(os.path.join(tmp_path, 'checkpoint.pt')) args = Namespace(composer_path=os.path.join(tmp_path, 'checkpoint.pt'), @@ -66,21 +92,29 @@ def test_convert_and_generate_torch(tmp_path: pathlib.Path): test_uploaded_model=False) convert_composer_to_hf(args) - config = transformers.AutoConfig.from_pretrained(os.path.join( - tmp_path, 'hf-output-folder'), - trust_remote_code=True) - config.attn_config['attn_impl'] = 'torch' - model = transformers.AutoModelForCausalLM.from_pretrained( + loaded_config = transformers.AutoConfig.from_pretrained( + os.path.join(tmp_path, 'hf-output-folder'), trust_remote_code=True) + loaded_model = transformers.AutoModelForCausalLM.from_pretrained( os.path.join(tmp_path, 'hf-output-folder'), - config=config, + config=loaded_config, trust_remote_code=True) tokenizer = transformers.AutoTokenizer.from_pretrained( os.path.join(tmp_path, 'hf-output-folder'), trust_remote_code=True) - output = model.generate(tokenizer('hello', - return_tensors='pt')['input_ids'], - max_new_tokens=1) - assert output.shape == (1, 2) + output = loaded_model.generate(tokenizer('hello', + return_tensors='pt')['input_ids'], + max_new_tokens=1) + assert output.shape == (1, 2 + (1 if model == 'llama2' else 0)) + + assert sum(p.numel() for p in original_model.model.parameters()) == sum( + p.numel() for p in loaded_model.parameters()) + assert all( + str(type(module1)).split('.')[-1] == str(type(module2)).split('.')[-1] + for module1, module2 in zip(original_model.model.modules(), + loaded_model.modules())) + for p1, p2 in zip(original_model.model.parameters(), + loaded_model.parameters()): + assert torch.allclose(p1, p2) delete_transformers_cache() diff --git a/tests/test_train_inputs.py b/tests/test_train_inputs.py index 857a2b8c57..e208a0a1ee 100644 --- a/tests/test_train_inputs.py +++ b/tests/test_train_inputs.py @@ -87,3 +87,31 @@ def test_optional_mispelled_params_raise_warning(self, str(warning.message) for warning in warning_list) # restore configs. 
cfg = copy.deepcopy(old_cfg) + + def test_extra_params_in_optimizer_cfg_errors(self, + cfg: DictConfig) -> None: + cfg.optimizer.beta2 = 'extra-parameter' + with pytest.raises(TypeError): + main(cfg) + + def test_invalid_name_in_optimizer_cfg_errors(self, + cfg: DictConfig) -> None: + cfg.optimizer.name = 'invalid-optimizer' + with pytest.raises(ValueError) as exception_info: + main(cfg) + assert str(exception_info.value + ) == 'Not sure how to build optimizer: invalid-optimizer' + + def test_extra_params_in_scheduler_cfg_errors(self, + cfg: DictConfig) -> None: + cfg.scheduler.t_warmup_extra = 'extra-parameter' + with pytest.raises(TypeError): + main(cfg) + + def test_invalid_name_in_scheduler_cfg_errors(self, + cfg: DictConfig) -> None: + cfg.scheduler.name = 'invalid-scheduler' + with pytest.raises(ValueError) as exception_info: + main(cfg) + assert str(exception_info.value + ) == 'Not sure how to build scheduler: invalid-scheduler'
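
Usage sketch (not part of the patch): a minimal example of how the refactored helpers compose after this change. pop_config(..., convert=True) hands back a plain dict, whose 'name' key is popped before the remaining keys are forwarded to the builder, mirroring the new flow in train.py. The toy model and the config values below are illustrative placeholders, not values taken from the repo.

import torch
from omegaconf import OmegaConf as om

from llmfoundry.utils.builders import build_optimizer, build_scheduler
from llmfoundry.utils.config_utils import pop_config

# Illustrative config values only; a real run reads these from its YAML.
cfg = om.create({
    'optimizer': {
        'name': 'decoupled_adamw',
        'lr': 1.0e-4,
        'betas': [0.9, 0.95],
        'eps': 1.0e-8,
        'weight_decay': 0.0,
    },
    'scheduler': {
        'name': 'cosine_with_warmup',
        't_warmup': '100ba',
        'alpha_f': 0.1,
    },
})

model = torch.nn.Linear(8, 8)  # stand-in for the real ComposerModel

# convert=True returns a plain dict, so after popping 'name' the leftover
# keys can be **-splatted straight into the optimizer/scheduler constructors.
optimizer_config = pop_config(cfg, 'optimizer', must_exist=True, convert=True)
optimizer = build_optimizer(model, optimizer_config.pop('name'),
                            optimizer_config)

scheduler_config = pop_config(cfg, 'scheduler', must_exist=True, convert=True)
scheduler = build_scheduler(scheduler_config.pop('name'), scheduler_config)

Passing plain dicts keeps the builders decoupled from OmegaConf, and any unexpected key in the YAML now surfaces as an ordinary TypeError from the underlying constructor, which is what the new tests in test_train_inputs.py assert.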