Skip to content

Commit

Permalink
Merge branch 'main' into shashank/seq_id_flash_attn
Browse files Browse the repository at this point in the history
  • Loading branch information
ShashankMosaicML authored Nov 29, 2023
2 parents a964aea + 5f21855 commit 5765724
Show file tree
Hide file tree
Showing 7 changed files with 331 additions and 149 deletions.
14 changes: 6 additions & 8 deletions .github/workflows/docker.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,17 @@ jobs:
GIT_SHA=$(echo ${{ github.sha }} | cut -c1-7)
echo "IMAGE_TAG=${GIT_SHA}" >> ${GITHUB_ENV}
if [ "${{ github.event_name }}" == "push" ]; then
echo "Triggered by push event."
PROD_REPO="mosaicml/llm-foundry"
IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest"
IMAGE_CACHE="${PROD_REPO}:${{matrix.name}}-buildcache"
elif [ "${{ github.event_name }}" == "pull_request" ]; then
if [ "${{ github.event_name }}" == "pull_request" ]; then
echo "Triggered by pull_request event."
STAGING_REPO="mosaicml/ci-staging"
IMAGE_TAG="${STAGING_REPO}:${{matrix.name}}-${GIT_SHA}"
IMAGE_CACHE="${STAGING_REPO}:${{matrix.name}}-buildcache"
else
echo "Triggered by unknown event: ${{ github.event_name }}"
exit 1
# Triggered by push or workflow_dispatch event
echo "Triggered by ${{ github.event_name }} event, releasing to prod"
PROD_REPO="mosaicml/llm-foundry"
IMAGE_TAG="${PROD_REPO}:${{matrix.name}}-${GIT_SHA},${PROD_REPO}:${{matrix.name}}-latest"
IMAGE_CACHE="${PROD_REPO}:${{matrix.name}}-buildcache"
fi
echo "IMAGE_TAG=${IMAGE_TAG}" >> ${GITHUB_ENV}
Expand Down
22 changes: 12 additions & 10 deletions llmfoundry/optim/lion8b.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, Iterable, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

import torch

Expand Down Expand Up @@ -58,15 +58,17 @@ class DecoupledLionW_8bit(torch.optim.Optimizer):
device, or b) step() is executed on a non-CUDA parameter.
"""

def __init__(self,
params: Iterable[torch.Tensor],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0,
quantize: bool = True,
compress_state_dict: bool = False,
error_correction: bool = False,
_fused: bool = True): # XXX this flag is mostly for testing...
def __init__(
self,
params: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0,
quantize: bool = True,
compress_state_dict: bool = False,
error_correction: bool = False,
_fused: bool = True, # XXX this flag is mostly for testing...
):

if lr < 0.0:
raise ValueError('Invalid learning rate: {}'.format(lr))
Expand Down
169 changes: 77 additions & 92 deletions llmfoundry/tokenizers/tiktoken.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,46 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers import PreTrainedTokenizer

DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible."""


# Taken from
# https://github.com/huggingface/transformers/blob/8aca43bdb3cb9a5020f6d57589d85679dc873b1c/src/transformers/models/gpt2/tokenization_gpt2.py#L62-L84
@lru_cache()
def bytes_to_unicode():
"""Returns list of utf-8 byte and a mapping to unicode strings.
We specifically avoids mapping to whitespace/control characters the bpe code
barfs on.
The reversible bpe codes work on unicode strings. This means you need a
large # of unicode characters in your vocab if you want to avoid UNKs. When
you're at something like a 10B token dataset you end up needing around 5K
for decent coverage. This is a significant percentage of your normal, say,
32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and
unicode strings.
"""
bs = (list(range(ord('!'),
ord('~') + 1)) + list(range(ord('¡'),
ord('¬') + 1)) +
list(range(ord('®'),
ord('ÿ') + 1)))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))


