Skip to content

Commit

Permalink
Fix InvalidPromptResponseKeysError bug
Browse files Browse the repository at this point in the history
  • Loading branch information
b-chu committed Apr 23, 2024
1 parent 0c6bd75 commit 724b668
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 32 deletions.
27 changes: 12 additions & 15 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
InvalidRoleError,
MisconfiguredHfDatasetError,
NotEnoughChatDataError,
TooManyKeysInExampleError,
UnableToProcessPromptResponseError,
UnknownExampleTypeError)
# yapf: enable
Expand Down Expand Up @@ -108,14 +107,20 @@ def _get_example_type(example: Example) -> ExampleType:
if not isinstance(example, Mapping):
raise TypeError(
f'Expected example to be a Mapping, but found {type(example)}')
if any(allowed_message_key in example
for allowed_message_key in _ALLOWED_MESSAGES_KEYS):
if (len(example.keys()) == 1 and
any(allowed_message_key in example
for allowed_message_key in _ALLOWED_MESSAGES_KEYS)):
return 'chat'
elif any(p in example for p in _ALLOWED_PROMPT_KEYS) and any(
r in example for r in _ALLOWED_RESPONSE_KEYS):
elif (len(example.keys()) == 2 and
any(p in example for p in _ALLOWED_PROMPT_KEYS) and
any(r in example for r in _ALLOWED_RESPONSE_KEYS)):
return 'prompt_response'
else:
raise UnknownExampleTypeError(example)
raise UnknownExampleTypeError((
f'Found keys {example.keys()} in dataset. Unknown example type. For prompt and response '
f'finetuning, the valid prompt keys are {_ALLOWED_PROMPT_KEYS} and the valid response keys are '
f'{_ALLOWED_RESPONSE_KEYS}. For chat finetuning, the allowed keys are {_ALLOWED_MESSAGES_KEYS}'
))


def _is_empty_or_nonexistent(dirpath: str) -> bool:
Expand All @@ -136,8 +141,6 @@ def _get_key(dictionary: Mapping[str, Any], allowed_keys: set[str]):
f'Expected dictionary to be a mapping, but found {type(dictionary)}'
)
desired_keys = allowed_keys.intersection(dictionary.keys())
if len(desired_keys) != 1:
raise TooManyKeysInExampleError(allowed_keys, desired_keys)
return list(desired_keys)[0]


Expand Down Expand Up @@ -307,12 +310,6 @@ def _tokenize_prompt_response_formatted_example(
prompt_keys = example_keys.intersection(_ALLOWED_PROMPT_KEYS)
response_keys = example_keys.intersection(_ALLOWED_RESPONSE_KEYS)

if len(prompt_keys) != 1:
raise TooManyKeysInExampleError(_ALLOWED_PROMPT_KEYS, prompt_keys)

if len(response_keys) != 1:
raise TooManyKeysInExampleError(_ALLOWED_RESPONSE_KEYS, response_keys)

prompt_key = prompt_keys.pop()
response_key = response_keys.pop()
prompt = example[prompt_key]
Expand Down Expand Up @@ -366,7 +363,7 @@ def tokenize_formatted_example(
return _tokenize_prompt_response_formatted_example(
prompt_response_example, tokenizer)
else:
raise UnknownExampleTypeError(example)
raise NotImplementedError


def is_valid_ift_example(max_seq_len: int, target_prompts: str,
Expand Down
15 changes: 1 addition & 14 deletions llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-License-Identifier: Apache-2.0

"""Custom exceptions for the LLMFoundry."""
from collections.abc import Mapping
from typing import Any, Dict, List


Expand Down Expand Up @@ -42,19 +41,7 @@ def __init__(self, dataset_name: str, split: str,
class UnknownExampleTypeError(KeyError):
"""Error thrown when an unknown example type is used in a task."""

def __init__(self, example: Mapping) -> None:
self.example = example
message = f'Unknown example type {example=}'
super().__init__(message)


class TooManyKeysInExampleError(ValueError):
"""Error thrown when a data sample has too many keys."""

def __init__(self, desired_keys: set[str], keys: set[str]) -> None:
self.desired_keys = desired_keys
self.keys = keys
message = f'Data sample has {len(keys)} keys in `allowed_keys`: {desired_keys} Please specify exactly one. Provided keys: {keys}'
def __init__(self, message: str) -> None:
super().__init__(message)


Expand Down
6 changes: 3 additions & 3 deletions tests/data/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,10 +790,10 @@ def test_malformed_data(
match='Expected response to be')
if add_unknown_example_type:
error_context = pytest.raises(UnknownExampleTypeError,
match='Unknown example type')
match=r'.*Unknown example type')
if add_too_many_example_keys:
error_context = pytest.raises(TooManyKeysInExampleError,
match='Please specify exactly one.')
error_context = pytest.raises(UnknownExampleTypeError,
match=r'.*Unknown example type')

with error_context:
dl = build_finetuning_dataloader(cfg, tokenizer,
Expand Down

0 comments on commit 724b668

Please sign in to comment.