Require an eos_token on tokenizers returned by build_tokenizer.

build_tokenizer now raises ValueError when the tokenizer has no
eos_token attribute or when that attribute is None. The new test
test_tokenizer_no_EOS expects this error for 'bert-base-uncased'.

NOTE(review): this patch was recovered from a copy whose line breaks and
leading whitespace had been collapsed. Indentation inside the hunks was
reconstructed from the Python block structure; confirm it against the
target tree before applying.

diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py
index 42f817b386..457f146986 100644
--- a/llmfoundry/utils/builders.py
+++ b/llmfoundry/utils/builders.py
@@ -407,6 +407,10 @@ def build_tokenizer(
             int(1e30),
         )
 
+    if not hasattr(tokenizer, 'eos_token') or tokenizer.eos_token is None:
+        raise ValueError(
+            f'The tokenizer {tokenizer_name} must have an eos_token.')
+
     if dist.is_available() and dist.is_initialized(
     ) and dist.get_world_size() > 1:
         if dist.get_local_rank() == 0:
diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py
index 303afc9b7d..b35e053c5d 100644
--- a/tests/utils/test_builders.py
+++ b/tests/utils/test_builders.py
@@ -48,6 +48,13 @@ def test_tokenizer_builder(tokenizer_name: str, tokenizer_kwargs: dict):
     assert isinstance(tokenizer, PreTrainedTokenizerBase)
 
 
+def test_tokenizer_no_EOS():
+    with pytest.raises(
+            ValueError,
+            match='The tokenizer bert-base-uncased must have an eos_token.'):
+        build_tokenizer('bert-base-uncased', {})
+
+
 def test_build_callback_fails():
     with pytest.raises(ValueError):
         build_callback('nonexistent_callback', {}, {})