Skip to content

Commit

Permalink
Set default value to unet config sample size (#1223)
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix authored Jul 24, 2023
1 parent c146b75 commit 29675d5
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions optimum/pipelines/diffusers/pipeline_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,8 @@ def __call__(
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
height = height or self.unet.config["sample_size"] * self.vae_scale_factor
width = width or self.unet.config["sample_size"] * self.vae_scale_factor
height = height or self.unet.config.get("sample_size", 64) * self.vae_scale_factor
width = width or self.unet.config.get("sample_size", 64) * self.vae_scale_factor

# check inputs. Raise error if not correct
self.check_inputs(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,8 @@ def __call__(
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
height = height or self.unet.config["sample_size"] * self.vae_scale_factor
width = width or self.unet.config["sample_size"] * self.vae_scale_factor
height = height or self.unet.config.get("sample_size", 64) * self.vae_scale_factor
width = width or self.unet.config.get("sample_size", 64) * self.vae_scale_factor

# check inputs. Raise error if not correct
self.check_inputs(
Expand Down

0 comments on commit 29675d5

Please sign in to comment.