From 824e4935f53fdbda8f4608f511b4c2e8daf79dfa Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 12 Dec 2023 12:03:29 -0500 Subject: [PATCH] Add dtype parameter to VAE object. --- comfy/sd.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 8c056e4ea2f..220637a05d7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -151,7 +151,7 @@ def get_key_patches(self): return self.patcher.get_key_patches() class VAE: - def __init__(self, sd=None, device=None, config=None): + def __init__(self, sd=None, device=None, config=None, dtype=None): if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) @@ -188,7 +188,9 @@ def __init__(self, sd=None, device=None, config=None): device = model_management.vae_device() self.device = device offload_device = model_management.vae_offload_device() - self.vae_dtype = model_management.vae_dtype() + if dtype is None: + dtype = model_management.vae_dtype() + self.vae_dtype = dtype self.first_stage_model.to(self.vae_dtype) self.output_device = model_management.intermediate_device()