Skip to content

Commit

Permalink
Add SA-Solver
Browse files Browse the repository at this point in the history
  • Loading branch information
chaObserv committed Sep 17, 2024
1 parent 7183fd1 commit 3ab36c0
Show file tree
Hide file tree
Showing 4 changed files with 335 additions and 1 deletion.
189 changes: 189 additions & 0 deletions comfy/k_diffusion/sa_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# Modify from: https://github.com/scxue/SA-Solver
# MIT license

import torch

def get_coefficients_exponential_positive(order, interval_start, interval_end, tau):
"""
Calculate the integral of exp(x(1+tau^2)) * x^order dx from interval_start to interval_end
For calculating the coefficient of gradient terms after the lagrange interpolation,
see Eq.(15) and Eq.(18) in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
For data_prediction formula.
"""
assert order in [0, 1, 2, 3], "order is only supported for 0, 1, 2 and 3"

# after change of variable(cov)
interval_end_cov = (1 + tau ** 2) * interval_end
interval_start_cov = (1 + tau ** 2) * interval_start

if order == 0:
return (torch.exp(interval_end_cov)
* (1 - torch.exp(-(interval_end_cov - interval_start_cov)))
/ ((1 + tau ** 2))
)
elif order == 1:
return (torch.exp(interval_end_cov)
* ((interval_end_cov - 1) - (interval_start_cov - 1) * torch.exp(-(interval_end_cov - interval_start_cov)))
/ ((1 + tau ** 2) ** 2)
)
elif order == 2:
return (torch.exp(interval_end_cov)
* ((interval_end_cov ** 2 - 2 * interval_end_cov + 2)
- (interval_start_cov ** 2 - 2 * interval_start_cov + 2)
* torch.exp(-(interval_end_cov - interval_start_cov))
)
/ ((1 + tau ** 2) ** 3)
)
elif order == 3:
return (torch.exp(interval_end_cov)
* ((interval_end_cov ** 3 - 3 * interval_end_cov ** 2 + 6 * interval_end_cov - 6)
- (interval_start_cov ** 3 - 3 * interval_start_cov ** 2 + 6 * interval_start_cov - 6)
* torch.exp(-(interval_end_cov - interval_start_cov))
)
/ ((1 + tau ** 2) ** 4)
)

def lagrange_polynomial_coefficient(order, lambda_list):
"""
Calculate the coefficient of lagrange polynomial
For lagrange interpolation
"""
assert order in [0, 1, 2, 3]
assert order == len(lambda_list) - 1
if order == 0:
return [[1.0]]
elif order == 1:
return [[1.0 / (lambda_list[0] - lambda_list[1]), -lambda_list[1] / (lambda_list[0] - lambda_list[1])],
[1.0 / (lambda_list[1] - lambda_list[0]), -lambda_list[0] / (lambda_list[1] - lambda_list[0])]]
elif order == 2:
denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2])
denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2])
denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1])
return [[1.0 / denominator1, (-lambda_list[1] - lambda_list[2]) / denominator1, lambda_list[1] * lambda_list[2] / denominator1],
[1.0 / denominator2, (-lambda_list[0] - lambda_list[2]) / denominator2, lambda_list[0] * lambda_list[2] / denominator2],
[1.0 / denominator3, (-lambda_list[0] - lambda_list[1]) / denominator3, lambda_list[0] * lambda_list[1] / denominator3]
]
elif order == 3:
denominator1 = (lambda_list[0] - lambda_list[1]) * (lambda_list[0] - lambda_list[2]) * (lambda_list[0] - lambda_list[3])
denominator2 = (lambda_list[1] - lambda_list[0]) * (lambda_list[1] - lambda_list[2]) * (lambda_list[1] - lambda_list[3])
denominator3 = (lambda_list[2] - lambda_list[0]) * (lambda_list[2] - lambda_list[1]) * (lambda_list[2] - lambda_list[3])
denominator4 = (lambda_list[3] - lambda_list[0]) * (lambda_list[3] - lambda_list[1]) * (lambda_list[3] - lambda_list[2])
return [[1.0 / denominator1,
(-lambda_list[1] - lambda_list[2] - lambda_list[3]) / denominator1,
(lambda_list[1] * lambda_list[2] + lambda_list[1] * lambda_list[3] + lambda_list[2] * lambda_list[3]) / denominator1,
(-lambda_list[1] * lambda_list[2] * lambda_list[3]) / denominator1],

[1.0 / denominator2,
(-lambda_list[0] - lambda_list[2] - lambda_list[3]) / denominator2,
(lambda_list[0] * lambda_list[2] + lambda_list[0] * lambda_list[3] + lambda_list[2] * lambda_list[3]) / denominator2,
(-lambda_list[0] * lambda_list[2] * lambda_list[3]) / denominator2],

[1.0 / denominator3,
(-lambda_list[0] - lambda_list[1] - lambda_list[3]) / denominator3,
(lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[3] + lambda_list[1] * lambda_list[3]) / denominator3,
(-lambda_list[0] * lambda_list[1] * lambda_list[3]) / denominator3],

[1.0 / denominator4,
(-lambda_list[0] - lambda_list[1] - lambda_list[2]) / denominator4,
(lambda_list[0] * lambda_list[1] + lambda_list[0] * lambda_list[2] + lambda_list[1] * lambda_list[2]) / denominator4,
(-lambda_list[0] * lambda_list[1] * lambda_list[2]) / denominator4]
]

