Fixed issue with WordPieceTokenizer by adding vocab_size argument #1216

Open
wants to merge 1 commit into base: master
53 changes: 41 additions & 12 deletions keras_nlp/tokenizers/word_piece_tokenizer.py
@@ -27,6 +27,7 @@
from keras_nlp.utils.tensor_utils import convert_to_ragged_batch
from keras_nlp.utils.tensor_utils import is_integer_dtype
from keras_nlp.utils.tensor_utils import is_string_dtype
from absl import logging

try:
import tensorflow_text as tf_text
@@ -202,6 +203,8 @@ class WordPieceTokenizer(tokenizer.Tokenizer):
plain text file containing a single WordPiece token per line.
sequence_length: int. If set, the output will be converted to a dense
tensor and padded/trimmed so all outputs are of sequence_length.
vocab_size: int. If set and smaller than the size of the input
vocabulary, the vocabulary will be truncated to the first `vocab_size`
tokens. Defaults to `None` (use the full input vocabulary).
lowercase: bool. If `True`, the input text will be
lowercased before tokenization. Defaults to `False`.
strip_accents: bool. If `True`, all accent marks will
@@ -294,6 +297,7 @@ def __init__(
self,
vocabulary=None,
sequence_length: int = None,
vocab_size: int = None,
lowercase: bool = False,
strip_accents: bool = False,
split: bool = True,
@@ -313,13 +317,46 @@ def __init__(

super().__init__(dtype=dtype, **kwargs)

self.vocab_size = vocab_size
self.sequence_length = sequence_length
self.lowercase = lowercase
self.strip_accents = strip_accents
self.split = split
self.split_on_cjk = split_on_cjk
self.suffix_indicator = suffix_indicator
self.oov_token = oov_token

if isinstance(vocabulary, str):
    vocabulary_list = [
        line.rstrip() for line in tf.io.gfile.GFile(vocabulary)
    ]
    input_vocabulary_size = len(vocabulary_list)
    if self.vocab_size is None:
        self.vocab_size = input_vocabulary_size
        self.vocabulary = vocabulary_list
    elif self.vocab_size < input_vocabulary_size:
        logging.warning(
            "Setting `vocab_size` to a smaller value than the input "
            "vocabulary file. Some token ids will never be output "
            "from the tokenizer."
        )
        self.vocabulary = vocabulary_list[: self.vocab_size]
    else:
        self.vocab_size = input_vocabulary_size
        self.vocabulary = vocabulary_list
elif isinstance(vocabulary, Iterable):
    # Make a copy.
    vocabulary_list = list(vocabulary)
    input_vocabulary_size = len(vocabulary_list)
    if self.vocab_size is None:
        self.vocab_size = input_vocabulary_size
        self.vocabulary = vocabulary_list
    elif self.vocab_size < input_vocabulary_size:
        logging.warning(
            "Setting `vocab_size` to a smaller value than the input "
            "vocabulary. Some token ids will never be output from "
            "the tokenizer."
        )
        self.vocabulary = vocabulary_list[: self.vocab_size]
    else:
        self.vocab_size = input_vocabulary_size
        self.vocabulary = vocabulary_list
else:
    raise ValueError(
        "Vocabulary must be a file path or list of terms. "
@@ -328,14 +365,6 @@ def __init__(
if oov_token is None:
raise ValueError("`oov_token` cannot be None.")

self.sequence_length = sequence_length
self.lowercase = lowercase
self.strip_accents = strip_accents
self.split = split
self.split_on_cjk = split_on_cjk
self.suffix_indicator = suffix_indicator
self.oov_token = oov_token

if oov_token not in self.vocabulary:
raise ValueError(
f'Cannot find `oov_token="{self.oov_token}"` in the '
@@ -360,7 +389,7 @@ def get_vocabulary(self) -> List[str]:

def vocabulary_size(self) -> int:
"""Get the size of the tokenizer vocabulary."""
return len(self.vocabulary)
return self.vocab_size

def id_to_token(self, id: int) -> str:
"""Convert an integer id to a string token."""
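For reviewers, here is a minimal usage sketch of the behaviour this diff appears to implement. It assumes the PR branch of keras-nlp is installed (the `vocab_size` argument is proposed here, not part of a released keras-nlp API), and the vocabulary tokens are purely illustrative.

```python
from keras_nlp.tokenizers import WordPieceTokenizer

# Illustrative vocabulary; "[UNK]" must survive truncation because the
# constructor checks that `oov_token` is present in the final vocabulary.
vocab = ["[UNK]", "the", "qu", "##ick", "br", "##own", "fox"]

# Request a vocabulary of 5 tokens; the input has 7, so the tokenizer
# should log a warning and keep only the first 5 entries.
tokenizer = WordPieceTokenizer(
    vocabulary=vocab,
    vocab_size=5,
    oov_token="[UNK]",
)

print(tokenizer.vocabulary_size())  # 5, i.e. the requested vocab_size
print(tokenizer.get_vocabulary()[:3])  # ['[UNK]', 'the', 'qu']

# Tokens cut off by the truncation ("fox") now map to the OOV id.
print(tokenizer("the fox"))
```

The printed values assume the truncation branch shown in the diff above; when `vocab_size` is unset or at least as large as the input vocabulary, the tokenizer behaves as before.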