Skip to content

Commit

Permalink
Fix lambda_prev in corrector
Browse files — browse the repository at this point in the history
  • Loading branch information
chaObserv committed Sep 17, 2024
1 parent 3ab36c0 commit 1bbb13b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion comfy/k_diffusion/sa_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_li
sigma_list = sigma_prev_list + [sigma]
lambda_list = [t_fn(sigma_list[-(i + 1)]) for i in range(order)]
lambda_t = lambda_list[0]
lambda_prev = lambda_list[1]
lambda_prev = lambda_list[1] if order >= 2 else t_fn(sigma_prev)
h = lambda_t - lambda_prev
gradient_coefficients = get_coefficients_fn(order, lambda_prev, lambda_t, lambda_list, tau)

Expand Down
12 changes: 6 additions & 6 deletions comfy/k_diffusion/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,8 +1093,8 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F

# Predictor step
x_p = sa_solver.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau_val,
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
noise=noise, sigma=sigma)
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
noise=noise, sigma=sigma)

# Evaluation step
denoised = model(x_p, sigma * s_in, **extra_args)
Expand All @@ -1105,8 +1105,8 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
# Corrector step
if corrector_order_used > 0:
x = sa_solver.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau_val,
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
noise=noise, sigma=sigma)
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
noise=noise, sigma=sigma)

else:
x = x_p
Expand All @@ -1129,8 +1129,8 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F

# Extra final step
x = sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0,
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
noise=0, sigma=sigmas[-1])
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
noise=0, sigma=sigmas[-1])
return x

@torch.no_grad()
Expand Down

0 comments on commit 1bbb13b

Please sign in to comment.