diff --git a/keras_hub/src/models/clip/clip_preprocessor.py b/keras_hub/src/models/clip/clip_preprocessor.py
index c8632e033..454afc87f 100644
--- a/keras_hub/src/models/clip/clip_preprocessor.py
+++ b/keras_hub/src/models/clip/clip_preprocessor.py
@@ -94,7 +94,7 @@ def build(self, input_shape):
         self.packer = StartEndPacker(
             start_value=self.tokenizer.start_token_id,
             end_value=self.tokenizer.end_token_id,
-            pad_value=self.tokenizer.end_token_id,
+            pad_value=self.tokenizer.pad_token_id,
             sequence_length=self.sequence_length,
             return_padding_mask=True,
         )
diff --git a/keras_hub/src/models/clip/clip_preprocessor_test.py b/keras_hub/src/models/clip/clip_preprocessor_test.py
index 8321d7be7..733150ba3 100644
--- a/keras_hub/src/models/clip/clip_preprocessor_test.py
+++ b/keras_hub/src/models/clip/clip_preprocessor_test.py
@@ -38,7 +38,7 @@ def test_preprocessor_basics(self):
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
             expected_output={
-                "token_ids": [[5, 1, 2, 1, 3, 4, 4, 4]],
+                "token_ids": [[5, 1, 2, 1, 3, 4, 0, 0]],
                 "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
             },
         )
@@ -52,7 +52,7 @@ def test_no_start_end_token(self):
             add_end_token=False,
         )
         x = preprocessor(input_data)
-        self.assertAllEqual(x["token_ids"], [[1, 2, 1, 3, 4, 4, 4, 4]] * 4)
+        self.assertAllEqual(x["token_ids"], [[1, 2, 1, 3, 0, 0, 0, 0]] * 4)
         self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)

     def test_sequence_length_override(self):
diff --git a/keras_hub/src/models/clip/clip_tokenizer.py b/keras_hub/src/models/clip/clip_tokenizer.py
index 2a594bda9..5917b0090 100644
--- a/keras_hub/src/models/clip/clip_tokenizer.py
+++ b/keras_hub/src/models/clip/clip_tokenizer.py
@@ -69,6 +69,7 @@ def __init__(
         self._add_special_token("<|startoftext|>", "start_token")
         self._add_special_token("<|endoftext|>", "end_token")
         self.pad_token_id = 0
+        self.pad_with_end_token = pad_with_end_token

         super().__init__(
             vocabulary=vocabulary,
@@ -77,12 +78,10 @@
             **kwargs,
         )

-        # When `pad_with_end_token` is True, we need to access the vocabulary,
-        # so the check is required.
-        if pad_with_end_token:
-            self._check_vocabulary()
+    def set_vocabulary_and_merges(self, vocabulary, merges):
+        super().set_vocabulary_and_merges(vocabulary, merges)
+        if self.pad_with_end_token:
             self.pad_token_id = self.end_token_id
-        self.pad_with_end_token = pad_with_end_token

     def _bpe_merge_and_update_cache(self, tokens):
         """Process unseen tokens and add to cache."""
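A quick sketch of what the `pad_value` fix changes. The `vocabulary`/`merges`
fixtures and the token ids are illustrative (borrowed from the unit-test
setup), not a real CLIP vocabulary:

    tokenizer = CLIPTokenizer(vocabulary, merges)  # `pad_token_id` defaults to 0.
    preprocessor = CLIPPreprocessor(tokenizer, sequence_length=8)
    x = preprocessor(["airplane airport"])
    # token_ids now end in 0s after `end_token_id`, matching `padding_mask`.

    tokenizer = CLIPTokenizer(vocabulary, merges, pad_with_end_token=True)
    preprocessor = CLIPPreprocessor(tokenizer, sequence_length=8)
    x = preprocessor(["airplane airport"])
    # With `pad_with_end_token=True`, `pad_token_id` is resolved to
    # `end_token_id` once the vocabulary is set, so padding still uses
    # `<|endoftext|>` for encoders that expect it (CLIP-L in SD3).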
diff --git a/keras_hub/src/models/preprocessor.py b/keras_hub/src/models/preprocessor.py
index 01ed2b6bf..07cbba4d1 100644
--- a/keras_hub/src/models/preprocessor.py
+++ b/keras_hub/src/models/preprocessor.py
@@ -47,6 +47,7 @@ class Preprocessor(PreprocessingLayer):
     image_converter_cls = None

     def __init__(self, *args, **kwargs):
+        self.config_name = kwargs.pop("config_name", PREPROCESSOR_CONFIG_FILE)
         super().__init__(*args, **kwargs)
         self._tokenizer = None
         self._image_converter = None
@@ -97,6 +98,11 @@ def get_config(self):
             config["image_converter"] = keras.layers.serialize(
                 self.image_converter
             )
+        config.update(
+            {
+                "config_name": self.config_name,
+            }
+        )
         return config

     @classmethod
@@ -126,6 +132,7 @@ def presets(cls):
     def from_preset(
         cls,
         preset,
+        config_name=PREPROCESSOR_CONFIG_FILE,
         **kwargs,
     ):
         """Instantiate a `keras_hub.models.Preprocessor` from a model preset.
@@ -175,7 +182,41 @@
         # Detect the correct subclass if we need to.
         if cls.backbone_cls != backbone_cls:
             cls = find_subclass(preset, cls, backbone_cls)
-        return loader.load_preprocessor(cls, **kwargs)
+        return loader.load_preprocessor(cls, config_name, **kwargs)
+
+    @classmethod
+    def _add_missing_kwargs(cls, loader, kwargs):
+        """Fill in required kwargs when loading from preset.
+
+        This is a private method hit when loading a preprocessing layer that
+        was not directly saved in the preset. This method should fill in all
+        kwargs required to call the class constructor. For almost all
+        preprocessors, the only required args are `tokenizer`,
+        `image_converter`, and `audio_converter`, but this can be overridden,
+        e.g. for a preprocessor with multiple tokenizers for different
+        encoders.
+        """
+        if "tokenizer" not in kwargs and cls.tokenizer_cls:
+            kwargs["tokenizer"] = loader.load_tokenizer(cls.tokenizer_cls)
+        if "audio_converter" not in kwargs and cls.audio_converter_cls:
+            kwargs["audio_converter"] = loader.load_audio_converter(
+                cls.audio_converter_cls
+            )
+        if "image_converter" not in kwargs and cls.image_converter_cls:
+            kwargs["image_converter"] = loader.load_image_converter(
+                cls.image_converter_cls
+            )
+        return kwargs
+
+    def load_preset_assets(self, preset):
+        """Load all static assets needed by the preprocessing layer.
+
+        Args:
+            preset: The path to the local model preset directory.
+        """
+        for layer in self._flatten_layers(include_self=False):
+            if hasattr(layer, "load_preset_assets"):
+                layer.load_preset_assets(preset)

     def save_to_preset(self, preset_dir):
         """Save preprocessor to a preset directory.
@@ -183,14 +224,7 @@
         Args:
             preset_dir: The path to the local model preset directory.
         """
-        save_serialized_object(
-            self,
-            preset_dir,
-            config_file=PREPROCESSOR_CONFIG_FILE,
-        )
-        if self.tokenizer:
-            self.tokenizer.save_to_preset(preset_dir)
-        if self.audio_converter:
-            self.audio_converter.save_to_preset(preset_dir)
-        if self.image_converter:
-            self.image_converter.save_to_preset(preset_dir)
+        save_serialized_object(self, preset_dir, config_file=self.config_name)
+        for layer in self._flatten_layers(include_self=False):
+            if hasattr(layer, "save_to_preset"):
+                layer.save_to_preset(preset_dir)
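With `save_to_preset` now walking all nested layers, a preprocessor holding
several tokenizers round-trips in a single call. A rough sketch of the intended
flow (`./sd3_preset` is a hypothetical directory; the exact file names come
from each layer's `config_name`):

    preprocessor.save_to_preset("./sd3_preset")
    # ./sd3_preset/preprocessor.json
    # ./sd3_preset/clip_l_tokenizer.json + ./sd3_preset/assets/clip_l_tokenizer/...
    # ./sd3_preset/clip_g_tokenizer.json + ./sd3_preset/assets/clip_g_tokenizer/...
    # ...plus clip_l_preprocessor.json / clip_g_preprocessor.json from the
    # nested preprocessors.
    restored = StableDiffusion3TextToImagePreprocessor.from_preset("./sd3_preset")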
diff --git a/keras_hub/src/models/stable_diffusion_3/__init__.py b/keras_hub/src/models/stable_diffusion_3/__init__.py
index fd48fde00..fa6f98835 100644
--- a/keras_hub/src/models/stable_diffusion_3/__init__.py
+++ b/keras_hub/src/models/stable_diffusion_3/__init__.py
@@ -11,3 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+    StableDiffusion3Backbone,
+)
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_presets import (
+    backbone_presets,
+)
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, StableDiffusion3Backbone)
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
index 883c2b11f..ef6a43c25 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
@@ -615,8 +615,22 @@ def get_config(self):

     @classmethod
     def from_config(cls, config, custom_objects=None):
-        # We expect `clip_l`, `clip_g` and/or `t5` to be instantiated.
         config = config.copy()
+
+        # Propagate `dtype` to text encoders if needed.
+        if "dtype" in config and config["dtype"] is not None:
+            dtype_config = config["dtype"]
+            if "dtype" not in config["clip_l"]["config"]:
+                config["clip_l"]["config"]["dtype"] = dtype_config
+            if "dtype" not in config["clip_g"]["config"]:
+                config["clip_g"]["config"]["dtype"] = dtype_config
+            if (
+                config["t5"] is not None
+                and "dtype" not in config["t5"]["config"]
+            ):
+                config["t5"]["config"]["dtype"] = dtype_config
+
+        # We expect `clip_l`, `clip_g` and/or `t5` to be instantiated.
         config["clip_l"] = layers.deserialize(
             config["clip_l"], custom_objects=custom_objects
         )
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py
new file mode 100644
index 000000000..bf71579bf
--- /dev/null
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""StableDiffusion3 preset configurations."""
+
+backbone_presets = {
+    "stable_diffusion_3_medium": {
+        "metadata": {
+            "description": (
+                "3 billion parameter model, including CLIP L and CLIP G text "
+                "encoders, MMDiT generative model, and VAE decoder. "
+                "Developed by Stability AI."
+            ),
+            "params": 2952806723,
+            "official_name": "StableDiffusion3",
+            "path": "stablediffusion3",
+            "model_card": "https://arxiv.org/abs/2403.03206",
+        },
+        "kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium/1",
+    }
+}
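The `from_config` change above means a `dtype` passed at load time reaches the
nested text encoders unless their serialized configs already pin one. A minimal
sketch of the effect:

    config = backbone.get_config()
    config["clip_l"]["config"].pop("dtype", None)  # leave the encoder unpinned
    config["dtype"] = "float16"
    restored = StableDiffusion3Backbone.from_config(config)
    # `restored.clip_l` now follows the float16 policy as well.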
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py
index 2a0656bdf..daef6a748 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py
@@ -11,10 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import keras
 from keras import layers

 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.preprocessor import Preprocessor
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+    StableDiffusion3Backbone,
+)


 @keras_hub_export("keras_hub.models.StableDiffusion3TextToImagePreprocessor")
@@ -33,6 +37,8 @@ class StableDiffusion3TextToImagePreprocessor(Preprocessor):
         t5_preprocessor: An optional `keras_hub.models.T5Preprocessor` instance.
     """

+    backbone_cls = StableDiffusion3Backbone
+
     def __init__(
         self,
         clip_l_preprocessor,
@@ -45,6 +51,11 @@ def __init__(
         self.clip_g_preprocessor = clip_g_preprocessor
         self.t5_preprocessor = t5_preprocessor

+    @property
+    def sequence_length(self):
+        """The padded length of model input sequences."""
+        return self.clip_l_preprocessor.sequence_length
+
     def build(self, input_shape):
         self.built = True

@@ -71,7 +82,15 @@ def get_config(self):
         )
         return config

-    @property
-    def sequence_length(self):
-        """The padded length of model input sequences."""
-        return self.clip_l_preprocessor.sequence_length
+    @classmethod
+    def from_config(cls, config):
+        for layer_name in (
+            "clip_l_preprocessor",
+            "clip_g_preprocessor",
+            "t5_preprocessor",
+        ):
+            if layer_name in config and isinstance(config[layer_name], dict):
+                config[layer_name] = keras.layers.deserialize(
+                    config[layer_name]
+                )
+        return cls(**config)
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor_test.py
index 58d2bb1a4..7934f8262 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor_test.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor_test.py
@@ -70,4 +70,4 @@ def test_generate_preprocess(self):
         self.assertIn("clip_l", x)
         self.assertIn("clip_g", x)
         self.assertAllEqual(x["clip_l"][0], [4, 0, 1, 3, 3, 3, 3, 3])
-        self.assertAllEqual(x["clip_g"][0], [4, 0, 1, 3, 3, 3, 3, 3])
+        self.assertAllEqual(x["clip_g"][0], [4, 0, 1, 3, 0, 0, 0, 0])
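The new `from_config` tolerates nested preprocessors arriving as serialized
dicts, so a plain config round-trip works; a sketch:

    config = preprocessor.get_config()  # nested preprocessors serialize to dicts
    clone = StableDiffusion3TextToImagePreprocessor.from_config(config)
    # `clone.clip_l_preprocessor` / `clone.clip_g_preprocessor` are live layers again.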
diff --git a/keras_hub/src/tokenizers/tokenizer.py b/keras_hub/src/tokenizers/tokenizer.py
index 7856b79ca..534e8ef90 100644
--- a/keras_hub/src/tokenizers/tokenizer.py
+++ b/keras_hub/src/tokenizers/tokenizer.py
@@ -17,14 +17,13 @@
 from keras_hub.src.layers.preprocessing.preprocessing_layer import (
     PreprocessingLayer,
 )
-from keras_hub.src.utils.preset_utils import TOKENIZER_ASSET_DIR
+from keras_hub.src.utils.preset_utils import ASSET_DIR
 from keras_hub.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
 from keras_hub.src.utils.preset_utils import builtin_presets
 from keras_hub.src.utils.preset_utils import find_subclass
 from keras_hub.src.utils.preset_utils import get_file
 from keras_hub.src.utils.preset_utils import get_preset_loader
 from keras_hub.src.utils.preset_utils import save_serialized_object
-from keras_hub.src.utils.preset_utils import save_tokenizer_assets
 from keras_hub.src.utils.python_utils import classproperty
 from keras_hub.src.utils.tensor_utils import preprocessing_function
@@ -80,6 +79,7 @@ def detokenize(self, inputs):
     backbone_cls = None

     def __init__(self, *args, **kwargs):
+        self.config_name = kwargs.pop("config_name", TOKENIZER_CONFIG_FILE)
         super().__init__(*args, **kwargs)
         self.file_assets = None
@@ -187,18 +187,26 @@ def _update_special_token_ids(self):
             token = getattr(self, attr)
             setattr(self, f"{attr}_id", self.token_to_id(token))

+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "config_name": self.config_name,
+            }
+        )
+        return config
+
     def save_to_preset(self, preset_dir):
         """Save tokenizer to a preset directory.

         Args:
             preset_dir: The path to the local model preset directory.
         """
-        save_serialized_object(
-            self,
-            preset_dir,
-            config_file=TOKENIZER_CONFIG_FILE,
-        )
-        save_tokenizer_assets(self, preset_dir)
+        save_serialized_object(self, preset_dir, config_file=self.config_name)
+        subdir = self.config_name.split(".")[0]
+        asset_dir = os.path.join(preset_dir, ASSET_DIR, subdir)
+        os.makedirs(asset_dir, exist_ok=True)
+        self.save_assets(asset_dir)

     @preprocessing_function
     def call(self, inputs, *args, training=None, **kwargs):
@@ -207,11 +215,11 @@ def call(self, inputs, *args, training=None, **kwargs):
     def load_preset_assets(self, preset):
         asset_path = None
         for asset in self.file_assets:
-            asset_path = get_file(
-                preset, os.path.join(TOKENIZER_ASSET_DIR, asset)
-            )
+            subdir = self.config_name.split(".")[0]
+            preset_path = os.path.join(ASSET_DIR, subdir, asset)
+            asset_path = get_file(preset, preset_path)
         tokenizer_asset_dir = os.path.dirname(asset_path)
         self.load_assets(tokenizer_asset_dir)

     @classproperty
     def presets(cls):
@@ -222,6 +230,7 @@ def presets(cls):
     def from_preset(
         cls,
         preset,
+        config_name=TOKENIZER_CONFIG_FILE,
         **kwargs,
     ):
         """Instantiate a `keras_hub.models.Tokenizer` from a model preset.
@@ -267,4 +276,4 @@ class like `keras_hub.models.Tokenizer.from_preset()`, or from
         backbone_cls = loader.check_backbone_class()
         if cls.backbone_cls != backbone_cls:
             cls = find_subclass(preset, cls, backbone_cls)
-        return loader.load_tokenizer(cls, **kwargs)
+        return loader.load_tokenizer(cls, config_name, **kwargs)
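Each tokenizer now saves its assets under a subdirectory derived from the stem
of `config_name`, which is what lets two CLIP tokenizers coexist in one preset.
A sketch of the resulting layout (the asset file names assume the usual BPE
tokenizer assets):

    tokenizer = CLIPTokenizer(vocabulary, merges, config_name="clip_g_tokenizer.json")
    tokenizer.save_to_preset("./preset")
    # ./preset/clip_g_tokenizer.json
    # ./preset/assets/clip_g_tokenizer/vocabulary.json
    # ./preset/assets/clip_g_tokenizer/merges.txt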
" - f"Error message: {e}" - ) - layer = keras.saving.deserialize_keras_object(config) - - for asset in layer.file_assets: - asset_path = get_file(preset, os.path.join(TOKENIZER_ASSET_DIR, asset)) - if not os.path.exists(asset_path): - tokenizer_asset_dir = os.path.dirname(asset_path) - raise FileNotFoundError( - f"Asset `{asset}` doesn't exist in the tokenizer asset direcotry" - f" `{tokenizer_asset_dir}`." - ) - config_dir = os.path.dirname(config_path) - asset_dir = os.path.join(config_dir, TOKENIZER_ASSET_DIR) - - tokenizer = get_tokenizer(layer) - if not tokenizer: - raise ValueError(f"Model or layer `{layer}` is missing tokenizer.") - tokenizer.load_assets(asset_dir) - - def _validate_backbone(preset): config_path = os.path.join(preset, CONFIG_FILE) if not os.path.exists(config_path): @@ -493,7 +456,6 @@ def upload_preset( raise FileNotFoundError(f"The preset directory {preset} doesn't exist.") _validate_backbone(preset) - _validate_tokenizer(preset) if uri.startswith(KAGGLE_PREFIX): if kagglehub is None: @@ -657,6 +619,20 @@ def __init__(self, preset, config): self.config = config self.preset = preset + def get_backbone_kwargs(self, **kwargs): + backbone_kwargs = {} + + # Forward `dtype` to backbone. + backbone_kwargs["dtype"] = kwargs.pop("dtype", None) + + # Forward `height` and `width` to backbone when using `TextToImage`. + if "height" in kwargs: + backbone_kwargs["height"] = kwargs.pop("height", None) + if "width" in kwargs: + backbone_kwargs["width"] = kwargs.pop("width", None) + + return backbone_kwargs, kwargs + def check_backbone_class(self): """Infer the backbone architecture.""" raise NotImplementedError @@ -665,7 +641,7 @@ def load_backbone(self, cls, load_weights, **kwargs): """Load the backbone model from the preset.""" raise NotImplementedError - def load_tokenizer(self, cls, **kwargs): + def load_tokenizer(self, cls, config_name=TOKENIZER_CONFIG_FILE, **kwargs): """Load a tokenizer layer from the preset.""" raise NotImplementedError @@ -685,8 +661,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): """ if "backbone" not in kwargs: backbone_class = cls.backbone_cls - # Forward dtype to backbone. - backbone_kwargs = {"dtype": kwargs.pop("dtype", None)} + backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs) kwargs["backbone"] = self.load_backbone( backbone_class, load_weights, **backbone_kwargs ) @@ -696,23 +671,16 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): ) return cls(**kwargs) - def load_preprocessor(self, cls, **kwargs): + def load_preprocessor( + self, cls, config_name=PREPROCESSOR_CONFIG_FILE, **kwargs + ): """Load a prepocessor layer from the preset. By default, we create a preprocessor from a tokenizer with default arguments. This allow us to support transformers checkpoints by only converting the backbone and tokenizer. 
""" - if "tokenizer" not in kwargs and cls.tokenizer_cls: - kwargs["tokenizer"] = self.load_tokenizer(cls.tokenizer_cls) - if "audio_converter" not in kwargs and cls.audio_converter_cls: - kwargs["audio_converter"] = self.load_audio_converter( - cls.audio_converter_cls - ) - if "image_converter" not in kwargs and cls.image_converter_cls: - kwargs["image_converter"] = self.load_image_converter( - cls.image_converter_cls - ) + kwargs = cls._add_missing_kwargs(self, kwargs) return cls(**kwargs) @@ -727,8 +695,8 @@ def load_backbone(self, cls, load_weights, **kwargs): backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE)) return backbone - def load_tokenizer(self, cls, **kwargs): - tokenizer_config = load_json(self.preset, TOKENIZER_CONFIG_FILE) + def load_tokenizer(self, cls, config_name=TOKENIZER_CONFIG_FILE, **kwargs): + tokenizer_config = load_json(self.preset, config_name) tokenizer = load_serialized_object(tokenizer_config, **kwargs) tokenizer.load_preset_assets(self.preset) return tokenizer @@ -755,8 +723,8 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): ) # We found a `task.json` with a complete config for our class. task = load_serialized_object(task_config, **kwargs) - if task.preprocessor and task.preprocessor.tokenizer: - task.preprocessor.tokenizer.load_preset_assets(self.preset) + if task.preprocessor: + task.preprocessor.load_preset_assets(self.preset) if load_weights: has_task_weights = check_file_exists(self.preset, TASK_WEIGHTS_FILE) if has_task_weights and load_task_weights: @@ -769,15 +737,17 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): task.backbone.load_weights(backbone_weights) return task - def load_preprocessor(self, cls, **kwargs): + def load_preprocessor( + self, cls, config_name=PREPROCESSOR_CONFIG_FILE, **kwargs + ): # If there is no `preprocessing.json` or it's for the wrong class, # delegate to the super class loader. - if not check_file_exists(self.preset, PREPROCESSOR_CONFIG_FILE): + if not check_file_exists(self.preset, config_name): return super().load_preprocessor(cls, **kwargs) - preprocessor_json = load_json(self.preset, PREPROCESSOR_CONFIG_FILE) + preprocessor_json = load_json(self.preset, config_name) if not issubclass(check_config_class(preprocessor_json), cls): return super().load_preprocessor(cls, **kwargs) # We found a `preprocessing.json` with a complete config for our class. 
diff --git a/keras_hub/src/utils/transformers/safetensor_utils.py b/keras_hub/src/utils/transformers/safetensor_utils.py
index 1f7fd80d2..ee4d21a87 100644
--- a/keras_hub/src/utils/transformers/safetensor_utils.py
+++ b/keras_hub/src/utils/transformers/safetensor_utils.py
@@ -26,7 +26,7 @@


 class SafetensorLoader(contextlib.ExitStack):
-    def __init__(self, preset, prefix=None):
+    def __init__(self, preset, prefix=None, fname=None):
         super().__init__()

         if safetensors is None:
@@ -44,6 +44,13 @@ def __init__(self, preset, prefix=None):
             self.safetensor_files = {}
         self.prefix = prefix

+        if fname is not None and self.safetensor_config is not None:
+            raise ValueError(
+                f"Cannot specify `fname` if {SAFETENSOR_CONFIG_FILE} exists. "
+                f"Received: fname={fname}"
+            )
+        self.fname = fname  # The name of the safetensors file to load.
+
     def get_prefixed_key(self, hf_weight_key, dict_like):
         """
         Determine and return a prefixed key for a given hf weight key.
@@ -71,7 +78,7 @@ def get_prefixed_key(self, hf_weight_key, dict_like):

     def get_tensor(self, hf_weight_key):
         if self.safetensor_config is None:
-            fname = SAFETENSOR_FILE
+            fname = self.fname if self.fname is not None else SAFETENSOR_FILE
         else:
             full_key = self.get_prefixed_key(
                 hf_weight_key, self.safetensor_config["weight_map"]
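`fname` lets a converter target one specific safetensors file when the
checkpoint repo has no `model.safetensors.index.json`, which is how the SD3
script below reads its text encoders; a sketch:

    with SafetensorLoader(
        "hf://stabilityai/stable-diffusion-3-medium",
        prefix="",
        fname="text_encoders/clip_g.safetensors",
    ) as loader:
        # `get_tensor` / `port_weight` now read from the named file.
        ...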
diff --git a/tools/checkpoint_conversion/convert_bloom_checkpoints.py b/tools/checkpoint_conversion/convert_bloom_checkpoints.py
index 8e3da4a97..02de89080 100644
--- a/tools/checkpoint_conversion/convert_bloom_checkpoints.py
+++ b/tools/checkpoint_conversion/convert_bloom_checkpoints.py
@@ -21,10 +21,11 @@
 from absl import app
 from absl import flags

-import keras_hub
-from keras_hub.models import BloomBackbone
-from keras_hub.models import BloomPreprocessor
-from keras_hub.models import BloomTokenizer
+from keras_hub.src.models.bloom.bloom_backbone import BloomBackbone
+from keras_hub.src.models.bloom.bloom_causal_lm_preprocessor import (
+    BloomCausalLMPreprocessor,
+)
+from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer

 FLAGS = flags.FLAGS
@@ -123,10 +124,10 @@ def convert_weights(keras_model, hf_model):
         hf_wts["word_embeddings.weight"].detach().numpy()
     )
     # LayerNorm.
-    keras_model.get_layer("token_embedding_layernorm").gamma.assign(
+    keras_model.get_layer("embedding_layernorm").gamma.assign(
         hf_wts["word_embeddings_layernorm.weight"].detach().numpy()
     )
-    keras_model.get_layer("token_embedding_layernorm").beta.assign(
+    keras_model.get_layer("embedding_layernorm").beta.assign(
         hf_wts["word_embeddings_layernorm.bias"].detach().numpy()
     )
@@ -222,13 +223,31 @@ def validate_output(
     hf_model_outputs = hf_model_outputs.detach().numpy()

     # KerasHub
-    preprocessor = BloomPreprocessor(
+    preprocessor = BloomCausalLMPreprocessor(
         tokenizer=keras_tokenizer,
         sequence_length=hf_model_outputs.shape[1],
         add_end_token=False,
         add_start_token=False,
     )
-    keras_model_input = preprocessor(input_str)
+
+    # Since `BloomPreprocessor` has been removed, we call the tokenizer and
+    # packer manually to build inputs for verification.
+    def preprocessor_call(input_str):
+        if not preprocessor.built:
+            preprocessor.build(None)
+        x = preprocessor.tokenizer(input_str)
+        token_ids, padding_mask = preprocessor.packer(
+            x,
+            sequence_length=None,
+            add_start_value=preprocessor.add_start_token,
+            add_end_value=preprocessor.add_end_token,
+        )
+        return {
+            "token_ids": token_ids,
+            "padding_mask": padding_mask,
+        }
+
+    keras_model_input = preprocessor_call(input_str)
     keras_model_outputs = keras_model.predict(keras_model_input)

     # Comparing the outputs.
@@ -280,7 +299,7 @@ def main(_):
         del hf_tokenizer

         # Save float32 keras preset
-        keras_hub.src.utils.preset_utils.save_to_preset(keras_model, preset)
+        keras_model.save_to_preset(preset)

         # Delete float32 Keras model
         del keras_model
@@ -290,10 +309,8 @@ def main(_):
         keras_model = BloomBackbone.from_preset(preset_path, dtype="float16")

         # Save float16 keras model
-        keras_hub.src.utils.preset_utils.save_to_preset(keras_model, preset)
-        keras_hub.src.utils.preset_utils.save_to_preset(
-            keras_tokenizer, preset, config_filename="tokenizer.json"
-        )
+        keras_model.save_to_preset(preset)
+        keras_tokenizer.save_to_preset(preset)

         print("✅ Preset saved")
     else:
diff --git a/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py b/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py
new file mode 100644
index 000000000..54ef1d91e
--- /dev/null
+++ b/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py
@@ -0,0 +1,513 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert StableDiffusion3 checkpoints.
+
+export KAGGLE_USERNAME=XXX
+export KAGGLE_KEY=XXX
+
+python tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py \
+    --preset stable_diffusion_3_medium --upload_uri kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium
+"""
+import os
+import shutil
+
+import keras
+import numpy as np
+from absl import app
+from absl import flags
+from PIL import Image
+
+import keras_hub
+from keras_hub.src.models.clip.clip_preprocessor import CLIPPreprocessor
+from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder
+from keras_hub.src.models.clip.clip_tokenizer import CLIPTokenizer
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
+    StableDiffusion3Backbone,
+)
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image import (
+    StableDiffusion3TextToImage,
+)
+from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
+    StableDiffusion3TextToImagePreprocessor,
+)
+from keras_hub.src.utils.preset_utils import load_json
+from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
+
+FLAGS = flags.FLAGS
+
+PRESET_MAP = {
+    "stable_diffusion_3_medium": {
+        # HF root
+        "root": "hf://stabilityai/stable-diffusion-3-medium",
+        # Model <-> Path
+        "clip_l": "text_encoders/clip_l.safetensors",
+        "clip_g": "text_encoders/clip_g.safetensors",
+        "diffuser": "sd3_medium.safetensors",
+        "decoder": "sd3_medium.safetensors",
+        # Tokenizer
+        "clip_tokenizer": "hf://openai/clip-vit-large-patch14",
+    }
+}
+
+flags.DEFINE_string(
+    "preset",
+    None,
+    f'Must be one of {",".join(PRESET_MAP.keys())}',
+    required=True,
+)
+flags.DEFINE_string(
+    "output_dir",
+    "output_dir",
+    "The generated image will be saved here.",
+    required=False,
+)
+flags.DEFINE_string(
+    "upload_uri",
+    None,
+    'Could be "kaggle://keras/{variant}/keras/{preset}"',
+    required=False,
+)
+
+
+def convert_model(preset, height, width):
+    # The text encoders are all the same.
+    clip_l = CLIPTextEncoder(
+        49408, 768, 768, 12, 12, 3072, "quick_gelu", -2, name="clip_l"
+    )
+    clip_g = CLIPTextEncoder(
+        49408, 1280, 1280, 32, 20, 5120, "gelu", -2, name="clip_g"
+    )
+    # TODO: Add T5.
+
+    # Currently, we hardcode the model arch by preset.
+    if preset == "stable_diffusion_3_medium":
+        backbone = StableDiffusion3Backbone(
+            2,
+            64 * 24,
+            24,
+            24,
+            192,
+            [512, 512, 256, 128],
+            [3, 3, 3, 3],
+            clip_l,
+            clip_g,
+            height=height,
+            width=width,
+        )
+    return backbone
+
+
+def convert_preprocessor():
+    tokenizer_content = load_json(
+        "hf://openai/clip-vit-large-patch14", "tokenizer.json"
+    )
+    vocabulary = tokenizer_content["model"]["vocab"]
+    merges = tokenizer_content["model"]["merges"]
+    clip_l_tokenizer = CLIPTokenizer(
+        vocabulary,
+        merges,
+        pad_with_end_token=True,
+        config_name="clip_l_tokenizer.json",
+    )
+    clip_g_tokenizer = CLIPTokenizer(
+        vocabulary, merges, config_name="clip_g_tokenizer.json"
+    )
+    clip_l_preprocessor = CLIPPreprocessor(
+        clip_l_tokenizer, config_name="clip_l_preprocessor.json"
+    )
+    clip_g_preprocessor = CLIPPreprocessor(
+        clip_g_tokenizer, config_name="clip_g_preprocessor.json"
+    )
+    preprocessor = StableDiffusion3TextToImagePreprocessor(
+        clip_l_preprocessor, clip_g_preprocessor
+    )
+    return preprocessor
+def convert_weights(preset, keras_model):
+    # Define helper functions.
+    def port_conv2d(loader, keras_variable, hf_weight_key):
+        loader.port_weight(
+            keras_variable.kernel,
+            f"{hf_weight_key}.weight",
+            hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
+        )
+        loader.port_weight(keras_variable.bias, f"{hf_weight_key}.bias")
+
+    def port_dense(loader, keras_variable, hf_weight_key):
+        loader.port_weight(
+            keras_variable.kernel,
+            f"{hf_weight_key}.weight",
+            hook_fn=lambda x, _: x.T,
+        )
+        loader.port_weight(keras_variable.bias, f"{hf_weight_key}.bias")
+
+    def port_mha(loader, keras_variable, hf_weight_key, num_heads, hidden_dim):
+        # query
+        loader.port_weight(
+            keras_variable.query_dense.kernel,
+            f"{hf_weight_key}.q_proj.weight",
+            hook_fn=lambda x, _: np.reshape(
+                x.T, (hidden_dim, num_heads, hidden_dim // num_heads)
+            ),
+        )
+        loader.port_weight(
+            keras_variable.query_dense.bias,
+            f"{hf_weight_key}.q_proj.bias",
+            hook_fn=lambda x, _: np.reshape(
+                x, (num_heads, hidden_dim // num_heads)
+            ),
+        )
+        # key
+        loader.port_weight(
+            keras_variable.key_dense.kernel,
+            f"{hf_weight_key}.k_proj.weight",
+            hook_fn=lambda x, _: np.reshape(
+                x.T, (hidden_dim, num_heads, hidden_dim // num_heads)
+            ),
+        )
+        loader.port_weight(
+            keras_variable.key_dense.bias,
+            f"{hf_weight_key}.k_proj.bias",
+            hook_fn=lambda x, _: np.reshape(
+                x, (num_heads, hidden_dim // num_heads)
+            ),
+        )
+        # value
+        loader.port_weight(
+            keras_variable.value_dense.kernel,
+            f"{hf_weight_key}.v_proj.weight",
+            hook_fn=lambda x, _: np.reshape(
+                x.T, (hidden_dim, num_heads, hidden_dim // num_heads)
+            ),
+        )
+        loader.port_weight(
+            keras_variable.value_dense.bias,
+            f"{hf_weight_key}.v_proj.bias",
+            hook_fn=lambda x, _: np.reshape(
+                x, (num_heads, hidden_dim // num_heads)
+            ),
+        )
+        # output
+        loader.port_weight(
+            keras_variable.output_dense.kernel,
+            f"{hf_weight_key}.out_proj.weight",
+            hook_fn=lambda x, _: np.reshape(
+                x.T, (num_heads, hidden_dim // num_heads, hidden_dim)
+            ),
+        )
+        loader.port_weight(
+            keras_variable.output_dense.bias, f"{hf_weight_key}.out_proj.bias"
+        )
+
+    def port_ln_or_gn(loader, keras_variable, hf_weight_key):
+        loader.port_weight(keras_variable.gamma, f"{hf_weight_key}.weight")
+        loader.port_weight(keras_variable.beta, f"{hf_weight_key}.bias")
+
+    def port_clip(preset, filename, model, projection_layer):
+        with SafetensorLoader(preset, prefix="", fname=filename) as loader:
+            # Embeddings
+            embedding = model.embedding
+            loader.port_weight(
+                embedding.token_embedding._embeddings,
+                "text_model.embeddings.token_embedding.weight",
+            )
+            loader.port_weight(
+                embedding.position_embedding.position_embeddings,
+                "text_model.embeddings.position_embedding.weight",
+            )
+
+            # Encoders
+            encoder_layers = model.encoder_layers
+            for i in range(len(encoder_layers)):
+                prefix = "text_model.encoder.layers"
+                num_heads = encoder_layers[i].num_heads
+                hidden_dim = encoder_layers[i].hidden_dim
+                port_mha(
+                    loader,
+                    encoder_layers[i].attention,
+                    f"{prefix}.{i}.self_attn",
+                    num_heads,
+                    hidden_dim,
+                )
+                port_ln_or_gn(
+                    loader,
+                    encoder_layers[i].layer_norm_1,
+                    f"{prefix}.{i}.layer_norm1",
+                )
+                port_ln_or_gn(
+                    loader,
+                    encoder_layers[i].layer_norm_2,
+                    f"{prefix}.{i}.layer_norm2",
+                )
+                port_dense(
+                    loader, encoder_layers[i].dense_1, f"{prefix}.{i}.mlp.fc1"
+                )
+                port_dense(
+                    loader, encoder_layers[i].dense_2, f"{prefix}.{i}.mlp.fc2"
+                )
+
+            # Output layers
+            port_ln_or_gn(
+                loader, model.layer_norm, "text_model.final_layer_norm"
+            )
+            try:
+                loader.port_weight(
+                    projection_layer.dense.kernel,
+                    "text_projection.weight",
+                    hook_fn=lambda x, _: x.T,
+                )
+            except Exception:
+                # Skip when the checkpoint doesn't include
+                # `text_projection.weight`.
+                pass
+        return model
+
+    def port_diffuser(preset, filename, model):
+        hf_prefix = "model.diffusion_model."
+        with SafetensorLoader(
+            preset, prefix=hf_prefix, fname=filename
+        ) as loader:
+            # Embeddings
+            port_conv2d(
+                loader, model.patch_embedding.patch_embedding, "x_embedder.proj"
+            )
+            loader.port_weight(
+                model.position_embedding.position_embeddings,
+                "pos_embed",
+                hook_fn=lambda x, _: x[0],
+            )
+            port_dense(loader, model.context_embedding, "context_embedder")
+            port_dense(
+                loader, model.vector_embedding.layers[0], "y_embedder.mlp.0"
+            )
+            port_dense(
+                loader, model.vector_embedding.layers[1], "y_embedder.mlp.2"
+            )
+            port_dense(
+                loader,
+                model.timestep_embedding.mlp.layers[0],
+                "t_embedder.mlp.0",
+            )
+            port_dense(
+                loader,
+                model.timestep_embedding.mlp.layers[1],
+                "t_embedder.mlp.2",
+            )
+
+            # Blocks
+            num_layers = model.num_layers
+            for i in range(num_layers):
+                x_block = model.joint_blocks[i].x_block
+                context_block = model.joint_blocks[i].context_block
+                for block_name, block in (
+                    ("x_block", x_block),
+                    ("context_block", context_block),
+                ):
+                    prefix = f"joint_blocks.{i}.{block_name}"
+                    port_dense(
+                        loader,
+                        block.adaptive_norm_modulation.layers[1],
+                        f"{prefix}.adaLN_modulation.1",
+                    )
+                    port_dense(
+                        loader, block.attention_qkv, f"{prefix}.attn.qkv"
+                    )
+
+                    if block_name == "context_block" and (i == num_layers - 1):
+                        continue
+
+                    port_dense(
+                        loader, block.attention_proj, f"{prefix}.attn.proj"
+                    )
+                    port_dense(loader, block.mlp.layers[0], f"{prefix}.mlp.fc1")
+                    port_dense(loader, block.mlp.layers[1], f"{prefix}.mlp.fc2")
+
+            # Output layer
+            port_dense(
+                loader,
+                model.output_layer.adaptive_norm_modulation.layers[1],
+                "final_layer.adaLN_modulation.1",
+            )
+            port_dense(
+                loader, model.output_layer.output_dense, "final_layer.linear"
+            )
+        return model
+
+    def port_decoder(preset, filename, model):
+        hf_prefix = "first_stage_model."
+
+        def port_resnet_block(
+            keras_variable_name, hf_weight_key, has_residual=False
+        ):
+            port_ln_or_gn(
+                loader,
+                model.get_layer(f"{keras_variable_name}_norm1"),
+                f"{hf_weight_key}.norm1",
+            )
+            port_conv2d(
+                loader,
+                model.get_layer(f"{keras_variable_name}_conv1"),
+                f"{hf_weight_key}.conv1",
+            )
+            port_ln_or_gn(
+                loader,
+                model.get_layer(f"{keras_variable_name}_norm2"),
+                f"{hf_weight_key}.norm2",
+            )
+            port_conv2d(
+                loader,
+                model.get_layer(f"{keras_variable_name}_conv2"),
+                f"{hf_weight_key}.conv2",
+            )
+            if has_residual:
+                port_conv2d(
+                    loader,
+                    model.get_layer(
+                        f"{keras_variable_name}_residual_projection"
+                    ),
+                    f"{hf_weight_key}.nin_shortcut",
+                )
+
+        def port_attention(keras_variable_name, hf_weight_key):
+            port_ln_or_gn(
+                loader,
+                model.get_layer(keras_variable_name).group_norm,
+                f"{hf_weight_key}.norm",
+            )
+            port_conv2d(
+                loader,
+                model.get_layer(keras_variable_name).query_conv2d,
+                f"{hf_weight_key}.q",
+            )
+            port_conv2d(
+                loader,
+                model.get_layer(keras_variable_name).key_conv2d,
+                f"{hf_weight_key}.k",
+            )
+            port_conv2d(
+                loader,
+                model.get_layer(keras_variable_name).value_conv2d,
+                f"{hf_weight_key}.v",
+            )
+            port_conv2d(
+                loader,
+                model.get_layer(keras_variable_name).output_conv2d,
+                f"{hf_weight_key}.proj_out",
+            )
+
+        with SafetensorLoader(
+            preset, prefix=hf_prefix, fname=filename
+        ) as loader:
+            # Stem
+            port_conv2d(
+                loader, model.get_layer("input_projection"), "decoder.conv_in"
+            )
+            port_resnet_block("input_block0", "decoder.mid.block_1")
+            port_attention("input_attention", "decoder.mid.attn_1")
+            port_resnet_block("input_block1", "decoder.mid.block_2")
+
+            # Stacks
+            input_filters = model.stackwise_num_filters[0]
+            for i, filters in enumerate(model.stackwise_num_filters):
+                for j in range(model.stackwise_num_blocks[i]):
+                    n = model.stackwise_num_blocks[i]
+                    prefix = f"decoder.up.{n-i}.block.{j}"
+                    port_resnet_block(
+                        f"block{i}_{j}",
+                        prefix,
+                        has_residual=filters != input_filters,
+                    )
+                    input_filters = filters
+                if i != len(model.stackwise_num_filters) - 1:
+                    port_conv2d(
+                        loader,
+                        model.get_layer(f"upsample_{i}_conv"),
+                        f"decoder.up.{n-i}.upsample.conv",
+                    )
+
+            # Output layers
+            port_ln_or_gn(
+                loader, model.get_layer("output_norm"), "decoder.norm_out"
+            )
+            port_conv2d(
+                loader, model.get_layer("output_projection"), "decoder.conv_out"
+            )
+        return model
+
+    # Start conversion.
+    config = PRESET_MAP[preset]
+    port_clip(
+        config["root"],
+        config["clip_l"],
+        keras_model.clip_l,
+        keras_model.clip_l_projection,
+    )
+    port_clip(
+        config["root"],
+        config["clip_g"],
+        keras_model.clip_g,
+        keras_model.clip_g_projection,
+    )
+    port_diffuser(config["root"], config["diffuser"], keras_model.diffuser)
+    port_decoder(config["root"], config["decoder"], keras_model.decoder)
+
+
+def validate_output(keras_model, keras_preprocessor, output_dir):
+    # TODO: Verify the numerics.
+    text_to_image = StableDiffusion3TextToImage(keras_model, keras_preprocessor)
+    image = text_to_image.generate("cute wallpaper art of a cat", seed=42)
+    image = Image.fromarray(image)
+    image.save(os.path.join(output_dir, "test.png"))
+
+
+def main(_):
+    preset = FLAGS.preset
+    output_dir = FLAGS.output_dir
+    if os.path.exists(preset):
+        shutil.rmtree(preset)
+    if os.path.exists(output_dir):
+        shutil.rmtree(output_dir)
+    os.makedirs(preset)
+    os.makedirs(output_dir)
+
+    print(f"🏃 Converting {preset}")
+
+    # Currently SD3 weights are float16 (and have much faster download
+    # times for it). We follow suit with Keras weights.
+    keras.config.set_dtype_policy("float16")
+    height, width = 512, 512  # Use a smaller image size to speed up generation.
+
+    keras_preprocessor = convert_preprocessor()
+    keras_model = convert_model(preset, height, width)
+    print("✅ KerasHub model loaded.")
+
+    convert_weights(preset, keras_model)
+    print("✅ Weights converted.")
+
+    validate_output(keras_model, keras_preprocessor, output_dir)
+    print("✅ Output validated.")
+
+    keras_preprocessor.save_to_preset(preset)
+    # Set the image size to 1024, the same as in huggingface/diffusers.
+    keras_model.height = 1024
+    keras_model.width = 1024
+    keras_model.save_to_preset(preset)
+    print(f"🏁 Preset saved to ./{preset}.")
+
+    upload_uri = FLAGS.upload_uri
+    if upload_uri:
+        keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}")
+        print(f"🏁 Preset uploaded to {upload_uri}")
+
+
+if __name__ == "__main__":
+    app.run(main)
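After conversion and upload, the registered preset should load end to end; a
sketch (assumes the Kaggle handle in the presets file is live):

    import keras_hub

    text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
        "stable_diffusion_3_medium", dtype="float16"
    )
    image = text_to_image.generate("cute wallpaper art of a cat")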