From 7c82d1fa977a850bf5c1971a2c18f85f700203c3 Mon Sep 17 00:00:00 2001 From: Max Marion Date: Wed, 28 Feb 2024 09:59:20 -0800 Subject: [PATCH] Remove "generation_length" in favor of "generation_kwargs" (#3014) * kill generation_length * fix tests * fix test * add deprecation warning * fix test * add gen_len back into static_keys * simplify setting variable in forward and add test * simply test * trailing comma * trailing comma * linting --------- Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com> --- .../in_context_learning_evaluation.py | 5 ++-- composer/models/huggingface.py | 12 ++++++++- .../test_in_context_learning_datasets.py | 25 +++++++++++-------- tests/models/test_hf_model.py | 25 +++++++++++++++++-- 4 files changed, 50 insertions(+), 17 deletions(-) diff --git a/composer/datasets/in_context_learning_evaluation.py b/composer/datasets/in_context_learning_evaluation.py index 459487f158..38c73bd876 100644 --- a/composer/datasets/in_context_learning_evaluation.py +++ b/composer/datasets/in_context_learning_evaluation.py @@ -715,10 +715,10 @@ def __init__( 'mode': 'generate', 'labels': [], 'cot_delimiter': self.cot_delimiter, - 'generation_length': self.max_answer_length, 'stopping_criteria': early_stopping_criteria, 'do_normalization': do_normalization, 'generation_kwargs': { + 'max_new_tokens': self.max_answer_length, 'pad_token_id': self.pad_tok_id, 'use_cache': True, 'eos_token_id': self.tokenizer.eos_token_id, @@ -1260,7 +1260,6 @@ class InContextLearningCodeEvalDataset(InContextLearningDataset): - test_outputs: List of test outputs - languages: List of languages - pass_at_k: Passed value for pass_at_k - - generation_length: Derrived maximum generation length - generation_kwargs: Dictionary of kwargs neeeded for generation. Includes the following, which will be individually overwritten by keys in generaiton_kwargs if set (see https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig for more details): @@ -1349,7 +1348,6 @@ def __init__( 'test_outputs': [], 'languages': [], 'pass_at_k': pass_at_k, - 'generation_length': min(self.max_answer_length, self.max_seq_len - self.max_prompt_length), 'generation_kwargs': { 'pad_token_id': self.pad_tok_id, 'num_beams': 1, # single beam @@ -1357,6 +1355,7 @@ def __init__( 'temperature': 0.2, # good default for code 'use_cache': True, 'eos_token_id': self.tokenizer.eos_token_id, + 'max_new_tokens': min(self.max_answer_length, self.max_seq_len - self.max_prompt_length), }, 'sample_id': [], 'pass_at_k': list(pass_at_k), diff --git a/composer/models/huggingface.py b/composer/models/huggingface.py index 439f8b50fe..07149e4ce6 100644 --- a/composer/models/huggingface.py +++ b/composer/models/huggingface.py @@ -466,10 +466,20 @@ def eval_forward(self, batch, outputs: Optional[Any] = None): raise ValueError( 'Generation eval cannot be used without providing a tokenizer to the model constructor.') + if 'generation_length' in batch: + warnings.warn( + ('`generation_length` has been deprecated in favor of passing `max_new_tokens` directly into `generation_kwargs`.' 
+ 'It will be removed in v0.21'), + DeprecationWarning, + ) + if 'generation_kwargs' in batch: + batch['generation_kwargs']['max_new_tokens'] = batch['generation_length'] + else: + batch['generation_kwargs'] = {'max_new_tokens': batch['generation_length']} + self.labels = batch.pop('labels') generation = self.generate(batch['input_ids'], attention_mask=batch['attention_mask'], - max_new_tokens=batch['generation_length'], synced_gpus=dist.get_world_size() > 1, **batch.get('generation_kwargs', {})) diff --git a/tests/datasets/test_in_context_learning_datasets.py b/tests/datasets/test_in_context_learning_datasets.py index 3611e20dd1..66d14cc76f 100644 --- a/tests/datasets/test_in_context_learning_datasets.py +++ b/tests/datasets/test_in_context_learning_datasets.py @@ -296,7 +296,7 @@ def test_update_generation_kwargs_no_kwargs_qa_dataset(tmp_path): continuation_delimiter=': ', destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'), generation_kwargs=None) - assert len(dl.base_batch['generation_kwargs']) == 3 + assert len(dl.base_batch['generation_kwargs']) == 4 def test_update_generation_kwargs_with_kwargs_qa_dataset(tmp_path): @@ -321,7 +321,7 @@ def test_update_generation_kwargs_with_kwargs_qa_dataset(tmp_path): generation_kwargs={'temperature': 0.9}) assert 'generation_kwargs' in dl.base_batch assert dl.base_batch['generation_kwargs']['temperature'] == 0.9 - assert len(dl.base_batch['generation_kwargs']) == 4 + assert len(dl.base_batch['generation_kwargs']) == 5 @pytest.mark.filterwarnings( @@ -1255,8 +1255,8 @@ def test_qa_split_batch(tiny_opt_tokenizer, dataset_uri, tmp_path): assert len(split2['labels']) == 1 assert all(isinstance(v, list) for v in split1['labels'] + split2['labels']) - assert isinstance(split1['generation_length'], int) - assert isinstance(split2['generation_length'], int) + assert isinstance(split1['generation_kwargs']['max_new_tokens'], int) + assert isinstance(split2['generation_kwargs']['max_new_tokens'], int) assert isinstance(split1['generation_kwargs'], dict) assert isinstance(split2['generation_kwargs'], dict) @@ -1326,7 +1326,7 @@ def test_qa_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path, num_fews assert batch['mode'] == 'generate' # the maximum generation length from the small test data - assert batch['generation_length'] == maximum_answer_length + assert batch['generation_kwargs']['max_new_tokens'] == maximum_answer_length assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids']) decoded_batch = tokenizer.batch_decode(batch['input_ids']) @@ -1376,7 +1376,7 @@ def test_qa_task_with_cot_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path, assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen - maximum_answer_length) assert batch['mode'] == 'generate' # the maximum generation length from the small test data - assert batch['generation_length'] == maximum_answer_length + assert batch['generation_kwargs']['max_new_tokens'] == maximum_answer_length assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids']) decoded_batch = tokenizer.batch_decode(batch['input_ids']) assert all(item.count('Q: ') == num_fewshot + 1 for item in decoded_batch) @@ -1491,8 +1491,11 @@ def test_code_eval_split_batch(dataset_uri, tmp_path): assert len(batch[field]) == size assert all(isinstance(val, type_) for val in batch[field]) - static_keys = {'pass_at_k': (int, list), 'generation_length': int, 'generation_kwargs': dict} + static_keys = {'pass_at_k': (int, list), 'generation_kwargs': dict} for batch in batches: + 
assert 'generation_kwargs' in batch + assert 'max_new_tokens' in batch['generation_kwargs'] + assert isinstance(batch['generation_kwargs']['max_new_tokens'], int) for field, type_ in static_keys.items(): assert isinstance(batch[field], type_) @@ -1544,7 +1547,7 @@ def test_code_eval_sentpiece_dataloader(dataset_uri, tmp_path, num_fewshot, prom assert tuple(batch['attention_mask'].shape) == (bs, max_prompt_length) assert batch['mode'] == 'generate' # the maximum generation length from the small test data - assert batch['generation_length'] == 129 + assert batch['generation_kwargs']['max_new_tokens'] == 129 has_left_padding.extend([item[0] == tokenizer.eos_token_id for item in batch['input_ids']]) assert not all(has_left_padding) # longest should be pushed left @@ -1613,7 +1616,7 @@ def test_code_eval_test_cases(dataset_uri, tmp_path, tiny_llama_tokenizer): assert tuple(batch['attention_mask'].shape) == (batch_size, max_prompt_length) assert batch['mode'] == 'generate' # the maximum generation length from the small test data - assert batch['generation_length'] == 129 + assert batch['generation_kwargs']['max_new_tokens'] == 129 assert any(item[0] != tokenizer.eos_token_id for item in batch['input_ids']) # longest should be pushed left mod = types.ModuleType('test_module') @@ -1703,7 +1706,7 @@ def test_code_eval_task_dataloader(dataset_uri, tmp_path, num_fewshot, prompt_st assert tuple(batch['attention_mask'].shape) == (bs, max_prompt_length) assert batch['mode'] == 'generate' # the maximum generation length from the small test data - assert batch['generation_length'] == 122 + assert batch['generation_kwargs']['max_new_tokens'] == 122 has_left_padding.extend([item[0] == tokenizer.eos_token_id for item in batch['input_ids']]) assert not all(has_left_padding) # longest should be pushed left @@ -2459,7 +2462,7 @@ def test_hf_dataloading_custom_parsing(dataset_uri, tiny_gpt2_tokenizer, tmp_pat assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen - maximum_answer_length) assert batch['mode'] == 'generate' # the maximum generation length from the small test data - assert batch['generation_length'] == maximum_answer_length + assert batch['generation_kwargs']['max_new_tokens'] == maximum_answer_length assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids']) decoded_batch = tokenizer.batch_decode(batch['input_ids']) diff --git a/tests/models/test_hf_model.py b/tests/models/test_hf_model.py index e677941e9e..56bc13106b 100644 --- a/tests/models/test_hf_model.py +++ b/tests/models/test_hf_model.py @@ -1195,11 +1195,12 @@ def test_eval_forward_generate(device, world_size, hf_model, hf_tokenizer, use_f for k, v in input_dict.items(): input_dict[k] = device.tensor_to_device(v) input_dict['mode'] = 'generate' + input_dict['generation_kwargs'] = {} - input_dict['generation_length'] = 5 + input_dict['generation_kwargs']['max_new_tokens'] = 5 input_dict['labels'] = [['answer1'], ['answer2']] generation1 = model.eval_forward(input_dict, None) - input_dict['generation_length'] = 3 + input_dict['generation_kwargs']['max_new_tokens'] = 3 input_dict['labels'] = [['answer1'], ['answer2']] generation2 = model.eval_forward(input_dict, None) @@ -1208,6 +1209,26 @@ def test_eval_forward_generate(device, world_size, hf_model, hf_tokenizer, use_f assert all(isinstance(decoded_generation, str) for decoded_generation in generation2) +def test_eval_forward_generate_adjust_generation_length(tiny_gpt2_model, tiny_gpt2_tokenizer): + model = HuggingFaceModel(tiny_gpt2_model, 
tokenizer=tiny_gpt2_tokenizer, use_logits=True) + input_dict = tiny_gpt2_tokenizer(['hello', 'goodbyes'], return_tensors='pt', padding=True) + + input_dict['mode'] = 'generate' + input_dict['generation_kwargs'] = {} + input_dict['generation_length'] = 5 + input_dict['labels'] = [['answer1'], ['answer2']] + with pytest.warns(DeprecationWarning): + generation1 = model.eval_forward(input_dict, None) + + input_dict['generation_length'] = 3 + input_dict['labels'] = [['answer1'], ['answer2']] + generation2 = model.eval_forward(input_dict, None) + + assert len(generation1) == len(generation2) == 2 + assert all(isinstance(decoded_generation, str) for decoded_generation in generation1) + assert all(isinstance(decoded_generation, str) for decoded_generation in generation2) + + @pytest.mark.parametrize('peft_type', ['LORA', 'loRa']) @pytest.mark.parametrize('task_type', ['CAUSAL_LM', 'causal_lm']) def test_peft_init(peft_type: str, task_type: str, tiny_gpt2_model, gpt2_peft_config):
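
Migration sketch (modeled on the tests above; not an official recipe): for callers that assemble 'generate' eval batches by hand, the change amounts to moving the generation-length cap into `generation_kwargs` as `max_new_tokens`. The 'gpt2' checkpoint and the pad-token assignment below are illustrative assumptions only; the `eval_forward` shim added in this patch keeps old `generation_length` batches working, with a DeprecationWarning, until v0.21.

    # Minimal sketch of the before/after for a hand-built generation eval batch.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from composer.models.huggingface import HuggingFaceModel

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token
    model = HuggingFaceModel(AutoModelForCausalLM.from_pretrained('gpt2'),
                             tokenizer=tokenizer,
                             use_logits=True)

    batch = tokenizer(['hello', 'goodbye'], return_tensors='pt', padding=True)
    batch['mode'] = 'generate'
    batch['labels'] = [['answer1'], ['answer2']]

    # Deprecated (warns, and is removed in v0.21):
    #     batch['generation_length'] = 5
    # Replacement: cap generation via max_new_tokens inside generation_kwargs.
    batch['generation_kwargs'] = {'max_new_tokens': 5}

    generations = model.eval_forward(batch, None)  # list of decoded strings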