def get_coefficients_fn(order, interval_start, interval_end, lambda_list, tau):
"""
Calculate the coefficient of gradients.
"""
assert order in [1, 2, 3, 4]
assert order == len(lambda_list), 'the length of lambda list must be equal to the order'
lagrange_coefficient = lagrange_polynomial_coefficient(order - 1, lambda_list)
coefficients = [sum(lagrange_coefficient[i][j] * get_coefficients_exponential_positive(order - 1 - j, interval_start, interval_end, tau)
for j in range(order))
for i in range(order)]
assert len(coefficients) == order, 'the length of coefficients does not match the order'
return coefficients

def adams_bashforth_update_few_steps(order, x, tau, model_prev_list, sigma_prev_list, noise, sigma):
"""
SA-Predictor, with the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
"""

assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4"
t_fn = lambda sigma: sigma.log().neg()
sigma_prev = sigma_prev_list[-1]
gradient_part = torch.zeros_like(x)
lambda_list = [t_fn(sigma_prev_list[-(i + 1)]) for i in range(order)]
lambda_t = t_fn(sigma)
lambda_prev = lambda_list[0]
h = lambda_t - lambda_prev
gradient_coefficients = get_coefficients_fn(order, lambda_prev, lambda_t, lambda_list, tau)

if order == 2: ## if order = 2 we do a modification that does not influence the convergence order similar to unipc. Note: This is used only for few steps sampling.
# The added term is O(h^3). Empirically we find it will slightly improve the image quality.
# ODE case
# gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2]))
# gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h ** 2 / 2 - (h - 1 + torch.exp(-h))) / (ns.marginal_lambda(t_prev_list[-1]) - ns.marginal_lambda(t_prev_list[-2]))
gradient_coefficients[0] += (1.0 * torch.exp((1 + tau ** 2) * lambda_t)
* (h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ((1 + tau ** 2) ** 2))
/ (lambda_prev - lambda_list[1])
)
gradient_coefficients[1] -= (1.0 * torch.exp((1 + tau ** 2) * lambda_t)
* (h ** 2 / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h))) / ((1 + tau ** 2) ** 2))
/ (lambda_prev - lambda_list[1])
)

for i in range(order):
gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)]
gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t)
noise_part = 0 if tau == 0 else sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise
x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
return x_t

def adams_moulton_update_few_steps(order, x, tau, model_prev_list, sigma_prev_list, noise, sigma):
"""
SA-Corrector, with the "rescaling" trick in Appendix D in SA-Solver paper https://arxiv.org/pdf/2309.05019.pdf
"""

assert order in [1, 2, 3, 4], "order of stochastic adams bashforth method is only supported for 1, 2, 3 and 4"
t_fn = lambda sigma: sigma.log().neg()
sigma_prev = sigma_prev_list[-1]
gradient_part = torch.zeros_like(x)
sigma_list = sigma_prev_list + [sigma]
lambda_list = [t_fn(sigma_list[-(i + 1)]) for i in range(order)]
lambda_t = lambda_list[0]
lambda_prev = lambda_list[1]
h = lambda_t - lambda_prev
gradient_coefficients = get_coefficients_fn(order, lambda_prev, lambda_t, lambda_list, tau)

if order == 2: ## if order = 2 we do a modification that does not influence the convergence order similar to UniPC. Note: This is used only for few steps sampling.
# The added term is O(h^3). Empirically we find it will slightly improve the image quality.
# ODE case
# gradient_coefficients[0] += 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h)
# gradient_coefficients[1] -= 1.0 * torch.exp(lambda_t) * (h / 2 - (h - 1 + torch.exp(-h)) / h)
gradient_coefficients[0] += (1.0 * torch.exp((1 + tau ** 2) * lambda_t)
* (h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h)))
/ ((1 + tau ** 2) ** 2 * h))
)
gradient_coefficients[1] -= (1.0 * torch.exp((1 + tau ** 2) * lambda_t)
* (h / 2 - (h * (1 + tau ** 2) - 1 + torch.exp((1 + tau ** 2) * (-h)))
/ ((1 + tau ** 2) ** 2 * h))
)

for i in range(order):
gradient_part += gradient_coefficients[i] * model_prev_list[-(i + 1)]
gradient_part *= (1 + tau ** 2) * sigma * torch.exp(- tau ** 2 * lambda_t)
noise_part = 0 if tau == 0 else sigma * torch.sqrt(1. - torch.exp(-2 * tau ** 2 * h)) * noise
x_t = torch.exp(-tau ** 2 * h) * (sigma / sigma_prev) * x + gradient_part + noise_part
return x_t

