StableDiffusion.text_to_image() causes an exception in Colab #2467

KouichiMatsuda opened this issue Jul 2, 2024 · 5 comments

@KouichiMatsuda

KouichiMatsuda commented Jul 2, 2024

Hi Keras Team,

Current Behavior:

The code based on https://keras.io/api/keras_cv/models/tasks/stable_diffusion/ causes an exception: ValueError: Exception encountered when calling DiffusionModelV2.call().

https://colab.research.google.com/drive/1OYet7JBOwgt7L5itxOOzVg-jclgPpdnT?usp=sharing

The same happens with the StableDiffusion class.

Am I missing something?

Steps To Reproduce:

https://colab.research.google.com/drive/1OYet7JBOwgt7L5itxOOzVg-jclgPpdnT?usp=sharing

Version:

Keras 3.4.1
TF 2.16.2
KerasCV 0.9.0
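
Since several replies in this thread hinge on exact package versions, a quick way to capture them in a Colab cell may help. This is a minimal sketch using the standard-library `importlib.metadata`; the helper name `report_versions` is made up for illustration, and the distribution names (`keras`, `tensorflow`, `keras-cv`) are assumed to match how the packages were installed.

```python
from importlib import metadata


def report_versions(packages=("keras", "tensorflow", "keras-cv")):
    """Return {distribution name: installed version, or None if missing}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Not installed under this distribution name.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in report_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Pasting the printed lines into a report makes it easier to compare working and failing combinations.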

@pjogi-testy

I get the same error when trying to replicate any Stable Diffusion tutorial from the official Keras repo (e.g. https://keras.io/examples/generative/random_walks_with_stable_diffusion/), whether I run it locally or use the original notebook in Colab. Possibly something is broken between Keras 3 and the latest TF? It looks as if the text encoder output (77 tokens, 768 values each) is being passed where the diffusion UNet expects its 64x64x4 latent. Any clues would be most welcome, or even confirmation of exactly which versions of tf, keras, and keras_cv the official repo works with, because the Keras API is so inconsistent between versions that it is really hard to follow. Further details of the error:

ValueError: Exception encountered when calling DiffusionModel.call().

Invalid input shape for input Tensor("data_2:0", shape=(3, 77, 768), dtype=float32). Expected shape (None, 64, 64, 4), but input has incompatible shape (3, 77, 768)

Arguments received by DiffusionModel.call():
• inputs={'latent': 'tf.Tensor(shape=(3, 64, 64, 4), dtype=float32)', 'timestep_embedding': 'tf.Tensor(shape=(3, 320), dtype=float32)', 'context': 'tf.Tensor(shape=(3, 77, 768), dtype=float32)'}
• training=False
• mask={'latent': 'None', 'timestep_embedding': 'None', 'context': 'None'}
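
The traceback is consistent with the inputs being matched to the model's input slots by position rather than by name. The sketch below only illustrates that hypothesis and is not Keras code: it assumes the dict values are flattened in sorted-key order while the model declares its inputs as latent, timestep_embedding, context. The shapes come from the error message above; the function names are hypothetical.

```python
# Input slots in the order the model is ASSUMED to declare them,
# with the shapes reported in the traceback (None = any batch size).
DECLARED_INPUTS = [
    ("latent", (None, 64, 64, 4)),
    ("timestep_embedding", (None, 320)),
    ("context", (None, 77, 768)),
]


def shape_matches(expected, actual):
    """True if `actual` fits `expected`, treating None as a wildcard."""
    return len(expected) == len(actual) and all(
        e is None or e == a for e, a in zip(expected, actual)
    )


def first_mismatch(values):
    """Pair values with declared slots by position; return the first clash."""
    for (slot, expected), actual in zip(DECLARED_INPUTS, values):
        if not shape_matches(expected, actual):
            return slot, expected, actual
    return None


# Shapes of the arguments shown in the traceback.
inputs = {
    "latent": (3, 64, 64, 4),
    "timestep_embedding": (3, 320),
    "context": (3, 77, 768),
}

# Assumed failure mode: values taken in sorted-key order, so "context" comes first.
by_sorted_key = [inputs[key] for key in sorted(inputs)]
# Matching by name keeps every tensor in its own slot.
by_name = [inputs[name] for name, _ in DECLARED_INPUTS]

print(first_mismatch(by_sorted_key))
print(first_mismatch(by_name))
```

Under that assumption, the first clash is the context tensor of shape (3, 77, 768) landing in the latent slot, which is the same pair of shapes the ValueError reports; matching by name produces no clash.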

@heydaari

I get the same error as in #2467 (comment).

I tried different backends, but the issue persists.

@OttoERM

OttoERM commented Jul 19, 2024

I ran into the same problem while following this TensorFlow tutorial, with the same error as @pjogi-testy.

Long story short, I got it working with: keras 2.13.1, keras-core 0.1.7, keras-cv 0.9.0, tensorflow 2.13.1
And also with: keras 2.15.0, keras-core 0.1.7, keras-cv 0.9.0, tensorflow 2.15.1

import keras_cv
from PIL import Image

# Build the Stable Diffusion pipeline for 512x512 output.
model = keras_cv.models.StableDiffusion(img_width=512, img_height=512)

# Generate a batch containing a single image for the prompt.
image = model.text_to_image(prompt="Flower", batch_size=1, num_steps=15)

# image[0] is the first (and only) image in the batch.
Image.fromarray(image[0]).save("Flower.png")
print("Saved at Flower.png")

I didn't try tensorflow 2.16.*.
Anyway, I guess it is just a broken combination of keras 3.4.1 and tensorflow 2.17.0 (the latest release at this time).

The Keras repo README has a note: "Keras 3 will not function with TensorFlow 2.14 or earlier." I'm not sure how you are supposed to use Keras 3, because whenever I installed tensorflow a specific version of keras was installed with it, and when I ran install --upgrade on keras I got a version incompatibility error.

@Hsieh-Yao-Tsung

I had the same error when I used the Docker image tensorflow/tensorflow:2.17.0-gpu.
I got it working with the Docker image tensorflow/tensorflow:2.16.1-gpu.

@mehtamansi29 added the type:Bug label Sep 17, 2024
@tenkeyless

tenkeyless commented Sep 23, 2024

This issue is caused by a bug in the predict_on_batch method, which does not correctly match inputs when they are passed as a dict.

In the stable_diffusion.py file, lines 125, 228 and 235 of the StableDiffusionBase class pass dict inputs; if you change them to list inputs, they work fine. You can work around the problem by modifying them as shown below.

However, this is only a temporary workaround; the underlying error when passing a dict to this method should eventually be fixed in Keras.

If you use Colab:

  1. Copy the stable_diffusion.py file,
  2. Paste it into Colab,
  3. Change the calls that begin with context = self.text_encoder.predict_on_batch(, unconditional_latent = self.diffusion_model.predict_on_batch( and latent = self.diffusion_model.predict_on_batch( as follows.
...

        # context = self.text_encoder.predict_on_batch(
        #     {"tokens": phrase, "positions": self._get_pos_ids()}
        # )
        context = self.text_encoder.predict_on_batch([phrase, self._get_pos_ids()])

...

            # unconditional_latent = self.diffusion_model.predict_on_batch(
            #     {
            #         "latent": latent,
            #         "timestep_embedding": t_emb,
            #         "context": unconditional_context,
            #     }
            # )
            unconditional_latent = self.diffusion_model.predict_on_batch([latent, t_emb, unconditional_context])
            # latent = self.diffusion_model.predict_on_batch(
            #     {
            #         "latent": latent,
            #         "timestep_embedding": t_emb,
            #         "context": context,
            #     }
            # )
            latent = self.diffusion_model.predict_on_batch([latent, t_emb, context])
...
  4. Then construct the model as follows.
# model = keras_cv.models.StableDiffusion(
model = StableDiffusion(
    img_width=512, img_height=512, jit_compile=False
)

My Colab notebook is written in Korean, but it shows the full change:

https://colab.research.google.com/drive/1JzF8rP8y8XN2jUsmJiH0NMS-Ti1Nd6Ql#scrollTo=y100OUIHDkOH

  • Keras 3.4.1
  • TF 2.17.0
