From 6252f791b52c68e3dd26eebbe1359b409d65c8d3 Mon Sep 17 00:00:00 2001 From: Brian <23239305+b-chu@users.noreply.github.com> Date: Wed, 24 Apr 2024 16:37:06 -0400 Subject: [PATCH] Fix InvalidPromptResponseKeysError bug (#1131) --- llmfoundry/data/finetuning/tasks.py | 37 ++++++++++-------------- llmfoundry/utils/exceptions.py | 24 +++++++-------- tests/data/test_dataloader.py | 7 ++--- tests/data/test_template_tokenization.py | 10 +++---- 4 files changed, 35 insertions(+), 43 deletions(-) diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index e6a3afb188..05a01b80c6 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -53,7 +53,10 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: stitch_turns_decoder_only, stitch_turns_encoder_decoder) # yapf: disable -from llmfoundry.utils.exceptions import (ConsecutiveRepeatedChatRolesError, +from llmfoundry.utils.exceptions import (ALLOWED_MESSAGES_KEYS, + ALLOWED_PROMPT_KEYS, + ALLOWED_RESPONSE_KEYS, + ConsecutiveRepeatedChatRolesError, IncorrectMessageKeyQuantityError, InvalidContentTypeError, InvalidFileExtensionError, @@ -64,7 +67,6 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: InvalidRoleError, MisconfiguredHfDatasetError, NotEnoughChatDataError, - TooManyKeysInExampleError, UnableToProcessPromptResponseError, UnknownExampleTypeError) # yapf: enable @@ -79,9 +81,6 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: 'StreamingFinetuningDataset', ] -_ALLOWED_RESPONSE_KEYS = {'response', 'completion'} -_ALLOWED_PROMPT_KEYS = {'prompt'} -_ALLOWED_MESSAGES_KEYS = {'messages'} _ALLOWED_ROLE_KEYS = {'role'} _ALLOWED_CONTENT_KEYS = {'content'} _ALLOWED_ROLES = {'user', 'assistant', 'system', 'tool'} @@ -113,11 +112,13 @@ def _get_example_type(example: Example) -> ExampleType: if not isinstance(example, Mapping): raise TypeError( f'Expected example to be a Mapping, but found {type(example)}') - if any(allowed_message_key in example - for allowed_message_key in _ALLOWED_MESSAGES_KEYS): + if (len(example.keys()) == 1 and + any(allowed_message_key in example + for allowed_message_key in ALLOWED_MESSAGES_KEYS)): return 'chat' - elif any(p in example for p in _ALLOWED_PROMPT_KEYS) and any( - r in example for r in _ALLOWED_RESPONSE_KEYS): + elif (len(example.keys()) == 2 and + any(p in example for p in ALLOWED_PROMPT_KEYS) and + any(r in example for r in ALLOWED_RESPONSE_KEYS)): return 'prompt_response' else: raise UnknownExampleTypeError(example) @@ -141,8 +142,6 @@ def _get_key(dictionary: Mapping[str, Any], allowed_keys: set[str]): f'Expected dictionary to be a mapping, but found {type(dictionary)}' ) desired_keys = allowed_keys.intersection(dictionary.keys()) - if len(desired_keys) != 1: - raise TooManyKeysInExampleError(allowed_keys, desired_keys) return list(desired_keys)[0] @@ -150,7 +149,7 @@ def _validate_chat_formatted_example(example: ChatFormattedDict): if not isinstance(example, Mapping): raise TypeError( f'Expected example to be a mapping, but found {type(example)}') - messages = example[_get_key(example, _ALLOWED_MESSAGES_KEYS)] + messages = example[_get_key(example, ALLOWED_MESSAGES_KEYS)] if not isinstance(messages, List): raise TypeError( f'Expected messages to be an iterable, but found {type(messages)}') @@ -200,7 +199,7 @@ def _slice_chat_formatted_example( KeyError: If a message does not have a role or content. """ _validate_chat_formatted_example(example) - messages = example[_get_key(example, _ALLOWED_MESSAGES_KEYS)] + messages = example[_get_key(example, ALLOWED_MESSAGES_KEYS)] last_message = messages[-1] if last_message['role'] != 'assistant': @@ -309,14 +308,8 @@ def _tokenize_prompt_response_formatted_example( tokenizer: PreTrainedTokenizerBase) -> TokenizedExample: """Tokenize a formatted example and validate expected keys.""" example_keys = set(example.keys()) - prompt_keys = example_keys.intersection(_ALLOWED_PROMPT_KEYS) - response_keys = example_keys.intersection(_ALLOWED_RESPONSE_KEYS) - - if len(prompt_keys) != 1: - raise TooManyKeysInExampleError(_ALLOWED_PROMPT_KEYS, prompt_keys) - - if len(response_keys) != 1: - raise TooManyKeysInExampleError(_ALLOWED_RESPONSE_KEYS, response_keys) + prompt_keys = example_keys.intersection(ALLOWED_PROMPT_KEYS) + response_keys = example_keys.intersection(ALLOWED_RESPONSE_KEYS) prompt_key = prompt_keys.pop() response_key = response_keys.pop() @@ -371,7 +364,7 @@ def tokenize_formatted_example( return _tokenize_prompt_response_formatted_example( prompt_response_example, tokenizer) else: - raise UnknownExampleTypeError(example) + raise NotImplementedError def is_valid_ift_example(max_seq_len: int, target_prompts: str, diff --git a/llmfoundry/utils/exceptions.py b/llmfoundry/utils/exceptions.py index 7a6be2be29..ba34b29be3 100644 --- a/llmfoundry/utils/exceptions.py +++ b/llmfoundry/utils/exceptions.py @@ -6,10 +6,12 @@ from typing import Any, Dict, List __all__ = [ + 'ALLOWED_RESPONSE_KEYS', + 'ALLOWED_PROMPT_KEYS', + 'ALLOWED_MESSAGES_KEYS', 'MissingHuggingFaceURLSplitError', 'NotEnoughDatasetSamplesError', 'UnknownExampleTypeError', - 'TooManyKeysInExampleError', 'NotEnoughChatDataError', 'ConsecutiveRepeatedChatRolesError', 'InvalidLastChatMessageRoleError', @@ -29,6 +31,10 @@ 'MisconfiguredHfDatasetError', ] +ALLOWED_RESPONSE_KEYS = {'response', 'completion'} +ALLOWED_PROMPT_KEYS = {'prompt'} +ALLOWED_MESSAGES_KEYS = {'messages'} + # Finetuning dataloader exceptions class MissingHuggingFaceURLSplitError(ValueError): @@ -68,17 +74,11 @@ class UnknownExampleTypeError(KeyError): def __init__(self, example: Mapping) -> None: self.example = example - message = f'Unknown example type {example=}' - super().__init__(message) - - -class TooManyKeysInExampleError(ValueError): - """Error thrown when a data sample has too many keys.""" - - def __init__(self, desired_keys: set[str], keys: set[str]) -> None: - self.desired_keys = desired_keys - self.keys = keys - message = f'Data sample has {len(keys)} keys in `allowed_keys`: {desired_keys} Please specify exactly one. Provided keys: {keys}' + message = ( + f'Found keys {example.keys()} in dataset. Unknown example type. For prompt and response ' + f'finetuning, the valid prompt keys are {ALLOWED_PROMPT_KEYS} and the valid response keys are ' + f'{ALLOWED_RESPONSE_KEYS}. For chat finetuning, the allowed keys are {ALLOWED_MESSAGES_KEYS}' + ) super().__init__(message) diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 3eb5e3773d..47910325cd 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -43,7 +43,6 @@ InvalidRoleError, MisconfiguredHfDatasetError, NotEnoughDatasetSamplesError, - TooManyKeysInExampleError, UnknownExampleTypeError) # yapf: enable from scripts.data_prep.convert_dataset_hf import main as main_hf @@ -789,10 +788,10 @@ def test_malformed_data( match='Expected response to be') if add_unknown_example_type: error_context = pytest.raises(UnknownExampleTypeError, - match='Unknown example type') + match=r'.*Unknown example type') if add_too_many_example_keys: - error_context = pytest.raises(TooManyKeysInExampleError, - match='Please specify exactly one.') + error_context = pytest.raises(UnknownExampleTypeError, + match=r'.*Unknown example type') with error_context: dl = build_finetuning_dataloader(cfg, tokenizer, diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py index 756912342f..79f17b4ace 100644 --- a/tests/data/test_template_tokenization.py +++ b/tests/data/test_template_tokenization.py @@ -6,12 +6,12 @@ import pytest import transformers -from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS, - _ALLOWED_RESPONSE_KEYS, - _slice_chat_formatted_example, +from llmfoundry.data.finetuning.tasks import (_slice_chat_formatted_example, dataset_constructor, tokenize_formatted_example) from llmfoundry.utils.builders import build_tokenizer +from llmfoundry.utils.exceptions import (ALLOWED_PROMPT_KEYS, + ALLOWED_RESPONSE_KEYS) def test_tokenize_chat_example_malformed(): @@ -167,8 +167,8 @@ def test_tokenize_instruct_example_malformed(): def test_tokenize_instruct_example_well_formed(): tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2') - for prompt_key in _ALLOWED_PROMPT_KEYS: - for response_key in _ALLOWED_RESPONSE_KEYS: + for prompt_key in ALLOWED_PROMPT_KEYS: + for response_key in ALLOWED_RESPONSE_KEYS: example = {prompt_key: 'prompt', response_key: 'response'} tokenized_example = tokenize_formatted_example(example, tokenizer)