def device_noise_sampler(x, noise_device='gpu'):
if noise_device == "gpu":
return torch.randn_like(x)
else:
return torch.randn(x.shape, device='cpu').to(x.device)

# Default tau function from https://github.com/scxue/SA-Solver?tab=readme-ov-file#-abstract
def default_tau_func(sigma, eta, eta_start_sigma, eta_end_sigma):
if eta == 0:
# Pure ODE
return 0
return eta if eta_end_sigma <= sigma <= eta_start_sigma else 0
112 changes: 112 additions & 0 deletions comfy/k_diffusion/sampling.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
from functools import partial

from scipy import integrate
import torch
Expand All @@ -8,6 +9,7 @@

from . import utils
from . import deis
from . import sa_solver
import comfy.model_patcher
import comfy.model_sampling

Expand Down Expand Up @@ -1050,6 +1052,116 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None,

return x_next

# Modify from: https://github.com/scxue/SA-Solver
# MIT license
@torch.no_grad()
def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, pc_mode="PEC", tau_func=None, noise_sampler=None):
if len(sigmas) <= 1:
return x

if sigmas[-1] == 0:
sigmas = sigmas.clone()
sigmas[-1] = 0.001

extra_args = {} if extra_args is None else extra_args
if tau_func is None:
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
start_sigma = model_sampling.percent_to_sigma(0.2)
end_sigma = model_sampling.percent_to_sigma(0.8)
tau_func = partial(sa_solver.default_tau_func, eta=1.0, eta_start_sigma=start_sigma, eta_end_sigma=end_sigma)
tau = tau_func
noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='cpu') if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])

sigma_prev_list = []
model_prev_list = []

for i in trange(len(sigmas) - 1, disable=disable):
sigma = sigmas[i]
if i == 0:
# Init the initial values.
denoised = model(x, sigma * s_in, **extra_args)
model_prev_list.append(denoised)
sigma_prev_list.append(sigma)
else:
# Lower order final
predictor_order_used = min(predictor_order, i, len(sigmas) - i)
corrector_order_used = min(corrector_order, i + 1, len(sigmas) - i + 1)

tau_val = tau(sigma)
noise = None if tau_val == 0 else noise_sampler()

# Predictor step
x_p = sa_solver.adams_bashforth_update_few_steps(order=predictor_order_used, x=x, tau=tau_val,
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
noise=noise, sigma=sigma)

# Evaluation step
denoised = model(x_p, sigma * s_in, **extra_args)

# Update model_list
model_prev_list.append(denoised)

# Corrector step
if corrector_order_used > 0:
x = sa_solver.adams_moulton_update_few_steps(order=corrector_order_used, x=x, tau=tau_val,
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
noise=noise, sigma=sigma)

else:
x = x_p

del noise, x_p

# Evaluation step if mode = pece and step != steps
if corrector_order_used > 0 and pc_mode == 'PECE':
del model_prev_list[-1]
denoised = model(x, sigma * s_in, **extra_args)
model_prev_list.append(denoised)

sigma_prev_list.append(sigma)
if len(model_prev_list) > max(predictor_order, corrector_order):
del model_prev_list[0]
del sigma_prev_list[0]

if callback is not None:
callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]})

# Extra final step
x = sa_solver.adams_bashforth_update_few_steps(order=1, x=x, tau=0,
model_prev_list=model_prev_list, sigma_prev_list=sigma_prev_list,
noise=0, sigma=sigmas[-1])
return x

@torch.no_grad()
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None):
if len(sigmas) <= 1:
return x
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable,
predictor_order=predictor_order, corrector_order=corrector_order,
pc_mode="PECE", tau_func=tau_func, noise_sampler=noise_sampler,
)

@torch.no_grad()
def sample_sa_solver_gpu(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None):
if len(sigmas) <= 1:
return x
noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='gpu') if noise_sampler is None else noise_sampler
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable,
predictor_order=predictor_order, corrector_order=corrector_order,
pc_mode="PEC", tau_func=tau_func, noise_sampler=noise_sampler,
)

@torch.no_grad()
def sample_sa_solver_pece_gpu(model, x, sigmas, extra_args=None, callback=None, disable=False, predictor_order=3, corrector_order=4, tau_func=None, noise_sampler=None):
if len(sigmas) <= 1:
return x
noise_sampler = partial(sa_solver.device_noise_sampler, x=x, noise_device='gpu') if noise_sampler is None else noise_sampler
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable,
predictor_order=predictor_order, corrector_order=corrector_order,
pc_mode="PECE", tau_func=tau_func, noise_sampler=noise_sampler,
)

@torch.no_grad()
def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
extra_args = {} if extra_args is None else extra_args
Expand Down
2 changes: 1 addition & 1 deletion comfy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ def max_denoise(self, model_wrap, sigmas):
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis"]
"ipndm", "ipndm_v", "deis", 'sa_solver', "sa_solver_gpu", "sa_solver_pece", "sa_solver_pece_gpu"]

class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
Expand Down
Loading

0 comments on commit 3ab36c0

Please sign in to comment.