diff --git a/distrifuser/pipelines.py b/distrifuser/pipelines.py index 7dbe05a..158c852 100644 --- a/distrifuser/pipelines.py +++ b/distrifuser/pipelines.py @@ -229,7 +229,7 @@ def prepare(self, **kwargs): prompt_embeds=None, negative_prompt_embeds=None, lora_scale=None, - clip_skip=pipeline.clip_skip, + clip_skip=kwargs.get("clip_skip", None), ) batch_size = 2 if distri_config.do_classifier_free_guidance else 1 diff --git a/scripts/run_sdxl.py b/scripts/run_sdxl.py index 087ad8d..3ff758b 100644 --- a/scripts/run_sdxl.py +++ b/scripts/run_sdxl.py @@ -40,7 +40,7 @@ def get_args() -> argparse.Namespace: "--sync_mode", type=str, default="corrected_async_gn", - choices=["separate_gn", "async_gn", "corrected_async_gn", "sync_gn", "full_sync", "no_sync"], + choices=["separate_gn", "stale_gn", "corrected_async_gn", "sync_gn", "full_sync", "no_sync"], help="Different GroupNorm synchronization modes", ) parser.add_argument( diff --git a/scripts/sd_example.py b/scripts/sd_example.py index cf92b5a..ccc5a08 100644 --- a/scripts/sd_example.py +++ b/scripts/sd_example.py @@ -3,7 +3,7 @@ from distrifuser.pipelines import DistriSDPipeline from distrifuser.utils import DistriConfig -distri_config = DistriConfig(height=512, width=512, warmup_steps=4) +distri_config = DistriConfig(height=512, width=512, warmup_steps=4, mode="stale_gn") pipeline = DistriSDPipeline.from_pretrained( distri_config=distri_config, pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",