Skip to content

Commit

Permalink
Fix InvalidPromptResponseKeysError bug
Browse files Browse the repository at this point in the history
  • Loading branch information
b-chu committed Apr 23, 2024
1 parent 0c6bd75 commit 6df5823
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 10 deletions.
21 changes: 15 additions & 6 deletions llmfoundry/data/finetuning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,23 @@ def _get_example_type(example: Example) -> ExampleType:
if not isinstance(example, Mapping):
raise TypeError(
f'Expected example to be a Mapping, but found {type(example)}')
if any(allowed_message_key in example
for allowed_message_key in _ALLOWED_MESSAGES_KEYS):
if (
len(example.keys()) == 1 and
any(allowed_message_key in example for allowed_message_key in _ALLOWED_MESSAGES_KEYS)
):
return 'chat'
elif any(p in example for p in _ALLOWED_PROMPT_KEYS) and any(
r in example for r in _ALLOWED_RESPONSE_KEYS):
elif (
len(example.keys()) == 2 and
any(p in example for p in _ALLOWED_PROMPT_KEYS) and
any(r in example for r in _ALLOWED_RESPONSE_KEYS)
):
return 'prompt_response'
else:
raise UnknownExampleTypeError(example)
raise UnknownExampleTypeError((
f"Found keys {example.keys()} in dataset. Unknown example type. For prompt and response ",
f"finetuning, the valid prompt keys are {_ALLOWED_PROMPT_KEYS} and the valid response keys are ",
f"{_ALLOWED_RESPONSE_KEYS}. For chat finetuning, the allowed keys are {_ALLOWED_MESSAGES_KEYS}",
))


def _is_empty_or_nonexistent(dirpath: str) -> bool:
Expand Down Expand Up @@ -366,7 +375,7 @@ def tokenize_formatted_example(
return _tokenize_prompt_response_formatted_example(
prompt_response_example, tokenizer)
else:
raise UnknownExampleTypeError(example)
raise NotImplementedError


def is_valid_ift_example(max_seq_len: int, target_prompts: str,
Expand Down
4 changes: 1 addition & 3 deletions llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ def __init__(self, dataset_name: str, split: str,
class UnknownExampleTypeError(KeyError):
"""Error thrown when an unknown example type is used in a task."""

def __init__(self, example: Mapping) -> None:
self.example = example
message = f'Unknown example type {example=}'
def __init__(self, message: str) -> None:
super().__init__(message)


Expand Down
2 changes: 1 addition & 1 deletion tests/data/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ def test_malformed_data(
match='Expected response to be')
if add_unknown_example_type:
error_context = pytest.raises(UnknownExampleTypeError,
match='Unknown example type')
match=r'.*Unknown example type')
if add_too_many_example_keys:
error_context = pytest.raises(TooManyKeysInExampleError,
match='Please specify exactly one.')
Expand Down

0 comments on commit 6df5823

Please sign in to comment.