diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index c4ca68d733..c0a1e65248 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -59,6 +59,7 @@ def __init__(
         use_cache: bool = False,
         init_config: Dict = init_config_defaults,
         fc_type: str = 'torch',
+        tie_word_embeddings: bool = True,
         verbose: Optional[int] = None,
         **kwargs: Any,
     ):
@@ -128,6 +129,7 @@ def __init__(
                 ---
                 See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
             fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs.
+            tie_word_embeddings (bool): Whether to tie the input embedding and output layers.
         """
         self.d_model = d_model
         self.n_heads = n_heads
@@ -164,7 +166,11 @@ def __init__(
             warnings.warn(
                 f'alibi or rope is turned on, setting `learned_pos_emb` to `False.`'
             )
-        super().__init__(**kwargs)
+        # tie_word_embeddings is set in Huggingface's PretrainedConfig __init__
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )

         self._validate_config()
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 0cb3ebd56c..10c042d27c 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -231,10 +231,11 @@ def __init__(self, config: MPTConfig):
         log.debug(self)
         log.debug(f'Using {self.config.init_config["name"]} initialization.')

-    def get_input_embeddings(self) -> nn.Embedding:
+    def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]:
         return self.wte

-    def set_input_embeddings(self, value: nn.Embedding) -> None:
+    def set_input_embeddings(
+            self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
         self.wte = value

     @torch.no_grad()
@@ -574,14 +575,20 @@ class MPTForCausalLM(MPTPreTrainedModel):
     def __init__(self, config: MPTConfig):
         super().__init__(config)

-        if not config.tie_word_embeddings:
-            raise ValueError(
-                'MPTForCausalLM only supports tied word embeddings')
-
         log.info(f'Instantiating an MPTForCausalLM model from {__file__}')

         self.transformer: MPTModel = MPTModel(config)

+        self.lm_head = None
+        if not config.tie_word_embeddings:
+            self.lm_head = nn.Linear(
+                config.d_model,
+                config.vocab_size,
+                bias=False,
+                device=config.init_device,
+            )
+            self.lm_head._fsdp_wrap = True
+
         for child in self.transformer.children():
             if isinstance(child, torch.nn.ModuleList):
                 continue
@@ -602,19 +609,38 @@ def __init__(self, config: MPTConfig):
                     )
             self.logit_scale = logit_scale

-    def get_input_embeddings(self) -> nn.Embedding:
-        return self.transformer.wte
+    def get_input_embeddings(self) -> Union[SharedEmbedding, nn.Embedding]:
+        return self.transformer.get_input_embeddings()

     def set_input_embeddings(
             self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
-        self.transformer.wte = value
+        self.transformer.set_input_embeddings(value)

-    def get_output_embeddings(self) -> nn.Embedding:
-        return self.transformer.wte
+    def get_output_embeddings(
+            self) -> Union[SharedEmbedding, nn.Embedding, nn.Linear]:
+        if self.lm_head is not None:
+            return self.lm_head
+        return self.transformer.get_input_embeddings()

     def set_output_embeddings(
-            self, new_embeddings: Union[SharedEmbedding, nn.Embedding]) -> None:
-        self.transformer.wte = new_embeddings
+            self, new_embeddings: Union[SharedEmbedding, nn.Embedding,
+                                        nn.Linear]) -> None:
+        if self.lm_head is not None:
+            self.lm_head = new_embeddings
+        else:
+            if not isinstance(new_embeddings,
+                              (SharedEmbedding, nn.Embedding)):
+                raise ValueError(
+                    'new_embeddings must be an instance of SharedEmbedding ' +
+                    f'or nn.Embedding, but got {type(new_embeddings)}.')
+            warnings.warn(
+                'Using `set_output_embeddings` to set the embedding layer of ' +
+                'MPTForCausalLM with tied weights. Given weights are tied, ' +
+                'using `set_input_embeddings` is recommended over using ' +
+                '`set_output_embeddings`.')
+            self.transformer.set_input_embeddings(new_embeddings)
+
+    def tie_weights(self) -> None:
+        self.lm_head = None

     def set_decoder(self, decoder: MPTModel) -> None:
         self.transformer = decoder
@@ -658,12 +684,14 @@ def forward(
             use_cache=use_cache,
         )

-        # move outputs to same device as weights for token embedding
-        # needed to support HF `device_map`
-        logits = self.transformer.wte(
-            outputs.last_hidden_state.to(self.transformer.wte.weight.device),
-            True,
-        )
+        if self.lm_head is not None:
+            logits = self.lm_head(outputs.last_hidden_state)
+        else:
+            # move outputs to same device as weights for token embedding
+            # needed to support HF `device_map`
+            out = outputs.last_hidden_state
+            out = out.to(self.transformer.wte.weight.device)
+            logits = self.transformer.wte(out, True)

         if self.logit_scale is not None:
             if self.logit_scale == 0:
@@ -859,7 +887,11 @@ def flops_per_batch(self, batch: Mapping) -> int:
         # assume the backward pass is approximately 2x the forward pass

         bs, msl = batch['input_ids'].shape[0:2]
-        params_flops_per_token = 2 * self.n_active_params
+        params = self.n_active_params
+        if not self.model.transformer.config.tie_word_embeddings:
+            # embedding layers are lookup tables, therefore are not counted in the FLOP computation
+            params -= self.model.transformer.wte.weight.numel()
+        params_flops_per_token = 2 * params
         params_flops_per_seq = params_flops_per_token * msl
         attn_flops_per_seq = (self.model.config.n_layers * 2 * 2 *
                               (self.model.config.d_model * (msl**2)))
diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py
index 6d5a282993..af94126225 100644
--- a/tests/test_hf_conversion_script.py
+++ b/tests/test_hf_conversion_script.py
@@ -248,20 +248,21 @@ def test_callback_inits_with_defaults():

 @pytest.mark.world_size(2)
 @pytest.mark.gpu
-@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2'])
+@pytest.mark.parametrize(
+    'model,tie_word_embeddings',
+    [('mpt', True), ('mpt', False), ('neo', None), ('llama2', None)],
+)
 @pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
 @pytest.mark.parametrize('log_to_mlflow', [True, False])
 @pytest.mark.parametrize(
     'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
     [('3ba', '2ba', '7ba', 3, 4), ('1dur', '2ba', '1ep', 1, 4)])
 @patch('os.cpu_count', MagicMock(return_value=None))
-def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
-                                         fsdp_state_dict_type: Optional[str],
-                                         log_to_mlflow: bool,
-                                         hf_save_interval: str,
-                                         save_interval: str, max_duration: str,
-                                         expected_hf_checkpoints: int,
-                                         expected_normal_checkpoints: int):
+def test_huggingface_conversion_callback(
+        model: str, tmp_path: pathlib.Path, tie_word_embeddings: bool,
+        fsdp_state_dict_type: Optional[str], log_to_mlflow: bool,
+        hf_save_interval: str, save_interval: str, max_duration: str,
+        expected_hf_checkpoints: int, expected_normal_checkpoints: int):
     delete_transformers_cache()

     dist.initialize_dist(get_device('gpu'))
@@ -298,9 +299,11 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
                 'attn_impl': 'torch',
             },
             'loss_fn': 'torch_crossentropy',
+            'tie_word_embeddings': tie_word_embeddings,
         }
         tokenizer_name = 'EleutherAI/gpt-neox-20b'
     elif model == 'neo':
+        assert tie_word_embeddings is None
         model_cfg = {
             'name': 'hf_causal_lm',
             'pretrained_model_name_or_path': 'EleutherAI/gpt-neo-125M',
@@ -313,6 +316,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
         }
         tokenizer_name = 'EleutherAI/gpt-neo-125M'
     elif model == 'llama2':
+        assert tie_word_embeddings is None
         if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
             pytest.skip(
                 'The CI cluster does not have access to the Llama models, so skip this test.'
@@ -489,19 +493,26 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
     delete_transformers_cache()


-@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2'])
-def test_convert_and_generate(model: str, tmp_path: pathlib.Path):
+@pytest.mark.parametrize(
+    'model,tie_word_embeddings',
+    [('mpt', True), ('mpt', False), ('neo', None), ('llama2', None)],
+)
+def test_convert_and_generate(model: str, tie_word_embeddings: bool,
+                              tmp_path: pathlib.Path):
     delete_transformers_cache()

     om_cfg = None
     if model == 'mpt':
         om_cfg = get_config(
             conf_path='scripts/train/yamls/pretrain/testing.yaml')
+        om_cfg['tie_word_embeddings'] = tie_word_embeddings
     elif model == 'neo':
+        assert tie_word_embeddings is None
         om_cfg = get_config(
             conf_path='scripts/train/yamls/pretrain/gpt-neo-125m.yaml')
         om_cfg['model']['config_overrides']['hidden_size'] = 36
     elif model == 'llama2':
+        assert tie_word_embeddings is None
         if 'HUGGING_FACE_HUB_TOKEN' not in os.environ:
             pytest.skip(
                 'The CI cluster does not have access to the Llama models, so skip this test.'
@@ -562,11 +573,14 @@ def test_convert_and_generate(model: str, tmp_path: pathlib.Path):


 @pytest.mark.gpu
-def test_convert_and_generate_triton(tmp_path: pathlib.Path):
+@pytest.mark.parametrize('tie_word_embeddings', [True, False])
+def test_convert_and_generate_triton(tie_word_embeddings: str,
+                                     tmp_path: pathlib.Path):
     delete_transformers_cache()

     cfg = get_config()
     cfg['model']['init_device'] = 'cpu'
+    cfg['tie_word_embeddings'] = tie_word_embeddings
     tokenizer = transformers.AutoTokenizer.from_pretrained(
         'EleutherAI/gpt-neox-20b')
     model = ComposerMPTCausalLM(cfg['model'], tokenizer)
@@ -602,7 +616,9 @@ def test_convert_and_generate_triton(tmp_path: pathlib.Path):
     delete_transformers_cache()


-def test_convert_and_generate_meta(tmp_path: pathlib.Path):
+@pytest.mark.parametrize('tie_word_embeddings', [True, False])
+def test_convert_and_generate_meta(tie_word_embeddings: str,
+                                   tmp_path: pathlib.Path):
     delete_transformers_cache()

     from composer.utils import dist
@@ -612,6 +628,7 @@ def test_convert_and_generate_meta(tmp_path: pathlib.Path):
     om_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml')

     om_cfg['model']['init_device'] = 'cpu'
+    om_cfg['tie_word_embeddings'] = tie_word_embeddings
     tokenizer = transformers.AutoTokenizer.from_pretrained(
         om_cfg.tokenizer.name)
     original_model = COMPOSER_MODEL_REGISTRY[om_cfg['model'].name](
diff --git a/tests/test_model.py b/tests/test_model.py
index 41b62f0ccf..3308c65fd3 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -466,7 +466,8 @@ def test_opt_wrapping():

 @pytest.mark.parametrize('norm_type', NORM_CLASS_REGISTRY.keys())
 @pytest.mark.parametrize('no_bias', [False, True])
-def test_mpt_creation(norm_type: str, no_bias: bool):
+@pytest.mark.parametrize('tie_word_embeddings', [True, False])
+def test_mpt_creation(norm_type: str, no_bias: bool, tie_word_embeddings: bool):
     # Test that the config constructs the model as expected.
     hf_config = MPTConfig(
         init_device='cpu',
@@ -482,6 +483,7 @@ def test_mpt_creation(norm_type: str, no_bias: bool):
         },
         norm_type=norm_type,
         no_bias=no_bias,
+        tie_word_embeddings=tie_word_embeddings,
     )
     mpt = MPTForCausalLM(hf_config)

@@ -493,6 +495,9 @@ def test_mpt_creation(norm_type: str, no_bias: bool):
     assert mpt.transformer.wte.weight.shape == torch.Size(
         [hf_config.vocab_size, hf_config.d_model])
+    if not tie_word_embeddings:
+        assert mpt.lm_head is not None
+        assert mpt.lm_head.weight.shape == mpt.transformer.wte.weight.shape
     assert mpt.transformer.wpe.weight.shape == torch.Size(
         [hf_config.max_seq_len, hf_config.d_model])
     assert mpt.transformer.emb_drop.p == 0.1
@@ -544,8 +549,9 @@ def test_mpt_creation(norm_type: str, no_bias: bool):
         'factor': 1.0,
     },
 }])
+@pytest.mark.parametrize('tie_word_embeddings', [True, False])
 def test_forward_with_padding(attention_impl: str, device: str,
-                              pos_emb_config: dict):
+                              pos_emb_config: dict, tie_word_embeddings: bool):
     # Test that different placement of padding does not affect the output.
     if not torch.cuda.is_available() and device == 'gpu':
         pytest.skip(
@@ -580,6 +586,7 @@ def test_forward_with_padding(attention_impl: str, device: str,
             'name': 'baseline_',
             'init_std': 0.02,
         },
+        tie_word_embeddings=tie_word_embeddings,
     )
     mpt = MPTForCausalLM(hf_config)
     mpt.eval()
@@ -736,10 +743,13 @@ def test_advanced_mask_building(attention_impl: str):
     assert torch.equal(attn_bias, expected_attn_bias)


-@pytest.mark.parametrize('attention_impl,device', [('torch', 'cpu'),
-                                                   ('flash', 'gpu'),
-                                                   ('triton', 'gpu'),
-                                                   ('torch', 'gpu')])
+@pytest.mark.parametrize('attention_impl,device,precision', [
+    ('torch', 'cpu', 'fp32'),
+    ('flash', 'gpu', 'amp_bf16'),
+    ('triton', 'gpu', 'amp_bf16'),
+    ('torch', 'gpu', 'amp_bf16'),
+    ('torch', 'gpu', 'fp32'),
+])
 @pytest.mark.parametrize('pos_emb_config', [{
     'alibi': False,
     'rope': False
@@ -766,7 +776,9 @@ def test_advanced_mask_building(attention_impl: str):
         'factor': 1.0,
     },
 }])
-def test_generate(attention_impl: str, device: str, pos_emb_config: dict):
+@pytest.mark.parametrize('tie_word_embeddings', [True, False])
+def test_generate(attention_impl: str, device: str, precision: str,
+                  pos_emb_config: dict, tie_word_embeddings: bool):
     # Test that generate works, and produces the same output with or without
     # padding in the input.
     if not torch.cuda.is_available() and device == 'gpu':
@@ -780,6 +792,8 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict):
             device != 'gpu' or not is_flash_v2_installed()):
         pytest.skip(
             f'dail implementation of rope requires gpu and flash attention 2.')
+    if attention_impl == 'torch' and precision == 'amp_bf16' and tie_word_embeddings == False:
+        pytest.skip(f'This test configuration has precision / sampling issues.')

     composer_device = get_device(device)

@@ -796,10 +810,11 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict):
             'attn_impl': attention_impl,
             **pos_emb_config,
         },
+        tie_word_embeddings=tie_word_embeddings,
     )
     mpt = MPTForCausalLM(hf_config)
-    mpt.eval()
     mpt = composer_device.module_to_device(mpt)
+    mpt.eval()

     # padding on the left of the input
     left_padding_input_ids = torch.tensor(
@@ -830,8 +845,7 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict):
     batched_attention_mask = composer_device.tensor_to_device(
         batched_attention_mask)

-    with get_precision_context('amp_bf16' if composer_device.name ==
-                               'gpu' else 'fp32'):
+    with get_precision_context(precision):
         # check that a batch with different amounts of padding doesn't crash
         # and produces the right output shape
         batched_generation = mpt.generate(input_ids=batched_input_ids,
@@ -861,8 +875,9 @@ def test_generate(attention_impl: str, device: str, pos_emb_config: dict):
 @pytest.mark.gpu
 @pytest.mark.parametrize('world_size', [1, 2])
 @pytest.mark.parametrize('use_cache', [False, True])
+@pytest.mark.parametrize('tie_word_embeddings', [True, False])
 def test_generate_with_device_map(tmp_path: pathlib.Path, world_size: int,
-                                  use_cache: bool):
+                                  use_cache: bool, tie_word_embeddings: bool):
     if not torch.cuda.is_available():
         pytest.skip(f'This test requires CUDA to be available.')
     if not torch.cuda.device_count() >= world_size:
@@ -882,6 +897,7 @@ def test_generate_with_device_map(tmp_path: pathlib.Path, world_size: int,
             'attn_impl': 'torch',
         },
         use_cache=use_cache,
+        tie_word_embeddings=tie_word_embeddings,
     )
     mpt = MPTForCausalLM(hf_config)
     mpt.save_pretrained(save_path)
@@ -994,8 +1010,10 @@ def test_save_from_pretrained(tmp_path: pathlib.Path):
         'factor': 1.0,
     },
 }])
+@pytest.mark.parametrize('tie_word_embeddings', [True, False])
 def test_forward_with_cache_and_padding(attn_impl: str, device: str,
-                                        pos_emb_config: dict):
+                                        pos_emb_config: dict,
+                                        tie_word_embeddings: bool):
     # Tests that the result is the same with or without padding when using kv caching
     if not torch.cuda.is_available() and device == 'gpu':
         pytest.skip(
@@ -1028,6 +1046,7 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str,
             'name': 'baseline_',
             'init_std': 0.02,
         },
+        tie_word_embeddings=tie_word_embeddings,
     )
     mpt = MPTForCausalLM(hf_config)
@@ -1133,7 +1152,9 @@ def test_forward_with_cache_and_padding(attn_impl: str, device: str,
         'factor': 1.0,
     },
 }])
-def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict):
+@pytest.mark.parametrize('tie_word_embeddings', [True, False])
+def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict,
+                            tie_word_embeddings: bool):
     # Test that model forward with and without the key-value cache produces the
     # same output.
     if not torch.cuda.is_available() and device == 'gpu':
@@ -1168,6 +1189,7 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict):
             'name': 'baseline_',
             'init_std': 0.02,
         },
+        tie_word_embeddings=tie_word_embeddings,
     )
     mpt = MPTForCausalLM(hf_config)
     mpt = composer_device.module_to_device(mpt)
@@ -1237,7 +1259,7 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict):
     torch.testing.assert_close(
         second_output.logits,
         full_output.logits[:, -1, :].unsqueeze(1),
-        atol=1e-2,
+        atol=1.1e-2,
         rtol=1e-2,
     )
@@ -1274,8 +1296,9 @@ def test_forward_with_cache(attn_impl: str, device: str, pos_emb_config: dict):
         'factor': 1.0,
     },
 }])
+@pytest.mark.parametrize('tie_word_embeddings', [True, False])
 def test_generate_with_past_kv(attn_impl: str, device: str,
-                               pos_emb_config: dict):
+                               pos_emb_config: dict, tie_word_embeddings: bool):
     if not torch.cuda.is_available() and device == 'gpu':
         pytest.skip(
             f'This test requires CUDA to be available in order to run with {attn_impl} attention.'
@@ -1307,6 +1330,7 @@ def test_generate_with_past_kv(attn_impl: str, device: str,
             'name': 'baseline_',
             'init_std': 0.02,
         },
+        tie_word_embeddings=tie_word_embeddings,
     )
     mpt = MPTForCausalLM(hf_config)
     mpt = composer_device.module_to_device(mpt)
@@ -1325,7 +1349,8 @@ def test_generate_with_past_kv(attn_impl: str, device: str,
     with mock.patch.object(MPTForCausalLM, 'forward',
                            autospec=True) as forward_mocked:
         forward_mocked.return_value = CausalLMOutputWithPast(
-            logits=torch.randn((1, 3, hf_config.vocab_size)),
+            logits=composer_device.tensor_to_device(
+                torch.randn((1, 3, hf_config.vocab_size))),
             past_key_values=[(torch.randn(1, 3, hf_config.d_model),
                               torch.randn(1, 3, hf_config.d_model))
                              for _ in range(hf_config.n_layers)])
@@ -1386,9 +1411,11 @@ def test_generate_with_past_kv(attn_impl: str, device: str,
         'factor': 1.0,
     },
 }])
+@pytest.mark.parametrize('tie_word_embeddings', [True, False])
 def test_generation_kwargs_dont_crash(attn_impl: str, device: str,
                                       generation_kwargs: Dict[str, Any],
-                                      pos_emb_config: dict):
+                                      pos_emb_config: dict,
+                                      tie_word_embeddings: bool):
     if not torch.cuda.is_available() and device == 'gpu':
         pytest.skip(
             f'This test requires CUDA to be available in order to run with {attn_impl} attention.'
@@ -1417,6 +1444,7 @@ def test_generation_kwargs_dont_crash(attn_impl: str, device: str,
             **pos_emb_config,
         },
         use_cache=True,
+        tie_word_embeddings=tie_word_embeddings,
     )
     mpt = MPTForCausalLM(hf_config)
     mpt = composer_device.module_to_device(mpt)
@@ -1467,7 +1495,9 @@ def test_generation_kwargs_dont_crash(attn_impl: str, device: str,
         'factor': 1.0,
     },
 }])
-def test_model_to(attention_impl: str, pos_emb_config: dict):
+@pytest.mark.parametrize('tie_word_embeddings', [True, False])
+def test_model_to(attention_impl: str, pos_emb_config: dict,
+                  tie_word_embeddings: bool):
     # test that moving the model to diff devices and dtypes in diff ways does not break the model
     if not torch.cuda.is_available():
         pytest.skip(
@@ -1498,6 +1528,7 @@ def test_model_to(attention_impl: str, pos_emb_config: dict):
             'name': 'baseline_',
             'init_std': 0.02,
         },
+        tie_word_embeddings=tie_word_embeddings,
     )
     mpt = MPTForCausalLM(hf_config)
     mpt = mpt.bfloat16()
@@ -1600,9 +1631,11 @@ def test_alibi_vs_hf():
 }])
 @pytest.mark.parametrize('output_attentions', [True, False])
 @pytest.mark.parametrize('output_hidden_states', [True, False])
+@pytest.mark.parametrize('tie_word_embeddings', [True, False])
 def test_forward_with_output_attentions_and_output_hidden_states(
         attn_impl: str, device: str, pos_emb_config: dict,
-        output_attentions: bool, output_hidden_states: bool):
+        output_attentions: bool, output_hidden_states: bool,
+        tie_word_embeddings: bool):
     # Test that model forward with output_attentions_and_output_hidden_states
     if not torch.cuda.is_available() and device == 'gpu':
         pytest.skip(
@@ -1639,6 +1672,7 @@ def test_forward_with_output_attentions_and_output_hidden_states(
             'name': 'baseline_',
             'init_std': 0.02,
         },
+        tie_word_embeddings=tie_word_embeddings,
     )
     mpt = MPTForCausalLM(hf_config)
     mpt = composer_device.module_to_device(mpt)
diff --git a/tests/test_mpt_gen.py b/tests/test_mpt_gen.py
index c52b765480..413e39bf8c 100644
--- a/tests/test_mpt_gen.py
+++ b/tests/test_mpt_gen.py
@@ -55,9 +55,11 @@ def forward(
 @pytest.mark.gpu
 @pytest.mark.parametrize('attn_impl', ['triton', 'torch'])
 @pytest.mark.parametrize('use_alibi', [True, False])
+@pytest.mark.parametrize('tie_word_embeddings', [True, False])
 @patch('llmfoundry.models.mpt.modeling_mpt.MPTForCausalLM',
        new=MockMPTForCausalLM)
 def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool,
+                                tie_word_embeddings: bool,
                                 build_tiny_mpt: Callable[...,
                                                          ComposerMPTCausalLM],
                                 mpt_tokenizer: PreTrainedTokenizerBase):
@@ -67,11 +69,14 @@ def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool,
     """
     device = get_device('gpu')

-    model = build_tiny_mpt(attn_config={
-        'attn_impl': attn_impl,
-        'attn_uses_sequence_id': False,
-        'alibi': use_alibi
-    },)
+    model = build_tiny_mpt(
+        tie_word_embeddings=tie_word_embeddings,
+        attn_config={
+            'attn_impl': attn_impl,
+            'attn_uses_sequence_id': False,
+            'alibi': use_alibi
+        },
+    )
     model = device.module_to_device(model)
     model.eval()

@@ -88,13 +93,25 @@ def test_mpt_generate_multi_gpu(attn_impl: str, use_alibi: bool,


 @pytest.mark.gpu
-def test_mpt_generate_callback(build_tiny_mpt: Callable[...,
+@pytest.mark.parametrize('attn_impl', ['triton', 'torch'])
+@pytest.mark.parametrize('use_alibi', [True, False])
+@pytest.mark.parametrize('tie_word_embeddings', [True, False])
+def test_mpt_generate_callback(attn_impl: str, use_alibi: bool,
+                               tie_word_embeddings: bool,
+                               build_tiny_mpt: Callable[...,
                                                         ComposerMPTCausalLM],
                                tiny_ft_dataloader: DataLoader):
     device = get_device('gpu')

     # build mpt model
-    model = build_tiny_mpt()
+    model = build_tiny_mpt(
+        tie_word_embeddings=tie_word_embeddings,
+        attn_config={
+            'attn_impl': attn_impl,
+            'attn_uses_sequence_id': False,
+            'alibi': use_alibi
+        },
+    )
     model = device.module_to_device(model)

     # generate callback
diff --git a/tests/test_onnx.py b/tests/test_onnx.py
index d0e01746eb..becd3c773f 100644
--- a/tests/test_onnx.py
+++ b/tests/test_onnx.py
@@ -3,6 +3,7 @@

 import pathlib

+import pytest
 import torch
 from transformers import AutoModelForCausalLM

@@ -25,7 +26,8 @@ def gen_random_batch(batch_size: int, vocab_size: int, max_seq_len: int):
     return batch


-def test_onnx_export(tmp_path: pathlib.Path):
+@pytest.mark.parametrize('tie_word_embeddings', [True, False])
+def test_onnx_export(tie_word_embeddings: bool, tmp_path: pathlib.Path):
     from transformers.models.auto.configuration_auto import CONFIG_MAPPING
     CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
     AutoModelForCausalLM.register(MPTConfig, MPTForCausalLM)
@@ -48,6 +50,7 @@ def test_onnx_export(tmp_path: pathlib.Path):
         use_cache=True,
         vocab_size=vocab_size,
         norm_type='layernorm',
+        tie_word_embeddings=tie_word_embeddings,
     )
     mpt = MPTForCausalLM(hf_config)
     mpt.eval()
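Usage sketch (illustrative, not part of the patch): the snippet below mirrors the assertions this diff adds to tests/test_model.py and shows what the new tie_word_embeddings flag changes when constructing MPTForCausalLM directly. It assumes llmfoundry is importable; the build_tiny_config helper and the tiny dimensions are made up for illustration.

import torch

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM


def build_tiny_config(tie_word_embeddings: bool) -> MPTConfig:
    # Hypothetical helper: small, CPU-friendly dimensions chosen only for illustration.
    return MPTConfig(
        d_model=64,
        n_heads=4,
        n_layers=2,
        expansion_ratio=2,
        max_seq_len=32,
        vocab_size=128,
        attn_config={'attn_impl': 'torch'},
        tie_word_embeddings=tie_word_embeddings,
    )


tied = MPTForCausalLM(build_tiny_config(tie_word_embeddings=True))
untied = MPTForCausalLM(build_tiny_config(tie_word_embeddings=False))

# Tied: no separate lm_head; logits come from the shared input embedding
# (transformer.wte), which is also what get_output_embeddings() returns.
assert tied.lm_head is None
assert tied.get_output_embeddings() is tied.transformer.wte

# Untied: a dedicated nn.Linear head with the same (vocab_size, d_model) weight
# shape as wte, but with parameters that are trained independently of it.
assert isinstance(untied.lm_head, torch.nn.Linear)
assert untied.lm_head.weight.shape == untied.transformer.wte.weight.shape
assert untied.get_output_embeddings() is untied.lm_head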