Skip to content

Commit

Permalink
Fix InvalidPromptResponseKeysError bug (#1131)
Browse files Browse the repository at this point in the history
  • Loading branch information
b-chu committed Apr 24, 2024
1 parent 72da1d7 commit 6252f79
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 43 deletions.
37 changes: 15 additions & 22 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
stitch_turns_decoder_only,
stitch_turns_encoder_decoder)
# yapf: disable
from llmfoundry.utils.exceptions import (ConsecutiveRepeatedChatRolesError,
from llmfoundry.utils.exceptions import (ALLOWED_MESSAGES_KEYS,
ALLOWED_PROMPT_KEYS,
ALLOWED_RESPONSE_KEYS,
ConsecutiveRepeatedChatRolesError,
IncorrectMessageKeyQuantityError,
InvalidContentTypeError,
InvalidFileExtensionError,
Expand All @@ -64,7 +67,6 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
InvalidRoleError,
MisconfiguredHfDatasetError,
NotEnoughChatDataError,
TooManyKeysInExampleError,
UnableToProcessPromptResponseError,
UnknownExampleTypeError)
# yapf: enable
Expand All @@ -79,9 +81,6 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
'StreamingFinetuningDataset',
]

_ALLOWED_RESPONSE_KEYS = {'response', 'completion'}
_ALLOWED_PROMPT_KEYS = {'prompt'}
_ALLOWED_MESSAGES_KEYS = {'messages'}
_ALLOWED_ROLE_KEYS = {'role'}
_ALLOWED_CONTENT_KEYS = {'content'}
_ALLOWED_ROLES = {'user', 'assistant', 'system', 'tool'}
Expand Down Expand Up @@ -113,11 +112,13 @@ def _get_example_type(example: Example) -> ExampleType:
if not isinstance(example, Mapping):
raise TypeError(
f'Expected example to be a Mapping, but found {type(example)}')
if any(allowed_message_key in example
for allowed_message_key in _ALLOWED_MESSAGES_KEYS):
if (len(example.keys()) == 1 and
any(allowed_message_key in example
for allowed_message_key in ALLOWED_MESSAGES_KEYS)):
return 'chat'
elif any(p in example for p in _ALLOWED_PROMPT_KEYS) and any(
r in example for r in _ALLOWED_RESPONSE_KEYS):
elif (len(example.keys()) == 2 and
any(p in example for p in ALLOWED_PROMPT_KEYS) and
any(r in example for r in ALLOWED_RESPONSE_KEYS)):
return 'prompt_response'
else:
raise UnknownExampleTypeError(example)
Expand All @@ -141,16 +142,14 @@ def _get_key(dictionary: Mapping[str, Any], allowed_keys: set[str]):
f'Expected dictionary to be a mapping, but found {type(dictionary)}'
)
desired_keys = allowed_keys.intersection(dictionary.keys())
if len(desired_keys) != 1:
raise TooManyKeysInExampleError(allowed_keys, desired_keys)
return list(desired_keys)[0]


