diff --git a/llmfoundry/utils/__init__.py b/llmfoundry/utils/__init__.py
index 837b645b6b..38cc562c9d 100644
--- a/llmfoundry/utils/__init__.py
+++ b/llmfoundry/utils/__init__.py
@@ -6,6 +6,8 @@
                                            build_icl_evaluators, build_logger,
                                            build_optimizer, build_scheduler,
                                            build_tokenizer)
+    from llmfoundry.utils.checkpoint_conversion_helpers import (
+        convert_and_save_ft_weights, get_hf_tokenizer_from_composer_state_dict)
     from llmfoundry.utils.config_utils import (calculate_batch_size_info,
                                                log_config, pop_config,
                                                update_batch_size_info)
@@ -23,6 +25,8 @@
     'build_icl_evaluators',
     'build_tokenizer',
     'calculate_batch_size_info',
+    'convert_and_save_ft_weights',
+    'get_hf_tokenizer_from_composer_state_dict',
     'update_batch_size_info',
     'log_config',
     'pop_config',
diff --git a/llmfoundry/utils/checkpoint_conversion_helpers.py b/llmfoundry/utils/checkpoint_conversion_helpers.py
new file mode 100644
index 0000000000..4d7a152157
--- /dev/null
+++ b/llmfoundry/utils/checkpoint_conversion_helpers.py
@@ -0,0 +1,295 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Helper methods for the checkpoint conversion scripts.
+
+The checkpoint conversion scripts are located in the
+scripts/inference/ folder. Users should run those
+scripts directly to convert between checkpoints; this file contains only common
+utility functions that are shared by multiple scripts.
+"""
+
+import json
+import os
+import random
+import string
+from pathlib import Path
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import sentencepiece as spm
+from transformers import AutoTokenizer, PreTrainedTokenizer
+
+
+def _get_weight_data_type(data_type: str):
+    if data_type == 'fp32':
+        return np.float32
+    elif data_type == 'fp16':
+        return np.float16
+    else:
+        raise RuntimeError(f'Unsupported data type: {data_type} for conversion.')
+
+
+# TODO: move this functionality to composer once the bug fixes are upstreamed
+def get_hf_tokenizer_from_composer_state_dict(
+    state_dict: Dict[str, Any],
+    tokenizer_save_dir: Optional[str] = None
+) -> Optional[PreTrainedTokenizer]:
+    if 'state' not in state_dict:
+        raise RuntimeError(
+            'Unexpected composer state dictionary. Did you pass in a full composer checkpoint?'
+        )
+    if 'integrations' not in state_dict[
+            'state'] or 'huggingface' not in state_dict['state']['integrations']:
+        raise RuntimeError(
+            'Did not find HuggingFace related state (e.g., tokenizer) in the provided composer checkpoint!'
+        )
+    hf_tokenizer_state = state_dict['state']['integrations']['huggingface'][
+        'tokenizer']
+    hf_tokenizer = None
+    if hf_tokenizer_state != {}:
+        if tokenizer_save_dir is None:
+            unique_suffix = ''.join(
+                random.choices(string.ascii_letters + string.digits, k=6))
+            tokenizer_save_dir = os.path.join(
+                os.getcwd(), f'tokenizer-save-dir-{unique_suffix}')
+        os.makedirs(tokenizer_save_dir, exist_ok=True)
+
+        for filename, saved_content in hf_tokenizer_state.items():
+            # This cannot be a temporary directory because huggingface relies on the slow tokenizer file
+            # being persistent on disk
+            tokenizer_file_path = Path(
+                tokenizer_save_dir
+            ) / f'{filename}{saved_content["file_extension"]}'
+            if saved_content['file_extension'] == '.json':
+                with open(tokenizer_file_path, 'w') as _tmp_file:
+                    json.dump(saved_content['content'], _tmp_file)
+            elif saved_content['file_extension'] == '.txt':
+                with open(tokenizer_file_path, 'w') as _tmp_file:
+                    for line in saved_content['content']:
+                        _tmp_file.write(line)
+                        _tmp_file.write('\n')
+            elif saved_content['file_extension'] == '.py':
+                with open(tokenizer_file_path, 'w') as _tmp_file:
+                    _tmp_file.write(saved_content['content'])
+            elif saved_content['file_extension'] == '.model':
+                s = spm.SentencePieceProcessor()
+                s.load_from_serialized_proto(saved_content['content'])
+                with open(tokenizer_file_path, 'wb') as _tmp_file:
+                    _tmp_file.write(s.serialized_model_proto())
+
+        hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_save_dir)
+
+        # remove 'name_or_path'
+        hf_tokenizer.name_or_path = ''
+        hf_tokenizer.init_kwargs['name_or_path'] = ''
+
+    return hf_tokenizer
+
+
+def _write_zero_bias(weight_name: str, weight_file_path: str,
+                     bias_shape: Union[Tuple[int, ...], int]) -> None:
+    """Write zeros for bias when converting MPT to FasterTransformer weights.
+
+    MPT model might not have bias while FT expects bias.
+
+    Args:
+        weight_name (str): Name of the weight tensor.
+        weight_file_path (str): Output path for storing the weight (NOT zero bias).
+        bias_shape (Union[Tuple[int, ...], int]): Shape of the bias array.
+    """
+    if 'weight' not in weight_file_path:
+        raise RuntimeError(
+            f'Cannot write zero bias for {weight_name}. Input is not a weight tensor'
+        )
+    print(f'zero bias for weight: {weight_name}')
+    bias_file_path = weight_file_path.replace('.weight', '.bias')
+    bias = np.zeros(bias_shape, dtype=np.float32)
+    bias.tofile(bias_file_path)
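The conversion routine that follows shards each transformer weight across `infer_gpu_num` tensor-parallel ranks: row-parallel matrices (`attention.dense`, `mlp.dense_4h_to_h`) are split along axis 0 after transposing out of the `nn.Linear` layout, while column-parallel matrices (`mlp.dense_h_to_4h` and the fused QKV projection) are split along the last axis. A minimal NumPy sketch of that splitting rule, with made-up dimensions that are not part of this PR:

```python
import numpy as np

d_model, mlp_ratio, infer_gpu_num = 8, 4, 2

# Column-parallel (up projection / QKV): split the transposed weight along its
# last axis, so each rank holds a slice of the output features.
up_proj = np.random.randn(mlp_ratio * d_model, d_model).T  # nn.Linear stores (out, in)
up_shards = np.split(up_proj, infer_gpu_num, axis=-1)
assert up_shards[0].shape == (d_model, mlp_ratio * d_model // infer_gpu_num)

# Row-parallel (attention output / down projection): split along axis 0,
# so each rank holds a slice of the input features.
down_proj = np.random.randn(d_model, mlp_ratio * d_model).T
down_shards = np.split(down_proj, infer_gpu_num, axis=0)
assert down_shards[0].shape == (mlp_ratio * d_model // infer_gpu_num, d_model)
```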
+
+
+def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int,
+                               tensor_name: str, config: Dict[str, Any],
+                               data: np.ndarray):
+    """Convert each MPT weight to a FasterTransformer compatible format.
+
+    Args:
+        save_dir (str): Path of the directory to save the weight in FT format. The directory must already exist.
+        infer_gpu_num (int): The number of gpus you are planning to use for inference.
+        tensor_name (str): Name of the weight tensor. Used in naming the output file.
+        config (Dict[str, Any]): Configuration for the model. This is used in getting model specific parameters.
+        data (np.ndarray): Tensor data in np.ndarray format.
+
+    Returns:
+        None: Writes to a file in `save_dir`. File name is based on the `tensor_name`.
+    """
+    if tensor_name.find('input_layernorm.weight') != -1 or tensor_name.find('input_layernorm.bias') != -1 or \
+        tensor_name.find('attention.dense.bias') != -1 or tensor_name.find('post_attention_layernorm.weight') != -1 or \
+        tensor_name.find('post_attention_layernorm.bias') != -1 or tensor_name.find('mlp.dense_4h_to_h.bias') != -1 or \
+        tensor_name.find('final_layernorm.weight') != -1 or tensor_name.find('final_layernorm.bias') != -1:
+
+        save_path = os.path.join(save_dir, f'model.{tensor_name}.bin')
+        data.tofile(save_path)
+        if 'weight' in tensor_name and config['no_bias']:
+            _write_zero_bias(tensor_name, save_path, data.shape[-1])
+
+    elif tensor_name.find('attention.dense.weight') != -1:
+        assert data.shape == (
+            config['d_model'],
+            config['d_model']), f'unexpected dim for {tensor_name}'
+        # nn.Linear weights are transposed
+        data = data.T
+        split_vals = np.split(data, infer_gpu_num, axis=0)
+        for j in range(infer_gpu_num):
+            save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin')
+            split_vals[j].tofile(save_path)
+        if config['no_bias']:
+            fake_weight_path = os.path.join(save_dir,
+                                            f'model.{tensor_name}.bin')
+            _write_zero_bias(tensor_name, fake_weight_path, data.shape[-1])
+
+    elif tensor_name.find('mlp.dense_4h_to_h.weight') != -1:
+        assert data.shape == (
+            config['d_model'], config['mlp_ratio'] *
+            config['d_model']), f'unexpected dim for {tensor_name}'
+        # nn.Linear weights are transposed
+        data = data.T
+        split_vals = np.split(data, infer_gpu_num, axis=0)
+        for j in range(infer_gpu_num):
+            save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin')
+            split_vals[j].tofile(save_path)
+        if config['no_bias']:
+            fake_weight_path = os.path.join(save_dir,
+                                            f'model.{tensor_name}.bin')
+            _write_zero_bias(tensor_name, fake_weight_path, data.shape[-1])
+
+    elif tensor_name.find('mlp.dense_h_to_4h.weight') != -1:
+        assert data.shape == (
+            config['mlp_ratio'] * config['d_model'],
+            config['d_model']), f'unexpected dim for {tensor_name}'
+        # nn.Linear weights are transposed
+        data = data.T
+
+        split_vals = np.split(data, infer_gpu_num, axis=-1)
+        for j in range(infer_gpu_num):
+            save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin')
+            split_vals[j].tofile(save_path)
+            if config['no_bias']:
+                _write_zero_bias(tensor_name, save_path,
+                                 split_vals[j].shape[-1])
+
+    elif tensor_name.find('mlp.dense_h_to_4h.bias') != -1:
+        assert data.shape == (
+            config['mlp_ratio'] *
+            config['d_model'],), f'unexpected dim for {tensor_name}'
+        split_vals = np.split(data, infer_gpu_num, axis=-1)
+        for j in range(infer_gpu_num):
+            save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin')
+            split_vals[j].tofile(save_path)
+
+    elif tensor_name.find('attention.query_key_value.bias') != -1:
+        assert data.shape == (
+            3 * config['d_model'],), f'unexpected dim for {tensor_name}'
+
+        data = data.reshape(3, config['d_model'])
+
+        split_vals = np.split(data, infer_gpu_num, axis=-1)
+
+        for j in range(infer_gpu_num):
+            save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin')
+            split_vals[j].tofile(save_path)
+
+    elif tensor_name.find('attention.query_key_value.weight') != -1:
+        assert data.shape == (
+            3 * config['d_model'],
+            config['d_model']), f'unexpected dim for {tensor_name}'
+        # nn.Linear weights are transposed
+        data = data.T
+
+        data = data.reshape(config['d_model'], 3, config['d_model'])
+        split_vals = np.split(data, infer_gpu_num, axis=-1)
+
+        for j in range(infer_gpu_num):
+            save_path = os.path.join(save_dir,
+                                     f'model.{tensor_name}.{j}.bin')
+            split_vals[j].tofile(save_path)
+            if config['no_bias']:
+                _write_zero_bias(tensor_name, save_path,
+                                 (3, split_vals[j].shape[-1]))
+
+    else:
+        raise RuntimeError(f'Tensor with name {tensor_name} is not handled')
+
+
+def convert_and_save_ft_weights(named_params: dict,
+                                config: dict,
+                                infer_gpu_num: int = 1,
+                                weight_data_type: str = 'fp32',
+                                save_dir: str = ''):
+    """Convert a Composer MPT checkpoint to a FasterTransformer format.
+
+    Args:
+        named_params (Dict[str, Parameter]): A dictionary containing the Composer MPT model's parameter names and data.
+        config (Dict[str, Any]): Configuration for the model. This is used in getting model specific parameters.
+        infer_gpu_num (int): The number of gpus you are planning to use for inference.
+        weight_data_type (str): The dtype of the converted FasterTransformer model.
+        save_dir (str): Path of the directory to save the weight in FT format. The directory must already exist.
+
+    Returns:
+        None: Writes to the `save_dir` folder. File names within this folder are based on the model parameter names.
+    """
+    np_weight_data_type = _get_weight_data_type(weight_data_type)
+
+    param_remapping = {
+        'norm_1.bias': 'input_layernorm.bias',
+        'norm_1.weight': 'input_layernorm.weight',
+        'attn.Wqkv.bias': 'attention.query_key_value.bias',
+        'attn.Wqkv.weight': 'attention.query_key_value.weight',
+        'attn.out_proj.bias': 'attention.dense.bias',
+        'attn.out_proj.weight': 'attention.dense.weight',
+        'norm_2.bias': 'post_attention_layernorm.bias',
+        'norm_2.weight': 'post_attention_layernorm.weight',
+        'ffn.up_proj.bias': 'mlp.dense_h_to_4h.bias',
+        'ffn.up_proj.weight': 'mlp.dense_h_to_4h.weight',
+        'ffn.down_proj.bias': 'mlp.dense_4h_to_h.bias',
+        'ffn.down_proj.weight': 'mlp.dense_4h_to_h.weight',
+    }
+
+    for name, param in named_params.items():
+        print(f'Working on parameter {name} ...')
+        data = param.detach().cpu().numpy().astype(np_weight_data_type)
+        if name.find('weight') == -1 and name.find('bias') == -1:
+            print(f'found a parameter name that is not handled: {name}')
+            continue
+        if name == 'transformer.wpe.weight':
+            assert data.shape == (
+                config['max_seq_len'],
+                config['d_model']), f'unexpected dim for {name}'
+            data.tofile(os.path.join(save_dir, 'model.wpe.bin'))
+        elif name == 'transformer.wte.weight':
+            assert data.shape == (
+                config['vocab_size'],
+                config['d_model']), f'unexpected dim for {name}'
+            data.tofile(os.path.join(save_dir, 'model.wte.bin'))
+        elif name == 'transformer.norm_f.bias':
+            assert data.shape == (
+                config['d_model'],), f'unexpected dim for {name}'
+            data.tofile(os.path.join(save_dir,
+                                     'model.final_layernorm.bias.bin'))
+        elif name == 'transformer.norm_f.weight':
+            assert data.shape == (
+                config['d_model'],), f'unexpected dim for {name}'
+            save_path = os.path.join(save_dir,
+                                     'model.final_layernorm.weight.bin')
+            data.tofile(save_path)
+            if config['no_bias']:
+                _write_zero_bias(name, save_path, data.shape[-1])
+        elif name == 'transformer.lm_head.weight':
+            data.tofile(os.path.join(save_dir, 'model.lm_head.weight.bin'))
+        else:
+            for mpt_pattern, ft_pattern in param_remapping.items():
+                if name.find(mpt_pattern) != -1:
+                    new_name = name.replace('transformer.blocks.',
+                                            'layers.').replace(
+                                                mpt_pattern, ft_pattern)
+                    _convert_weight_to_ft_each(save_dir, infer_gpu_num,
+                                               new_name, config, data)
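For reference, the two exported helpers can be driven directly from an already-loaded Composer state dict; this is essentially what the new `convert_composer_mpt_to_ft.py` script later in this PR does. A rough sketch (the checkpoint filename and output directory are placeholders, and the output directory must already exist):

```python
import torch

from llmfoundry.utils import (convert_and_save_ft_weights,
                              get_hf_tokenizer_from_composer_state_dict)

state_dict = torch.load('composer-checkpoint.pt', map_location='cpu')  # hypothetical path

tokenizer = get_hf_tokenizer_from_composer_state_dict(state_dict)
model_config = state_dict['state']['integrations']['huggingface']['model'][
    'config']['content']

weights = state_dict['state']['model']
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(weights,
                                                               prefix='model.')

convert_and_save_ft_weights(named_params=weights,
                            config=model_config,
                            infer_gpu_num=1,
                            weight_data_type='fp16',
                            save_dir='mpt-ft/1-gpu')  # directory must exist
```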
diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py
index f620c6516e..94ecfcfaff 100644
--- a/scripts/eval/eval.py
+++ b/scripts/eval/eval.py
@@ -97,8 +97,8 @@ def evaluate_model(model_cfg: DictConfig, dist_timeout: Union[float, int],
                    max_seq_len: int, device_eval_batch_size: int,
                    model_gauntlet_config: Optional[Union[str, DictConfig]],
                    fsdp_config: Optional[Dict], num_retries: int,
-                   loggers_cfg: Dict[str, Any], precision: str,
-                   model_gauntlet_df: Optional[pd.DataFrame]):
+                   loggers_cfg: Dict[str, Any], python_log_level: str,
+                   precision: str, model_gauntlet_df: Optional[pd.DataFrame]):
     print(f'Evaluating model: {model_cfg.model_name}', flush=True)
     # Build tokenizer and model
     tokenizer = build_tokenizer(model_cfg.tokenizer)
@@ -154,6 +154,7 @@ def evaluate_model(model_cfg: DictConfig, dist_timeout: Union[float, int],
         progress_bar=False,
         log_to_console=True,
         dist_timeout=dist_timeout,
+        python_log_level=python_log_level,
     )
 
     if torch.cuda.is_available():
@@ -191,6 +192,10 @@ def main(cfg: DictConfig):
                                              'device_eval_batch_size',
                                              must_exist=True)
     precision: str = pop_config(cfg, 'precision', must_exist=True)
+    python_log_level: str = pop_config(cfg,
+                                       'python_log_level',
+                                       must_exist=False,
+                                       default_value='debug')
 
     # Optional Evaluation Parameters with default values
     seed: int = pop_config(cfg, 'seed', must_exist=False, default_value=17)
@@ -240,8 +245,10 @@ def main(cfg: DictConfig):
             fsdp_config=fsdp_config,
             num_retries=num_retries,
             loggers_cfg=loggers_cfg,
+            python_log_level=python_log_level,
             precision=precision,
-            model_gauntlet_df=model_gauntlet_df)
+            model_gauntlet_df=model_gauntlet_df,
+        )
 
         if model_gauntlet_callback is not None:
             # TODO(bmosaicml) This needs to be refactored to fix the typing issue
diff --git a/scripts/inference/README.md b/scripts/inference/README.md
index ba4956c0d5..92e10c7a2a 100644
--- a/scripts/inference/README.md
+++ b/scripts/inference/README.md
@@ -199,6 +199,12 @@ python convert_hf_mpt_to_ft.py -i mpt-7b -o mpt-ft-7b --infer_gpu_num 1
 ```
 
 You can change `infer_gpu_num` to > 1 to prepare a FT checkpoint for multi-gpu inference. Please open a Github issue if you discover any problems!
 
+## Converting a Composer MPT to FasterTransformer
+We include a script `convert_composer_mpt_to_ft.py` that directly converts a Composer MPT checkpoint to the FasterTransformer format. You can provide either a path to a local Composer checkpoint or a URI to a file stored in a cloud storage backend supported by Composer (e.g. `s3://`). Simply run:
+```
+python convert_composer_mpt_to_ft.py -i <composer_checkpoint_path_or_uri> -o mpt-ft-7b --infer_gpu_num 1
+```
+
 ## Running MPT with FasterTransformer
 This step assumes that you already have converted an MPT checkpoint to FT format by following the instructions in [Converting an HF MPT to FasterTransformer](#converting-an-hf-mpt-to-fastertransformer). It also assumes that you have
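A converted folder can be sanity-checked before launching FasterTransformer: `config.ini` records the model geometry, and each weight is a flat binary that can be reshaped from those dimensions. A small illustrative check (the `mpt-ft-7b/1-gpu` path is just the hypothetical output of the README command above):

```python
import configparser
import os

import numpy as np

ft_dir = 'mpt-ft-7b/1-gpu'  # hypothetical conversion output

config = configparser.ConfigParser()
config.read(os.path.join(ft_dir, 'config.ini'))
gpt = config['gpt']
d_model = int(gpt['head_num']) * int(gpt['size_per_head'])
dtype = np.float32 if gpt['weight_data_type'] == 'fp32' else np.float16

# Weights are written with ndarray.tofile, so read them back as flat binaries
# and reshape using the dimensions recorded in config.ini.
wte = np.fromfile(os.path.join(ft_dir, 'model.wte.bin'), dtype=dtype)
print(wte.reshape(int(gpt['vocab_size']), d_model).shape)
```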
diff --git a/scripts/inference/convert_composer_mpt_to_ft.py b/scripts/inference/convert_composer_mpt_to_ft.py
new file mode 100644
index 0000000000..d260c31491
--- /dev/null
+++ b/scripts/inference/convert_composer_mpt_to_ft.py
@@ -0,0 +1,232 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+# Note: This script is specifically for converting MPT Composer checkpoints to FasterTransformer format.
+
+import configparser
+import os
+import tempfile
+from argparse import ArgumentParser, Namespace
+from pathlib import Path
+from typing import Optional, Union
+
+import torch
+from composer.utils import get_file, safe_torch_load
+from transformers import PreTrainedTokenizer
+
+from llmfoundry.utils import (convert_and_save_ft_weights,
+                              get_hf_tokenizer_from_composer_state_dict)
+
+
+def save_ft_config(composer_config: dict,
+                   tokenizer: PreTrainedTokenizer,
+                   save_dir: str,
+                   infer_gpu_num: int = 1,
+                   weight_data_type: str = 'fp32',
+                   force: bool = False):
+
+    config = configparser.ConfigParser()
+    config['gpt'] = {}
+    try:
+        config['gpt']['model_name'] = 'mpt'
+        config['gpt']['head_num'] = str(composer_config['n_heads'])
+        n_embd = composer_config['d_model']
+        config['gpt']['size_per_head'] = str(n_embd //
+                                             composer_config['n_heads'])
+        config['gpt']['inter_size'] = str(n_embd * composer_config['mlp_ratio'])
+        config['gpt']['max_pos_seq_len'] = str(composer_config['max_seq_len'])
+        config['gpt']['num_layer'] = str(composer_config['n_layers'])
+        config['gpt']['vocab_size'] = str(composer_config['vocab_size'])
+        config['gpt']['start_id'] = str(tokenizer.bos_token_id)
+        config['gpt']['end_id'] = str(tokenizer.eos_token_id)
+        config['gpt']['weight_data_type'] = weight_data_type
+        config['gpt']['tensor_para_size'] = str(infer_gpu_num)
+        # nn.LayerNorm default eps is 1e-5
+        config['gpt']['layernorm_eps'] = str(1e-5)
+        if composer_config['alibi']:
+            config['gpt']['has_positional_encoding'] = str(False)
+            config['gpt']['use_attention_linear_bias'] = str(True)
+        if composer_config['attn_clip_qkv'] and not force:
+            raise RuntimeError(
+                'clip_qkv is enabled for this MPT model. This may not work as expected in FT. Use --force to force a conversion.'
+            )
+        if composer_config['attn_qk_ln'] and not force:
+            raise RuntimeError(
+                'qk_ln is enabled for this MPT model. This may not work as expected in FT. Use --force to force a conversion.'
+            )
+
+        with open(os.path.join(save_dir, 'config.ini'), 'w') as configfile:
+            config.write(configfile)
+        return config
+    except Exception:
+        print('Failed to save the config in config.ini.')
+        raise
+
+
+def write_ft_checkpoint_from_composer_checkpoint(
+        checkpoint_path: Union[Path, str],
+        infer_gpu_num: int,
+        save_dir: str,
+        output_precision: str = 'fp32',
+        local_checkpoint_save_location: Optional[Union[Path,
+                                                       str]] = None) -> None:
+    """Convert a Composer checkpoint to a FasterTransformer checkpoint folder.
+
+    .. note:: This function may not work properly if you used surgery algorithms when you trained your model. In that case you may need to
+        edit the parameter conversion methods to properly convert your custom model.
+
+    Args:
+        checkpoint_path (Union[Path, str]): Path to the composer checkpoint, can be a local path, or a remote path beginning with ``s3://``, or another backend
+            supported by Composer.
+        infer_gpu_num (int): The number of gpus you are planning to use for inference.
+        save_dir (str): Path of the directory to save the checkpoint in FT format.
+        output_precision (str, optional): The precision of the output weights saved to the FasterTransformer model. Can be either ``fp32`` or ``fp16``.
+        local_checkpoint_save_location (Optional[Union[Path, str]], optional): If specified, where to save the checkpoint file to locally.
+            If the input ``checkpoint_path`` is already a local path, this will be a symlink.
+            Defaults to None, which will use a temporary file.
+    """
+    dtype = {
+        'fp32': torch.float32,
+        'fp16': torch.float16,
+    }[output_precision]
+
+    # default local path to a tempfile if path is not provided
+    if local_checkpoint_save_location is None:
+        tmp_dir = tempfile.TemporaryDirectory()
+        local_checkpoint_save_location = Path(
+            tmp_dir.name) / 'local-composer-checkpoint.pt'
+
+    # download the checkpoint file
+    print(
+        f'Downloading checkpoint from {checkpoint_path} -> {local_checkpoint_save_location}'
+    )
+    get_file(str(checkpoint_path), str(local_checkpoint_save_location))
+
+    # Load the Composer checkpoint. Use it to get the
+    # Composer state dict and weights
+    print('Loading checkpoint into CPU RAM...')
+    composer_state_dict = safe_torch_load(local_checkpoint_save_location)
+
+    # Extract Composer config from state dict
+    if 'state' not in composer_state_dict:
+        raise RuntimeError(
+            f'"state" is not an available key in the provided composer checkpoint. Is {local_checkpoint_save_location} ill-formed?'
+        )
+    if 'integrations' not in composer_state_dict[
+            'state'] or 'huggingface' not in composer_state_dict['state'][
+                'integrations']:
+        raise RuntimeError(
+            'Did not find HuggingFace related state (e.g., tokenizer) in the provided composer checkpoint!'
+        )
+    composer_config = composer_state_dict['state']['integrations'][
+        'huggingface']['model']['config']['content']
+
+    # Extract the HF tokenizer
+    print('#' * 30)
+    print('Extracting HF Tokenizer...')
+    hf_tokenizer = get_hf_tokenizer_from_composer_state_dict(
+        composer_state_dict)
+    if hf_tokenizer is None:
+        print('Warning! No HF Tokenizer found!')
+
+    # Extract the model weights
+    weights_state_dict = composer_state_dict['state']['model']
+    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
+        weights_state_dict, prefix='model.')
+
+    # Converting weights to desired dtype
+    for k, v in weights_state_dict.items():
+        if isinstance(v, torch.Tensor):
+            weights_state_dict[k] = v.to(dtype=dtype)
+
+    # Convert the weights using the config and tokenizer to FasterTransformer format
+    print('#' * 30)
+    print('Saving FasterTransformer config...')
+    save_ft_config(composer_config,
+                   tokenizer=hf_tokenizer,
+                   save_dir=save_dir,
+                   infer_gpu_num=infer_gpu_num,
+                   weight_data_type=output_precision)
+    print('#' * 30)
+    print('Converting weights to FasterTransformer format...')
+    convert_and_save_ft_weights(named_params=weights_state_dict,
+                                config=composer_config,
+                                infer_gpu_num=infer_gpu_num,
+                                weight_data_type=output_precision,
+                                save_dir=save_dir)
+
+    print('#' * 30)
+    print(
+        f'FasterTransformer checkpoint folder successfully created at {save_dir}.'
+    )
+
+    print('Done.')
+    print('#' * 30)
+
+
+def parse_args() -> Namespace:
+    """Parse commandline arguments."""
+    parser = ArgumentParser(
+        description=
+        'Convert an MPT Composer checkpoint into a standard FasterTransformer checkpoint folder.'
+    )
+    parser.add_argument(
+        '--composer_path',
+        '-i',
+        type=str,
+        help='Composer checkpoint path. Can be a local file path or cloud URI',
+        required=True)
+    parser.add_argument(
+        '--local_checkpoint_save_location',
+        type=str,
+        help='If specified, where to save the checkpoint file to locally. \
+            If the input ``checkpoint_path`` is already a local path, this will be a symlink. \
+            Defaults to None, which will use a temporary file.',
+        default=None)
+    parser.add_argument(
+        '--ft_save_dir',
+        '-o',
+        type=str,
+        help='Directory to save FasterTransformer converted checkpoint in',
+        required=True)
+    parser.add_argument('--infer_gpu_num',
+                        '-i_g',
+                        type=int,
+                        help='How many gpus for inference?',
+                        required=True)
+    parser.add_argument(
+        '--force',
+        action='store_true',
+        help=
+        'Force conversion to FT even if some features may not work as expected in FT'
+    )
+    parser.add_argument(
+        '--output_precision',
+        type=str,
+        help=
+        'Data type of weights in the FasterTransformer output model. Input checkpoint weights will be converted to this dtype.',
+        choices=['fp32', 'fp16'],
+        default='fp32')
+
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    print('\n=============== Argument ===============')
+    for key in vars(args):
+        print('{}: {}'.format(key, vars(args)[key]))
+    print('========================================')
+
+    save_dir = os.path.join(args.ft_save_dir, f'{args.infer_gpu_num}-gpu')
+
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    else:
+        raise RuntimeError(f'Output path {save_dir} already exists!')
+
+    write_ft_checkpoint_from_composer_checkpoint(
+        checkpoint_path=args.composer_path,
+        infer_gpu_num=args.infer_gpu_num,
+        save_dir=save_dir,
+        output_precision=args.output_precision,
+        local_checkpoint_save_location=args.local_checkpoint_save_location)
diff --git a/scripts/inference/convert_composer_to_hf.py b/scripts/inference/convert_composer_to_hf.py
index 72377d0fcd..5c4d4117c5 100644
--- a/scripts/inference/convert_composer_to_hf.py
+++ b/scripts/inference/convert_composer_to_hf.py
@@ -1,86 +1,25 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-import json
 import os
-import random
-import string
 import tempfile
 from argparse import ArgumentParser, Namespace
 from pathlib import Path
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
-import sentencepiece as spm
 import torch
 import transformers
 from composer.models.huggingface import get_hf_config_from_composer_state_dict
 from composer.utils import (get_file, maybe_create_object_store_from_uri,
                             parse_uri, safe_torch_load)
-from transformers import (AutoConfig, AutoTokenizer, PretrainedConfig,
-                          PreTrainedTokenizer, PreTrainedTokenizerBase)
+from transformers import AutoConfig, PretrainedConfig, PreTrainedTokenizerBase
 
 from llmfoundry import MPTConfig, MPTForCausalLM
+from llmfoundry.utils import get_hf_tokenizer_from_composer_state_dict
 from llmfoundry.utils.huggingface_hub_utils import \
     edit_files_for_hf_compatibility
 
 
-# TODO: move this functionality to composer once the bug fixes are upstreamed
-def get_hf_tokenizer_from_composer_state_dict(
-    state_dict: Dict[str, Any],
-    tokenizer_save_dir: Optional[str] = None
-) -> Optional[PreTrainedTokenizer]:
-    if 'state' not in state_dict:
-        raise RuntimeError(
-            'Unexpected composer state dictionary. Did you pass in a full composer checkpoint?'
-        )
-    if 'integrations' not in state_dict[
-            'state'] or 'huggingface' not in state_dict['state']['integrations']:
-        raise RuntimeError(
-            'Did not find HuggingFace related state (e.g., tokenizer) in the provided composer checkpoint!'
-        )
-    hf_tokenizer_state = state_dict['state']['integrations']['huggingface'][
-        'tokenizer']
-    hf_tokenizer = None
-    if hf_tokenizer_state != {}:
-        if tokenizer_save_dir is None:
-            unique_suffix = ''.join(
-                random.choices(string.ascii_letters + string.digits, k=6))
-            tokenizer_save_dir = os.path.join(
-                os.getcwd(), f'tokenizer-save-dir-{unique_suffix}')
-        os.makedirs(tokenizer_save_dir, exist_ok=True)
-
-        for filename, saved_content in hf_tokenizer_state.items():
-            # This cannot be a temporary directory because huggingface relies on the slow tokenizer file
-            # being persistent on disk
-            tokenizer_file_path = Path(
-                tokenizer_save_dir
-            ) / f'{filename}{saved_content["file_extension"]}'
-            if saved_content['file_extension'] == '.json':
-                with open(tokenizer_file_path, 'w') as _tmp_file:
-                    json.dump(saved_content['content'], _tmp_file)
-            elif saved_content['file_extension'] == '.txt':
-                with open(tokenizer_file_path, 'w') as _tmp_file:
-                    for line in saved_content['content']:
-                        _tmp_file.write(line)
-                        _tmp_file.write('\n')
-            elif saved_content['file_extension'] == '.py':
-                with open(tokenizer_file_path, 'w') as _tmp_file:
-                    _tmp_file.write(saved_content['content'])
-            elif saved_content['file_extension'] == '.model':
-                s = spm.SentencePieceProcessor()
-                s.load_from_serialized_proto(saved_content['content'])
-                with open(tokenizer_file_path, 'wb') as _tmp_file:
-                    _tmp_file.write(s.serialized_model_proto())
-
-        hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_save_dir)
-
-        # remove 'name_or_path'
-        hf_tokenizer.name_or_path = ''
-        hf_tokenizer.init_kwargs['name_or_path'] = ''
-
-    return hf_tokenizer
-
-
 def write_huggingface_pretrained_from_composer_checkpoint(
     checkpoint_path: Union[Path, str],
     output_path: Union[Path, str],
@@ -250,6 +189,7 @@ def convert_composer_to_hf(args: Namespace) -> None:
         print(f'Loading model from {local_folder_path}')
         if config.model_type == 'mpt':
             config.attn_config['attn_impl'] = 'torch'
+            config.init_device = 'cpu'
 
         if config.model_type == 'mpt':
             loaded_hf_model = MPTForCausalLM.from_pretrained(local_folder_path,
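The added `config.init_device = 'cpu'` line matters for checkpoints whose saved HF config requested `meta` initialization (the case exercised by the new `test_convert_and_generate_meta` test at the end of this PR): overriding it before `from_pretrained` ensures real CPU tensors are allocated for the loaded weights. A minimal sketch of the same pattern, with a placeholder export folder:

```python
from transformers import AutoConfig, AutoModelForCausalLM

hf_folder = 'hf-output-folder'  # hypothetical local HF export

config = AutoConfig.from_pretrained(hf_folder, trust_remote_code=True)
if config.model_type == 'mpt':
    config.attn_config['attn_impl'] = 'torch'
    config.init_device = 'cpu'  # override a possible 'meta' init_device

model = AutoModelForCausalLM.from_pretrained(hf_folder,
                                             config=config,
                                             trust_remote_code=True)
```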
diff --git a/scripts/inference/convert_hf_mpt_to_ft.py b/scripts/inference/convert_hf_mpt_to_ft.py
index ceb4c5c770..104d0d6b15 100644
--- a/scripts/inference/convert_hf_mpt_to_ft.py
+++ b/scripts/inference/convert_hf_mpt_to_ft.py
@@ -24,151 +24,10 @@
 import argparse
 import configparser
 import os
-from typing import Any, Dict, Tuple, Union
 
-import numpy as np
 import transformers
 
-
-def get_weight_data_type(data_type: str):
-    if data_type == 'fp32':
-        return np.float32
-    elif data_type == 'fp16':
-        return np.float16
-    else:
-        raise RuntimeError('Unsupported data type: {data_type} for conversion')
-
-
-def write_zero_bias(weight_name: str, weight_file_path: str,
-                    bias_shape: Union[Tuple[int, ...], int]) -> None:
-    """Write zeros for bias.
-
-    MPT model might not have bias while FT expects bias.
-
-    Args:
-        weight_name (str): Name of the weight tensor.
-        weight_file_path (str): Output path for storing the weight (NOT zero bias).
-        bias_shape (Union[Tuple[int, ...], int]): Shape of the bias array.
-    """
-    if 'weight' not in weight_file_path:
-        raise RuntimeError(
-            f'Cannot write zero bias for {weight_name}. Input is not a weight tensor'
-        )
-    print(f'zero bias for weight: {weight_name}')
-    bias_file_path = weight_file_path.replace('.weight', '.bias')
-    bias = np.zeros(bias_shape, dtype=np.float32)
-    bias.tofile(bias_file_path)
-
-
-def convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int,
-                              tensor_name: str, config: Dict[str, Any],
-                              data: np.ndarray):
-    """Convert an MPT checkpoint to a FasterTransformer compatible format.
-
-    Args:
-        save_dir (str): Path of the directory to save the weight in FT format. The directory must already exist.
-        infer_gpu_num (int): The number of gpus you are planning to use for inference.
-        tensor_name (str): Name of the weight tensor. Used in naming the output file.
-        config (Dict[str, Any]): Configuration for the model. This is used in getting model specific parameters.
-        data (np.ndarray): Tensor data in np.ndarray format.
-
-    Returns:
-        None: Writes to a file in `save_dir`. File name is based on the `tensor_name`
-    """
-    if tensor_name.find('input_layernorm.weight') != -1 or tensor_name.find('input_layernorm.bias') != -1 or \
-        tensor_name.find('attention.dense.bias') != -1 or tensor_name.find('post_attention_layernorm.weight') != -1 or \
-        tensor_name.find('post_attention_layernorm.bias') != -1 or tensor_name.find('mlp.dense_4h_to_h.bias') != -1 or \
-        tensor_name.find('final_layernorm.weight') != -1 or tensor_name.find('final_layernorm.bias') != -1:
-
-        save_path = os.path.join(save_dir, f'model.{tensor_name}.bin')
-        data.tofile(save_path)
-        if 'weight' in tensor_name and config['no_bias']:
-            write_zero_bias(tensor_name, save_path, data.shape[-1])
-
-    elif tensor_name.find('attention.dense.weight') != -1:
-        assert data.shape == (
-            config['d_model'],
-            config['d_model']), f'unexpected dim for {tensor_name}'
-        # nn.Linear weights are transposed
-        data = data.T
-        split_vals = np.split(data, infer_gpu_num, axis=0)
-        for j in range(infer_gpu_num):
-            save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin')
-            split_vals[j].tofile(save_path)
-        if config['no_bias']:
-            fake_weight_path = os.path.join(save_dir,
-                                            f'model.{tensor_name}.bin')
-            write_zero_bias(tensor_name, fake_weight_path, data.shape[-1])
-
-    elif tensor_name.find('mlp.dense_4h_to_h.weight') != -1:
-        assert data.shape == (
-            config['d_model'], config['expansion_ratio'] *
-            config['d_model']), f'unexpected dim for {tensor_name}'
-        # nn.Linear weights are transposed
-        data = data.T
-        split_vals = np.split(data, infer_gpu_num, axis=0)
-        for j in range(infer_gpu_num):
-            save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin')
-            split_vals[j].tofile(save_path)
-        if config['no_bias']:
-            fake_weight_path = os.path.join(save_dir,
-                                            f'model.{tensor_name}.bin')
-            write_zero_bias(tensor_name, fake_weight_path, data.shape[-1])
-
-    elif tensor_name.find('mlp.dense_h_to_4h.weight') != -1:
-        assert data.shape == (
-            config['expansion_ratio'] * config['d_model'],
-            config['d_model']), f'unexpected dim for {tensor_name}'
-        # nn.Linear weights are transposed
-        data = data.T
-
-        split_vals = np.split(data, infer_gpu_num, axis=-1)
-        for j in range(infer_gpu_num):
-            save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin')
-            split_vals[j].tofile(save_path)
-            if config['no_bias']:
-                write_zero_bias(tensor_name, save_path, split_vals[j].shape[-1])
-
-    elif tensor_name.find('mlp.dense_h_to_4h.bias') != -1:
-        assert data.shape == (
-            config['expansion_ratio'] *
-            config['d_model'],), f'unexpected dim for {tensor_name}'
-        split_vals = np.split(data, infer_gpu_num, axis=-1)
-        for j in range(infer_gpu_num):
-            save_path = os.path.join(save_dir + f'model.{tensor_name}.{j}.bin')
-            split_vals[j].tofile(save_path)
-
-    elif tensor_name.find('attention.query_key_value.bias') != -1:
-        assert data.shape == (
-            3 * config['d_model'],), f'unexpected dim for {tensor_name}'
-
-        data = data.reshape(3, config['d_model'])
-
-        split_vals = np.split(data, infer_gpu_num, axis=-1)
-
-        for j in range(infer_gpu_num):
-            save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin')
-            split_vals[j].tofile(save_path)
-
-    elif tensor_name.find('attention.query_key_value.weight') != -1:
-        assert data.shape == (
-            3 * config['d_model'],
-            config['d_model']), f'unexpected dim for {tensor_name}'
-        # nn.Linear weights are transposed
-        data = data.T
-
-        data = data.reshape(config['d_model'], 3, config['d_model'])
-        split_vals = np.split(data, infer_gpu_num, axis=-1)
-
-        for j in range(infer_gpu_num):
-            save_path = os.path.join(save_dir, f'model.{tensor_name}.{j}.bin')
-            split_vals[j].tofile(save_path)
-            if config['no_bias']:
-                write_zero_bias(tensor_name, save_path,
-                                (3, split_vals[j].shape[-1]))
-
-    else:
-        raise RuntimeError(f'Tensor with name {tensor_name} is not handled')
+from llmfoundry.utils import convert_and_save_ft_weights
 
 
 def convert_mpt_to_ft(model_name_or_path: str,
@@ -243,62 +102,14 @@ def convert_mpt_to_ft(model_name_or_path: str,
         print(f'Failed to save the config in config.ini.')
         raise
 
-    np_weight_data_type = get_weight_data_type(weight_data_type)
-
-    param_remapping = {
-        'norm_1.bias': 'input_layernorm.bias',
-        'norm_1.weight': 'input_layernorm.weight',
-        'attn.Wqkv.bias': 'attention.query_key_value.bias',
-        'attn.Wqkv.weight': 'attention.query_key_value.weight',
-        'attn.out_proj.bias': 'attention.dense.bias',
-        'attn.out_proj.weight': 'attention.dense.weight',
-        'norm_2.bias': 'post_attention_layernorm.bias',
-        'norm_2.weight': 'post_attention_layernorm.weight',
-        'ffn.up_proj.bias': 'mlp.dense_h_to_4h.bias',
-        'ffn.up_proj.weight': 'mlp.dense_h_to_4h.weight',
-        'ffn.down_proj.bias': 'mlp.dense_4h_to_h.bias',
-        'ffn.down_proj.weight': 'mlp.dense_4h_to_h.weight',
+    named_params_dict = {
+        name: param for name, param in model.named_parameters()
     }
-
-    for name, param in model.named_parameters():
-        print(f'Working on parameter {name} ...')
-        data = param.detach().cpu().numpy().astype(np_weight_data_type)
-        if name.find('weight') == -1 and name.find('bias') == -1:
-            print(f'found a parameter name that is not handled: {name}')
-            continue
-        if name == 'transformer.wpe.weight':
-            assert data.shape == (
-                hf_config['max_seq_len'],
-                hf_config['d_model']), f'unexpected dim for {name}'
-            data.tofile(os.path.join(save_dir, 'model.wpe.bin'))
-        elif name == 'transformer.wte.weight':
-            assert data.shape == (
-                hf_config['vocab_size'],
-                hf_config['d_model']), f'unexpected dim for {name}'
-            data.tofile(os.path.join(save_dir, 'model.wte.bin'))
-        elif name == 'transformer.norm_f.bias':
-            assert data.shape == (
-                hf_config['d_model'],), f'unexpected dim for {name}'
-            data.tofile(os.path.join(save_dir,
-                                     'model.final_layernorm.bias.bin'))
-        elif name == 'transformer.norm_f.weight':
-            assert data.shape == (
-                hf_config['d_model'],), f'unexpected dim for {name}'
-            save_path = os.path.join(save_dir,
-                                     'model.final_layernorm.weight.bin')
-            data.tofile(save_path)
-            if hf_config['no_bias']:
-                write_zero_bias(name, save_path, data.shape[-1])
-        elif name == 'transformer.lm_head.weight':
-            data.tofile(os.path.join(save_dir, 'model.lm_head.weight.bin'))
-        else:
-            for mpt_pattern, ft_pattern in param_remapping.items():
-                if name.find(mpt_pattern) != -1:
-                    new_name = name.replace('transformer.blocks.',
-                                            'layers.').replace(
-                                                mpt_pattern, ft_pattern)
-                    convert_weight_to_ft_each(save_dir, infer_gpu_num, new_name,
-                                              hf_config, data)
+    convert_and_save_ft_weights(named_params=named_params_dict,
+                                config=hf_config,
+                                infer_gpu_num=infer_gpu_num,
+                                weight_data_type=weight_data_type,
+                                save_dir=save_dir)
 
 
 if __name__ == '__main__':
diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py
index 1561f965c9..e16832d803 100644
--- a/tests/test_hf_conversion_script.py
+++ b/tests/test_hf_conversion_script.py
@@ -157,3 +157,66 @@ def test_convert_and_generate_triton(tmp_path: pathlib.Path):
     assert output.shape == (1, 2)
 
     delete_transformers_cache()
+
+
+def test_convert_and_generate_meta(tmp_path: pathlib.Path):
+    delete_transformers_cache()
+
+    from composer.utils import dist
+    gathered_paths = dist.all_gather_object(tmp_path)
+    tmp_path_gathered = gathered_paths[0]
+
+    om_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml')
+
+    om_cfg['model']['init_device'] = 'cpu'
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        om_cfg.tokenizer.name)
+    original_model = COMPOSER_MODEL_REGISTRY[om_cfg['model'].name](
+        om_cfg['model'], tokenizer)
+    trainer = Trainer(model=original_model, device='cpu')
+    trainer.save_checkpoint(os.path.join(tmp_path_gathered, 'checkpoint.pt'))
+
+    # patch in the meta device for testing
+    sd = torch.load(os.path.join(tmp_path_gathered, 'checkpoint.pt'),
+                    map_location='cpu')
+    sd['state']['integrations']['huggingface']['model']['config']['content'][
+        'init_device'] = 'meta'
+    torch.save(sd, os.path.join(tmp_path_gathered, 'checkpoint.pt'))
+
+    args = Namespace(composer_path=os.path.join(tmp_path_gathered,
+                                                'checkpoint.pt'),
+                     hf_output_path=os.path.join(tmp_path_gathered,
+                                                 'hf-output-folder'),
+                     output_precision='fp32',
+                     local_checkpoint_save_location=None,
+                     hf_repo_for_upload=None,
+                     test_uploaded_model=False)
+    convert_composer_to_hf(args)
+
+    loaded_config = transformers.AutoConfig.from_pretrained(
+        os.path.join(tmp_path_gathered, 'hf-output-folder'),
+        trust_remote_code=True)
+    loaded_model = transformers.AutoModelForCausalLM.from_pretrained(
+        os.path.join(tmp_path_gathered, 'hf-output-folder'),
+        config=loaded_config,
+        trust_remote_code=True)
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        os.path.join(tmp_path_gathered, 'hf-output-folder'),
+        trust_remote_code=True)
+
+    output = loaded_model.generate(tokenizer('hello',
+                                             return_tensors='pt')['input_ids'],
+                                   max_new_tokens=1)
+    assert output.shape == (1, 2)
+
+    assert sum(p.numel() for p in original_model.model.parameters()) == sum(
+        p.numel() for p in loaded_model.parameters())
+    assert all(
+        str(type(module1)).split('.')[-1] == str(type(module2)).split('.')[-1]
+        for module1, module2 in zip(original_model.model.modules(),
+                                    loaded_model.modules()))
+    for p1, p2 in zip(original_model.model.parameters(),
+                      loaded_model.parameters()):
+        assert torch.allclose(p1, p2)
+
+    delete_transformers_cache()
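The new test above covers the Composer-to-HF path end to end; a similar, hypothetical unit test could exercise `convert_and_save_ft_weights` directly with a tiny fake parameter set (the names, shapes, and the test itself are illustrative, not part of this PR):

```python
import os

import torch

from llmfoundry.utils import convert_and_save_ft_weights


def test_convert_and_save_ft_weights_writes_files(tmp_path):
    d_model, vocab_size = 8, 16
    config = {
        'd_model': d_model,
        'vocab_size': vocab_size,
        'max_seq_len': 4,
        'mlp_ratio': 4,
        'no_bias': True,
    }
    # Minimal fake parameters covering the embedding and final layernorm paths.
    named_params = {
        'transformer.wte.weight': torch.randn(vocab_size, d_model),
        'transformer.norm_f.weight': torch.randn(d_model),
    }

    convert_and_save_ft_weights(named_params=named_params,
                                config=config,
                                infer_gpu_num=1,
                                weight_data_type='fp32',
                                save_dir=str(tmp_path))

    assert os.path.isfile(os.path.join(tmp_path, 'model.wte.bin'))
    assert os.path.isfile(
        os.path.join(tmp_path, 'model.final_layernorm.weight.bin'))
```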