From a8e878351f477943c2e33e8a2fab3fe3b2a8bb02 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Wed, 20 Sep 2023 18:32:23 -0700 Subject: [PATCH] Support for using tiktoken tokenizers (#610) --- .pre-commit-config.yaml | 4 + llmfoundry/__init__.py | 2 + .../callbacks/monolithic_ckpt_callback.py | 3 + .../models/inference_api_wrapper/__init__.py | 3 +- .../inference_api_wrapper/openai_causal_lm.py | 93 +----- llmfoundry/tokenizers/__init__.py | 8 + llmfoundry/tokenizers/tiktoken.py | 290 ++++++++++++++++++ llmfoundry/utils/builders.py | 13 +- llmfoundry/utils/config_utils.py | 2 +- pyproject.toml | 2 +- scripts/data_prep/convert_dataset_hf.py | 12 +- tests/horrible_strings.py | 106 +++++++ tests/test_dataloader.py | 2 + tests/test_inference_api_eval_wrapper.py | 10 +- tests/test_tiktoken.py | 203 ++++++++++++ tests/test_training.py | 1 + 16 files changed, 650 insertions(+), 104 deletions(-) create mode 100644 llmfoundry/tokenizers/__init__.py create mode 100644 llmfoundry/tokenizers/tiktoken.py create mode 100644 tests/horrible_strings.py create mode 100644 tests/test_tiktoken.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 66990493ae..d4c8cc699c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,6 +7,7 @@ repos: hooks: - id: yapf name: yapf + exclude: tests/horrible_strings.py description: A formatter for Python files. entry: yapf args: [-i, -vv, -p] # inplace @@ -50,6 +51,7 @@ repos: - id: debug-statements - id: destroyed-symlinks - id: double-quote-string-fixer + exclude: tests/horrible_strings.py - id: end-of-file-fixer - id: fix-byte-order-marker - id: mixed-line-ending @@ -93,6 +95,7 @@ repos: hooks: - id: pyright name: pyright + exclude: tests/horrible_strings.py entry: pyright language: node types: [python] @@ -104,6 +107,7 @@ repos: hooks: - id: trufflehog name: secret scan + exclude: tests/horrible_strings.py entry: trufflehog filesystem ./ args: - --only-verified diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py index 60560f120e..a60781dc3d 100644 --- a/llmfoundry/__init__.py +++ b/llmfoundry/__init__.py @@ -23,6 +23,7 @@ from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig, MPTForCausalLM, MPTModel, MPTPreTrainedModel) + from llmfoundry.tokenizers import TiktokenTokenizerWrapper except ImportError as e: try: @@ -64,6 +65,7 @@ 'build_alibi_bias', 'optim', 'utils', + 'TiktokenTokenizerWrapper', ] __version__ = '0.2.0' diff --git a/llmfoundry/callbacks/monolithic_ckpt_callback.py b/llmfoundry/callbacks/monolithic_ckpt_callback.py index 6d72762323..a71db9bd43 100644 --- a/llmfoundry/callbacks/monolithic_ckpt_callback.py +++ b/llmfoundry/callbacks/monolithic_ckpt_callback.py @@ -74,6 +74,9 @@ def _save_checkpoint(self, state: State, logger: Logger) -> None: ) if self.upload_to_object_store else contextlib.nullcontext( enter_result=save_dir) with dir_context_mgr as temp_save_dir: + # pyright doesn't know about enter_result + assert isinstance(temp_save_dir, str) + save_path = str(Path(temp_save_dir) / Path(filename)) dirname = os.path.dirname(save_path) if dirname: diff --git a/llmfoundry/models/inference_api_wrapper/__init__.py b/llmfoundry/models/inference_api_wrapper/__init__.py index b9cd71ad47..496abf2aa6 100644 --- a/llmfoundry/models/inference_api_wrapper/__init__.py +++ b/llmfoundry/models/inference_api_wrapper/__init__.py @@ -4,11 +4,10 @@ from llmfoundry.models.inference_api_wrapper.interface import \ InferenceAPIEvalWrapper from 
llmfoundry.models.inference_api_wrapper.openai_causal_lm import ( - OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper, OpenAITokenizerWrapper) + OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper) __all__ = [ 'OpenAICausalLMEvalWrapper', 'OpenAIChatAPIEvalWrapper', - 'OpenAITokenizerWrapper', 'InferenceAPIEvalWrapper', ] diff --git a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py index 72dfd9db76..609112b944 100644 --- a/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py +++ b/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py @@ -11,8 +11,7 @@ import torch from composer.core.types import Batch from composer.utils.import_helpers import MissingConditionalImportError -from transformers import (AutoTokenizer, PreTrainedTokenizer, - PreTrainedTokenizerFast) +from transformers import AutoTokenizer log = logging.getLogger(__name__) @@ -20,93 +19,13 @@ InferenceAPIEvalWrapper __all__ = [ - 'OpenAICausalLMEvalWrapper', 'OpenAIChatAPIEvalWrapper', - 'OpenAITokenizerWrapper' + 'OpenAICausalLMEvalWrapper', + 'OpenAIChatAPIEvalWrapper', ] -Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] - MAX_RETRIES = 10 -class OpenAITokenizerWrapper(AutoTokenizer): - # this API is experimental and for evaluation only. It is subject to change as we add support for training - def __init__(self, name: str) -> None: - try: - import tiktoken - except ImportError as e: - raise MissingConditionalImportError( - extra_deps_group='openai', - conda_package='tiktoken', - conda_channel='conda-forge') from e - self.tokenizer = tiktoken.encoding_for_model(name) - - def __call__(self, x: str, add_special_tokens: bool = False): - if add_special_tokens: - raise ValueError( - 'OpenAITokenizerWrapper only supports add_special_tokens=False') - return self.encode(x) - - def encode(self, - x: Union[str, List[str]], - add_special_tokens: bool = False): - if add_special_tokens: - raise ValueError( - 'OpenAITokenizerWrapper only supports add_special_tokens=False') - if isinstance(x, str): - return { - 'input_ids': - self.tokenizer.encode(x, allowed_special={'<|endoftext|>'}) - } - elif isinstance(x, - list): # pyright: ignore [reportUnnecessaryIsInstance] - return { - 'input_ids': - self.tokenizer.encode_batch( - x, allowed_special={'<|endoftext|>'}) - } - else: - raise ValueError( - f'`encode` argument must be str or List[str], got: {type(x)}') - - def decode( - self, - x: Union[List[int], List[List[int]]], - ): - if len(x) > 0 and isinstance(x[0], list): - return self.tokenizer.decode_batch( - x) # pyright: ignore [reportGeneralTypeIssues] - else: - assert isinstance(x, list) - return self.tokenizer.decode( - x) # pyright: ignore [reportGeneralTypeIssues] - - @property - def pad_token_id(self): - return self.tokenizer.eot_token - - @property - def eos_token_id(self): - return self.tokenizer.eot_token - - @property - def vocab_size(self): - return self.tokenizer.n_vocab - - def construct_logit_tensor(self, logprobs: Dict[str, float]): - """Construct tensor of shape (vocab_size,) mapping words to logprobs. - - Args: - logprobs (Dict[str, float]): Dictionary mapping tokens to log probabilities assigned to them by the model. 
- """ - tensor = torch.tensor([min(logprobs.values()) - 1] * (self.vocab_size)) - for k in logprobs: - encoding = self.encode(k)['input_ids'] - idx = encoding[0] - tensor[idx] = logprobs[k] - return tensor - - class OpenAIEvalInterface(InferenceAPIEvalWrapper): def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None: @@ -185,7 +104,7 @@ def retokenize(self, tokens: List[int], cont_idxs: List[int]): re-tokenize with the space removed. """ original_len = len(tokens) - retokenized_continuation = self.tokenizer.encode( + retokenized_continuation = self.tokenizer( self.tokenizer.decode(tokens[cont_idxs[0]:cont_idxs[-1] + 1]).strip())['input_ids'] @@ -275,8 +194,8 @@ def process_result(self, completion: Optional[dict]): assert isinstance(completion, dict) if len(completion['choices']) > 0: tensors = [] - for t in self.tokenizer.encode(completion['choices'][0]['message'] - ['content'])['input_ids']: + for t in self.tokenizer(completion['choices'][0]['message'] + ['content'])['input_ids']: tensors.append( self.tokenizer.construct_logit_tensor( {self.tokenizer.decode([t]): 0.0})) diff --git a/llmfoundry/tokenizers/__init__.py b/llmfoundry/tokenizers/__init__.py new file mode 100644 index 0000000000..1703ed8862 --- /dev/null +++ b/llmfoundry/tokenizers/__init__.py @@ -0,0 +1,8 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper + +__all__ = [ + 'TiktokenTokenizerWrapper', +] diff --git a/llmfoundry/tokenizers/tiktoken.py b/llmfoundry/tokenizers/tiktoken.py new file mode 100644 index 0000000000..001be6a030 --- /dev/null +++ b/llmfoundry/tokenizers/tiktoken.py @@ -0,0 +1,290 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from transformers import PreTrainedTokenizer + + +class TiktokenTokenizerWrapper(PreTrainedTokenizer): + """A thin wrapper around tiktoken to make it compatible with Hugging Face. + + tokenizers. + + See HuggingFace for further documentation on general tokenizer methods. + """ + + model_input_names = ['input_ids', 'attention_mask'] + + def __init__(self, + model_name: Optional[str] = None, + encoding_name: Optional[str] = None, + add_bos_token: bool = False, + unk_token: Optional[str] = '<|endoftext|>', + eos_token: Optional[str] = '<|endoftext|>', + bos_token: Optional[str] = '<|endoftext|>', + pad_token: Optional[str] = None, + **kwargs: Dict[str, Any]): + """Constructor creates a tiktoken tokenizer to use as the underlying. + + tokenizer. + + Args: + model_name (Optional[str], optional): The name of the model to load from tiktoken. Defaults to None. + Either model_name or encoding_name must be set, but not both. + encoding_name (Optional[str], optional): The name of the encoding to load from tiktoken. Defaults to None. + Either model_name or encoding_name must be set, but not both. + add_bos_token (bool, optional): Whether to add bos tokens. Defaults to False. + unk_token (Optional[str], optional): The unk token. Defaults to '<|endoftext|>'. + eos_token (Optional[str], optional): The eos token. Defaults to '<|endoftext|>'. + bos_token (Optional[str], optional): The bos token. Defaults to '<|endoftext|>'. + pad_token (Optional[str], optional): The pad token. Defaults to None. 
+ """ + try: + import tiktoken + except: + raise ImportError( + 'You need to install tiktoken to use TiktokenTokenizerWrapper.') + + if model_name is not None and encoding_name is not None: + raise ValueError( + 'You need to specify either model_name or encoding_name, not both.' + ) + + self.model_name = model_name + self.encoding_name = encoding_name + + if self.model_name is not None: + self.encoding = tiktoken.encoding_for_model( # type: ignore (thirdParty) + self.model_name) + elif self.encoding_name is not None: + self.encoding = tiktoken.get_encoding( # type: ignore (thirdParty) + self.encoding_name) + else: + raise ValueError( + 'You need to specify either model_name or encoding_name.') + + self.add_bos_token = add_bos_token + + super().__init__(model_name=model_name, + encoding_name=encoding_name, + add_bos_token=add_bos_token, + unk_token=unk_token, + eos_token=eos_token, + bos_token=bos_token, + pad_token=pad_token, + **kwargs) + + @property + def vocab_size(self) -> int: + """Returns vocab size.""" + return self.encoding.n_vocab + + @property + def is_fast(self) -> bool: + return False + + def get_vocab(self) -> Dict[str, int]: + """Returns vocab as a dict.""" + vocab = {} + for i in range(self.vocab_size): + try: + # need to try this first, so that we get a proper KeyError, + # otherwise it crashes in the rust code + _ = self.encoding.decode_single_token_bytes(i) + vocab[self.encoding.decode([i])] = i + except KeyError: + pass + + return vocab + + def _tokenize(self, text: str) -> List[int]: + """Returns a tokenized string. + + Note: We have slightly redefined the expected contract between this method and + the _convert_token_to_id method. Normally, this method turns a string, into a list of strings, + and then the _convert_token_to_id method turns that list of strings into a list of integers. + However, not all vocab indices can be decoded into a string, so instead we just return the integers + from this function, and have adjusted the _convert_token_to_id method to handle integers as well as strings. + The only use of _tokenize that I could find was in this way, so this _should_ be safe. + """ + if not isinstance(text, str): + raise ValueError( + f'Expected a string input to _tokenize but got {type(text)}.') + + tokens = [t for t in self.encoding.encode(text, allowed_special='all')] + + return tokens + + def _convert_token_to_id(self, token: Union[int, str]) -> int: + """Converts a token (str) into an id using the vocab.""" + if isinstance(token, int): + return token + + return self.encoding.encode(token, allowed_special='all')[0] + + def _convert_id_to_token(self, index: int) -> str: + """Converts an index (integer) into a token (str) using the vocab.""" + return self.encoding.decode([index]) + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + """Converts a sequence of tokens (string) in a single string.""" + return ''.join(tokens) + + def convert_ids_to_tokens( + self, + ids: Union[int, List[int]], + skip_special_tokens: bool = False) -> Union[str, List[str]]: + """Converts a single index or a sequence of indices into a token or a. + + sequence of tokens, using the vocabulary and added tokens. + + Args: + ids (`int` or `List[int]`): + The token id (or token ids) to convert to tokens. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + + Returns: + `str` or `List[str]`: The decoded token(s). 
+ """ + if isinstance(ids, int): + if ids in self.added_tokens_decoder: + return self.added_tokens_decoder[ids] + + return self._convert_id_to_token(ids) + + # current_stream will collect multiple tokens, and then separately add items + # for each added token. This is done so that decode works properly with token ids + # that cannot be represented naively in utf-8. + tokens = [] + current_stream = [] + for index in ids: + if skip_special_tokens and index in self.all_special_ids: + continue + + if index in self.added_tokens_decoder: + tokens.append(self.encoding.decode(current_stream)) + current_stream = [] + tokens.append(self.added_tokens_decoder[index]) + else: + current_stream.append(index) + + if len(current_stream) > 0: + tokens.append(self.encoding.decode(current_stream)) + return tokens + + def build_inputs_with_special_tokens( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: + if self.add_bos_token: + bos_token_ids = [self.bos_token_id] + else: + bos_token_ids = [] + + output = bos_token_ids + token_ids_0 + + if token_ids_1 is None: + return output + + return output + bos_token_ids + token_ids_1 + + def get_special_tokens_mask( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False) -> List[int]: + """Retrieves sequence ids from a token list that has no special tokens. + + Function copied from + https://github.com/huggingface/transformers/blob/e3a4bd2bee212a2d0fd9f03b27fe7bfc1debe42d/src/transformers/models/gpt2/tokenization_gpt2.py#L265-L295 + + added. This method is called when adding special tokens using the + tokenizer `prepare_for_model` or `encode_plus` methods. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True) + + if not self.add_bos_token: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=False) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + + def create_token_type_ids_from_sequences( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None) -> List[int]: + sep = [self.sep_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + sep) * [0] + return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, + save_directory: str, + filename_prefix: Optional[str] = None) -> Tuple[str]: + + # ignore the below type to keep the original signature + # we are knowingly breaking the signature here, although not 100% certain + # it doesn't have side effects + # There is some code in huggingface that calls this function to get the vocab files, + # but it doesn't seem to access them (or at least checks for their existence + # before accessing them) + return (None, None) # type: ignore + + def sanitize_special_tokens(self) -> int: + """Make sure that all the special tokens attributes of the tokenizer. 
+ + (`tokenizer.mask_token`, `tokenizer.cls_token`, etc.) are in the + vocabulary. + + Add the missing ones to the vocabulary if needed. + + Return: + `int`: The number of tokens added in the vocabulary during the operation. + """ + actual_new_tokens = [] + for token in self.all_special_tokens_extended: + encoded = self.encoding.encode(token, allowed_special='all') + if len(encoded) > 1: + actual_new_tokens.append(token) + + return self.add_tokens(actual_new_tokens, special_tokens=True) + + def construct_logit_tensor(self, logprobs: Dict[str, + float]) -> torch.Tensor: + """Construct tensor of shape (vocab_size,) mapping words to logprobs. + + Args: + logprobs (Dict[str, float]): Dictionary mapping tokens to log probabilities assigned to them by the model. + """ + tensor = torch.tensor([min(logprobs.values()) - 1] * (self.vocab_size)) + for k in logprobs: + encoding = self(k)['input_ids'] + idx = encoding[0] + tensor[idx] = logprobs[k] + return tensor + + +TiktokenTokenizerWrapper.register_for_auto_class() diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index b89ff899ee..071fb98ed8 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -30,10 +30,9 @@ GlobalLRScaling, HuggingFaceCheckpointer, LayerFreezing, MonolithicCheckpointSaver, ScheduledGarbageCollector) -from llmfoundry.models.inference_api_wrapper.openai_causal_lm import \ - OpenAITokenizerWrapper from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion, DecoupledLionW, DecoupledLionW_8bit) +from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper log = logging.getLogger(__name__) @@ -168,12 +167,12 @@ def build_scheduler(name: str, def build_tokenizer( tokenizer_name: str, tokenizer_kwargs: Dict[str, Any]) -> PreTrainedTokenizerBase: - if tokenizer_name == 'openai': - return OpenAITokenizerWrapper(**tokenizer_kwargs) - else: - os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1' - os.environ['TOKENIZERS_PARALLELISM'] = 'false' + os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1' + os.environ['TOKENIZERS_PARALLELISM'] = 'false' + if tokenizer_name.startswith('tiktoken'): + tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs) + else: tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_kwargs) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 8690271874..6680154e87 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -47,7 +47,7 @@ def pop_config(cfg: DictConfig, def calculate_batch_size_info( - global_batch_size: int, device_microbatch_size: Union[int, str] + global_batch_size: int, device_microbatch_size: Union[int, Literal['auto']] ) -> Tuple[int, Union[int, Literal['auto']], Union[int, Literal['auto']]]: if global_batch_size % dist.get_world_size() != 0: raise ValueError( diff --git a/pyproject.toml b/pyproject.toml index efa8a7b582..a2fcec3eed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ exclude = ['env-**', 'venv*', '**/flash_attn_triton.py'] ignore = ['llmfoundry/models/layers/flash_attn_triton.py'] stubPath = "" # suppress useless 'stubPath is not a valid directory' errors -reportUnnecessaryIsInstance = "warning" +reportUnnecessaryIsInstance = "none" # it is ok to do this for clarity or safety reportMissingTypeStubs = "none" reportIncompatibleMethodOverride = "none" reportIncompatibleVariableOverride = "error" diff --git a/scripts/data_prep/convert_dataset_hf.py b/scripts/data_prep/convert_dataset_hf.py index fee56de54e..964b05ed09 
100644 --- a/scripts/data_prep/convert_dataset_hf.py +++ b/scripts/data_prep/convert_dataset_hf.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """Streaming dataset conversion scripts for C4 and The Pile.""" +import json import os import platform from argparse import ArgumentParser, Namespace @@ -14,9 +15,10 @@ from streaming import MDSWriter from torch.utils.data import DataLoader, Dataset, IterableDataset from tqdm import tqdm -from transformers import AutoTokenizer, PreTrainedTokenizerBase +from transformers import PreTrainedTokenizerBase from llmfoundry.data import ConcatTokensDataset, NoConcatDataset +from llmfoundry.utils.builders import build_tokenizer class ConcatMode(Enum): @@ -49,6 +51,7 @@ def parse_args() -> Namespace: help='Convert text to tokens and concatenate up to this many tokens') parser.add_argument('--tokenizer', type=str, required=False, default=None) + parser.add_argument('--tokenizer_kwargs', type=str, required=False) parser.add_argument('--bos_text', type=str, required=False, default=None) parser.add_argument('--eos_text', type=str, required=False, default=None) parser.add_argument('--no_wrap', default=False, action='store_true') @@ -56,6 +59,11 @@ def parse_args() -> Namespace: parsed = parser.parse_args() + if parsed.tokenizer_kwargs is not None: + parsed.tokenizer_kwargs = json.loads(parsed.tokenizer_kwargs) + else: + parsed.tokenizer_kwargs = {} + if os.path.isdir(parsed.out_root) and len( set(os.listdir(parsed.out_root)).intersection(set( parsed.splits))) > 0: @@ -316,7 +324,7 @@ def main(args: Namespace) -> None: if args.concat_tokens is not None: mode = ConcatMode.CONCAT_TOKENS - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) + tokenizer = build_tokenizer(args.tokenizer, args.tokenizer_kwargs) # we will enforce length, so suppress warnings about sequences too long for the model tokenizer.model_max_length = int(1e30) columns = {'tokens': 'bytes'} diff --git a/tests/horrible_strings.py b/tests/horrible_strings.py new file mode 100644 index 0000000000..31cd55cb9b --- /dev/null +++ b/tests/horrible_strings.py @@ -0,0 +1,106 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +# taken from https://github.com/explosion/spaCy/blob/8f0d6b0a8c42e4852bf6e24cdf629043f2f39361/spacy/tests/tokenizer/test_naughty_strings.py#L7 +HORRIBLE_STRINGS = [ + # ASCII punctuation + r",./;'[]\-=", + r'<>?:"{}|_+', + r'!@#$%^&*()`~"', + # Unicode additional control characters, byte order marks + r"­؀؁؂؃؄؅؜۝܏᠎​‌‍‎‏‪", + r"￾", + # Unicode Symbols + r"Ω≈ç√∫˜µ≤≥÷", + r"åß∂ƒ©˙∆˚¬…æ", + "œ∑´®†¥¨ˆøπ“‘", + r"¡™£¢∞§¶•ªº–≠", + r"¸˛Ç◊ı˜Â¯˘¿", + r"ÅÍÎÏ˝ÓÔÒÚÆ☃", + r"Œ„´‰ˇÁ¨ˆØ∏”’", + r"`⁄€‹›fifl‡°·‚—±", + r"⅛⅜⅝⅞", + r"ЁЂЃЄЅІЇЈЉЊЋЌЍЎЏАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюя", + r"٠١٢٣٤٥٦٧٨٩", + # Unicode Subscript/Superscript/Accents + r"⁰⁴⁵", + r"₀₁₂", + r"⁰⁴⁵₀₁₂", + r"ด้้้้้็็็็็้้้้้็็็็็้้้้้้้้็็็็็้้้้้็็็็็้้้้้้้้็็็็็้้้้้็็็็็้้้้้้้้็็็็็้้้้้็็็็ ด้้้้้็็็็็้้้้้็็็็็้้้้้้้้็็็็็้้้้้็็็็็้้้้้้้้็็็็็้้้้้็็็็็้้้้้้้้็็็็็้้้้้็็็็ ด้้้้้็็็็็้้้้้็็็็็้้้้้้้้็็็็็้้้้้็็็็็้้้้้้้้็็็็็้้้้้็็็็็้้้้้้้้็็็็็้้้้้็็็็", + r" ̄ ̄", + # Two-Byte Characters + r"田中さんにあげて下さい", + r"パーティーへ行かないか", + r"和製漢語", + r"部落格", + r"사회과학원 어학연구소", + r"찦차를 타고 온 펲시맨과 쑛다리 똠방각하", + r"社會科學院語學研究所", + r"울란바토르", + r"𠜎𠜱𠝹𠱓𠱸𠲖𠳏", + # Japanese Emoticons + r"ヽ༼ຈل͜ຈ༽ノ ヽ༼ຈل͜ຈ༽ノ", + r"(。◕ ∀ ◕。)", + r"`ィ(´∀`∩", + r"__ロ(,_,*)", + r"・( ̄∀ ̄)・:*:", + r"゚・✿ヾ╲(。◕‿◕。)╱✿・゚", + r",。・:*:・゜’( ☻ ω ☻ )。・:*:・゜’", + r"(╯°□°)╯︵ ┻━┻)" "(ノಥ益ಥ)ノ ┻━┻", # 
type: ignore + r"┬─┬ノ( º _ ºノ)", + r"( ͡° ͜ʖ ͡°)", + # Emoji + r"😍", + r"👩🏽", + r"👾 🙇 💁 🙅 🙆 🙋 🙎 🙍", + r"🐵 🙈 🙉 🙊", + r"❤️ 💔 💌 💕 💞 💓 💗 💖 💘 💝 💟 💜 💛 💚 💙", + r"✋🏿 💪🏿 👐🏿 🙌🏿 👏🏿 🙏🏿", + r"🚾 🆒 🆓 🆕 🆖 🆗 🆙 🏧", + r"0️⃣ 1️⃣ 2️⃣ 3️⃣ 4️⃣ 5️⃣ 6️⃣ 7️⃣ 8️⃣ 9️⃣ 🔟", + # Regional Indicator Symbols + r"🇺🇸🇷🇺🇸 🇦🇫🇦🇲🇸", + r"🇺🇸🇷🇺🇸🇦🇫🇦🇲", + r"🇺🇸🇷🇺🇸🇦", + # Unicode Numbers + r"123", + r"١٢٣", + # Right-To-Left Strings + r"ثم نفس سقطت وبالتحديد،, جزيرتي باستخدام أن دنو. إذ هنا؟ الستار وتنصيب كان. أهّل ايطاليا، بريطانيا-فرنسا قد أخذ. سليمان، إتفاقية بين ما, يذكر الحدود أي بعد, معاملة بولندا، الإطلاق عل إيو.", + r"إيو.", + r"בְּרֵאשִׁית, בָּרָא אֱלֹהִים, אֵת הַשָּׁמַיִם, וְאֵת הָאָרֶץ", + r"הָיְתָהtestالصفحات التّحول", + r"﷽", + r"ﷺ", + r"مُنَاقَشَةُ سُبُلِ اِسْتِخْدَامِ اللُّغَةِ فِي النُّظُمِ الْقَائِمَةِ وَفِيم يَخُصَّ التَّطْبِيقَاتُ الْحاسُوبِيَّةُ،", + # Trick Unicode + r"‪‪test‪", + r"‫test", + r"
test
", + r"test⁠test", + r"⁦test⁧", + # Zalgo Text + r"Ṱ̺̺̕o͞ ̷i̲̬͇̪͙n̝̗͕v̟̜̘̦͟o̶̙̰̠kè͚̮̺̪̹̱̤ ̖t̝͕̳̣̻̪͞h̼͓̲̦̳̘̲e͇̣̰̦̬͎ ̢̼̻̱̘h͚͎͙̜̣̲ͅi̦̲̣̰̤v̻͍e̺̭̳̪̰-m̢iͅn̖̺̞̲̯̰d̵̼̟͙̩̼̘̳ ̞̥̱̳̭r̛̗̘e͙p͠r̼̞̻̭̗e̺̠̣͟s̘͇̳͍̝͉e͉̥̯̞̲͚̬͜ǹ̬͎͎̟̖͇̤t͍̬̤͓̼̭͘ͅi̪̱n͠g̴͉ ͏͉ͅc̬̟h͡a̫̻̯͘o̫̟̖͍̙̝͉s̗̦̲.̨̹͈̣", + r"̡͓̞ͅI̗̘̦͝n͇͇͙v̮̫ok̲̫̙͈i̖͙̭̹̠̞n̡̻̮̣̺g̲͈͙̭͙̬͎ ̰t͔̦h̞̲e̢̤ ͍̬̲͖f̴̘͕̣è͖ẹ̥̩l͖͔͚i͓͚̦͠n͖͍̗͓̳̮g͍ ̨o͚̪͡f̘̣̬ ̖̘͖̟͙̮c҉͔̫͖͓͇͖ͅh̵̤̣͚͔á̗̼͕ͅo̼̣̥s̱͈̺̖̦̻͢.̛̖̞̠̫̰", + r"̗̺͖̹̯͓Ṯ̤͍̥͇͈h̲́e͏͓̼̗̙̼̣͔ ͇̜̱̠͓͍ͅN͕͠e̗̱z̘̝̜̺͙p̤̺̹͍̯͚e̠̻̠͜r̨̤͍̺̖͔̖̖d̠̟̭̬̝͟i̦͖̩͓͔̤a̠̗̬͉̙n͚͜ ̻̞̰͚ͅh̵͉i̳̞v̢͇ḙ͎͟-҉̭̩̼͔m̤̭̫i͕͇̝̦n̗͙ḍ̟ ̯̲͕͞ǫ̟̯̰̲͙̻̝f ̪̰̰̗̖̭̘͘c̦͍̲̞͍̩̙ḥ͚a̮͎̟̙͜ơ̩̹͎s̤.̝̝ ҉Z̡̖̜͖̰̣͉̜a͖̰͙̬͡l̲̫̳͍̩g̡̟̼̱͚̞̬ͅo̗͜.̟", + r"̦H̬̤̗̤͝e͜ ̜̥̝̻͍̟́w̕h̖̯͓o̝͙̖͎̱̮ ҉̺̙̞̟͈W̷̼̭a̺̪͍į͈͕̭͙̯̜t̶̼̮s̘͙͖̕ ̠̫̠B̻͍͙͉̳ͅe̵h̵̬͇̫͙i̹͓̳̳̮͎̫̕n͟d̴̪̜̖ ̰͉̩͇͙̲͞ͅT͖̼͓̪͢h͏͓̮̻e̬̝̟ͅ ̤̹̝W͙̞̝͔͇͝ͅa͏͓͔̹̼̣l̴͔̰̤̟͔ḽ̫.͕", + r"Z̮̞̠͙͔ͅḀ̗̞͈̻̗Ḷ͙͎̯̹̞͓G̻O̭̗̮", + # Unicode Upsidedown + r"˙ɐnbᴉlɐ ɐuƃɐɯ ǝɹolop ʇǝ ǝɹoqɐl ʇn ʇunpᴉpᴉɔuᴉ ɹodɯǝʇ poɯsnᴉǝ op pǝs 'ʇᴉlǝ ƃuᴉɔsᴉdᴉpɐ ɹnʇǝʇɔǝsuoɔ 'ʇǝɯɐ ʇᴉs ɹolop ɯnsdᴉ ɯǝɹo˥", + r"00˙Ɩ$-", + # Unicode font + r"The quick brown fox jumps over the lazy dog", + r"𝐓𝐡𝐞 𝐪𝐮𝐢𝐜𝐤 𝐛𝐫𝐨𝐰𝐧 𝐟𝐨𝐱 𝐣𝐮𝐦𝐩𝐬 𝐨𝐯𝐞𝐫 𝐭𝐡𝐞 𝐥𝐚𝐳𝐲 𝐝𝐨𝐠", + r"𝕿𝖍𝖊 𝖖𝖚𝖎𝖈𝖐 𝖇𝖗𝖔𝖜𝖓 𝖋𝖔𝖝 𝖏𝖚𝖒𝖕𝖘 𝖔𝖛𝖊𝖗 𝖙𝖍𝖊 𝖑𝖆𝖟𝖞 𝖉𝖔𝖌", + r"𝑻𝒉𝒆 𝒒𝒖𝒊𝒄𝒌 𝒃𝒓𝒐𝒘𝒏 𝒇𝒐𝒙 𝒋𝒖𝒎𝒑𝒔 𝒐𝒗𝒆𝒓 𝒕𝒉𝒆 𝒍𝒂𝒛𝒚 𝒅𝒐𝒈", + r"𝓣𝓱𝓮 𝓺𝓾𝓲𝓬𝓴 𝓫𝓻𝓸𝔀𝓷 𝓯𝓸𝔁 𝓳𝓾𝓶𝓹𝓼 𝓸𝓿𝓮𝓻 𝓽𝓱𝓮 𝓵𝓪𝔃𝔂 𝓭𝓸𝓰", + r"𝕋𝕙𝕖 𝕢𝕦𝕚𝕔𝕜 𝕓𝕣𝕠𝕨𝕟 𝕗𝕠𝕩 𝕛𝕦𝕞𝕡𝕤 𝕠𝕧𝕖𝕣 𝕥𝕙𝕖 𝕝𝕒𝕫𝕪 𝕕𝕠𝕘", + r"𝚃𝚑𝚎 𝚚𝚞𝚒𝚌𝚔 𝚋𝚛𝚘𝚠𝚗 𝚏𝚘𝚡 𝚓𝚞𝚖𝚙𝚜 𝚘𝚟𝚎𝚛 𝚝𝚑𝚎 𝚕𝚊𝚣𝚢 𝚍𝚘𝚐", + r"⒯⒣⒠ ⒬⒰⒤⒞⒦ ⒝⒭⒪⒲⒩ ⒡⒪⒳ ⒥⒰⒨⒫⒮ ⒪⒱⒠⒭ ⒯⒣⒠ ⒧⒜⒵⒴ ⒟⒪⒢", + # File paths + r"../../../../../../../../../../../etc/passwd%00", + r"../../../../../../../../../../../etc/hosts", + # iOS Vulnerabilities + r"Powerلُلُصّبُلُلصّبُررً ॣ ॣh ॣ ॣ冗", + r"🏳0🌈️", +] diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 72bfac1d08..eea887d663 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -92,6 +92,7 @@ def test_correct_padding(tokenizer_name: str, 'compression': None, 'concat_tokens': 2048, 'tokenizer': tokenizer_name, + 'tokenizer_kwargs': {}, 'bos_text': bos_text, 'eos_text': eos_text, 'no_wrap': False, @@ -108,6 +109,7 @@ def test_correct_padding(tokenizer_name: str, 'compression': None, 'concat_tokens': None, 'tokenizer': tokenizer_name, + 'tokenizer_kwargs': {}, 'bos_text': bos_text, 'eos_text': eos_text, 'no_wrap': False, diff --git a/tests/test_inference_api_eval_wrapper.py b/tests/test_inference_api_eval_wrapper.py index ba065b6020..6e5f91de00 100644 --- a/tests/test_inference_api_eval_wrapper.py +++ b/tests/test_inference_api_eval_wrapper.py @@ -8,8 +8,8 @@ from omegaconf import DictConfig, ListConfig from llmfoundry.models.inference_api_wrapper import (OpenAICausalLMEvalWrapper, - OpenAIChatAPIEvalWrapper, - OpenAITokenizerWrapper) + OpenAIChatAPIEvalWrapper) +from llmfoundry.tokenizers import TiktokenTokenizerWrapper from llmfoundry.utils.builders import build_icl_evaluators @@ -84,7 +84,8 @@ def test_openai_api_eval_wrapper(tmp_path: str): with patch('openai.Completion') as mock: mock.create = mock_create model_name = 'davinci' - tokenizer = OpenAITokenizerWrapper(model_name) + tokenizer = TiktokenTokenizerWrapper(model_name=model_name, + pad_token='<|endoftext|>') model = OpenAICausalLMEvalWrapper(model_cfg={'version': model_name}, tokenizer=tokenizer) task_cfg = load_icl_config() @@ -118,7 +119,8 @@ def test_chat_api_eval_wrapper(tmp_path: str): }], } model_name = 'gpt-3.5-turbo' - tokenizer = OpenAITokenizerWrapper(model_name) + tokenizer = TiktokenTokenizerWrapper(model_name=model_name, + pad_token='<|endoftext|>') chatmodel = OpenAIChatAPIEvalWrapper(model_cfg={'version': model_name}, tokenizer=tokenizer) task_cfg = 
load_icl_config() diff --git a/tests/test_tiktoken.py b/tests/test_tiktoken.py new file mode 100644 index 0000000000..a255a5ffa7 --- /dev/null +++ b/tests/test_tiktoken.py @@ -0,0 +1,203 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import pathlib +from typing import TYPE_CHECKING, Optional, Tuple + +import pytest +import transformers + +from llmfoundry import TiktokenTokenizerWrapper +from tests.horrible_strings import HORRIBLE_STRINGS +from tests.test_hf_conversion_script import check_hf_tokenizer_equivalence + +if TYPE_CHECKING: + from tiktoken.core import Encoding + +TEST_STRINGS = [ + 'Hello world!', 'def hello_world(input: str):\n print(input)', + '0000000000000000000000000000', + '19234324 asas sf 119aASDFM AW3RAW-AF;;9900', '\n\n\n\nhello\n\t', + ' hello\n\t\\\\ goodbye!?*#&@!) ', + 'This is just a normal sentence. And here is another one!', + 'hello<|endoftext|>world', 'hello <|endoftext|> world', + 'hello <|endoftext|>', 'hello <|endoftext|> ', '<|endoftext}>', + '<|endoftext}> ', ' <|endoftext|>', + '<|endoftext|><|endoftext|><|endoftext|><|endoftext|>', + '<|endoftext|> <|endoftext|> <|endoftext|> <|endoftext|>' +] + +TEST_STRINGS += HORRIBLE_STRINGS + +MODEL_OR_ENCODING_NAME_TO_NON_UTF8_TOKENS = { + 'gpt-4': 77, + 'gpt-3.5-turbo': 77, + 'text-davinci-003': 14, + 'cl100k_base': 77, +} + +MODEL_ENCODING_NAME_PARAMETRIZATION = [ + ('gpt-4', None), + ('gpt-3.5-turbo', None), + ('text-davinci-003', None), + (None, 'cl100k_base'), +] + + +def get_tokenizers_for_testing( + model_name: Optional[str], encoding_name: Optional[str], + tmp_path: pathlib.Path +) -> Tuple[TiktokenTokenizerWrapper, TiktokenTokenizerWrapper, 'Encoding']: + tiktoken = pytest.importorskip('tiktoken') + + # Construction + wrapped_tokenizer = TiktokenTokenizerWrapper(model_name=model_name, + encoding_name=encoding_name) + if model_name is not None: + original_tokenizer = tiktoken.encoding_for_model(model_name) + else: + original_tokenizer = tiktoken.get_encoding(encoding_name) + + # Repr works + _ = wrapped_tokenizer.__repr__() + + # Save and load + wrapped_tokenizer.save_pretrained(tmp_path) + reloaded_wrapped_tokenizer = transformers.AutoTokenizer.from_pretrained( + tmp_path, trust_remote_code=True) + + return wrapped_tokenizer, reloaded_wrapped_tokenizer, original_tokenizer + + +@pytest.mark.parametrize('model_name,encoding_name', + MODEL_ENCODING_NAME_PARAMETRIZATION) +def test_tiktoken_simple(model_name: Optional[str], + encoding_name: Optional[str], tmp_path: pathlib.Path): + wrapped_tokenizer, reloaded_wrapped_tokenizer, original_tokenizer = get_tokenizers_for_testing( + model_name, encoding_name, tmp_path) + + # Simple tokenization test + for string in TEST_STRINGS: + wrapped_output = wrapped_tokenizer(string) + original_output = original_tokenizer.encode(string, + allowed_special='all') + reloaded_wrapped_output = reloaded_wrapped_tokenizer(string) + + assert wrapped_output['input_ids'] == original_output + assert set(wrapped_output.keys()) == {'input_ids', 'attention_mask'} + assert reloaded_wrapped_output == wrapped_output + + +@pytest.mark.parametrize('model_name,encoding_name', + MODEL_ENCODING_NAME_PARAMETRIZATION) +def test_tiktoken_roundtrip(model_name: Optional[str], + encoding_name: Optional[str], + tmp_path: pathlib.Path): + wrapped_tokenizer, reloaded_wrapped_tokenizer, original_tokenizer = get_tokenizers_for_testing( + model_name, encoding_name, tmp_path) + + for string in TEST_STRINGS: + wrapped_output = wrapped_tokenizer.decode( + 
wrapped_tokenizer(string)['input_ids']) + original_output = original_tokenizer.decode( + original_tokenizer.encode(string, allowed_special='all')) + reloaded_wrapped_output = reloaded_wrapped_tokenizer.decode( + reloaded_wrapped_tokenizer(string)['input_ids']) + assert wrapped_output == string + assert original_output == string + assert reloaded_wrapped_output == string + + +@pytest.mark.parametrize('model_name,encoding_name', + MODEL_ENCODING_NAME_PARAMETRIZATION) +def test_tiktoken_batched(model_name: Optional[str], + encoding_name: Optional[str], tmp_path: pathlib.Path): + wrapped_tokenizer, reloaded_wrapped_tokenizer, original_tokenizer = get_tokenizers_for_testing( + model_name, encoding_name, tmp_path) + + wrapped_output = wrapped_tokenizer( + ['Hello world!', 'Hello world but longer!']) + original_output = original_tokenizer.encode_batch( + ['Hello world!', 'Hello world but longer!']) + reloaded_wrapped_output = reloaded_wrapped_tokenizer( + ['Hello world!', 'Hello world but longer!']) + assert wrapped_output['input_ids'] == original_output + assert set(wrapped_output.keys()) == {'input_ids', 'attention_mask'} + assert reloaded_wrapped_output == wrapped_output + assert wrapped_tokenizer.batch_decode( + wrapped_output['input_ids']) == original_tokenizer.decode_batch( + original_output) + assert reloaded_wrapped_tokenizer.batch_decode( + reloaded_wrapped_output['input_ids'] + ) == original_tokenizer.decode_batch(original_output) + + +@pytest.mark.parametrize('model_name,encoding_name', + MODEL_ENCODING_NAME_PARAMETRIZATION) +def test_tiktoken_padding(model_name: Optional[str], + encoding_name: Optional[str], tmp_path: pathlib.Path): + wrapped_tokenizer, reloaded_wrapped_tokenizer, original_tokenizer = get_tokenizers_for_testing( + model_name, encoding_name, tmp_path) + + wrapped_tokenizer.pad_token_id = wrapped_tokenizer.eos_token_id + reloaded_wrapped_tokenizer.pad_token_id = reloaded_wrapped_tokenizer.eos_token_id + wrapped_output = wrapped_tokenizer( + ['Hello world!', 'Hello world but longer!'], padding=True) + original_output = original_tokenizer.encode_batch( + ['Hello world!', 'Hello world but longer!']) + reloaded_wrapped_output = reloaded_wrapped_tokenizer( + ['Hello world!', 'Hello world but longer!'], padding=True) + for wrapped1, attn_mask, original1 in zip(wrapped_output['input_ids'], + wrapped_output['attention_mask'], + original_output): + original_length = len(original1) + assert wrapped1[:original_length] == original1 + assert sum(attn_mask) == original_length + + assert set(wrapped_output.keys()) == {'input_ids', 'attention_mask'} + assert reloaded_wrapped_output == wrapped_output + + +@pytest.mark.parametrize('model_name,encoding_name', + MODEL_ENCODING_NAME_PARAMETRIZATION) +def test_tiktoken_vocab(model_name: Optional[str], encoding_name: Optional[str], + tmp_path: pathlib.Path): + wrapped_tokenizer, reloaded_wrapped_tokenizer, original_tokenizer = get_tokenizers_for_testing( + model_name, encoding_name, tmp_path) + + wrapped_vocab = wrapped_tokenizer.get_vocab() + reloaded_wrapped_vocab = reloaded_wrapped_tokenizer.get_vocab() + assert wrapped_vocab == reloaded_wrapped_vocab + + didnt_match = [] + for key, value in wrapped_vocab.items(): + if original_tokenizer.encode(key, allowed_special='all') == [value]: + continue + else: + didnt_match.append( + (key, original_tokenizer.encode(key, + allowed_special='all'), value)) + + # Decode is lossy because some bytes are not representable in utf-8 + # see 
https://github.com/openai/tiktoken/blob/39f29cecdb6fc38d9a3434e5dd15e4de58cf3c80/tiktoken/core.py#L245-L247 + # This means that the str: int vocab mapping doesn't work. Would have to look more into how other HF tokenizers handle this. + model_or_encoding_name = model_name or encoding_name + if model_or_encoding_name is not None: + expected_didnt_match = MODEL_OR_ENCODING_NAME_TO_NON_UTF8_TOKENS.get( + model_or_encoding_name) + assert len(didnt_match) == expected_didnt_match + else: + raise NotImplementedError( + 'Add the new tokenizer and how many tokens in the vocab are not utf8 representable.' + ) + + +@pytest.mark.parametrize('model_name,encoding_name', + MODEL_ENCODING_NAME_PARAMETRIZATION) +def test_tiktoken_save_from_pretrained(model_name: Optional[str], + encoding_name: Optional[str], + tmp_path: pathlib.Path): + wrapped_tokenizer, reloaded_wrapped_tokenizer, _ = get_tokenizers_for_testing( + model_name, encoding_name, tmp_path) + check_hf_tokenizer_equivalence(wrapped_tokenizer, + reloaded_wrapped_tokenizer) diff --git a/tests/test_training.py b/tests/test_training.py index 09254f79e9..e03703c859 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -36,6 +36,7 @@ def create_c4_dataset_xsmall(prefix: str) -> str: 'compression': None, 'concat_tokens': 2048, 'tokenizer': 'EleutherAI/gpt-neox-20b', + 'tokenizer_kwargs': {}, 'bos_text': '', 'eos_text': '<|endoftext|>', 'no_wrap': False,
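The new `TiktokenTokenizerWrapper` is exercised end to end by `tests/test_tiktoken.py`; the following is a minimal usage sketch distilled from those tests (it assumes `tiktoken` and `transformers` are installed, and the save directory name is only illustrative):

```python
import transformers

from llmfoundry.tokenizers import TiktokenTokenizerWrapper

# Build from a tiktoken model name; alternatively pass encoding_name='cl100k_base'.
# The constructor raises ValueError if neither, or both, are given.
tokenizer = TiktokenTokenizerWrapper(model_name='gpt-4',
                                     pad_token='<|endoftext|>')

encoded = tokenizer('hello <|endoftext|> world')
print(encoded['input_ids'])                    # ids from tiktoken, special token included
print(tokenizer.decode(encoded['input_ids']))  # round-trips back to the input string

# register_for_auto_class() is called on the class, so the wrapper can be saved
# and reloaded through the normal Hugging Face AutoTokenizer machinery.
tokenizer.save_pretrained('./tiktoken-tokenizer')
reloaded = transformers.AutoTokenizer.from_pretrained('./tiktoken-tokenizer',
                                                      trust_remote_code=True)
assert reloaded('hello world')['input_ids'] == tokenizer('hello world')['input_ids']
```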
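`construct_logit_tensor` moves from the deleted `OpenAITokenizerWrapper` onto the new wrapper essentially unchanged, and the OpenAI eval wrappers keep using it to spread per-token logprobs over a dense `(vocab_size,)` tensor. A small sketch of its contract, with made-up logprob values:

```python
from llmfoundry.tokenizers import TiktokenTokenizerWrapper

tokenizer = TiktokenTokenizerWrapper(model_name='gpt-4')

# Every position defaults to min(logprobs) - 1; the first token id of each key
# is then set to that key's logprob, so argmax recovers the likeliest key.
logits = tokenizer.construct_logit_tensor({' yes': -0.05, ' no': -3.2})
assert logits.shape == (tokenizer.vocab_size,)
assert logits.argmax().item() == tokenizer(' yes')['input_ids'][0]
```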
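`build_tokenizer` now dispatches on the tokenizer name: anything starting with `tiktoken` is built through the wrapper, everything else still goes through `AutoTokenizer.from_pretrained`, and `convert_dataset_hf.py` gains a `--tokenizer_kwargs` flag whose JSON payload is forwarded as those kwargs. A rough sketch of both paths (the specific kwargs are illustrative, and the Hugging Face branch will download the named tokenizer):

```python
import json

from llmfoundry.utils.builders import build_tokenizer

# Names starting with 'tiktoken' hit the new branch; the kwargs become the
# TiktokenTokenizerWrapper constructor arguments.
tiktoken_tokenizer = build_tokenizer(tokenizer_name='tiktoken',
                                     tokenizer_kwargs={'model_name': 'gpt-4'})

# Any other name is handled exactly as before.
hf_tokenizer = build_tokenizer(tokenizer_name='EleutherAI/gpt-neox-20b',
                               tokenizer_kwargs={'model_max_length': 2048})

# convert_dataset_hf.py takes the same dict as a JSON string, e.g.
#   --tokenizer tiktoken --tokenizer_kwargs '{"model_name": "gpt-4"}'
assert json.loads('{"model_name": "gpt-4"}') == {'model_name': 'gpt-4'}
```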
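Both `get_vocab` and `test_tiktoken_vocab` point out that some tiktoken tokens are byte sequences that are not valid UTF-8, so a `str -> int` vocabulary cannot represent them faithfully. Below is a sketch of counting such tokens directly with tiktoken; the result approximates, but need not exactly equal, the 77 mismatches the test records for `cl100k_base`, since the test compares re-encoded ids rather than raw bytes:

```python
import tiktoken

enc = tiktoken.get_encoding('cl100k_base')

non_utf8 = []
for token_id in range(enc.n_vocab):
    try:
        raw = enc.decode_single_token_bytes(token_id)
    except KeyError:
        # some ids in the range are unused; get_vocab has to skip these too
        continue
    try:
        raw.decode('utf-8')
    except UnicodeDecodeError:
        non_utf8.append(token_id)

print(len(non_utf8))  # tokens whose bytes cannot stand alone as a Python str
```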