diff --git a/.gitignore b/.gitignore index 7ba86b45cb..83be856ec3 100644 --- a/.gitignore +++ b/.gitignore @@ -120,6 +120,9 @@ ENV/ env.bak/ venv.bak/ +# python venv installed in the dir, llmfoundry-venv +*-venv + # Spyder project settings .spyderproject .spyproject @@ -143,3 +146,6 @@ dmypy.json # macOS .DS_Store + +# notebooks +notebooks/ diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py index db979a910c..b6bb8e663b 100644 --- a/llmfoundry/__init__.py +++ b/llmfoundry/__init__.py @@ -1,9 +1,9 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -try: - import torch +import torch +try: from llmfoundry import optim, utils from llmfoundry.data import (ConcatTokensDataset, MixtureOfDenoisersCollator, NoConcatDataset, @@ -24,7 +24,7 @@ except ImportError as e: try: - is_cuda_available = torch.cuda.is_available() # type: ignore + is_cuda_available = torch.cuda.is_available() except: is_cuda_available = False diff --git a/llmfoundry/callbacks/fdiff_callback.py b/llmfoundry/callbacks/fdiff_callback.py index ede24f1f4f..bcef73875d 100644 --- a/llmfoundry/callbacks/fdiff_callback.py +++ b/llmfoundry/callbacks/fdiff_callback.py @@ -47,10 +47,12 @@ def batch_end(self, state: State, logger: Logger): def eval_end(self, state: State, logger: Logger): if self.diff_eval_metrics: evaluator = state.dataloader_label - metrics = list(state.eval_metrics[evaluator].keys()) # type: ignore + assert evaluator is not None, 'dataloader should have been set' + + metrics = list(state.eval_metrics[evaluator].keys()) for k in metrics: - mkey = '/'.join(['metrics', evaluator, k]) # type: ignore + mkey = '/'.join(['metrics', evaluator, k]) if mkey in self.eval_prev_metric.keys(): logger.log_metrics({ f'{mkey}_fdiff': @@ -59,5 +61,5 @@ def eval_end(self, state: State, logger: Logger): }) for k in metrics: - mkey = '/'.join(['metrics', evaluator, k]) # type: ignore + mkey = '/'.join(['metrics', evaluator, k]) self.eval_prev_metric[mkey] = state.eval_metric_values[k] diff --git a/llmfoundry/callbacks/generate_callback.py b/llmfoundry/callbacks/generate_callback.py index 89dc4e965e..476f9a0948 100644 --- a/llmfoundry/callbacks/generate_callback.py +++ b/llmfoundry/callbacks/generate_callback.py @@ -74,9 +74,9 @@ def generate(self, state: State, logger: Logger): dummy_input = device.tensor_to_device(dummy_input) with get_precision_context(state.precision): with torch.no_grad(): - _ = model.model(input_ids=dummy_input) # type: ignore + _ = model.model(input_ids=dummy_input) - output_token_ids = model.model.generate( # type: ignore + output_token_ids = model.model.generate( input_ids=tokenized_input['input_ids'], attention_mask=tokenized_input['attention_mask'], synced_gpus=True, @@ -85,9 +85,11 @@ def generate(self, state: State, logger: Logger): if dist.get_global_rank() == 0: if self.wandb_logger is not None: - artifact = wandb.Artifact( - 'generate_samples_' + str(wandb.run.id), # type: ignore - type='predictions') + assert wandb.run is not None, 'wandb should have started run' + + artifact = wandb.Artifact('generate_samples_' + + str(wandb.run.id), + type='predictions') rows = [] for i in range(len(self.prompts)): diff --git a/llmfoundry/data/denoising.py b/llmfoundry/data/denoising.py index 09c060c654..8a953f5841 100644 --- a/llmfoundry/data/denoising.py +++ b/llmfoundry/data/denoising.py @@ -354,7 +354,7 @@ def build_text_denoising_dataloader( cfg: DictConfig, tokenizer: Tokenizer, device_batch_size: int, -) -> DataLoader: +) -> DataLoader[Dict]: """Constructor function for a 
Mixture of Denoisers dataloader.

    This function constructs a dataloader that can be used to train an
@@ -480,7 +480,7 @@ build_text_denoising_dataloader
         batch_size=device_batch_size,
     )

-    if dataset.tokenizer.pad_token is None:  # type: ignore
+    if dataset.tokenizer.pad_token is None:
         dataset.tokenizer.pad_token = dataset.tokenizer.eos_token

     if cfg.dataset.get('packing_ratio'):
@@ -564,7 +564,7 @@ def noise_token_sequence(
         else:
             u = np.random.uniform(low=(mask_ratio * 2) - 1, high=1.0)
             mean_span_length = float(np.round(1 + u * (length - 1)))
-            mask_ratio = mean_span_length / length  # type: ignore
+            mask_ratio = mean_span_length / length
             use_sentinels = False
     else:
         use_sentinels = True
@@ -871,9 +871,9 @@ def _format_tokens_for_decoder_only(
     tokenizer = build_tokenizer(tokenizer_cfg)

     loader = build_text_denoising_dataloader(cfg, tokenizer, device_batch_size)
+    assert isinstance(loader.dataset, StreamingTextDataset)

-    print(
-        f'\n\nTRUNCATING TO: {loader.dataset.max_seq_len}\n\n')  # type: ignore
+    print(f'\n\nTRUNCATING TO: {loader.dataset.max_seq_len}\n\n')

     packing = cfg.dataset.get('packing_ratio') is not None
     if packing:
diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index 01acfbafb1..c1e845263c 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -6,6 +6,9 @@
 import os
 from typing import Mapping, Union

+# required for loading a python model into composer
+import peft
+import transformers
 from composer.metrics.nlp import (InContextLearningLMAccuracy,
                                   InContextLearningLMExpectedCalibrationError,
                                   InContextLearningMCExpectedCalibrationError,
@@ -30,7 +33,8 @@ class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
     """Configures a :class:`.HuggingFaceModel` around a Causal LM.

     Args:
-        cfg (DictConfig): An omegaconf dictionary used to configure the model:
+        om_model_config (DictConfig | peft.peft_model.PeftModel | transformers.PreTrainedModel): either an omegaconf dictionary used to configure the model, or an instantiated model object from the peft or transformers library.
+        If DictConfig, the following keys are required:
             cfg.pretrained_model_name_or_path (str): The name of or local path to
                 the HF Causal LM (e.g., `gpt2` to instantiate a GPT2LMHeadModel).
             cfg.config_overrides (dict, optional): An optional dictionary of keyword
@@ -45,34 +49,12 @@ class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
         tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
     """

-    def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer):
-        trust_remote_code = om_model_config.get('trust_remote_code', True)
-        use_auth_token = om_model_config.get('use_auth_token', False)
-        config = AutoConfig.from_pretrained(
-            om_model_config.pretrained_model_name_or_path,
-            trust_remote_code=trust_remote_code,
-            use_auth_token=use_auth_token,
-        )
-
-        # set config overrides
-        for k, v in om_model_config.get('config_overrides', {}).items():
-            if not hasattr(config, k):
-                raise ValueError(
-                    f'config does not have attribute "{k}" to override ({k}: {v}).'
-                )
-
-            attr = getattr(config, k)
-            if isinstance(attr, Mapping):
-                extra_keys = [_k for _k in v.keys() if _k not in attr.keys()]
-                if extra_keys:
-                    raise ValueError(
-                        f'Config dict override got unknown keys. '
-                        f'Extra keys: {extra_keys}. '
-                        f'Expected (a subset of) keys: {list(attr.keys())}.')
-                getattr(config, k).update(v)
-            else:
-                setattr(config, k, v)
+    def __init__(self,
+                 om_model_config: Union[DictConfig, peft.peft_model.PeftModel,
+                                        transformers.PreTrainedModel],
+                 tokenizer: Tokenizer):
+        # set up training and eval metrics

         train_metrics = [
             LanguageCrossEntropy(),
             LanguagePerplexity(),
@@ -87,64 +69,116 @@ def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer):
             InContextLearningMCExpectedCalibrationError()
         ]

-        init_device = om_model_config.get('init_device', 'cpu')
-
-        # Get the device we want to initialize, and use the
-        # reolved version to initialize the HF model
-        resolved_init_device = hf_get_init_device(init_device)
-
-        # We need to have all non-zero local ranks be not-pretrained
-        # Rank 0 will still be pretrained, and distribute the weights appropriately
-        if dist.get_local_rank() != 0 and init_device == 'mixed':
-            om_model_config.pretrained = False
-
-        if resolved_init_device == 'cpu':
-            if om_model_config.pretrained:
-                model = AutoModelForCausalLM.from_pretrained(
-                    om_model_config.pretrained_model_name_or_path,
-                    trust_remote_code=trust_remote_code,
-                    use_auth_token=use_auth_token,
-                    config=config)
+        # if we are passed a DictConfig, we need to instantiate the model
+        if isinstance(om_model_config, DictConfig):
+
+            # load the model config
+            trust_remote_code = om_model_config.get('trust_remote_code', True)
+            use_auth_token = om_model_config.get('use_auth_token', False)
+            config = AutoConfig.from_pretrained(
+                om_model_config.pretrained_model_name_or_path,
+                trust_remote_code=trust_remote_code,
+                use_auth_token=use_auth_token,
+            )
+
+            # set config overrides
+            for k, v in om_model_config.get('config_overrides', {}).items():
+                if not hasattr(config, k):
+                    raise ValueError(
+                        f'config does not have attribute "{k}" to override ({k}: {v}).'
+                    )
+
+                attr = getattr(config, k)
+                if isinstance(attr, Mapping):
+                    extra_keys = [
+                        _k for _k in v.keys() if _k not in attr.keys()
+                    ]
+                    if extra_keys:
+                        raise ValueError(
+                            f'Config dict override got unknown keys. '
+                            f'Extra keys: {extra_keys}. '
+                            f'Expected (a subset of) keys: {list(attr.keys())}.'
+                        )
+                    getattr(config, k).update(v)
+                else:
+                    setattr(config, k, v)
+
+            # below we set up the device to initialize the model on
+            init_device = om_model_config.get('init_device', 'cpu')
+
+            # Get the device we want to initialize, and use the
+            # resolved version to initialize the HF model
+            resolved_init_device = hf_get_init_device(init_device)
+
+            # We need to have all non-zero local ranks be not-pretrained
+            # Rank 0 will still be pretrained, and distribute the weights appropriately
+            if dist.get_local_rank() != 0 and init_device == 'mixed':
+                om_model_config.pretrained = False
+
+            # initialize the model on the correct device
+            if resolved_init_device == 'cpu':
+                if om_model_config.pretrained:
+                    model = AutoModelForCausalLM.from_pretrained(
+                        om_model_config.pretrained_model_name_or_path,
+                        trust_remote_code=trust_remote_code,
+                        use_auth_token=use_auth_token,
+                        config=config)
+                else:
+                    model = AutoModelForCausalLM.from_config(
+                        config,
+                        trust_remote_code=trust_remote_code,
+                    )
+            elif resolved_init_device == 'meta':
+                if om_model_config.pretrained:
+                    raise ValueError(
+                        'Setting cfg.pretrained=True is not supported when init_device="meta".'
+                    )
+                with init_empty_weights(include_buffers=False):
+                    model = AutoModelForCausalLM.from_config(
+                        config,
+                        trust_remote_code=trust_remote_code,
+                    )
             else:
-                model = AutoModelForCausalLM.from_config(
-                    config,
-                    trust_remote_code=trust_remote_code,
-                )
-        elif resolved_init_device == 'meta':
-            if om_model_config.pretrained:
                 raise ValueError(
-                    'Setting cfg.pretrained=True is not supported when init_device="meta".'
-                )
-            with init_empty_weights(include_buffers=False):
-                model = AutoModelForCausalLM.from_config(
-                    config,
-                    trust_remote_code=trust_remote_code,
+                    f'init_device="{init_device}" must be either "cpu" or "meta".'
                 )
-        else:
-            raise ValueError(
-                f'init_device="{init_device}" must be either "cpu" or "meta".')

-        signal_file_path = '.local_rank0_completed_autoresume'
-        if dist.get_local_rank() == 0:
-            with open(signal_file_path, 'wb') as f:
-                f.write(b'local_rank0_completed_download')
+            signal_file_path = '.local_rank0_completed_autoresume'
+            if dist.get_local_rank() == 0:
+                with open(signal_file_path, 'wb') as f:
+                    f.write(b'local_rank0_completed_download')
+
+            # Avoid the collective call until the local rank zero has finished trying to download the checkpoint
+            # so that we don't timeout for large downloads. This syncs all processes on the node
+            with dist.local_rank_zero_download_and_wait(signal_file_path):
+                # Then, wait to ensure every node has finished downloading the checkpoint
+                dist.barrier()

-        # Avoid the collective call until the local rank zero has finished trying to download the checkpoint
-        # so that we don't timeout for large downloads. This syncs all processes on the node
-        with dist.local_rank_zero_download_and_wait(signal_file_path):
-            # Then, wait to ensure every node has finished downloading the checkpoint
-            dist.barrier()
+            if dist.get_local_rank() == 0:
+                os.remove(signal_file_path)

-        if dist.get_local_rank() == 0:
-            os.remove(signal_file_path)
+            z_loss = om_model_config.get('z_loss', 0.0)
+
+        # elif the model is either a PeftModel or a PreTrainedModel
+        elif isinstance(
+                om_model_config,
+                (peft.peft_model.PeftModel, transformers.PreTrainedModel)):
+            model = om_model_config
+            init_device = 'cpu'
+            z_loss = 0.0
+
+        # else, unsupported type
+        else:
+            raise ValueError(
+                f'om_model_config must be either a DictConfig, PeftModel, or PreTrainedModel, but got {type(om_model_config)}'
+            )

         composer_model = super().__init__(model=model,
                                           shift_labels=True,
                                           tokenizer=tokenizer,
                                           metrics=train_metrics,
                                           eval_metrics=eval_metrics,
-                                          z_loss=om_model_config.get(
-                                              'z_loss', 0.0),
+                                          z_loss=z_loss,
                                           init_device=init_device)

         return composer_model
diff --git a/llmfoundry/models/hf/hf_prefix_lm.py b/llmfoundry/models/hf/hf_prefix_lm.py
index 31f7c16b60..863db0d08a 100644
--- a/llmfoundry/models/hf/hf_prefix_lm.py
+++ b/llmfoundry/models/hf/hf_prefix_lm.py
@@ -98,8 +98,6 @@ def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer):
         if om_model_config.get('adapt_vocab_for_denoising', False):
             adapt_tokenizer_for_denoising(tokenizer)

-        vocab_size = len(tokenizer)
-
         init_device = om_model_config.get('init_device', 'cpu')

         # Get the device we want to initialize, and use the
diff --git a/llmfoundry/models/hf/hf_t5.py b/llmfoundry/models/hf/hf_t5.py
index 028267e3da..e3b26e4b26 100644
--- a/llmfoundry/models/hf/hf_t5.py
+++ b/llmfoundry/models/hf/hf_t5.py
@@ -91,8 +91,6 @@ def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer):
         if om_model_config.get('adapt_vocab_for_denoising', False):
             adapt_tokenizer_for_denoising(tokenizer)

-        vocab_size = len(tokenizer)
-
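For orientation, a minimal sketch (not from the diff itself) of the two construction paths the widened `ComposerHFCausalLM.__init__` above now accepts: an omegaconf `DictConfig`, or a `peft`/`transformers` model instantiated elsewhere. The `gpt2` checkpoint, the LoRA hyperparameters, and the `c_attn` target module are illustrative assumptions.

```python
from omegaconf import OmegaConf
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmfoundry import ComposerHFCausalLM

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # placeholder checkpoint

# Path 1: a DictConfig; ComposerHFCausalLM builds the HF model itself.
model_cfg = OmegaConf.create({
    'pretrained_model_name_or_path': 'gpt2',
    'pretrained': True,
    'init_device': 'cpu',
})
composer_model = ComposerHFCausalLM(model_cfg, tokenizer)

# Path 2: a model built elsewhere (PeftModel or PreTrainedModel); in this
# branch init_device falls back to 'cpu' and z_loss to 0.0.
base_model = AutoModelForCausalLM.from_pretrained('gpt2')
peft_model = get_peft_model(
    base_model,
    LoraConfig(task_type='CAUSAL_LM', r=8, lora_alpha=16,
               target_modules=['c_attn']))  # illustrative LoRA settings
composer_peft_model = ComposerHFCausalLM(peft_model, tokenizer)
```

The DictConfig path mirrors what the model registry passes through from a training config; the second path is what the new LoRA plumbing in `scripts/train/train.py` relies on.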
init_device = om_model_config.get('init_device', 'cpu') # Get the device we want to initialize, and use the diff --git a/llmfoundry/models/layers/norm.py b/llmfoundry/models/layers/norm.py index 51c2cb4e65..8eff7b545e 100644 --- a/llmfoundry/models/layers/norm.py +++ b/llmfoundry/models/layers/norm.py @@ -1,6 +1,8 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +from typing import Dict, Type + import torch @@ -107,7 +109,7 @@ def forward(self, x): self.eps).to(dtype=x.dtype) -NORM_CLASS_REGISTRY = { +NORM_CLASS_REGISTRY: Dict[str, Type[torch.nn.Module]] = { 'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index c4538d0e62..01a3f4b3b0 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -33,21 +33,31 @@ from llmfoundry.models.layers.custom_embedding import SharedEmbedding from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY from llmfoundry.models.mpt.configuration_mpt import MPTConfig -# NOTE: We import all the utils directly just so that HuggingFace will detect -# all the files that it needs to copy into its modules folder. Otherwise it misses -# the ones imported in the submodule + +# NOTE: All utils are imported directly even if unused so that +# HuggingFace can detect all the needed files to copy into its modules folder. +# Otherwise, certain modules are missing. +# isort: off from llmfoundry.models.utils.adapt_tokenizer import ( - AutoTokenizerForMOD, adapt_tokenizer_for_denoising) + AutoTokenizerForMOD, # type: ignore (see note), + adapt_tokenizer_for_denoising, # type: ignore (see note) +) from llmfoundry.models.utils.hf_prefixlm_converter import ( - add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm) -from llmfoundry.models.utils.meta_init_context import init_empty_weights -from llmfoundry.models.utils.param_init_fns import ( # type: ignore - MODEL_INIT_REGISTRY, generic_param_init_fn_) + add_bidirectional_mask_if_missing, # type: ignore (see note) + convert_hf_causal_lm_to_prefix_lm, # type: ignore (see note) +) +from llmfoundry.models.utils.meta_init_context import \ + init_empty_weights # type: ignore (see note) +from llmfoundry.models.utils.param_init_fns import ( + generic_param_init_fn_, # type: ignore (see note) + MODEL_INIT_REGISTRY, +) try: from llmfoundry.models.layers.flash_attn_triton import flash_attn_func except: pass +# isort: on Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] @@ -145,7 +155,7 @@ def __init__(self, config: MPTConfig): def get_input_embeddings(self): return self.wte - def set_input_embeddings(self, value): + def set_input_embeddings(self, value: nn.Embedding): self.wte = value @torch.no_grad() @@ -212,9 +222,8 @@ def _attn_bias( if prefix_mask is not None and (attention_mask.shape != prefix_mask.shape): raise ValueError( - f'attention_mask shape={attention_mask.shape} ' +\ - f'and prefix_mask shape={prefix_mask.shape} are not equal.' 
- ) + f'attention_mask shape={attention_mask.shape} ' + + f'and prefix_mask shape={prefix_mask.shape} are not equal.') min_val = torch.finfo(attn_bias.dtype).min attn_bias = attn_bias.masked_fill( ~attention_mask.view(-1, 1, 1, s_k), min_val) @@ -226,10 +235,9 @@ def _apply_prefix_mask(self, attn_bias: torch.Tensor, s_k, s_q = attn_bias.shape[-2:] if (s_k != self.config.max_seq_len) or (s_q != self.config.max_seq_len): raise ValueError( - 'attn_bias does not match the expected shape. ' +\ - f'The last two dimensions should both be {self.config.max_length} ' +\ - f'but are {s_k} and {s_q}.' - ) + 'attn_bias does not match the expected shape. ' + + f'The last two dimensions should both be {self.config.max_length} ' + + f'but are {s_k} and {s_q}.') seq_len = prefix_mask.shape[-1] if seq_len > self.config.max_seq_len: raise ValueError( @@ -267,8 +275,10 @@ def _apply_sequence_id(self, attn_bias: torch.Tensor, # Restrict attention to tokens that share the same value # in sequence_id cannot_attend = torch.logical_not( - torch.eq(sequence_id.view(-1, seq_len, 1), - sequence_id.view(-1, 1, seq_len))).unsqueeze(1) + torch.eq( + sequence_id.view(-1, seq_len, 1), + sequence_id.view(-1, 1, seq_len), + )).unsqueeze(1) min_val = torch.finfo(attn_bias.dtype).min attn_bias = attn_bias.masked_fill(cannot_attend, min_val) @@ -285,12 +295,16 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, use_cache: Optional[bool] = None, + inputs_embeds: Optional[torch.Tensor] = None, ): - return_dict = return_dict if return_dict is not None else self.config.return_dict - use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = (return_dict + if return_dict is not None else self.config.return_dict) + use_cache = (use_cache + if use_cache is not None else self.config.use_cache) if attention_mask is not None: attention_mask = attention_mask.bool() + if prefix_mask is not None: prefix_mask = prefix_mask.bool() @@ -306,8 +320,9 @@ def forward( 'output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.' ) - if attention_mask is not None and attention_mask[:, 0].sum( - ) != attention_mask.shape[0] and self.training: + if (attention_mask is not None and + attention_mask[:, 0].sum() != attention_mask.shape[0] and + self.training): raise NotImplementedError( 'MPT does not support training with left padding.') @@ -316,16 +331,21 @@ def forward( 'prefix_mask is a required argument when MPT is configured with prefix_lm=True.' ) + # Raise a not implemented error if input_embeds is not None (this is an arg in huggingface transformers and we need to support it for PEFT) + if inputs_embeds is not None: + raise NotImplementedError( + 'inputs_embeds is not implemented for MPT.') + if self.training: if self.attn_uses_sequence_id and sequence_id is None: raise ValueError( - 'sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' +\ - 'and the model is in train mode.' - ) + 'sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + + 'and the model is in train mode.') elif (self.attn_uses_sequence_id is False) and (sequence_id is not None): warnings.warn( - 'MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' +\ + 'MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + + 'This input will be ignored. 
If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.' ) @@ -343,7 +363,8 @@ def forward( if past_key_values is not None: if len(past_key_values) != self.config.n_layers: raise ValueError( - f'past_key_values must provide a past_key_value for each attention ' +\ + f'past_key_values must provide a past_key_value for each attention ' + + f'layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).' ) # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim). @@ -358,16 +379,19 @@ def forward( f'Cannot forward input with past sequence length {past_position} and current sequence length ' f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.' ) - pos = torch.arange(past_position, - S + past_position, - dtype=torch.long, - device=input_ids.device).unsqueeze(0) + pos = torch.arange( + past_position, + S + past_position, + dtype=torch.long, + device=input_ids.device, + ).unsqueeze(0) if attention_mask is not None: # adjust the position indices to account for padding tokens - pos = torch.clamp(pos - torch.cumsum( - (~attention_mask).to(torch.int32), dim=1)[:, - past_position:], - min=0) + pos = torch.clamp( + pos - torch.cumsum((~attention_mask).to(torch.int32), + dim=1)[:, past_position:], + min=0, + ) pos_emb = self.wpe(pos) # type: ignore x = tok_emb + pos_emb @@ -386,7 +410,8 @@ def forward( dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, - sequence_id=sequence_id) + sequence_id=sequence_id, + ) # initialize the past key values cache if it should be used if use_cache and past_key_values is None: @@ -399,8 +424,8 @@ def forward( if output_hidden_states: assert all_hidden_states is not None # pyright all_hidden_states = all_hidden_states + (x,) - past_key_value = past_key_values[ - b_idx] if past_key_values is not None else None + past_key_value = (past_key_values[b_idx] + if past_key_values is not None else None) x, attn_weights, past_key_value = block( x, past_key_value=past_key_value, @@ -456,7 +481,9 @@ def __init__(self, config: MPTConfig): raise ValueError( 'MPTForCausalLM only supports tied word embeddings') - self.transformer = MPTModel(config) + print(f'Instantiating an MPTForCausalLM model from {__file__}') + + self.transformer: MPTModel = MPTModel(config) for child in self.transformer.children(): if isinstance(child, torch.nn.ModuleList): @@ -508,10 +535,17 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, use_cache: Optional[bool] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, ): - return_dict = return_dict if return_dict is not None else self.config.return_dict - use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = (return_dict + if return_dict is not None else self.config.return_dict) + use_cache = (use_cache + if use_cache is not None else self.config.use_cache) + # if input_embeds is not none, raise a not implemented error + if inputs_embeds is not None: + raise NotImplementedError( + 'inputs_embeds has to be None (for hf/peft support).') # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.transformer( input_ids=input_ids, @@ -529,7 +563,8 @@ def forward( # needed to support HF `device_map` logits = self.transformer.wte( outputs.last_hidden_state.to(self.transformer.wte.weight.device), - True) + True, + ) if self.logit_scale is not None: if self.logit_scale == 0: @@ -540,10 +575,12 @@ def forward( loss = None if 
labels is not None: - labels = torch.roll(labels, shifts=-1) - labels[:, -1] = -100 - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), - labels.to(logits.device).view(-1)) + _labels = torch.roll(labels, shifts=-1) + _labels[:, -1] = -100 + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), + _labels.to(logits.device).view(-1), + ) return CausalLMOutputWithPast( loss=loss, @@ -651,7 +688,7 @@ def __init__( InContextLearningMultipleChoiceAccuracy(), InContextLearningQAAccuracy(), InContextLearningLMExpectedCalibrationError(), - InContextLearningMCExpectedCalibrationError() + InContextLearningMCExpectedCalibrationError(), ] super().__init__( @@ -670,6 +707,7 @@ def __init__( if loss_fn_config == 'fused_crossentropy': try: from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss # type: ignore # isort: skip + if hf_config.verbose > 1: warnings.warn('Using Fused Cross Entropy Loss.') self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100) @@ -700,6 +738,7 @@ def forward(self, batch): attention_mask=batch.get('attention_mask', None), prefix_mask=batch.get('bidirectional_mask', None), sequence_id=batch.get('sequence_id', None), + inputs_embeds=batch.get('inputs_embeds', None), ) def loss(self, outputs, batch): @@ -715,7 +754,7 @@ def flops_per_batch(self, batch): bs, msl = batch['input_ids'].shape[0:2] params_flops_per_token = 2 * self.n_active_params params_flops_per_seq = params_flops_per_token * msl - attn_flops_per_seq = self.model.config.n_layers * 2 * 2 * ( - self.model.config.d_model * (msl**2)) + attn_flops_per_seq = (self.model.config.n_layers * 2 * 2 * + (self.model.config.d_model * (msl**2))) return (params_flops_per_seq + attn_flops_per_seq) * 3 * bs diff --git a/llmfoundry/utils/huggingface_hub_utils.py b/llmfoundry/utils/huggingface_hub_utils.py index a123aa9e26..647a070870 100644 --- a/llmfoundry/utils/huggingface_hub_utils.py +++ b/llmfoundry/utils/huggingface_hub_utils.py @@ -30,6 +30,8 @@ def convert_to_relative_import( def find_module_file(module_name: str) -> str: + if not module_name: + raise ValueError(f'Invalid input: {module_name=}') module = importlib.import_module(module_name) module_file = module.__file__ return module_file diff --git a/pyproject.toml b/pyproject.toml index c86ad98125..444206c5df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,8 @@ include = [ # Pyright [tool.pyright] -exclude = ['env-**'] +exclude = ['env-**', 'venv*'] +ignore = ['llmfoundry/models/layers/flash_attn_triton.py'] stubPath = "" # suppress useless 'stubPath is not a valid directory' errors reportUnnecessaryIsInstance = "warning" diff --git a/scripts/inference/README.md b/scripts/inference/README.md index de5069b5c0..ba4956c0d5 100644 --- a/scripts/inference/README.md +++ b/scripts/inference/README.md @@ -109,13 +109,47 @@ For MPT models specifically, you can pass args like `--attn_impl triton`, and `- ## Interactive Chat with HF models -Chat models need to pass conversation history back to the model for multi-turn conversations. To make that easier, we include `hf_chat.py`. Chat models usually require an introductory/system prompt, as well as a wrapper around user and model messages, to fit the training format. Default values work with our ChatML-trained models, but you can specify these values with CLI args: +Chat models need to pass conversation history back to the model for multi-turn conversations. To make that easier, we include `hf_chat.py`. 
Chat models usually require an introductory/system prompt, as well as a wrapper around user and model messages, to fit the training format. Default values work with our ChatML-trained models, but you can set other important values like generation kwargs:

 ```bash
-python hf_chat.py -n my_hf_model/ --system_prompt="You are a helpful assistant\n" --user_msg_fmt="user: {}\n" --assistant_msg_fmt="assistant: {}\n" --max_new_tokens=512
+# using an MPT/ChatML style model
+python hf_chat.py -n mosaicml/mpt-7b-chat-v2 \
+    --max_new_tokens=2048 \
+    --temperature 0.3 \
+    --top_k 0 \
+    --model_dtype bf16 \
+    --trust_remote_code
 ```
+
+```bash
+# using an MPT/ChatML style model on > 1 GPU
+python hf_chat.py -n mosaicml/mpt-7b-chat-v2 \
+    --max_new_tokens=1024 \
+    --temperature 0.3 \
+    --top_k 0 \
+    --model_dtype bf16 \
+    --trust_remote_code \
+    --device_map auto
+```
+
+The script also works with other style models. Here is an example of using it with a Vicuna-style model:
+
+
+```bash
+python hf_chat.py -n eachadea/vicuna-7b-1.1 --system_prompt="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." --user_msg_fmt="USER: {}\n" --assistant_msg_fmt="ASSISTANT: {}\n" --max_new_tokens=512
+```
+
+The `system_prompt` is the message that gives the bot context for the conversation, and can be used to make the bot take on different personalities.
+
+In the REPL you see while using `hf_chat.py`, you can enter text to interact with the model (hit return TWICE to send, which allows you to input text with single newlines). You can also enter the following commands:
+
+- `clear` — clear the conversation history, and start a new conversation (does not change the system prompt)
+- `system` — change the system prompt
+- `history` — see the conversation history
+- `quit` — exit
+
 ## Converting an HF model to ONNX

 We include a script `convert_hf_to_onnx.py` that demonstrates how to convert your HF model to ONNX format. For more details and examples
diff --git a/scripts/inference/hf_chat.py b/scripts/inference/hf_chat.py
index 2d1253ec65..28f9af90b8 100644
--- a/scripts/inference/hf_chat.py
+++ b/scripts/inference/hf_chat.py
@@ -134,8 +134,9 @@ def turn(self, user_inp: str) -> None:
         self.history[-1][-1] = assistant_response

     def __call__(self) -> None:
+        print(self.cli_instructions)
         while True:
-            print(self.cli_instructions)
+            print('User:')
             user_inp_lines = []
             while True:
                 line = input()
diff --git a/scripts/train/train.py b/scripts/train/train.py
index d34b19068f..01442cd085 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -9,9 +9,13 @@
 from composer import Trainer
 from composer.core import Evaluator
 from composer.utils import dist, get_device, reproducibility
+from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
+from peft import LoraConfig, get_peft_model
+from transformers import PreTrainedTokenizer

-from llmfoundry import (COMPOSER_MODEL_REGISTRY, build_finetuning_dataloader,
+from llmfoundry import (COMPOSER_MODEL_REGISTRY, ComposerHFCausalLM,
+                        MPTForCausalLM, build_finetuning_dataloader,
                         build_text_denoising_dataloader)
 from llmfoundry.data.text_data import build_text_dataloader
 from llmfoundry.models.utils import init_empty_weights
@@ -67,6 +71,40 @@ def build_composer_model(model_cfg, tokenizer):
     return COMPOSER_MODEL_REGISTRY[model_cfg.name](model_cfg, tokenizer)


+def build_composer_peft_model(
+        model_cfg: DictConfig, lora_cfg: DictConfig,
+        tokenizer: PreTrainedTokenizer) -> ComposerHFCausalLM:
+    # 1) loads a hf model, 2) adds peft modules, 3) wraps it in a ComposerHFCausalLM.
+    print('Building Lora config...')
+    lora_cfg = LoraConfig(**lora_cfg.args)
+
+    print('Building model from HuggingFace checkpoint...')
+    model = MPTForCausalLM.from_pretrained(
+        model_cfg.pretrained_model_name_or_path, trust_remote_code=True)
+    print('Model built!')
+
+    print('Adding Lora modules...')
+    model = get_peft_model(model, lora_cfg)
+    print('Lora modules added!')
+
+    model = ComposerHFCausalLM(model, tokenizer)
+
+    return model
+
+
+def print_trainable_parameters(model) -> None:
+    # Prints the number of trainable parameters in the model.
+ trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + all_param += param.numel() + if param.requires_grad: + trainable_params += param.numel() + print( + f'trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}' + ) + + def build_dataloader(cfg, tokenizer, device_batch_size): if cfg.name == 'text': return build_text_dataloader( @@ -161,7 +199,13 @@ def main(cfg): # Build Model print('Initializing model...') with init_context: - model = build_composer_model(cfg.model, tokenizer) + if cfg.get('lora', + None) is not None: # frozen model + trainable lora modules + model: ComposerHFCausalLM = build_composer_peft_model( + cfg.model, cfg.lora, tokenizer) + print_trainable_parameters(model) # should not be 100% + else: # standard model + model = build_composer_model(cfg.model, tokenizer) cfg.n_params = sum(p.numel() for p in model.parameters()) print(f'{cfg.n_params=:.2e}') diff --git a/setup.py b/setup.py index 3302030fa6..3c34b844a0 100644 --- a/setup.py +++ b/setup.py @@ -57,11 +57,15 @@ 'omegaconf>=2.2.3,<3', 'slack-sdk<4', 'mosaicml-cli>=0.3,<1', - 'onnx==1.13.1', - 'onnxruntime==1.14.1', + 'onnx==1.14.0', + 'onnxruntime==1.15.1', 'cmake>=3.25.0,<=3.26.3', # required for triton-pre-mlir below # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI 'triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir_sm90#subdirectory=python', + 'loralib==0.1.1', # lora core + 'peft @ git+https://github.com/huggingface/peft.git', # TODO: pin it down only after it stabilizes. + 'bitsandbytes==0.39.1', # 8bit + 'scipy>=1.10.0,<=1.11.0', # bitsandbytes dependency; TODO: eliminate when incorporated to bitsandbytes ] extra_deps = {} diff --git a/tests/test_hf_mpt_gen.py b/tests/test_hf_mpt_gen.py index 59739e23b3..8b1c4df4cb 100644 --- a/tests/test_hf_mpt_gen.py +++ b/tests/test_hf_mpt_gen.py @@ -3,7 +3,7 @@ import pytest from composer.core.precision import get_precision_context -from composer.utils import dist, get_device, reproducibility +from composer.utils import get_device, reproducibility from omegaconf import OmegaConf as om from llmfoundry import COMPOSER_MODEL_REGISTRY diff --git a/tests/test_model.py b/tests/test_model.py index 93b83eb6ae..0a13ecd145 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -29,6 +29,7 @@ ComposerHFPrefixLM) from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss from llmfoundry.models.layers import NORM_CLASS_REGISTRY, build_alibi_bias +from llmfoundry.models.layers.blocks import MPTBlock from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM from llmfoundry.utils import build_tokenizer @@ -359,13 +360,15 @@ def test_loss_fn(): pytest.skip('Fused cross entropy was not installed') # run numerical test in pure fp32 - torch.backends.cuda.matmul.allow_tf32 = False - torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False # type: ignore (third-party) + torch.backends.cudnn.allow_tf32 = False # type: ignore (third-party) conf_path = 'scripts/train/yamls/pretrain/testing.yaml' with open(conf_path) as f: test_cfg = om.load(f) + assert isinstance(test_cfg, DictConfig) + test_cfg.device = 'cuda:0' test_cfg.model.init_device = 'cuda:0' test_cfg.model.init_config = { @@ -471,25 +474,24 @@ def test_mpt_creation(norm_type, no_bias): assert mpt.config.expansion_ratio == 2 assert mpt.config.max_seq_len == 2048 - assert mpt.transformer.wte.weight.shape == torch.Size( # type: 
ignore + assert mpt.transformer.wte.weight.shape == torch.Size( [hf_config.vocab_size, hf_config.d_model]) - assert mpt.transformer.wpe.weight.shape == torch.Size( # type: ignore + assert mpt.transformer.wpe.weight.shape == torch.Size( [hf_config.max_seq_len, hf_config.d_model]) - assert mpt.transformer.emb_drop.p == 0.1 # type: ignore - assert len(mpt.transformer.blocks) == 2 # type: ignore + assert mpt.transformer.emb_drop.p == 0.1 + assert len(mpt.transformer.blocks) == 2 d_model = hf_config.d_model - for block in mpt.transformer.blocks: # type: ignore - assert block.norm_1.weight.shape == torch.Size([d_model - ]) # type: ignore - assert block.norm_2.weight.shape == torch.Size([d_model - ]) # type: ignore - assert block.ffn.up_proj.weight.shape == torch.Size( # type: ignore + for block in mpt.transformer.blocks: + assert isinstance(block, MPTBlock) + assert block.norm_1.weight.shape == torch.Size([d_model]) + assert block.norm_2.weight.shape == torch.Size([d_model]) + assert block.ffn.up_proj.weight.shape == torch.Size( [hf_config.d_model * hf_config.expansion_ratio, hf_config.d_model]) - assert block.ffn.down_proj.weight.shape == torch.Size( # type: ignore + assert block.ffn.down_proj.weight.shape == torch.Size( [hf_config.d_model, hf_config.d_model * hf_config.expansion_ratio]) - assert block.resid_attn_dropout.p == 0.2 # type: ignore - assert block.resid_ffn_dropout.p == 0.2 # type: ignore + assert block.resid_attn_dropout.p == 0.2 + assert block.resid_ffn_dropout.p == 0.2 @pytest.mark.parametrize('attention_impl,device', [('torch', 'cpu'),
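As a usage note on the new `lora` branch wired into `scripts/train/train.py` above: a hedged sketch of the config section it expects. `cfg.lora.args` is unpacked directly into `peft.LoraConfig`; the hyperparameter values and the `Wqkv` target-module name below are illustrative assumptions, not taken from this diff.

```python
from omegaconf import OmegaConf

# Hypothetical `lora` block merged into the training config consumed by main(cfg).
# When cfg.get('lora') is not None, train.py calls
# build_composer_peft_model(cfg.model, cfg.lora, tokenizer).
lora_block = OmegaConf.create({
    'lora': {
        'args': {  # splatted into peft.LoraConfig(**lora_cfg.args)
            'r': 16,
            'lora_alpha': 32,
            'lora_dropout': 0.05,
            'target_modules': ['Wqkv'],  # assumed MPT fused attention projection
        },
    },
})
```

In YAML form this would simply be a top-level `lora:` section with the same keys; with the adapters attached, `print_trainable_parameters` should report a small trainable fraction rather than 100%.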