
Add StableDiffusion3 #1820

Merged: 62 commits into keras-team:master on Sep 25, 2024

Conversation

@james77777778 (Collaborator) commented Sep 11, 2024

This is more of a draft, as we may need further discussion regarding the implementation.

Notes for reviewing:

  • There are several small modifications in this PR to ensure the numerical stability of all modules/layers (e.g., LayerNorm and Softmax should run in float32).
  • StableDiffusion3Backbone is a large model, resulting in a very long init signature. Is this acceptable? How could we refactor it?
  • Defining a functional model for the backbone and the text_to_image model is challenging for me.
  • I couldn't compile the entire text_to_image due to unexpected OOM issues. However, when splitting it into encode, denoise, and decode functions, it worked fine. I'm unsure of the performance impact of not compiling the entire function.
  • I have dropped T5 for simplification: skip it and zero-pad the embeddings from the CLIP models (see the sketch below).
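
A rough sketch of the zero-padding idea from the last bullet, assuming hypothetical sizes (the real widths come from the SD3 checkpoint):

from keras import ops

# When T5 is skipped, pad the concatenated CLIP-L + CLIP-G embeddings with
# zeros up to the width the joint context expects (hypothetical sizes below).
clip_embeddings = ops.ones((1, 77, 2048))  # e.g. CLIP-L (768) + CLIP-G (1280)
joint_width = 4096                         # assumed joint feature width
padded = ops.pad(clip_embeddings, [[0, 0], [0, 0], [0, joint_width - 2048]])
# `padded` now has shape (1, 77, 4096), the slot the T5 embeddings would fill.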

Refs:

Demo colab (including weights conversion for https://huggingface.co/stabilityai/stable-diffusion-3-medium):
https://colab.research.google.com/drive/1rrQMs0nlKSEzYNhIJChQwgnrZNiydexS?usp=sharing

"a cat holding a sign that says hello world" "cute wallpaper art of a cat"
1 2

TODO:

  • Rename model folder to stable_diffusion_3
  • Add docstrings
  • Add weight conversion script
  • Add tests

@divyashreepathihalli @mattdangerw @SamanehSaadat

BTW, I will be unavailable from 9/17 to 9/22.

* Agg Vgg16 backbone

* update names

* update tests

* update test

* add image classifier

* incorporate review comments

* Update test case

* update backbone test

* add image classifier

* classifier cleanup

* code reformat

* add vgg16 image classifier

* make vgg generic

* update doc string

* update docstring

* add classifier test

* update tests

* update docstring

* address review comments

* code reformat

* update the configs

* address review comments

* fix task saved model test

* update init

* code reformatted
* Add ResNetV1 and ResNetV2

* Address comments
* Add CSP DarkNet

* Add CSP DarkNet

* snake_case function names

* change use_depthwise to block_type
…Backbone` (keras-team#1769)

* Add FeaturePyramidBackbone and update ResNetBackbone

* Simplify the implementation

* Fix CI

* Make ResNetBackbone compatible with timm and add FeaturePyramidBackbone

* Add conversion implementation

* Update docstrings

* Address comments
* Add DenseNet

* fix testcase

* address comments

* nit

* fix lint errors

* move description
* add vit det vit_det_backbone

* update docstring

* code reformat

* fix tests

* address review comments

* bump year on all files

* address review comments

* rename backbone

* fix tests

* change back to ViT

* address review comments

* update image shape
* Add MixTransformer

* fix testcase

* test changes and comments

* lint fix

* update config list

* modify testcase for 2 layers
* update input_image_shape -> image_shape

* update docstring example

* code reformat

* update tests
add missing __init__ file to vit_det
This is a temporary way to test out the keras-hub branch.
- Does a global rename of all symbols during package build.
- Registers the "old" name on symbol export for saving compat.
- Adds a github action to publish every commit to keras-hub as
  a new package.
- Removes our descriptions on PyPI temporarily, until we want
  to message this more broadly.
* Add `CLIPTokenizer`, `T5XXLTokenizer`, `CLIPTextEncoder` and `T5XXLTextEncoder`.

* Make CLIPTextEncoder as Backbone

* Add `T5XXLPreprocessor` and remove `T5XXLTokenizer`

Add `CLIPPreprocessor`

* Use `tf = None` at the top

* Replace manual implementation of `CLIPAttention` with `MultiHeadAttention`
* Bounding box utils

* - Correct test cases

* - Remove hard tensorflow dtype

* - fix api gen

* - Fix import for test cases
- Use setup for converters test case

* - fix api_gen issue

* - FIx api gen

* - Fix api gen error

* - Correct test cases as per new api changes
* mobilenet_v3 added in keras-nlp

* minor bug fixed in mobilenet_v3_backbone

* formatting corrected

* refactoring backbone

* correct_pad_downsample method added

* refactoring backbone

* parameters updated

* Testcaseupdated, expected output shape corrected

* code formatted with black

* testcase updated

* refactoring and description added

* comments updated

* added mobilenet v1 and v2

* merge conflict resolved

* version arg removed, and config options added

* input_shape changed to image_shape in arg

* config updated

* input shape corrected

* comments resolved

* activation function format changed

* minor bug fixed

* minor bug fixed

* added vision_backbone_test

* channel_first bug resolved

* channel_first cases working

* comments  resolved

* formatting fixed

* refactoring

---------

Co-authored-by: ushareng <[email protected]>
* migrating efficientnet models to keras-hub

* merging changes from other sources

* autoformatting pass

* initial consolidation of efficientnet_backbone

* most updates and removing separate implementation

* cleanup, autoformatting, keras generalization

* removed layer examples outside of effiicient net

* many, mainly documentation changes, small test fixes
* Add ResNet_vd to ResNet backbone

* Addressed requested parameter changes

* Fixed tests and updated comments

* Added new parameters to docstring
* Add `VAEImageDecoder` for StableDiffusionV3

* Use `keras.Model` for `VAEImageDecoder` and follows the coding style in `VAEAttention`
* add pyramid outputs

* fix testcase

* format fix

* make common testcase for pyramid outputs

* change default shape

* simplify testcase

* test case change and add channel axis
* Add `MMDiT`

* Update

* Update

* Update implementation
* - Add formats, iou, utils for bounding box

* - Add `AnchorGenerator`, `BoxMatcher` and `NonMaxSupression` layers

* - Remove scope_name  not required.

* use default keras name scope

* - Correct format error

* - Remove layers as of now and keep them at model level till keras core supports them

* - Correct api_gen
@mattdangerw (Member)

Awesome work! The samples are very exciting. Still reading through it, but some initial comments.

StableDiffusion3Backbone is a large model, resulting in a very long init signature. Is this acceptable? How could we refactor it?

Yeah, definitely see what you mean. I think the best way to break this up would be to have StableDiffusion3Backbone take in other sub-models as arguments and serialize them in the config with keras.layers.serialize(component), but that requires making all the sub-components we expose this way public (see the sketch below).
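
A minimal sketch of that idea, with a hypothetical constructor signature (not the final API):

import keras

class StableDiffusion3Backbone(keras.Model):
    """Sketch: the backbone receives sub-models and serializes them itself."""

    def __init__(self, clip_l, clip_g, t5=None, **kwargs):
        super().__init__(**kwargs)
        self.clip_l = clip_l
        self.clip_g = clip_g
        self.t5 = t5

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "clip_l": keras.layers.serialize(self.clip_l),
                "clip_g": keras.layers.serialize(self.clip_g),
                "t5": keras.layers.serialize(self.t5) if self.t5 else None,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        # Rebuild the sub-models before calling the constructor.
        for key in ("clip_l", "clip_g", "t5"):
            if config.get(key) is not None:
                config[key] = keras.layers.deserialize(config[key])
        return cls(**config)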

In the case of CLIP, it actually seems like we would like a standalone CLIP model with its own from_preset constructor. So we could probably make a separate model directory for clip, and a CLIP backbone class with a from_preset constructor (even if we don't upload any standalone CLIP weights/configs yet).

In the case of MMDiT, I'm not sure it makes sense as its own standalone backbone that we someday provide weights for, but we could just expose it as a separate model (though not a backbone with from_preset) so that StableDiffusion3Backbone can take it as an argument and have a decluttered constructor.

WDYT?

I have dropped T5 for simplification: skip it and zero-pad the embeddings from CLIP models.

You mean in the colab, right? The code looks like it's still supporting it everywhere. I think that is fine, probably the right way to start. Part of me is tempted to just rip out the T5 part entirely for now and wait for someone to ask for it. It'd make the initial implementation a lot simpler, and it seems like what almost all users will want anyway.

@james77777778 (Collaborator, Author)

I think the best way to break this up would be to have StableDiffusion3Backbone take in other sub-models as arguments

Sounds good to me. This should be the first nested backbone model in KerasHub but it probably won’t be the last. 😅

So we could probably make a separate model directory for clip, and a clip backbone class with a from_preset constructor

I’m not too familiar with CLIP, but I noticed the CLIPTextModel and CLIPVisionModel structures in huggingface/transformers. I will move CLIPTextEncoder into a new model directory and rename it to CLIPTextModel. We can complete the full implementation of CLIP as future work.

In the case of MMDiT, I'm not sure if that makes sense to be it's own standalone backbone we someday provide weights for, but we could just expose it as a separate model (but not backbone with from_preset) so that StableDiffusion3Backbone can just take it as a argument and have a decluttered constructor.

Yeah, I second that. I will make MMDiT a plain keras.Model, which StableDiffusion3Backbone can import.

You mean in the colab right? The code looks like it's still supporting it everywhere. I think that is fine, probably the right way to start. Part of me is tempted to just rip our the T5 part entirely for now, wait for someone to ask for it.

Actually, in the colab, T5 wasn't loaded because I set t5_vocabulary_size to None to skip the loading.
I also agree with not providing T5 by default, but we could still let users decide whether to load it through the constructor signature. Should we?

@mattdangerw (Member) commented Sep 11, 2024

Actually, in the colab, T5 wasn't loaded because I set t5_vocabulary_size to None to skip the loading.
I also agree with not providing T5 as default but we can still allow users to decide whether to load it through the constructor signature. Should we?

Yeah, that sounds like the right way to do it. Right now our backbone weights are monolithic and loaded in full. We could consider supporting partial instantiation of weights, but I might actually do this as a separate upload:

stable_diffusion_3_medium
stable_diffusion_3_medium_with_t5

Something like that. Then users wouldn't even have to download T5 in the more common path. Anyway, we can figure that out later; supporting both as config options sounds good if it doesn't sound too bad to implement.
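
Hypothetical usage if the two uploads land as sketched above (preset names taken from this thread; nothing published at this point):

import keras_hub

# CLIP-only variant: lighter download, T5 slot zero-padded.
backbone = keras_hub.models.StableDiffusion3Backbone.from_preset(
    "stable_diffusion_3_medium"
)

# Variant that also bundles the T5 encoder weights.
backbone_with_t5 = keras_hub.models.StableDiffusion3Backbone.from_preset(
    "stable_diffusion_3_medium_with_t5"
)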

@james77777778 (Collaborator, Author)

Then users wouldn't even have to download T5 in the more common path. Anyway, we can figure that out later; supporting both as config options sounds good if it doesn't sound too bad to implement.

Agreed, and it shouldn't be difficult to implement.

Some updates:

  • Use clip_l, clip_g and t5 as arguments to simplify the SD3 backbone signature.
  • Ensure the implementation of CLIPTextEncoder is consistent with huggingface/transformers and move it into a new model directory.
  • Simplify the original encode, denoise and decode steps into one compilable text_to_image step using ops.fori_loop (see the sketch after this list). However, there is no significant speedup; the win is cleaner code.
  • I chose to instantiate MMDiT and VAEImageDecoder in the SD3 backbone because it's the only way to support arbitrary image sizes. If we separate them from the backbone, it becomes difficult or even impossible to change the image size in uploaded presets. (Reason: many ops.reshape calls need a fixed latent shape.)
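
A rough sketch of the ops.fori_loop idea from the third bullet; scheduler, diffuser and euler_step are hypothetical stand-ins for the actual sub-models:

from keras import ops

def denoise(latents, embeddings, num_steps, guidance_scale):
    # Fold the whole denoising loop into one compilable step.
    def body(step, latents):
        sigma, sigma_next = scheduler(step, num_steps)  # hypothetical helper
        pred = diffuser(latents, embeddings, sigma, guidance_scale)
        return euler_step(latents, pred, sigma, sigma_next)  # hypothetical helper

    return ops.fori_loop(0, num_steps, body, latents)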

Let me know if this is ready. If so, I will add docstrings, the weight conversion script (in tools/checkpoint_conversion) and tests.

@divyashreepathihalli (Collaborator)

@james77777778 We have renamed the repo and code to KerasHub! Sorry about this disruptive change, but it's a one-time cost. Please feel free to close this PR and open a new one against the new master.

@james77777778 (Collaborator, Author)

No problem. I will be back on 9/22 and resubmit the PR soon.

@mattdangerw (Member) left a comment

Looking good to me! I think we need to think a bit about saving and loading more, but we can do that in a follow up PR.

Interesting questions there.

Should keras_hub.tokenizers.Tokenizer.from_preset("sd3_preset_name") return something? Do we want to add a way to create each tokenizer individually? What does instantiation look like?

This will also need a rebase after our big symbol rename change.

keras_nlp/src/models/text_to_image.py:
metrics="auto",
**kwargs,
):
# TODO: Figure out how to compile.
Member:

Probably we can just chain to super() here and clear the generate_function. Compile doesn't actually create a traced function; we do that lazily the first time a generate, predict, or train function is called.

If we ever had an argument that we thought would be common to a lot of models and requires recompiling the function (like sampler for text models), we could consider adding it here, but I don't think we need to do that now.

Collaborator (Author):

Updated. I referred to this implementation:
https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L410-L414

    def compile(
        self,
        optimizer="auto",
        loss="auto",
        *,
        metrics="auto",
        **kwargs,
    ):
        # Ref: https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L410-L414
        if optimizer == "auto":
            optimizer = keras.optimizers.AdamW(
                1e-4, weight_decay=1e-2, epsilon=1e-8, clipnorm=1.0
            )
        if loss == "auto":
            loss = keras.losses.MeanSquaredError()
        if metrics == "auto":
            metrics = [keras.metrics.MeanSquaredError()]
        super().compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            **kwargs,
        )
        self.generate_function = None



@keras_nlp_export("keras_nlp.models.StableDiffusion3Backbone")
class StableDiffusion3Backbone(Backbone):
Member:

Backbone is looking good! One good thing to test is model.summary() to make sure it looks reasonable, now that we are on a functional model.

Collaborator (Author):

I have added some helper layers to get a cleaner model.summary().
Now it looks like this:

Model: "stable_diffusion3_backbone"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                  ┃ Output Shape              ┃         Param # ┃ Connected to               ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ clip_l_token_ids (InputLayer) │ (None, None)              │               0 │ -                          │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ clip_g_token_ids (InputLayer) │ (None, None)              │               0 │ -                          │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ clip_l_negative_token_ids     │ (None, None)              │               0 │ -                          │
│ (InputLayer)                  │                           │                 │                            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ clip_g_negative_token_ids     │ (None, None)              │               0 │ -                          │
│ (InputLayer)                  │                           │                 │                            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ clip_l (CLIPTextEncoder)      │ [(None, None, 768),       │     123,060,480 │ clip_l_token_ids[0][0],    │
│                               │ (None, None, 768)]        │                 │ clip_l_negative_token_ids… │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ clip_g (CLIPTextEncoder)      │ [(None, None, 1280),      │     693,021,440 │ clip_g_token_ids[0][0],    │
│                               │ (None, None, 1280)]       │                 │ clip_g_negative_token_ids… │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ num_steps (InputLayer)        │ ()                        │               0 │ -                          │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ concatenate_1 (Concatenate)   │ (None, None, 2048)        │               0 │ clip_l[0][0], clip_g[0][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ concatenate_3 (Concatenate)   │ (None, None, 2048)        │               0 │ clip_l[1][0], clip_g[1][0] │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ cast_63 (Cast)                │ (None, None, 768)         │               0 │ clip_l[0][1]               │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ cast_64 (Cast)                │ (None, None, 1280)        │               0 │ clip_g[0][1]               │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ cast_65 (Cast)                │ (None, None, 768)         │               0 │ clip_l[1][1]               │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ cast_66 (Cast)                │ (None, None, 1280)        │               0 │ clip_g[1][1]               │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ scheduler                     │ ()                        │               0 │ num_steps[0][0],           │
│ (FlowMatchEulerDiscreteSched… │                           │                 │ num_steps[0][0]            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ pad (Pad)                     │ (None, None, 4096)        │               0 │ concatenate_1[0][0]        │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ pad_2 (Pad)                   │ (None, None, 4096)        │               0 │ concatenate_3[0][0]        │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ clip_l_projection             │ (None, 768)               │         589,824 │ cast_63[0][0],             │
│ (CLIPProjection)              │                           │                 │ clip_l_token_ids[0][0],    │
│                               │                           │                 │ cast_65[0][0],             │
│                               │                           │                 │ clip_l_negative_token_ids… │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ clip_g_projection             │ (None, 1280)              │       1,638,400 │ cast_64[0][0],             │
│ (CLIPProjection)              │                           │                 │ clip_g_token_ids[0][0],    │
│                               │                           │                 │ cast_66[0][0],             │
│                               │                           │                 │ clip_g_negative_token_ids… │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ multiply (Multiply)           │ ()                        │               0 │ scheduler[0][0]            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ latents (InputLayer)          │ (None, 100, 100, 16)      │               0 │ -                          │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ pad_1 (Pad)                   │ (None, None, 4096)        │               0 │ pad[0][0]                  │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ pad_3 (Pad)                   │ (None, None, 4096)        │               0 │ pad_2[0][0]                │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ concatenate (Concatenate)     │ (None, 2048)              │               0 │ clip_l_projection[0][0],   │
│                               │                           │                 │ clip_g_projection[0][0]    │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ concatenate_2 (Concatenate)   │ (None, 2048)              │               0 │ clip_l_projection[1][0],   │
│                               │                           │                 │ clip_g_projection[1][0]    │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ broadcast_to (BroadcastTo)    │ (None)                    │               0 │ multiply[0][0]             │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ concatenate_4 (Concatenate)   │ (None, None, 4096)        │               0 │ pad_1[0][0], pad_3[0][0]   │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ concatenate_6 (Concatenate)   │ (None, 100, 100, 16)      │               0 │ latents[0][0],             │
│                               │                           │                 │ latents[0][0]              │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ concatenate_5 (Concatenate)   │ (None, 2048)              │               0 │ concatenate[0][0],         │
│                               │                           │                 │ concatenate_2[0][0]        │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ concatenate_7 (Concatenate)   │ (None)                    │               0 │ broadcast_to[0][0],        │
│                               │                           │                 │ broadcast_to[0][0]         │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ diffuser (MMDiT)              │ (None, None, None, 16)    │   2,084,951,104 │ concatenate_4[0][0],       │
│                               │                           │                 │ concatenate_6[0][0],       │
│                               │                           │                 │ concatenate_5[0][0],       │
│                               │                           │                 │ concatenate_7[0][0]        │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ cast_67 (Cast)                │ (None, None, None, 16)    │               0 │ diffuser[0][0]             │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ guidance_scale (InputLayer)   │ ()                        │               0 │ -                          │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ classifier_free_guidance      │ (None, None, None, 16)    │               0 │ cast_67[0][0],             │
│ (ClassifierFreeGuidance)      │                           │                 │ guidance_scale[0][0]       │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ euler_step (EulerStep)        │ (None, 100, 100, 16)      │               0 │ latents[0][0],             │
│                               │                           │                 │ classifier_free_guidance[… │
│                               │                           │                 │ scheduler[0][0],           │
│                               │                           │                 │ scheduler[1][0]            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ latent_calibration            │ (None, 100, 100, 16)      │               0 │ euler_step[0][0]           │
│ (LatentCalibration)           │                           │                 │                            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ decoder (VAEImageDecoder)     │ (None, 800, 800, 3)       │      49,545,475 │ latent_calibration[0][0]   │
└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘
 Total params: 2,952,806,723 (5.50 GB)
 Trainable params: 2,952,806,723 (5.50 GB)
 Non-trainable params: 0 (0.00 B)
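
For reference, a rough sketch of what the ClassifierFreeGuidance and EulerStep helper layers in the summary presumably compute (standard classifier-free guidance and a flow-matching Euler update; not copied from this PR, and the positive/negative batch order is assumed):

from keras import ops

def classifier_free_guidance(noise_pred, guidance_scale):
    # The diffuser ran on positive and negative batches concatenated on axis 0.
    positive, negative = ops.split(noise_pred, 2, axis=0)
    return negative + guidance_scale * (positive - negative)

def euler_step(latents, noise_pred, sigma, sigma_next):
    # One Euler step of a flow-match discrete scheduler.
    return latents + (sigma_next - sigma) * noise_pred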

config = super().get_config()
config.update(
    {
        "clip_l_preprocessor": layers.serialize(
Member:

Do clip_l and clip_g have separate tokenizers? We should think about serialization a bit here. Right now for our "preset" saving, we assume one tokenizer, in a fixed directory of assets/tokenizer. We probably need to tweak our saving and loading a bit.

Probably makes sense to do this as a separate PR? I'll think about how to best do this and post some thoughts.

Member:

Here's a scaffold of what we could do here. #1860

Collaborator (Author):

I will keep the current implementation for now.
Should I create a new PR for this after SD3 is merged, or will you finish it?

Member:

Feel free to patch this into a PR you own. I think actually using it to save SD3 assets will be an important way to test it out.

for batched inputs.

Args:
latents: A <float>[batch_size, height, width, channels] tensor
Member:

I don't think we use this notation usually. I'd just say "A float tensor with shape (batch_size, height, width, channels)..."

Collaborator (Author):

There might be some legacy docstrings in the codebase:

"""Score a generation represented by the provided token ids.
Args:
token_ids: A <int>[batch_size, num_tokens] tensor containing tokens
to score. Typically, this tensor captures the output from a call
to `GemmaCausalLM.generate()`, i.e., tokens for both the input
text and the model-generated text.
padding_mask: A <bool>[batch_size, num_tokens] tensor indicating the
tokens that should be preserved during generation. This is an
artifact required by the GemmaBackbone and isn't influential on
the computation of this function. If omitted, this function uses
`keras.ops.ones()` to create a tensor of the appropriate shape.
scoring_mode: The type of scores to return, either "logits" or
"loss", both will be per input token.
layer_intercept_fn: An optional function for augmenting activations
with additional computation, for example, as part of
interpretability research. This function will be passed the
activations as its first parameter and a numeric index
associated with that backbone layer. _This index _is not_ an
index into `self.backbone.layers`_. The index -1 accompanies the
embeddings returned by calling `self.backbone.token_embedding()`
on `token_ids` in the forward direction. All subsequent indexes
will be 0-based indices for the activations returned by each of
the Transformers layers in the backbone. This function must
return a <float>[batch_size, num_tokens, hidden_dims] tensor
that can be passed as an input to the next layer in the model.
target_ids: An <bool>[batch_size, num_tokens] tensor containing the
predicted tokens against which the loss should be computed. If a
span of tokens is provided (sequential truthy values along
axis=1 in the tensor), the loss will be computed as the
aggregate across those tokens.

I adapted it from there. It's fixed now.

@mattdangerw (Member)

I think it probably makes sense to pull this in without compilation or saving fully figured out, and take those on as two (hopefully independent) follow-ups. Nice work!

@james77777778 (Collaborator, Author)

I have switched to the master branch and fixed the renaming issues.
I will start addressing the above comments tomorrow.

@james77777778 (Collaborator, Author)

@mattdangerw @divyashreepathihalli

I have addressed the above comments. The model works as-is:
https://colab.research.google.com/drive/1rrQMs0nlKSEzYNhIJChQwgnrZNiydexS?usp=sharing

Please let me know if the implementation is ready. I will add unit tests and complete the missing docstrings afterward.

Also, please let me know if I should take over #1860 or not.

@divyashreepathihalli (Collaborator)

@james77777778 The implementation is looking good and the results are looking good! Please go ahead and finish up the docstrings and unit tests!

Thanks!

@james77777778 changed the title from Add StableDiffusion3TextToImage to Add StableDiffusion3 on Sep 24, 2024
@james77777778 (Collaborator, Author)

@divyashreepathihalli
This PR is ready.
I will port the weights and add the preset in a separate PR (with the #1860 patch)

@divyashreepathihalli added the kokoro:force-run (Runs Tests on GPU) label on Sep 24, 2024
@divyashreepathihalli (Collaborator)

@james77777778 The JAX test is failing for Keras 3.1 but passing on the others. Let me know if it is fixable or just a bug in Keras 3.1.

@james77777778 (Collaborator, Author)

The JAX test is failing for Keras 3.1 but passing on the others. Let me know if it is fixable or just a bug in Keras 3.1.

It should be fixed now. The root cause is that Keras 3.1 doesn't have a setter for the dtype policy. The solution is to assign a DTypePolicy object directly to the target attribute.

@divyashreepathihalli (Collaborator)

Thanks James!! merging this!

@divyashreepathihalli merged commit 743adea into keras-team:master on Sep 25, 2024
7 checks passed
@james77777778 deleted the add-sd3 branch on September 26, 2024 at 01:52