Skip to content

Commit

Permalink
Add dtype parameter to VAE object.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Dec 12, 2023
1 parent 32b7e7e commit 824e493
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def get_key_patches(self):
return self.patcher.get_key_patches()

class VAE:
def __init__(self, sd=None, device=None, config=None):
def __init__(self, sd=None, device=None, config=None, dtype=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)

Expand Down Expand Up @@ -188,7 +188,9 @@ def __init__(self, sd=None, device=None, config=None):
device = model_management.vae_device()
self.device = device
offload_device = model_management.vae_offload_device()
self.vae_dtype = model_management.vae_dtype()
if dtype is None:
dtype = model_management.vae_dtype()
self.vae_dtype = dtype
self.first_stage_model.to(self.vae_dtype)
self.output_device = model_management.intermediate_device()

Expand Down

0 comments on commit 824e493

Please sign in to comment.