Skip to content

Commit

Permalink
Preserve eos in encoding when max_seq_length = -1 (#1694)
Browse files Browse the repository at this point in the history
Co-authored-by: awaelchli <[email protected]>
  • Loading branch information
sanderland and awaelchli authored Aug 26, 2024
1 parent ea01fbc commit f655f01
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
4 changes: 3 additions & 1 deletion litgpt/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def __getitem__(self, idx: int) -> Dict[str, Tensor]:
prompt = self.prompt_style.apply(prompt=example["instruction"], **example)
encoded_prompt = self.tokenizer.encode(prompt, max_length=self.max_seq_length)
encoded_response = self.tokenizer.encode(example["output"], bos=False, eos=True, max_length=self.max_seq_length)
encoded_prompt_and_response = torch.cat((encoded_prompt, encoded_response)).type(torch.int64)[: self.max_seq_length]
encoded_prompt_and_response = torch.cat((encoded_prompt, encoded_response)).type(torch.int64)
if self.max_seq_length > 0: # do not slice off last token when self.max_seq_length = -1
encoded_prompt_and_response = encoded_prompt_and_response[: self.max_seq_length]

# The labels are the full prompt with response, but with the prompt masked out
labels = encoded_prompt_and_response.clone()
Expand Down
10 changes: 7 additions & 3 deletions tests/data/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

@pytest.mark.parametrize("mask_prompt", [True, False])
@pytest.mark.parametrize("ignore_index", [-1, -100])
@pytest.mark.parametrize("max_seq_length", [1000, 5])
@pytest.mark.parametrize("max_seq_length", [1000, 5, -1])
def test_sft_dataset(max_seq_length, ignore_index, mask_prompt, mock_tokenizer):
class Style(PromptStyle):
def apply(self, prompt, **kwargs):
Expand All @@ -34,8 +34,12 @@ def apply(self, prompt, **kwargs):
torch.tensor([i, i, i, i, i, i, i, i, i, i, i, i, 66, 97, 114, 1]) if mask_prompt else expected_input_ids
)

assert torch.equal(dataset[0]["input_ids"], expected_input_ids[:max_seq_length])
assert torch.equal(dataset[0]["labels"], expected_labels[:max_seq_length])
if max_seq_length == -1:
assert torch.equal(dataset[0]["input_ids"], expected_input_ids)
assert torch.equal(dataset[0]["labels"], expected_labels)
else:
assert torch.equal(dataset[0]["input_ids"], expected_input_ids[:max_seq_length])
assert torch.equal(dataset[0]["labels"], expected_labels[:max_seq_length])


@pytest.mark.parametrize("ignore_index", [-1, -100])
Expand Down

0 comments on commit f655f01

Please sign in to comment.