Skip to content

Commit

Permalink
Merge branch 'master' into memory-optim
Browse files Browse the repository at this point in the history
  • Loading branch information
jn-jairo committed Nov 21, 2023
2 parents 0de7950 + d66b631 commit dca30ca
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 19 deletions.
17 changes: 12 additions & 5 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import comfy.lora
import comfy.t2i_adapter.adapter
import comfy.supported_models_base
import comfy.taesd.taesd

def load_model_weights(model, sd):
m, u = model.load_state_dict(sd, strict=False)
Expand Down Expand Up @@ -154,10 +155,16 @@ def __init__(self, sd=None, device=None, config=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)

self.memory_used_encode = lambda shape: (2078 * shape[2] * shape[3]) * 1.7 #These are for AutoencoderKL and need tweaking
self.memory_used_decode = lambda shape: (2562 * shape[2] * shape[3] * 64) * 1.7

if config is None:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
if "taesd_decoder.1.weight" in sd:
self.first_stage_model = comfy.taesd.taesd.TAESD()
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
else:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
Expand Down Expand Up @@ -206,7 +213,7 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
def decode(self, samples_in):
self.first_stage_model = self.first_stage_model.to(self.device)
try:
memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7
memory_used = self.memory_used_decode(samples_in.shape)
model_management.free_memory(memory_used, self.device)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
Expand Down Expand Up @@ -245,7 +252,7 @@ def encode(self, pixel_samples):
self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1)
try:
memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
memory_used = self.memory_used_encode(pixel_samples.shape) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
model_management.free_memory(memory_used, self.device)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
Expand Down
19 changes: 14 additions & 5 deletions comfy/taesd/taesd.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,16 @@ class TAESD(nn.Module):
latent_magnitude = 3
latent_shift = 0.5

def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"):
def __init__(self, encoder_path=None, decoder_path=None):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
self.taesd_encoder = Encoder()
self.taesd_decoder = Decoder()
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
if encoder_path is not None:
self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
if decoder_path is not None:
self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))

@staticmethod
def scale_latents(x):
Expand All @@ -65,3 +66,11 @@ def scale_latents(x):
def unscale_latents(x):
"""[0, 1] -> raw latents"""
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)

def decode(self, x):
x_sample = self.taesd_decoder(x * self.vae_scale)
x_sample = x_sample.sub(0.5).mul(2)
return x_sample

def encode(self, x):
return self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale
5 changes: 1 addition & 4 deletions latent_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@ def __init__(self, taesd, device):
self.device = device

def decode_latent_to_preview(self, x0):
x_sample = self.taesd.decoder(x0[:1].to(self.device))[0].detach()
# x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
x_sample = x_sample.sub(0.5).mul(2)

x_sample = self.taesd.decode(x0[:1].to(self.device))[0].detach()
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
Expand Down
55 changes: 52 additions & 3 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,18 +573,67 @@ def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
return (model_lora, clip_lora)

class VAELoader:
@staticmethod
def vae_list():
vaes = folder_paths.get_filename_list("vae")
approx_vaes = folder_paths.get_filename_list("vae_approx")
sdxl_taesd_enc = False
sdxl_taesd_dec = False
sd1_taesd_enc = False
sd1_taesd_dec = False

for v in approx_vaes:
if v.startswith("taesd_decoder."):
sd1_taesd_dec = True
elif v.startswith("taesd_encoder."):
sd1_taesd_enc = True
elif v.startswith("taesdxl_decoder."):
sdxl_taesd_dec = True
elif v.startswith("taesdxl_encoder."):
sdxl_taesd_enc = True
if sd1_taesd_dec and sd1_taesd_enc:
vaes.append("taesd")
if sdxl_taesd_dec and sdxl_taesd_enc:
vaes.append("taesdxl")
return vaes

@staticmethod
def load_taesd(name):
sd = {}
approx_vaes = folder_paths.get_filename_list("vae_approx")

encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))

enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
for k in enc:
sd["taesd_encoder.{}".format(k)] = enc[k]

dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
for k in dec:
sd["taesd_decoder.{}".format(k)] = dec[k]

if name == "taesd":
sd["vae_scale"] = torch.tensor(0.18215)
elif name == "taesdxl":
sd["vae_scale"] = torch.tensor(0.13025)
return sd

@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}}
return {"required": { "vae_name": (s.vae_list(), )}}
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"

CATEGORY = "loaders"

#TODO: scale factor?
def load_vae(self, vae_name):
vae_path = folder_paths.get_full_path("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
if vae_name in ["taesd", "taesdxl"]:
sd = self.load_taesd(vae_name)
else:
vae_path = folder_paths.get_full_path("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
return (vae,)

Expand Down
5 changes: 3 additions & 2 deletions web/scripts/domWidget.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@ function getClipPath(node, element, elRect) {
const MARGIN = 7;
const scale = app.canvas.ds.scale;

const bounding = selectedNode.getBounding();
const intersection = intersect(
{ x: elRect.x / scale, y: elRect.y / scale, width: elRect.width / scale, height: elRect.height / scale },
{
x: selectedNode.pos[0] + app.canvas.ds.offset[0] - MARGIN,
y: selectedNode.pos[1] + app.canvas.ds.offset[1] - LiteGraph.NODE_TITLE_HEIGHT - MARGIN,
width: selectedNode.size[0] + MARGIN + MARGIN,
height: selectedNode.size[1] + LiteGraph.NODE_TITLE_HEIGHT + MARGIN + MARGIN,
width: bounding[2] + MARGIN + MARGIN,
height: bounding[3] + MARGIN + MARGIN,
}
);

Expand Down

0 comments on commit dca30ca

Please sign in to comment.