class TiktokenTokenizerWrapper(PreTrainedTokenizer):
"""A thin wrapper around tiktoken to make it compatible with Hugging Face.
Expand Down Expand Up @@ -93,6 +124,28 @@ def pickle_Encoding(enc: Encoding):
self.add_eos_token = add_eos_token
self.use_default_system_prompt = use_default_system_prompt

self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

self.decoder: Dict[int, str] = {}
for i in range(self.encoding.n_vocab):
try:
self.encoding.decode_single_token_bytes(i)
except KeyError:
continue
# Taken from
# https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
decoding = ''.join([
bytes_to_unicode()[ord(char)] for char in
self.encoding.decode_single_token_bytes(i).decode('latin-1')
])
self.decoder[i] = decoding

self.encoder: Dict[str, int] = {}
for i in range(self.encoding.n_vocab):
if i in self.decoder:
self.encoder[self.decoder[i]] = i

super().__init__(model_name=model_name,
encoding_name=encoding_name,
add_bos_token=add_bos_token,
Expand Down Expand Up @@ -135,122 +188,54 @@ def default_chat_template(self):
return template

def get_vocab(self) -> Dict[str, int]:
"""Returns vocab as a dict.
Note: This function does not work properly due to difference in assumptions between tiktoken and Hugging Face tokenizers.
Most uses do not need to use get_vocab, so this is not a priority to fix.
"""
warnings.warn(
'get_vocab does not work properly with TiktokenTokenizerWrapper. Please do not rely on it being perfectly correct.'
+
' It will be called once init just to get the size of the vocab inside the base class.'
)

vocab = {}
for i in range(self.vocab_size):
try:
# need to try this first, so that we get a proper KeyError,
# otherwise it crashes in the rust code
_ = self.encoding.decode_single_token_bytes(i)
vocab[self.encoding.decode([i])] = i
except KeyError:
pass

"""Returns vocab as a dict."""
# As far as I can tell, we don't require get_vocab to completely work,
# but when using additional_special_tokens, Hugging Face determines the next
# token index to add with len(self.get_vocab()) so we need the _size_ of this dictionary to be correct.
vocab_clone = self.encoder.copy()
extra_id_index = 0
candidate_extra_id = f'<extra_id_{extra_id_index}>'
indices_to_fill_in = {i for i in range(self.vocab_size)} - set(
vocab.values())
vocab_clone.values())

# Add enough indices to make get_vocab() the right length
for index_to_add in indices_to_fill_in:
# Make sure we don't overwrite a token that already exists
while candidate_extra_id in vocab:
while candidate_extra_id in vocab_clone:
extra_id_index += 1
candidate_extra_id = f'<extra_id_{extra_id_index}>'

# Get an index to add and add the item
vocab[candidate_extra_id] = index_to_add
vocab_clone[candidate_extra_id] = index_to_add

return vocab
return vocab_clone

def _tokenize(self, text: str) -> List[int]:
"""Returns a tokenized string.
Note: We have slightly redefined the expected contract between this method and
the _convert_token_to_id method. Normally, this method turns a string, into a list of strings,
and then the _convert_token_to_id method turns that list of strings into a list of integers.
However, not all vocab indices can be decoded into a string, so instead we just return the integers
from this function, and have adjusted the _convert_token_to_id method to handle integers as well as strings.
The only use of _tokenize that I could find was in this way, so this _should_ be safe.
"""
def _tokenize(self, text: str) -> List[str]:
"""Returns a tokenized string."""
if not isinstance(text, str):
raise ValueError(
f'Expected a string input to _tokenize but got {type(text)}.')

tokens = [t for t in self.encoding.encode(text, allowed_special='all')]
tokens = [
self.decoder[t]
for t in self.encoding.encode(text, allowed_special='all')
]

return tokens

def _convert_token_to_id(self, token: Union[int, str]) -> int:
"""Converts a token (str) into an id using the vocab."""
if isinstance(token, int):
return token
def _convert_token_to_id(self, token: str) -> Optional[int]:
"""Converts a token (str) in an id using the vocab."""
return self.encoder.get(token, self.encoder.get(self.unk_token))

return self.encoding.encode(token, allowed_special='all')[0]

def _convert_id_to_token(self, index: int) -> str:
"""Converts an index (integer) into a token (str) using the vocab."""
return self.encoding.decode([index])
def _convert_id_to_token(self, index: int) -> Optional[str]:
"""Converts an index (integer) in a token (str) using the vocab."""
return self.decoder.get(index)

def convert_tokens_to_string(self, tokens: List[str]) -> str:
"""Converts a sequence of tokens (string) in a single string."""
return ''.join(tokens)

def convert_ids_to_tokens(
self,
ids: Union[int, List[int]],
skip_special_tokens: bool = False) -> Union[str, List[str]]:
"""Converts a single index or a sequence of indices into a token or a.
sequence of tokens, using the vocabulary and added tokens.
Args:
ids (`int` or `List[int]`):
The token id (or token ids) to convert to tokens.
skip_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not to remove special tokens in the decoding.
Returns:
`str` or `List[str]`: The decoded token(s).
"""
if isinstance(ids, int):
if ids in self.added_tokens_decoder:
return str(self.added_tokens_decoder[ids])

return self._convert_id_to_token(ids)

# current_stream will collect multiple tokens, and then separately add items
# for each added token. This is done so that decode works properly with token ids
# that cannot be represented naively in utf-8.
tokens = []
current_stream = []
for index in ids:
if skip_special_tokens and index in self.all_special_ids:
continue

if index in self.added_tokens_decoder:
tokens.append(self.encoding.decode(current_stream))
current_stream = []
tokens.append(str(self.added_tokens_decoder[index]))
else:
current_stream.append(index)

if len(current_stream) > 0:
tokens.append(self.encoding.decode(current_stream))
return tokens
text = ''.join(tokens)
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8')
return text

def build_inputs_with_special_tokens(
self,
Expand Down
Loading

0 comments on commit 5765724

Please sign in to comment.