Skip to content

Commit

Permalink
add assertion for the guidance_scale
Browse files Browse the repository at this point in the history
  • Loading branch information
lmxyy committed May 5, 2024
1 parent f1f45fb commit c71cba0
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions distrifuser/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ def __call__(self, *args, **kwargs):
assert "height" not in kwargs, "height should not be in kwargs"
assert "width" not in kwargs, "width should not be in kwargs"
config = self.distri_config
if not config.do_classifier_free_guidance:
if "guidance_scale" not in kwargs:
kwargs["guidance_scale"] = 1
else:
assert kwargs["guidance_scale"] == 1
self.pipeline.unet.set_counter(0)
return self.pipeline(height=config.height, width=config.width, *args, **kwargs)

Expand Down Expand Up @@ -202,6 +207,11 @@ def __call__(self, *args, **kwargs):
assert "height" not in kwargs, "height should not be in kwargs"
assert "width" not in kwargs, "width should not be in kwargs"
config = self.distri_config
if not config.do_classifier_free_guidance:
if not "guidance_scale" not in kwargs:
kwargs["guidance_scale"] = 1
else:
assert kwargs["guidance_scale"] == 1
self.pipeline.unet.set_counter(0)
return self.pipeline(height=config.height, width=config.width, *args, **kwargs)

Expand Down

0 comments on commit c71cba0

Please sign in to comment.