Skip to content

Commit

Permalink
Disable autocast in unet for increased speed.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Jul 6, 2023
1 parent 603f02d commit ddc6f12
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 82 deletions.
9 changes: 6 additions & 3 deletions comfy/gligen.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,12 @@ def __init__(self, in_dim, out_dim, fourier_freqs=8):

def forward(self, boxes, masks, positive_embeddings):
B, N, _ = boxes.shape
masks = masks.unsqueeze(-1)
dtype = self.linears[0].weight.dtype
masks = masks.unsqueeze(-1).to(dtype)
positive_embeddings = positive_embeddings.to(dtype)

# embedding position (it may includes padding as placeholder)
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
xyxy_embedding = self.fourier_embedder(boxes.to(dtype)) # B*N*4 --> B*N*C

# learnable null embedding
positive_null = self.null_positive_feature.view(1, 1, -1)
Expand Down Expand Up @@ -252,7 +254,8 @@ def _set_position(self, boxes, masks, positive_embeddings):

if self.lowvram == True:
self.position_net.cpu()
def func_lowvram(key, x):
def func_lowvram(x, extra_options):
key = extra_options["transformer_index"]
module = self.module_list[key]
module.to(x.device)
r = module(x, objs)
Expand Down
4 changes: 2 additions & 2 deletions comfy/ldm/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def forward(self, x, context=None, value=None, mask=None):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in

r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

mem_free_total = model_management.get_free_memory(q.device)

Expand Down Expand Up @@ -314,7 +314,7 @@ def forward(self, x, context=None, value=None, mask=None):
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
first_op_done = True

s2 = s1.softmax(dim=-1)
s2 = s1.softmax(dim=-1).to(v.dtype)
del s1

r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
Expand Down
10 changes: 5 additions & 5 deletions comfy/ldm/modules/diffusionmodules/openaimodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def __init__(
self.use_scale_shift_norm = use_scale_shift_norm

self.in_layers = nn.Sequential(
normalization(channels, dtype=dtype),
nn.GroupNorm(32, channels, dtype=dtype),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
)
Expand All @@ -244,7 +244,7 @@ def __init__(
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels, dtype=dtype),
nn.GroupNorm(32, self.out_channels, dtype=dtype),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
Expand Down Expand Up @@ -778,13 +778,13 @@ def __init__(
self._feature_size += ch

self.out = nn.Sequential(
normalization(ch, dtype=self.dtype),
nn.GroupNorm(32, ch, dtype=self.dtype),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
normalization(ch),
nn.GroupNorm(32, ch, dtype=self.dtype),
conv_nd(dims, model_channels, n_embed, 1),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
Expand Down Expand Up @@ -821,7 +821,7 @@ def forward(self, x, timesteps=None, context=None, y=None, control=None, transfo
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
emb = self.time_embed(t_emb)

if self.num_classes is not None:
Expand Down
4 changes: 2 additions & 2 deletions comfy/ldm/modules/sub_quadratic_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _summarize_chunk(
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach()
torch.exp(attn_weights - max_score, out=attn_weights)
exp_weights = attn_weights
exp_weights = attn_weights.to(value.dtype)
exp_values = torch.bmm(exp_weights, value)
max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
Expand Down Expand Up @@ -166,7 +166,7 @@ def _get_attention_scores_no_kv_chunking(
attn_scores /= summed
attn_probs = attn_scores

hidden_states_slice = torch.bmm(attn_probs, value)
hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
return hidden_states_slice

class ScannedChunk(NamedTuple):
Expand Down
8 changes: 7 additions & 1 deletion comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,13 @@ def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control
else:
xc = x
context = torch.cat(c_crossattn, 1)
return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options)
dtype = self.get_dtype()
xc = xc.to(dtype)
t = t.to(dtype)
context = context.to(dtype)
if c_adm is not None:
c_adm = c_adm.to(dtype)
return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float()

def get_dtype(self):
return self.diffusion_model.dtype
Expand Down
1 change: 1 addition & 0 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def load_model_gpu(model):

torch_dev = model.load_device
model.model_patches_to(torch_dev)
model.model_patches_to(model.model_dtype())

if is_device_cpu(torch_dev):
vram_set_state = VRAMState.DISABLED
Expand Down
6 changes: 3 additions & 3 deletions comfy/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ def get_models_from_cond(cond, model_type):
models += [c[1][model_type]]
return models

def load_additional_models(positive, negative):
def load_additional_models(positive, negative, dtype):
"""loads additional models in positive and negative conditioning"""
control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
gligen = [x[1] for x in gligen]
gligen = [x[1].to(dtype) for x in gligen]
models = control_nets + gligen
comfy.model_management.load_controlnet_gpu(models)
return models
Expand All @@ -81,7 +81,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
positive_copy = broadcast_cond(positive, noise.shape[0], device)
negative_copy = broadcast_cond(negative, noise.shape[0], device)

models = load_additional_models(positive, negative)
models = load_additional_models(positive, negative, model.model_dtype())

sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)

Expand Down
124 changes: 59 additions & 65 deletions comfy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from .k_diffusion import external as k_diffusion_external
from .extra_samplers import uni_pc
import torch
import contextlib
from comfy import model_management
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
Expand Down Expand Up @@ -577,11 +576,6 @@ def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=N
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])

if self.model.get_dtype() == torch.float16:
precision_scope = torch.autocast
else:
precision_scope = contextlib.nullcontext

if self.model.is_adm():
positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
Expand Down Expand Up @@ -612,67 +606,67 @@ def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=N
else:
max_denoise = True

with precision_scope(model_management.get_autocast_device(self.device)):
if self.sampler == "uni_pc":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
elif self.sampler == "uni_pc_bh2":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
elif self.sampler == "ddim":
timesteps = []
for s in range(sigmas.shape[0]):
timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
noise_mask = None
if denoise_mask is not None:
noise_mask = 1.0 - denoise_mask

ddim_callback = None
if callback is not None:
total_steps = len(timesteps) - 1
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)

sampler = DDIMSampler(self.model, device=self.device)
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
conditioning=positive,
batch_size=noise.shape[0],
shape=noise.shape[1:],
verbose=False,
unconditional_guidance_scale=cfg,
unconditional_conditioning=negative,
eta=0.0,
x_T=z_enc,
x0=latent_image,
img_callback=ddim_callback,
denoise_function=sampling_function,
extra_args=extra_args,
mask=noise_mask,
to_zero=sigmas[-1]==0,
end_step=sigmas.shape[0] - 1,
disable_pbar=disable_pbar)

