From 22ae919c1d6b2b542278399586a1835e0e632bba Mon Sep 17 00:00:00 2001
From: Sam Havens
Date: Thu, 30 Nov 2023 17:47:43 -0800
Subject: [PATCH] Support inputs_embeds (#687)

* support inputs_embeds

* update tests to test inputs_embeds

* make iids optional inputs to fwd

* remove check for both iids and inputs_embeds in MPTForCausalLM. It is
  checked in the base model, and it is actually a common practice to pass
  both during autoregressive generation. Embeds are used first, then once
  the kvcache is nonempty, iids are used instead

* reorder kwargs

* add more tests

* fix device merge artifact in test_model.py

* fix generate test

* yapf
---
 llmfoundry/models/mpt/modeling_mpt.py | 51 +++++++++--------
 tests/test_model.py                   | 79 +++++++++++++++++++++++++--
 2 files changed, 101 insertions(+), 29 deletions(-)

diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 274c1b76e5..d6b23c04d0 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -368,7 +368,7 @@ def _apply_sequence_id(self, attn_bias: torch.Tensor,
 
     def forward(
         self,
-        input_ids: torch.LongTensor,
+        input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
         attention_mask: Optional[torch.ByteTensor] = None,
         prefix_mask: Optional[torch.ByteTensor] = None,
@@ -412,11 +412,6 @@ def forward(
                 'prefix_mask is a required argument when MPT is configured with prefix_lm=True.'
             )
 
-        # Raise a not implemented error if input_embeds is not None (this is an arg in huggingface transformers and we need to support it for PEFT)
-        if inputs_embeds is not None:
-            raise NotImplementedError(
-                'inputs_embeds is not implemented for MPT.')
-
         if self.training:
             if self.attn_uses_sequence_id and sequence_id is None:
                 raise ValueError(
@@ -430,14 +425,25 @@ def forward(
                     'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.'
                )
 
-        S = input_ids.size(1)
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError(
+                'You cannot specify both input_ids and inputs_embeds.')
+        elif input_ids is not None:
+            S = input_ids.size(1)
+            x = self.wte(input_ids)
+            input_device = input_ids.device
+        elif inputs_embeds is not None:
+            S = inputs_embeds.size(1)
+            x = inputs_embeds
+            input_device = inputs_embeds.device
+        else:
+            raise ValueError('You must specify input_ids or inputs_embeds')
 
         assert (
             S <= self.config.max_seq_len
         ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
 
         rotary_emb_w_meta_info = None
-        x = self.wte(input_ids)
         if self.learned_pos_emb or self.rope:
             past_position = 0
             if past_key_values is not None:
@@ -467,7 +473,7 @@ def forward(
                 past_position,
                 S + past_position,
                 dtype=torch.long,
-                device=input_ids.device,
+                device=input_device,
             ).unsqueeze(0)
             if attention_mask is not None:
                 # adjust the position indices to account for padding tokens
@@ -652,7 +658,7 @@ def get_decoder(self) -> MPTModel:
 
     def forward(
         self,
-        input_ids: torch.LongTensor,
+        input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
         attention_mask: Optional[torch.ByteTensor] = None,
         prefix_mask: Optional[torch.ByteTensor] = None,
@@ -669,11 +675,6 @@ def forward(
         use_cache = (use_cache
                      if use_cache is not None else self.config.use_cache)
-        # if input_embeds is not none, raise a not implemented error
-        if inputs_embeds is not None:
-            raise NotImplementedError(
-                'inputs_embeds has to be None (for hf/peft support).')
-
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.transformer(
             input_ids=input_ids,
             past_key_values=past_key_values,
@@ -684,6 +685,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             use_cache=use_cache,
+            inputs_embeds=inputs_embeds,
         )
 
         if self.lm_head is not None:
@@ -773,10 +775,6 @@ def prepare_inputs_for_generation(
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: Any,
     ) -> Dict[str, Any]:
-        if inputs_embeds is not None:
-            raise NotImplementedError(
-                'inputs_embeds is not implemented for MPT yet')
-
         attention_mask = kwargs['attention_mask'].bool()
         if attention_mask[:, -1].sum() != attention_mask.shape[0]:
             raise NotImplementedError(
@@ -787,6 +785,7 @@ def prepare_inputs_for_generation(
         else:
             sequence_id = None
 
+        # only last token for inputs_ids if past is defined in kwargs
         if past_key_values is not None:
             input_ids = input_ids[:, -1].unsqueeze(-1)
 
@@ -800,14 +799,20 @@ def prepare_inputs_for_generation(
         else:
             prefix_mask = None
 
-        return {
-            'input_ids': input_ids,
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {'inputs_embeds': inputs_embeds}
+        else:
+            model_inputs = {'input_ids': input_ids}
+
+        model_inputs.update({
             'attention_mask': attention_mask,
             'prefix_mask': prefix_mask,
             'sequence_id': sequence_id,
             'past_key_values': past_key_values,
             'use_cache': kwargs.get('use_cache', True),
-        }
+        })
+        return model_inputs
 
     @staticmethod
     def _reorder_cache(
@@ -898,7 +903,7 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast:
             add_bidirectional_mask_if_missing(batch)
         # Note: prefix_mask is only used if model.prefix_lm is True
         return self.model(
-            input_ids=batch['input_ids'],
+            input_ids=batch.get('input_ids', None),
             attention_mask=batch.get('attention_mask', None),
             prefix_mask=batch.get('bidirectional_mask', None),
             sequence_id=batch.get('sequence_id', None),
diff --git a/tests/test_model.py b/tests/test_model.py
index 4d5b0a4dbc..acb2074ae9 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -5,7 +5,7 @@
 import os
 import pathlib
 import warnings
-from typing import Any, Dict, Union, cast
+from typing import Any, Dict, List, Optional, Union, cast
 from unittest import mock
 
 import pytest
@@ -94,13 +94,26 @@ def get_objs(conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'):
     return test_cfg, model, optimizer
 
 
-def gen_random_batch(batch_size: int, test_cfg: Union[DictConfig, ListConfig]):
+def gen_random_batch(batch_size: int,
+                     test_cfg: Union[DictConfig, ListConfig],
+                     inputs: Optional[List[str]] = None):
+    # inputs can be [], ['input_ids'], ['input_ids', 'inputs_embeds'], and ['inputs_embeds']
+    # default to only input ids
+    if inputs == None:
+        inputs = ['input_ids']
     # generate input batch of random data, suitable for a Causal or Prefix LM
     batch = {}
-    batch['input_ids'] = torch.randint(
-        low=0,
-        high=test_cfg.model.vocab_size,
-        size=(batch_size, test_cfg.max_seq_len)).to(test_cfg.device)
+    for inp in inputs:
+        if inp == 'input_ids':
+            batch['input_ids'] = torch.randint(
+                low=0,
+                high=test_cfg.model.vocab_size,
+                size=(batch_size, test_cfg.max_seq_len)).to(test_cfg.device)
+        if inp == 'inputs_embeds':
+            batch['inputs_embeds'] = torch.randn(
+                batch_size, test_cfg.max_seq_len,
+                test_cfg.model.d_model).to(test_cfg.device)
+
     batch['labels'] = torch.randint(low=0,
                                     high=test_cfg.model.vocab_size,
                                     size=(batch_size, test_cfg.max_seq_len)).to(
@@ -150,6 +163,34 @@ def test_full_forward_and_backward(batch_size: int = 2):
     assert not torch.equal(original_params, updated_params)
 
 
+def test_full_forward_and_backward_with_inputs_embeds(batch_size: int = 2):
+    test_cfg, model, optimizer = get_objs(
+        conf_path='scripts/train/yamls/pretrain/testing.yaml')
+
+    batch = gen_random_batch(batch_size, test_cfg, inputs=['inputs_embeds'])
+
+    model.train()
+    original_params = next(model.parameters()).clone().data
+    outputs = model(batch)
+    loss = model.loss(outputs, batch)
+    loss.backward()
+    optimizer.step()
+    updated_params = next(model.parameters()).clone().data
+    assert not torch.equal(original_params, updated_params)
+
+
+@pytest.mark.parametrize('inputs', [[], ['input_ids', 'inputs_embeds']])
+def test_invalid_inputs_embeds_input_ids_combinations(inputs: List[str]):
+    test_cfg, model, _ = get_objs(
+        conf_path='scripts/train/yamls/pretrain/testing.yaml')
+
+    batch = gen_random_batch(2, test_cfg, inputs=inputs)
+
+    model.train()
+    with pytest.raises(ValueError):
+        _ = model(batch)
+
+
 def test_attention_mechanism(batch_size: int = 2):
     test_cfg, model, _ = get_objs(
         conf_path='scripts/train/yamls/pretrain/testing.yaml')
@@ -825,6 +866,9 @@ def test_generate(attention_impl: str, precision: str, pos_emb_config: dict,
     no_padding_attention_mask = composer_device.tensor_to_device(
         no_padding_attention_mask)
 
+    # inputs_embeds
+    inputs_embeds = composer_device.tensor_to_device(torch.randn(2, 3, 128))
+
     # a single batch with different amounts of left padding in the input
     batched_input_ids = torch.tensor([[50256, 50256, 50256, 11274, 16390, 11],
                                       [50256, 50256, 16, 11274, 16390, 11]])
@@ -860,6 +904,29 @@ def test_generate(attention_impl: str, precision: str, pos_emb_config: dict,
     assert generation_with_no_padding[:, 3:].equal(
         generation_with_left_padding[:, 6:])
 
+    # check that both/neither ids and embeds do not error
+    # note that we need to set the BOS token ID for generating from neither
+    _ = mpt.generate(input_ids=no_padding_input_ids,
+                     inputs_embeds=inputs_embeds,
+                     attention_mask=no_padding_attention_mask,
+                     max_new_tokens=5,
+                     use_cache=False)
+    _ = mpt.generate(input_ids=no_padding_input_ids,
+                     inputs_embeds=inputs_embeds,
+                     attention_mask=no_padding_attention_mask,
+                     max_new_tokens=5,
+                     use_cache=True)
+    _ = mpt.generate(input_ids=None,
+                     inputs_embeds=None,
+                     max_new_tokens=5,
+                     use_cache=False,
+                     bos_token_id=50256)
+    _ = mpt.generate(input_ids=None,
+                     inputs_embeds=None,
+                     max_new_tokens=5,
+                     use_cache=True,
+                     bos_token_id=50256)
+
 
 @pytest.mark.gpu
 @pytest.mark.parametrize('world_size', [1, 2])
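
Usage sketch (not part of the patch): the snippet below mirrors the new tests to show what the inputs_embeds path enables after this change. The import path and config values are assumed to match those used in tests/test_model.py and are illustrative only; attn_impl 'torch' is assumed so the sketch runs on CPU.

import torch

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

# Small illustrative config, similar to the small test configs (assumption).
hf_config = MPTConfig(
    init_device='cpu',
    d_model=128,
    n_heads=4,
    n_layers=2,
    expansion_ratio=2,
    max_seq_len=2048,
    attn_config={'attn_impl': 'torch'},
)
mpt = MPTForCausalLM(hf_config).eval()

# 1) Forward directly from embeddings: input_ids is now optional, but exactly
#    one of input_ids / inputs_embeds must be provided to MPTModel.forward.
inputs_embeds = torch.randn(2, 3, hf_config.d_model)
logits = mpt(inputs_embeds=inputs_embeds).logits  # (batch, seq_len, vocab_size)

# 2) Generation with both input_ids and inputs_embeds: the embeddings feed the
#    first step (empty kv-cache); later steps use the freshly sampled
#    input_ids, per prepare_inputs_for_generation.
input_ids = torch.tensor([[11274, 16390, 11], [11274, 16390, 11]])
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
_ = mpt.generate(input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 attention_mask=attention_mask,
                 max_new_tokens=5,
                 use_cache=True)

# 3) Generation from neither: a bos_token_id is required so generate() can
#    build the initial input_ids.
_ = mpt.generate(max_new_tokens=5, use_cache=True, bos_token_id=50256)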