[DRAFT] Support multiple tokenizers and other layers with assets #1860

Open · wants to merge 1 commit into base: master

42 changes: 36 additions & 6 deletions keras_hub/src/models/preprocessor.py
@@ -177,6 +177,39 @@ def from_preset(
         cls = find_subclass(preset, cls, backbone_cls)
         return loader.load_preprocessor(cls, **kwargs)
 
+    @classmethod
+    def _add_missing_kwargs(cls, loader, kwargs):
+        """Fill in required kwargs when loading from a preset.
+
+        This private method is called when loading a preprocessing layer
+        that was not directly saved in the preset. It should fill in all
+        kwargs required to call the class constructor. For almost all
+        preprocessors, the only required args are `tokenizer`,
+        `image_converter`, and `audio_converter`, but this can be
+        overridden, e.g. for a preprocessor with multiple tokenizers for
+        different encoders.
+        """
+        if "tokenizer" not in kwargs and cls.tokenizer_cls:
+            kwargs["tokenizer"] = loader.load_tokenizer(cls.tokenizer_cls)
+        if "audio_converter" not in kwargs and cls.audio_converter_cls:
+            kwargs["audio_converter"] = loader.load_audio_converter(
+                cls.audio_converter_cls
+            )
+        if "image_converter" not in kwargs and cls.image_converter_cls:
+            kwargs["image_converter"] = loader.load_image_converter(
+                cls.image_converter_cls
+            )
+        return kwargs
+
+    def load_preset_assets(self, preset):
+        """Load all static assets needed by the preprocessing layer.
+
+        Args:
+            preset: The path to the local model preset directory.
+        """
+        for layer in self._flatten_layers(include_self=False):
+            if hasattr(layer, "load_preset_assets"):
+                layer.load_preset_assets(preset)
+
     def save_to_preset(self, preset_dir):
         """Save preprocessor to a preset directory.
 
@@ -188,9 +221,6 @@ def save_to_preset(self, preset_dir):
             preset_dir,
             config_file=PREPROCESSOR_CONFIG_FILE,
         )
-        if self.tokenizer:
-            self.tokenizer.save_to_preset(preset_dir)
-        if self.audio_converter:
-            self.audio_converter.save_to_preset(preset_dir)
-        if self.image_converter:
-            self.image_converter.save_to_preset(preset_dir)
+        for layer in self._flatten_layers(include_self=False):
+            if hasattr(layer, "save_to_preset"):
+                layer.save_to_preset(preset_dir)
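
The override hook above is easiest to see with a concrete case. A minimal sketch, assuming a hypothetical seq2seq preprocessor with separate encoder and decoder tokenizers; the class, attribute, and config names here are invented for illustration and are not part of this diff:

```python
from keras_hub.src.models.preprocessor import Preprocessor


class Seq2SeqPreprocessor(Preprocessor):
    """Hypothetical preprocessor with two tokenizers (illustration only)."""

    # Assumed subclass hooks, mirroring `tokenizer_cls` on the base class.
    encoder_tokenizer_cls = None
    decoder_tokenizer_cls = None

    @classmethod
    def _add_missing_kwargs(cls, loader, kwargs):
        # Each tokenizer is saved under its own `config_name`, so both can
        # round-trip through a single preset directory.
        if "encoder_tokenizer" not in kwargs and cls.encoder_tokenizer_cls:
            kwargs["encoder_tokenizer"] = loader.load_tokenizer(
                cls.encoder_tokenizer_cls,
                config_name="encoder_tokenizer.json",
            )
        if "decoder_tokenizer" not in kwargs and cls.decoder_tokenizer_cls:
            kwargs["decoder_tokenizer"] = loader.load_tokenizer(
                cls.decoder_tokenizer_cls,
                config_name="decoder_tokenizer.json",
            )
        return kwargs
```

This keeps the base-class loading path unchanged (`load_preprocessor` simply calls `cls._add_missing_kwargs(self, kwargs)`, see the preset_utils.py diff below) while letting subclasses decide which layers a preset must supply.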
38 changes: 23 additions & 15 deletions keras_hub/src/tokenizers/tokenizer.py
@@ -17,14 +17,12 @@
 from keras_hub.src.layers.preprocessing.preprocessing_layer import (
     PreprocessingLayer,
 )
-from keras_hub.src.utils.preset_utils import TOKENIZER_ASSET_DIR
-from keras_hub.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
+from keras_hub.src.utils.preset_utils import ASSET_DIR
 from keras_hub.src.utils.preset_utils import builtin_presets
 from keras_hub.src.utils.preset_utils import find_subclass
 from keras_hub.src.utils.preset_utils import get_file
 from keras_hub.src.utils.preset_utils import get_preset_loader
 from keras_hub.src.utils.preset_utils import save_serialized_object
-from keras_hub.src.utils.preset_utils import save_tokenizer_assets
 from keras_hub.src.utils.python_utils import classproperty
 from keras_hub.src.utils.tensor_utils import preprocessing_function

@@ -80,6 +78,7 @@ def detokenize(self, inputs):
     backbone_cls = None
 
     def __init__(self, *args, **kwargs):
+        self.config_name = kwargs.pop("config_name", "tokenizer.json")
         super().__init__(*args, **kwargs)
         self.file_assets = None
@@ -187,18 +186,26 @@ def _update_special_token_ids(self):
             token = getattr(self, attr)
             setattr(self, f"{attr}_id", self.token_to_id(token))
 
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "config_name": self.config_name,
+            }
+        )
+        return config
+
     def save_to_preset(self, preset_dir):
         """Save tokenizer to a preset directory.
 
         Args:
             preset_dir: The path to the local model preset directory.
         """
-        save_serialized_object(
-            self,
-            preset_dir,
-            config_file=TOKENIZER_CONFIG_FILE,
-        )
-        save_tokenizer_assets(self, preset_dir)
+        save_serialized_object(self, preset_dir, config_file=self.config_name)
+        subdir = self.config_name.split(".")[0]
+        asset_dir = os.path.join(preset_dir, ASSET_DIR, subdir)
+        os.makedirs(asset_dir, exist_ok=True)
+        self.save_assets(asset_dir)
 
     @preprocessing_function
     def call(self, inputs, *args, training=None, **kwargs):
@@ -207,11 +214,11 @@ def call(self, inputs, *args, training=None, **kwargs):
     def load_preset_assets(self, preset):
         asset_path = None
         for asset in self.file_assets:
-            asset_path = get_file(
-                preset, os.path.join(TOKENIZER_ASSET_DIR, asset)
-            )
-        tokenizer_asset_dir = os.path.dirname(asset_path)
-        self.load_assets(tokenizer_asset_dir)
+            subdir = self.config_name.split(".")[0]
+            preset_path = os.path.join(ASSET_DIR, subdir, asset)
+            asset_path = get_file(preset, preset_path)
+        tokenizer_asset_dir = os.path.dirname(asset_path)
+        self.load_assets(tokenizer_asset_dir)
 
     @classproperty
     def presets(cls):
@@ -222,6 +229,7 @@ def presets(cls):
     def from_preset(
         cls,
         preset,
+        config_name="tokenizer.json",
         **kwargs,
     ):
         """Instantiate a `keras_hub.models.Tokenizer` from a model preset.
@@ -267,4 +275,4 @@ class like `keras_hub.models.Tokenizer.from_preset()`, or from
         backbone_cls = loader.check_backbone_class()
         if cls.backbone_cls != backbone_cls:
             cls = find_subclass(preset, cls, backbone_cls)
-        return loader.load_tokenizer(cls, **kwargs)
+        return loader.load_tokenizer(cls, config_name, **kwargs)
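
Putting the `config_name` plumbing together: saving two tokenizers with distinct config names should produce one JSON config and one asset subdirectory per tokenizer, since the subdirectory is derived from `config_name.split(".")[0]`. A hedged usage sketch; the preset path and config names below are hypothetical:

```python
import keras_hub

# Assumed layout after saving a preset containing two tokenizers:
#
#   preset_dir/
#   ├── encoder_tokenizer.json
#   ├── decoder_tokenizer.json
#   └── assets/
#       ├── encoder_tokenizer/   # vocabulary, merges files, etc.
#       └── decoder_tokenizer/
#
# Each one can then be re-loaded by its config name:
tokenizer = keras_hub.models.Tokenizer.from_preset(
    "./preset_dir",
    config_name="decoder_tokenizer.json",
)
```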
65 changes: 9 additions & 56 deletions keras_hub/src/utils/preset_utils.py
@@ -55,7 +55,8 @@
 GS_SCHEME = "gs"
 HF_SCHEME = "hf"
 
-TOKENIZER_ASSET_DIR = "assets/tokenizer"
+ASSET_DIR = "assets"
+TOKENIZER_ASSET_DIR = f"{ASSET_DIR}/tokenizer"
 
 # Config file names.
 CONFIG_FILE = "config.json"
@@ -307,13 +308,6 @@ def make_preset_dir(preset):
     os.makedirs(preset, exist_ok=True)
 
 
-def save_tokenizer_assets(tokenizer, preset):
-    if tokenizer:
-        asset_dir = os.path.join(preset, TOKENIZER_ASSET_DIR)
-        os.makedirs(asset_dir, exist_ok=True)
-        tokenizer.save_assets(asset_dir)
-
-
 def save_serialized_object(
     layer,
     preset,
@@ -345,37 +339,6 @@ def save_metadata(layer, preset):
         metadata_file.write(json.dumps(metadata, indent=4))
 
 
-def _validate_tokenizer(preset):
-    if not check_file_exists(preset, TOKENIZER_CONFIG_FILE):
-        return
-    config_path = get_file(preset, TOKENIZER_CONFIG_FILE)
-    try:
-        with open(config_path, encoding="utf-8") as config_file:
-            config = json.load(config_file)
-    except Exception as e:
-        raise ValueError(
-            f"Tokenizer config file `{config_path}` is an invalid json file. "
-            f"Error message: {e}"
-        )
-    layer = keras.saving.deserialize_keras_object(config)
-
-    for asset in layer.file_assets:
-        asset_path = get_file(preset, os.path.join(TOKENIZER_ASSET_DIR, asset))
-        if not os.path.exists(asset_path):
-            tokenizer_asset_dir = os.path.dirname(asset_path)
-            raise FileNotFoundError(
-                f"Asset `{asset}` doesn't exist in the tokenizer asset "
-                f"directory `{tokenizer_asset_dir}`."
-            )
-    config_dir = os.path.dirname(config_path)
-    asset_dir = os.path.join(config_dir, TOKENIZER_ASSET_DIR)
-
-    tokenizer = get_tokenizer(layer)
-    if not tokenizer:
-        raise ValueError(f"Model or layer `{layer}` is missing tokenizer.")
-    tokenizer.load_assets(asset_dir)
-
-
 def _validate_backbone(preset):
     config_path = os.path.join(preset, CONFIG_FILE)
     if not os.path.exists(config_path):
@@ -493,7 +456,6 @@ def upload_preset(
         raise FileNotFoundError(f"The preset directory {preset} doesn't exist.")
 
     _validate_backbone(preset)
-    _validate_tokenizer(preset)
 
     if uri.startswith(KAGGLE_PREFIX):
         if kagglehub is None:
@@ -665,7 +627,7 @@ def load_backbone(self, cls, load_weights, **kwargs):
         """Load the backbone model from the preset."""
         raise NotImplementedError
 
-    def load_tokenizer(self, cls, **kwargs):
+    def load_tokenizer(self, cls, config_name="tokenizer.json", **kwargs):
         """Load a tokenizer layer from the preset."""
         raise NotImplementedError
 
@@ -703,16 +665,7 @@ def load_preprocessor(self, cls, **kwargs):
         arguments. This allows us to support transformers checkpoints by
         only converting the backbone and tokenizer.
         """
-        if "tokenizer" not in kwargs and cls.tokenizer_cls:
-            kwargs["tokenizer"] = self.load_tokenizer(cls.tokenizer_cls)
-        if "audio_converter" not in kwargs and cls.audio_converter_cls:
-            kwargs["audio_converter"] = self.load_audio_converter(
-                cls.audio_converter_cls
-            )
-        if "image_converter" not in kwargs and cls.image_converter_cls:
-            kwargs["image_converter"] = self.load_image_converter(
-                cls.image_converter_cls
-            )
+        kwargs = cls._add_missing_kwargs(self, kwargs)
         return cls(**kwargs)


@@ -727,8 +680,8 @@ def load_backbone(self, cls, load_weights, **kwargs):
         backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
         return backbone
 
-    def load_tokenizer(self, cls, **kwargs):
-        tokenizer_config = load_json(self.preset, TOKENIZER_CONFIG_FILE)
+    def load_tokenizer(self, cls, config_name="tokenizer.json", **kwargs):
+        tokenizer_config = load_json(self.preset, config_name)
         tokenizer = load_serialized_object(tokenizer_config, **kwargs)
         tokenizer.load_preset_assets(self.preset)
         return tokenizer
@@ -755,8 +708,8 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
         )
         # We found a `task.json` with a complete config for our class.
         task = load_serialized_object(task_config, **kwargs)
-        if task.preprocessor and task.preprocessor.tokenizer:
-            task.preprocessor.tokenizer.load_preset_assets(self.preset)
+        if task.preprocessor:
+            task.preprocessor.load_preset_assets(self.preset)
         if load_weights:
             has_task_weights = check_file_exists(self.preset, TASK_WEIGHTS_FILE)
             if has_task_weights and load_task_weights:
@@ -779,5 +732,5 @@ def load_preprocessor(self, cls, **kwargs):
             return super().load_preprocessor(cls, **kwargs)
         # We found a `preprocessing.json` with a complete config for our class.
         preprocessor = load_serialized_object(preprocessor_json, **kwargs)
-        preprocessor.tokenizer.load_preset_assets(self.preset)
+        preprocessor.load_preset_assets(self.preset)
         return preprocessor
2 changes: 1 addition & 1 deletion keras_hub/src/utils/transformers/preset_loader.py
@@ -69,7 +69,7 @@ def load_backbone(self, cls, load_weights, **kwargs):
         self.converter.convert_weights(backbone, loader, self.config)
         return backbone
 
-    def load_tokenizer(self, cls, **kwargs):
+    def load_tokenizer(self, cls, config_name="tokenizer.json", **kwargs):
         return self.converter.convert_tokenizer(cls, self.preset, **kwargs)
 
     def load_image_converter(self, cls, **kwargs):
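
One detail visible in this hunk: on the transformers path the tokenizer is produced by a converter rather than deserialized from a saved `tokenizer.json`, so `config_name` is accepted only for signature compatibility and is not forwarded to `convert_tokenizer`. A hedged usage sketch, assuming a checkpoint the converter supports:

```python
import keras_hub

# Hypothetical call; the "hf://" scheme matches HF_SCHEME defined above.
# The converter builds this tokenizer from the Hugging Face checkpoint,
# so `config_name` has no effect on this code path.
tokenizer = keras_hub.models.Tokenizer.from_preset("hf://openai-community/gpt2")
```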