def _validate_chat_formatted_example(example: ChatFormattedDict):
if not isinstance(example, Mapping):
raise TypeError(
f'Expected example to be a mapping, but found {type(example)}')
messages = example[_get_key(example, _ALLOWED_MESSAGES_KEYS)]
messages = example[_get_key(example, ALLOWED_MESSAGES_KEYS)]
if not isinstance(messages, List):
raise TypeError(
f'Expected messages to be an iterable, but found {type(messages)}')
Expand Down Expand Up @@ -200,7 +199,7 @@ def _slice_chat_formatted_example(
KeyError: If a message does not have a role or content.
"""
_validate_chat_formatted_example(example)
messages = example[_get_key(example, _ALLOWED_MESSAGES_KEYS)]
messages = example[_get_key(example, ALLOWED_MESSAGES_KEYS)]

last_message = messages[-1]
if last_message['role'] != 'assistant':
Expand Down Expand Up @@ -309,14 +308,8 @@ def _tokenize_prompt_response_formatted_example(
tokenizer: PreTrainedTokenizerBase) -> TokenizedExample:
"""Tokenize a formatted example and validate expected keys."""
example_keys = set(example.keys())
prompt_keys = example_keys.intersection(_ALLOWED_PROMPT_KEYS)
response_keys = example_keys.intersection(_ALLOWED_RESPONSE_KEYS)

if len(prompt_keys) != 1:
raise TooManyKeysInExampleError(_ALLOWED_PROMPT_KEYS, prompt_keys)

if len(response_keys) != 1:
raise TooManyKeysInExampleError(_ALLOWED_RESPONSE_KEYS, response_keys)
prompt_keys = example_keys.intersection(ALLOWED_PROMPT_KEYS)
response_keys = example_keys.intersection(ALLOWED_RESPONSE_KEYS)

prompt_key = prompt_keys.pop()
response_key = response_keys.pop()
Expand Down Expand Up @@ -371,7 +364,7 @@ def tokenize_formatted_example(
return _tokenize_prompt_response_formatted_example(
prompt_response_example, tokenizer)
else:
raise UnknownExampleTypeError(example)
raise NotImplementedError


def is_valid_ift_example(max_seq_len: int, target_prompts: str,
Expand Down
24 changes: 12 additions & 12 deletions llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
from typing import Any, Dict, List

__all__ = [
'ALLOWED_RESPONSE_KEYS',
'ALLOWED_PROMPT_KEYS',
'ALLOWED_MESSAGES_KEYS',
'MissingHuggingFaceURLSplitError',
'NotEnoughDatasetSamplesError',
'UnknownExampleTypeError',
'TooManyKeysInExampleError',
'NotEnoughChatDataError',
'ConsecutiveRepeatedChatRolesError',
'InvalidLastChatMessageRoleError',
Expand All @@ -29,6 +31,10 @@
'MisconfiguredHfDatasetError',
]

ALLOWED_RESPONSE_KEYS = {'response', 'completion'}
ALLOWED_PROMPT_KEYS = {'prompt'}
ALLOWED_MESSAGES_KEYS = {'messages'}


# Finetuning dataloader exceptions
class MissingHuggingFaceURLSplitError(ValueError):
Expand Down Expand Up @@ -68,17 +74,11 @@ class UnknownExampleTypeError(KeyError):

def __init__(self, example: Mapping) -> None:
self.example = example
message = f'Unknown example type {example=}'
super().__init__(message)


class TooManyKeysInExampleError(ValueError):
"""Error thrown when a data sample has too many keys."""

def __init__(self, desired_keys: set[str], keys: set[str]) -> None:
self.desired_keys = desired_keys
self.keys = keys
message = f'Data sample has {len(keys)} keys in `allowed_keys`: {desired_keys} Please specify exactly one. Provided keys: {keys}'
message = (
f'Found keys {example.keys()} in dataset. Unknown example type. For prompt and response '
f'finetuning, the valid prompt keys are {ALLOWED_PROMPT_KEYS} and the valid response keys are '
f'{ALLOWED_RESPONSE_KEYS}. For chat finetuning, the allowed keys are {ALLOWED_MESSAGES_KEYS}'
)
super().__init__(message)


Expand Down
7 changes: 3 additions & 4 deletions tests/data/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
InvalidRoleError,
MisconfiguredHfDatasetError,
NotEnoughDatasetSamplesError,
TooManyKeysInExampleError,
UnknownExampleTypeError)
# yapf: enable
from scripts.data_prep.convert_dataset_hf import main as main_hf
Expand Down Expand Up @@ -789,10 +788,10 @@ def test_malformed_data(
match='Expected response to be')
if add_unknown_example_type:
error_context = pytest.raises(UnknownExampleTypeError,
match='Unknown example type')
match=r'.*Unknown example type')
if add_too_many_example_keys:
error_context = pytest.raises(TooManyKeysInExampleError,
match='Please specify exactly one.')
error_context = pytest.raises(UnknownExampleTypeError,
match=r'.*Unknown example type')

with error_context:
dl = build_finetuning_dataloader(cfg, tokenizer,
Expand Down
10 changes: 5 additions & 5 deletions tests/data/test_template_tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import pytest
import transformers

from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS,
_ALLOWED_RESPONSE_KEYS,
_slice_chat_formatted_example,
from llmfoundry.data.finetuning.tasks import (_slice_chat_formatted_example,
dataset_constructor,
tokenize_formatted_example)
from llmfoundry.utils.builders import build_tokenizer
from llmfoundry.utils.exceptions import (ALLOWED_PROMPT_KEYS,
ALLOWED_RESPONSE_KEYS)


def test_tokenize_chat_example_malformed():
Expand Down Expand Up @@ -167,8 +167,8 @@ def test_tokenize_instruct_example_malformed():
def test_tokenize_instruct_example_well_formed():
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')

for prompt_key in _ALLOWED_PROMPT_KEYS:
for response_key in _ALLOWED_RESPONSE_KEYS:
for prompt_key in ALLOWED_PROMPT_KEYS:
for response_key in ALLOWED_RESPONSE_KEYS:

example = {prompt_key: 'prompt', response_key: 'response'}
tokenized_example = tokenize_formatted_example(example, tokenizer)
Expand Down

0 comments on commit 6252f79

Please sign in to comment.