add sentencepiece support (#2093)
dakinggg authored and Bandish Shah committed Mar 31, 2023
1 parent fba78d7 commit 98a2699
Showing 3 changed files with 61 additions and 10 deletions.
41 changes: 31 additions & 10 deletions composer/models/huggingface.py
@@ -274,13 +274,25 @@ def hf_from_composer_checkpoint(
         if hf_tokenizer_state != {}:
             with tempfile.TemporaryDirectory() as _tmp_dir:
                 for filename, saved_content in hf_tokenizer_state.items():
-                    with open(Path(_tmp_dir) / f'{filename}{saved_content["file_extension"]}', 'w') as _tmp_file:
-                        if saved_content['file_extension'] == '.json':
+                    tokenizer_file_path = Path(_tmp_dir) / f'{filename}{saved_content["file_extension"]}'
+                    if saved_content['file_extension'] == '.json':
+                        with open(tokenizer_file_path, 'w') as _tmp_file:
                             json.dump(saved_content['content'], _tmp_file)
-                        elif saved_content['file_extension'] == '.txt':
+                    elif saved_content['file_extension'] == '.txt':
+                        with open(tokenizer_file_path, 'w') as _tmp_file:
                             for line in saved_content['content']:
                                 _tmp_file.write(line)
                                 _tmp_file.write('\n')
+                    elif saved_content['file_extension'] == '.model':
+                        try:
+                            import sentencepiece as spm
+                        except ImportError as e:
+                            raise MissingConditionalImportError(extra_deps_group='sentencepiece',
+                                                                conda_package='sentencepiece') from e
+                        s = spm.SentencePieceProcessor()
+                        s.load_from_serialized_proto(saved_content['content'])
+                        with open(tokenizer_file_path, 'wb') as _tmp_file:
+                            _tmp_file.write(s.serialized_model_proto())
                 hf_tokenizer = transformers.AutoTokenizer.from_pretrained(_tmp_dir)

                 # we need to set the name_or_path back because otherwise it is the tmp dir we are loading from here
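The new '.model' branch rebuilds the tokenizer's SentencePiece file from the serialized proto bytes stored in the checkpoint. Below is a minimal sketch of that round-trip, written for illustration rather than taken from this commit; the file names are placeholders:

import sentencepiece as spm

# Serialize an existing SentencePiece model to bytes; this is the form the
# tokenizer state takes inside a Composer checkpoint.
sp = spm.SentencePieceProcessor(model_file='spiece.model')  # placeholder path
proto_bytes = sp.serialized_model_proto()

# Rebuild a processor from those bytes and write an equivalent .model file,
# which transformers.AutoTokenizer.from_pretrained can then load from disk.
restored = spm.SentencePieceProcessor()
restored.load_from_serialized_proto(proto_bytes)
with open('restored_spiece.model', 'wb') as f:
    f.write(restored.serialized_model_proto())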
@@ -388,15 +400,24 @@ def get_metadata(self):
         if self.tokenizer is not None:
             for tokenizer_file_name in tokenizer_dir.iterdir():
                 tokenizer_file_path = tokenizer_dir / tokenizer_file_name
-                with open(tokenizer_file_path) as _tokenizer_file:
-                    tokenizer_file_extension = tokenizer_file_path.suffix
-                    if tokenizer_file_extension == '.txt':
+                tokenizer_file_extension = tokenizer_file_path.suffix
+                if tokenizer_file_extension == '.txt':
+                    with open(tokenizer_file_path) as _tokenizer_file:
                         tokenizer_file_content = _tokenizer_file.read().split('\n')
-                    elif tokenizer_file_extension == '.json':
+                elif tokenizer_file_extension == '.json':
+                    with open(tokenizer_file_path) as _tokenizer_file:
                         tokenizer_file_content = json.load(_tokenizer_file)
-                    else:
-                        raise ValueError(
-                            f'Unexpected file ending {tokenizer_file_name} in output of tokenizer.save_pretrained.')
+                elif tokenizer_file_extension == '.model':
+                    try:
+                        import sentencepiece as spm
+                    except ImportError as e:
+                        raise MissingConditionalImportError(extra_deps_group='sentencepiece',
+                                                            conda_package='sentencepiece') from e
+                    s = spm.SentencePieceProcessor(model_file=str(tokenizer_file_path))
+                    tokenizer_file_content = s.serialized_model_proto()
+                else:
+                    raise ValueError(
+                        f'Unexpected file ending {tokenizer_file_name} in output of tokenizer.save_pretrained.')
                 tokenizer_output[tokenizer_file_path.stem] = {
                     'file_extension': tokenizer_file_extension,
                     'content': tokenizer_file_content
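With this change, get_metadata records a SentencePiece vocabulary the same way it records .json and .txt tokenizer files: keyed by file stem, with the extension and the serialized proto bytes as content. A rough illustration of the resulting structure, assuming a hypothetical T5-style 'spiece.model' file; the keys and values shown are placeholders:

# Hypothetical shape of the saved tokenizer metadata after this change.
tokenizer_output = {
    'spiece': {
        'file_extension': '.model',
        'content': b'<serialized SentencePiece model proto bytes>',
    },
    'special_tokens_map': {
        'file_extension': '.json',
        'content': {'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'},  # illustrative
    },
}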
2 changes: 2 additions & 0 deletions setup.py
@@ -179,6 +179,8 @@ def package_files(prefix: str, directory: str, extension: str):
     'datasets>=2.4,<3',
 ]
 
+extra_deps['sentencepiece'] = ['sentencepiece==0.1.97']
+
 extra_deps['mlperf'] = [
     # TODO: use pip when available: https://github.com/mlcommons/logging/issues/218
     # "mlperf_logging @ git+https://github.com/mlperf/logging.git",
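The pin above exposes sentencepiece as an optional extras group rather than a hard dependency, matching the guarded imports in huggingface.py. A generic sketch of how such a group typically reaches pip; the package name and setup() call here are illustrative, not composer's actual setup.py:

from setuptools import setup

extra_deps = {}
extra_deps['sentencepiece'] = ['sentencepiece==0.1.97']

setup(
    name='example-package',  # placeholder name
    extras_require=extra_deps,  # enables: pip install 'example-package[sentencepiece]'
)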
28 changes: 28 additions & 0 deletions tests/models/test_hf_model.py
@@ -127,6 +127,11 @@ def check_hf_tokenizer_equivalence(tokenizer1, tokenizer2):
     tokenizer1.__dict__['init_kwargs'].pop('tokenizer_file', None)
     tokenizer2.__dict__['init_kwargs'].pop('tokenizer_file', None)
 
+    # vocab_file will be the path that the tokenizer was loaded from, which will just be a temporary directory for
+    # the reloaded tokenizer, so we remove it and don't compare it between the two tokenizers
+    tokenizer1.__dict__.pop('vocab_file', None)
+    tokenizer2.__dict__.pop('vocab_file', None)
+
     assert tokenizer1.__dict__ == tokenizer2.__dict__


@@ -350,6 +355,29 @@ def test_hf_loading_load_save_paths(checkpoint_upload_path: Optional[str], local
     assert os.path.getsize(local_save_checkpoint_path) > 1000
 
 
+@pytest.mark.parametrize('modify_tokenizer', [False, True])
+def test_hf_loading_sentencepiece_tokenizer(modify_tokenizer: bool, tmp_path: Path, tiny_t5_model):
+    transformers = pytest.importorskip('transformers')
+
+    t0_pp_tokenizer = transformers.AutoTokenizer.from_pretrained('bigscience/T0pp')
+
+    if modify_tokenizer:
+        assert t0_pp_tokenizer is not None  # pyright
+        t0_pp_tokenizer.add_special_tokens({'bos_token': '[NEWSPECIAL]'})
+        t0_pp_tokenizer.add_special_tokens({'additional_special_tokens': ['[MOSAICML']})
+        t0_pp_tokenizer.add_tokens(['totallyarealtoken', 'mosaicml'])
+        tiny_t5_model.resize_token_embeddings(len(t0_pp_tokenizer))
+
+    trainer = get_lm_trainer(tiny_t5_model, t0_pp_tokenizer, str(tmp_path), is_conditional_generation=True)
+    trainer.save_checkpoint(str(tmp_path / 'hf-checkpoint.pt'))
+
+    hf_loaded_model, hf_loaded_tokenizer = HuggingFaceModel.hf_from_composer_checkpoint(
+        checkpoint_path=str(tmp_path / 'hf-checkpoint.pt'))
+
+    check_hf_model_equivalence(hf_loaded_model, tiny_t5_model)
+    check_hf_tokenizer_equivalence(hf_loaded_tokenizer, t0_pp_tokenizer)
+
+
 @pytest.mark.parametrize('modify_tokenizer', [False, True])
 def test_hf_loading_tokenizer(modify_tokenizer: bool, tmp_path: Path, tiny_bert_model, tiny_bert_tokenizer):
     pytest.importorskip('transformers')
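The new test exercises the full path for a SentencePiece-backed tokenizer (the T5 tokenizer behind bigscience/T0pp): a checkpoint is saved, then both model and tokenizer are recovered from it and compared against the originals. A condensed sketch of the same user-facing flow; the checkpoint path is a placeholder:

from composer.models import HuggingFaceModel

# Recover the Hugging Face model and its SentencePiece-backed tokenizer from a
# Composer checkpoint produced by a HuggingFaceModel-wrapped Trainer.
hf_model, hf_tokenizer = HuggingFaceModel.hf_from_composer_checkpoint(
    checkpoint_path='hf-checkpoint.pt')  # placeholder path

# The reloaded tokenizer behaves like the original sentencepiece tokenizer.
print(hf_tokenizer('composer now supports sentencepiece tokenizers')['input_ids'])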