else:
extra_args["denoise_mask"] = denoise_mask
self.model_k.latent_image = latent_image
self.model_k.noise = noise

if max_denoise:
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
else:
noise = noise * sigmas[0]
if self.sampler == "uni_pc":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
elif self.sampler == "uni_pc_bh2":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
elif self.sampler == "ddim":
timesteps = []
for s in range(sigmas.shape[0]):
timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s]))
noise_mask = None
if denoise_mask is not None:
noise_mask = 1.0 - denoise_mask

ddim_callback = None
if callback is not None:
total_steps = len(timesteps) - 1
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)

sampler = DDIMSampler(self.model, device=self.device)
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
conditioning=positive,
batch_size=noise.shape[0],
shape=noise.shape[1:],
verbose=False,
unconditional_guidance_scale=cfg,
unconditional_conditioning=negative,
eta=0.0,
x_T=z_enc,
x0=latent_image,
img_callback=ddim_callback,
denoise_function=sampling_function,
extra_args=extra_args,
mask=noise_mask,
to_zero=sigmas[-1]==0,
end_step=sigmas.shape[0] - 1,
disable_pbar=disable_pbar)

k_callback = None
total_steps = len(sigmas) - 1
if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
else:
extra_args["denoise_mask"] = denoise_mask
self.model_k.latent_image = latent_image
self.model_k.noise = noise

if latent_image is not None:
noise += latent_image
if self.sampler == "dpm_fast":
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
elif self.sampler == "dpm_adaptive":
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
else:
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
else:
noise = noise * sigmas[0]

k_callback = None
total_steps = len(sigmas) - 1
if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)

if latent_image is not None:
noise += latent_image
if self.sampler == "dpm_fast":
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
elif self.sampler == "dpm_adaptive":
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
else:
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)

return self.model.process_latent_out(samples.to(torch.float32))
3 changes: 2 additions & 1 deletion comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,8 @@ def model_patches_to(self, device):
patch_list[k] = patch_list[k].to(device)

def model_dtype(self):
return self.model.get_dtype()
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()

def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
p = {}
Expand Down

0 comments on commit ddc6f12

Please sign in to comment.