diff --git a/distrifuser/pipelines.py b/distrifuser/pipelines.py index 158c852..ea95b9f 100644 --- a/distrifuser/pipelines.py +++ b/distrifuser/pipelines.py @@ -49,6 +49,11 @@ def __call__(self, *args, **kwargs): assert "height" not in kwargs, "height should not be in kwargs" assert "width" not in kwargs, "width should not be in kwargs" config = self.distri_config + if not config.do_classifier_free_guidance: + if "guidance_scale" not in kwargs: + kwargs["guidance_scale"] = 1 + else: + assert kwargs["guidance_scale"] == 1 self.pipeline.unet.set_counter(0) return self.pipeline(height=config.height, width=config.width, *args, **kwargs) @@ -202,6 +207,11 @@ def __call__(self, *args, **kwargs): assert "height" not in kwargs, "height should not be in kwargs" assert "width" not in kwargs, "width should not be in kwargs" config = self.distri_config + if not config.do_classifier_free_guidance: + if not "guidance_scale" not in kwargs: + kwargs["guidance_scale"] = 1 + else: + assert kwargs["guidance_scale"] == 1 self.pipeline.unet.set_counter(0) return self.pipeline(height=config.height, width=config.width, *args, **kwargs)