
Commit

Add SD3 preset (#1884)
james77777778 committed Sep 26, 2024
1 parent f67b4db commit a10de04
Showing 15 changed files with 740 additions and 118 deletions.
2 changes: 1 addition & 1 deletion keras_hub/src/models/clip/clip_preprocessor.py
@@ -94,7 +94,7 @@ def build(self, input_shape):
self.packer = StartEndPacker(
start_value=self.tokenizer.start_token_id,
end_value=self.tokenizer.end_token_id,
pad_value=self.tokenizer.end_token_id,
pad_value=self.tokenizer.pad_token_id,
sequence_length=self.sequence_length,
return_padding_mask=True,
)
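A minimal sketch of the padding behavior this hunk changes, using the toy token ids from the tests below and assuming the packer is available as the public `keras_hub.layers.StartEndPacker`: switching `pad_value` from the end token id to a dedicated pad token id changes the trailing ids but leaves the padding mask untouched.

import keras_hub

packer = keras_hub.layers.StartEndPacker(
    start_value=5,  # <|startoftext|>
    end_value=4,    # <|endoftext|>
    pad_value=0,    # dedicated pad id; previously the end token id (4)
    sequence_length=8,
    return_padding_mask=True,
)
token_ids, padding_mask = packer([[1, 2, 1, 3]])
print(token_ids)     # [[5, 1, 2, 1, 3, 4, 0, 0]]
print(padding_mask)  # [[1, 1, 1, 1, 1, 1, 0, 0]]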
4 changes: 2 additions & 2 deletions keras_hub/src/models/clip/clip_preprocessor_test.py
@@ -38,7 +38,7 @@ def test_preprocessor_basics(self):
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output={
"token_ids": [[5, 1, 2, 1, 3, 4, 4, 4]],
"token_ids": [[5, 1, 2, 1, 3, 4, 0, 0]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
},
)
@@ -52,7 +52,7 @@ def test_no_start_end_token(self):
add_end_token=False,
)
x = preprocessor(input_data)
self.assertAllEqual(x["token_ids"], [[1, 2, 1, 3, 4, 4, 4, 4]] * 4)
self.assertAllEqual(x["token_ids"], [[1, 2, 1, 3, 0, 0, 0, 0]] * 4)
self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)

def test_sequence_length_override(self):
9 changes: 4 additions & 5 deletions keras_hub/src/models/clip/clip_tokenizer.py
@@ -69,6 +69,7 @@ def __init__(
self._add_special_token("<|startoftext|>", "start_token")
self._add_special_token("<|endoftext|>", "end_token")
self.pad_token_id = 0
self.pad_with_end_token = pad_with_end_token

super().__init__(
vocabulary=vocabulary,
@@ -77,12 +78,10 @@
**kwargs,
)

# When `pad_with_end_token` is True, we need to access the vocabulary,
# so the check is required.
if pad_with_end_token:
self._check_vocabulary()
def set_vocabulary_and_merges(self, vocabulary, merges):
super().set_vocabulary_and_merges(vocabulary, merges)
if self.pad_with_end_token:
self.pad_token_id = self.end_token_id
self.pad_with_end_token = pad_with_end_token

def _bpe_merge_and_update_cache(self, tokens):
"""Process unseen tokens and add to cache."""
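As a self-contained sketch of the reordering above (plain Python, not the real `CLIPTokenizer`): the pad id is now promoted to the end token id whenever a vocabulary is attached via `set_vocabulary_and_merges`, rather than eagerly at construction time, so building the tokenizer without a vocabulary no longer requires the `_check_vocabulary()` call.

class ToyTokenizer:
    """Toy stand-in mirroring the new lazy pad-id resolution."""

    def __init__(self, pad_with_end_token=False):
        self.pad_token_id = 0
        self.end_token_id = None
        self.pad_with_end_token = pad_with_end_token

    def set_vocabulary(self, vocabulary):
        # Resolve the special token ids from the vocabulary first...
        self.end_token_id = vocabulary.get("<|endoftext|>")
        # ...then promote the pad id, mirroring the new
        # `set_vocabulary_and_merges` override.
        if self.pad_with_end_token and self.end_token_id is not None:
            self.pad_token_id = self.end_token_id


tokenizer = ToyTokenizer(pad_with_end_token=True)
print(tokenizer.pad_token_id)  # 0 -- no vocabulary attached yet
tokenizer.set_vocabulary({"<|endoftext|>": 4})
print(tokenizer.pad_token_id)  # 4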
58 changes: 46 additions & 12 deletions keras_hub/src/models/preprocessor.py
@@ -47,6 +47,7 @@ class Preprocessor(PreprocessingLayer):
image_converter_cls = None

def __init__(self, *args, **kwargs):
self.config_name = kwargs.pop("config_name", PREPROCESSOR_CONFIG_FILE)
super().__init__(*args, **kwargs)
self._tokenizer = None
self._image_converter = None
@@ -97,6 +98,11 @@ def get_config(self):
config["image_converter"] = keras.layers.serialize(
self.image_converter
)
config.update(
{
"config_name": self.config_name,
}
)
return config

@classmethod
@@ -126,6 +132,7 @@ def presets(cls):
def from_preset(
cls,
preset,
config_name=PREPROCESSOR_CONFIG_FILE,
**kwargs,
):
"""Instantiate a `keras_hub.models.Preprocessor` from a model preset.
@@ -175,22 +182,49 @@
# Detect the correct subclass if we need to.
if cls.backbone_cls != backbone_cls:
cls = find_subclass(preset, cls, backbone_cls)
return loader.load_preprocessor(cls, **kwargs)
return loader.load_preprocessor(cls, config_name, **kwargs)

@classmethod
def _add_missing_kwargs(cls, loader, kwargs):
"""Fill in required kwargs when loading from preset.
This is a private method hit when loading a preprocessing layer that
was not directly saved in the preset. This method should fill in
all required kwargs required to call the class constructor. For almost,
all preprocessors, the only required args are `tokenizer`,
`image_converter`, and `audio_converter`, but this can be overridden,
e.g. for a preprocessor with multiple tokenizers for different
encoders.
"""
if "tokenizer" not in kwargs and cls.tokenizer_cls:
kwargs["tokenizer"] = loader.load_tokenizer(cls.tokenizer_cls)
if "audio_converter" not in kwargs and cls.audio_converter_cls:
kwargs["audio_converter"] = loader.load_audio_converter(
cls.audio_converter_cls
)
if "image_converter" not in kwargs and cls.image_converter_cls:
kwargs["image_converter"] = loader.load_image_converter(
cls.image_converter_cls
)
return kwargs

def load_preset_assets(self, preset):
"""Load all static assets needed by the preprocessing layer.
Args:
preset: The path to the local model preset directory.
"""
for layer in self._flatten_layers(include_self=False):
if hasattr(layer, "load_preset_assets"):
layer.load_preset_assets(preset)

def save_to_preset(self, preset_dir):
"""Save preprocessor to a preset directory.
Args:
preset_dir: The path to the local model preset directory.
"""
save_serialized_object(
self,
preset_dir,
config_file=PREPROCESSOR_CONFIG_FILE,
)
if self.tokenizer:
self.tokenizer.save_to_preset(preset_dir)
if self.audio_converter:
self.audio_converter.save_to_preset(preset_dir)
if self.image_converter:
self.image_converter.save_to_preset(preset_dir)
save_serialized_object(self, preset_dir, config_file=self.config_name)
for layer in self._flatten_layers(include_self=False):
if hasattr(layer, "save_to_preset"):
layer.save_to_preset(preset_dir)
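A hedged sketch of what the new `config_name` plumbing allows (the preset name is the SD3 preset added later in this commit; Kaggle access for the handle is assumed, and the local directory is illustrative): `from_preset` still reads `preprocessor.json` by default, but because `config_name` is now a constructor kwarg carried through `get_config()` and `save_to_preset()`, one preset can hold more than one serialized preprocessor.

import keras_hub

# Default behavior is unchanged: this reads "preprocessor.json".
preprocessor = (
    keras_hub.models.StableDiffusion3TextToImagePreprocessor.from_preset(
        "stable_diffusion_3_medium"
    )
)

# Re-saving writes the config back out under the same `config_name` the
# layer was constructed with, alongside the tokenizer assets.
preprocessor.save_to_preset("./sd3_preset")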
9 changes: 9 additions & 0 deletions keras_hub/src/models/stable_diffusion_3/__init__.py
@@ -11,3 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
StableDiffusion3Backbone,
)
from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_presets import (
backbone_presets,
)
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, StableDiffusion3Backbone)
@@ -615,8 +615,22 @@ def get_config(self):

@classmethod
def from_config(cls, config, custom_objects=None):
# We expect `clip_l`, `clip_g` and/or `t5` to be instantiated.
config = config.copy()

# Propagate `dtype` to text encoders if needed.
if "dtype" in config and config["dtype"] is not None:
dtype_config = config["dtype"]
if "dtype" not in config["clip_l"]["config"]:
config["clip_l"]["config"]["dtype"] = dtype_config
if "dtype" not in config["clip_g"]["config"]:
config["clip_g"]["config"]["dtype"] = dtype_config
if (
config["t5"] is not None
and "dtype" not in config["t5"]["config"]
):
config["t5"]["config"]["dtype"] = dtype_config

# We expect `clip_l`, `clip_g` and/or `t5` to be instantiated.
config["clip_l"] = layers.deserialize(
config["clip_l"], custom_objects=custom_objects
)
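The intent of the `dtype` propagation above, as a small self-contained sketch (toy config dicts rather than real serialized layers): when the saved backbone config carries a top-level dtype policy but the nested text encoder configs do not, the policy is copied down before `layers.deserialize`, so `clip_l`, `clip_g`, and the optional `t5` are rebuilt with the same dtype as the rest of the model.

config = {
    "dtype": "bfloat16",
    "clip_l": {"class_name": "CLIPTextEncoder", "config": {}},
    "clip_g": {"class_name": "CLIPTextEncoder", "config": {}},
    "t5": None,  # T5 is optional in SD3
}

if config.get("dtype") is not None:
    dtype_config = config["dtype"]
    for key in ("clip_l", "clip_g"):
        if "dtype" not in config[key]["config"]:
            config[key]["config"]["dtype"] = dtype_config
    if config["t5"] is not None and "dtype" not in config["t5"]["config"]:
        config["t5"]["config"]["dtype"] = dtype_config

print(config["clip_l"]["config"]["dtype"])  # bfloat16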
@@ -0,0 +1,31 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""StableDiffusion3 preset configurations."""

backbone_presets = {
"stable_diffusion_3_medium": {
"metadata": {
"description": (
"3 billion parameter, including CLIP L and CLIP G text "
"encoders, MMDiT generative model, and VAE decoder. "
"Developed by Stability AI."
),
"params": 2952806723,
"official_name": "StableDiffusion3",
"path": "stablediffusion3",
"model_card": "https://arxiv.org/abs/2110.00476",
},
"kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium/1",
}
}
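With the preset registered, the backbone can be pulled by name. A minimal sketch, assuming the public `keras_hub.models.StableDiffusion3Backbone` export and Kaggle access for the handle above:

import keras_hub

backbone = keras_hub.models.StableDiffusion3Backbone.from_preset(
    "stable_diffusion_3_medium"
)
print(backbone.count_params())  # roughly 2.95e9, matching "params" above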
@@ -11,10 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras
from keras import layers

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.preprocessor import Preprocessor
from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
StableDiffusion3Backbone,
)


@keras_hub_export("keras_hub.models.StableDiffusion3TextToImagePreprocessor")
@@ -33,6 +37,8 @@ class StableDiffusion3TextToImagePreprocessor(Preprocessor):
t5_preprocessor: An optional `keras_hub.models.T5Preprocessor` instance.
"""

backbone_cls = StableDiffusion3Backbone

def __init__(
self,
clip_l_preprocessor,
@@ -45,6 +51,11 @@ def __init__(
self.clip_g_preprocessor = clip_g_preprocessor
self.t5_preprocessor = t5_preprocessor

@property
def sequence_length(self):
"""The padded length of model input sequences."""
return self.clip_l_preprocessor.sequence_length

def build(self, input_shape):
self.built = True

@@ -71,7 +82,15 @@ def get_config(self):
)
return config

@property
def sequence_length(self):
"""The padded length of model input sequences."""
return self.clip_l_preprocessor.sequence_length
@classmethod
def from_config(cls, config):
for layer_name in (
"clip_l_preprocessor",
"clip_g_preprocessor",
"t5_preprocessor",
):
if layer_name in config and isinstance(config[layer_name], dict):
config[layer_name] = keras.layers.deserialize(
config[layer_name]
)
return cls(**config)
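A hedged round-trip sketch of the new `from_config` (assumes the preprocessor is loaded from the preset added in this commit): `get_config()` serializes the nested CLIP/T5 preprocessors to dicts, and `from_config` re-inflates them before calling the constructor.

import keras_hub

cls = keras_hub.models.StableDiffusion3TextToImagePreprocessor
preprocessor = cls.from_preset("stable_diffusion_3_medium")

config = preprocessor.get_config()  # nested preprocessors become dicts
restored = cls.from_config(config)  # ...and layers again here
print(restored.sequence_length)     # delegates to clip_l_preprocessor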
@@ -70,4 +70,4 @@ def test_generate_preprocess(self):
self.assertIn("clip_l", x)
self.assertIn("clip_g", x)
self.assertAllEqual(x["clip_l"][0], [4, 0, 1, 3, 3, 3, 3, 3])
self.assertAllEqual(x["clip_g"][0], [4, 0, 1, 3, 3, 3, 3, 3])
self.assertAllEqual(x["clip_g"][0], [4, 0, 1, 3, 0, 0, 0, 0])
37 changes: 23 additions & 14 deletions keras_hub/src/tokenizers/tokenizer.py
@@ -17,14 +17,13 @@
from keras_hub.src.layers.preprocessing.preprocessing_layer import (
PreprocessingLayer,
)
from keras_hub.src.utils.preset_utils import TOKENIZER_ASSET_DIR
from keras_hub.src.utils.preset_utils import ASSET_DIR
from keras_hub.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_hub.src.utils.preset_utils import builtin_presets
from keras_hub.src.utils.preset_utils import find_subclass
from keras_hub.src.utils.preset_utils import get_file
from keras_hub.src.utils.preset_utils import get_preset_loader
from keras_hub.src.utils.preset_utils import save_serialized_object
from keras_hub.src.utils.preset_utils import save_tokenizer_assets
from keras_hub.src.utils.python_utils import classproperty
from keras_hub.src.utils.tensor_utils import preprocessing_function

@@ -80,6 +79,7 @@ def detokenize(self, inputs):
backbone_cls = None

def __init__(self, *args, **kwargs):
self.config_name = kwargs.pop("config_name", TOKENIZER_CONFIG_FILE)
super().__init__(*args, **kwargs)
self.file_assets = None

@@ -187,18 +187,26 @@ def _update_special_token_ids(self):
token = getattr(self, attr)
setattr(self, f"{attr}_id", self.token_to_id(token))

def get_config(self):
config = super().get_config()
config.update(
{
"config_name": self.config_name,
}
)
return config

def save_to_preset(self, preset_dir):
"""Save tokenizer to a preset directory.
Args:
preset_dir: The path to the local model preset directory.
"""
save_serialized_object(
self,
preset_dir,
config_file=TOKENIZER_CONFIG_FILE,
)
save_tokenizer_assets(self, preset_dir)
save_serialized_object(self, preset_dir, config_file=self.config_name)
subdir = self.config_name.split(".")[0]
asset_dir = os.path.join(preset_dir, ASSET_DIR, subdir)
os.makedirs(asset_dir, exist_ok=True)
self.save_assets(asset_dir)

@preprocessing_function
def call(self, inputs, *args, training=None, **kwargs):
@@ -207,11 +215,11 @@ def call(self, inputs, *args, training=None, **kwargs):
def load_preset_assets(self, preset):
asset_path = None
for asset in self.file_assets:
asset_path = get_file(
preset, os.path.join(TOKENIZER_ASSET_DIR, asset)
)
tokenizer_asset_dir = os.path.dirname(asset_path)
self.load_assets(tokenizer_asset_dir)
subdir = self.config_name.split(".")[0]
preset_path = os.path.join(ASSET_DIR, subdir, asset)
asset_path = get_file(preset, preset_path)
tokenizer_config_name = os.path.dirname(asset_path)
self.load_assets(tokenizer_config_name)

@classproperty
def presets(cls):
@@ -222,6 +230,7 @@ def presets(cls):
def from_preset(
cls,
preset,
config_name=TOKENIZER_CONFIG_FILE,
**kwargs,
):
"""Instantiate a `keras_hub.models.Tokenizer` from a model preset.
@@ -267,4 +276,4 @@ class like `keras_hub.models.Tokenizer.from_preset()`, or from
backbone_cls = loader.check_backbone_class()
if cls.backbone_cls != backbone_cls:
cls = find_subclass(preset, cls, backbone_cls)
return loader.load_tokenizer(cls, **kwargs)
return loader.load_tokenizer(cls, config_name, **kwargs)
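A small path-only sketch of where tokenizer assets now land (assumes `ASSET_DIR == "assets"` in `preset_utils`; the second config name is purely illustrative): the asset subdirectory is derived from `config_name`, so two tokenizers saved into the same preset no longer collide.

import os

ASSET_DIR = "assets"  # assumed value from keras_hub.src.utils.preset_utils

for config_name in ("tokenizer.json", "clip_g_tokenizer.json"):
    subdir = config_name.split(".")[0]
    print(os.path.join("my_preset", ASSET_DIR, subdir))
# my_preset/assets/tokenizer
# my_preset/assets/clip_g_tokenizer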