diff --git a/k_diffusion/external.py b/k_diffusion/external.py index e8563a3..9acfcf7 100644 --- a/k_diffusion/external.py +++ b/k_diffusion/external.py @@ -1,17 +1,75 @@ import math import torch -from torch import nn +from torch import nn, Tensor +from typing import Protocol, Generic, TypeVar, TYPE_CHECKING from . import sampling, utils +class AbstractModel(Protocol): + def __call__(self, *args, **kwargs) -> Tensor: ... + +class DenoiserModel(AbstractModel): + def __call__(self, x: Tensor, t: Tensor, *args, **kwargs) -> Tensor: ... + +class CompVisModel(AbstractModel): + alphas_cumprod: Tensor + def apply_model(self, x: Tensor, t: Tensor, cond: Tensor) -> Tensor: ... + +class WrappedModelProto(AbstractModel): + def sigma_to_t(self, sigma: Tensor) -> Tensor: ... + def t_to_sigma(self, t: Tensor) -> Tensor: ... + def discretize_sigma(self, sigma: Tensor) -> Tensor: ... + +# the 'default' arg of TypeVar isn't valid at runtime, but amazingly seems to be utilised +# at compile-time by some type-checkers +# https://github.com/python/mypy/issues/4236#issuecomment-344660299 +if TYPE_CHECKING: + TModel = TypeVar('TModel', bound=AbstractModel, default=AbstractModel) + TDenoiserModel = TypeVar('TDenoiserModel', bound=DenoiserModel, default=DenoiserModel) + TCompVisModel = TypeVar('TCompVisModel', bound=CompVisModel, default=CompVisModel) +else: + TModel = TypeVar('TModel', bound=AbstractModel) + TDenoiserModel = TypeVar('TDenoiserModel', bound=DenoiserModel) + TCompVisModel = TypeVar('TCompVisModel', bound=CompVisModel) + +class BaseModelWrapper(nn.Module, WrappedModelProto, Generic[TModel]): + inner_model: TModel + + """The base wrapper class for the k-diffusion model wrapper idiom. Model + wrappers should subclass this class and customize the behavior of the + wrapped model by implementing or overriding methods.""" + def __init__(self, model: TModel): + super().__init__() + self.inner_model = model + + def __dir__(self): + return list(set(super().__dir__() + dir(self.inner_model))) -class VDenoiser(nn.Module): + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.inner_model, name) + + def forward(self, *args, **kwargs): + return self.inner_model(*args, **kwargs) + + def sigma_to_t(self, sigma: Tensor) -> Tensor: + return sigma + + def t_to_sigma(self, t: Tensor) -> Tensor: + return t + + def discretize_sigma(self, sigma: Tensor) -> Tensor: + return self.t_to_sigma(self.sigma_to_t(sigma)) + + +class VDenoiser(BaseModelWrapper[TDenoiserModel]): """A v-diffusion-pytorch model wrapper for k-diffusion.""" - def __init__(self, inner_model): - super().__init__() - self.inner_model = inner_model + def __init__(self, model: TDenoiserModel): + super().__init__(model) self.sigma_data = 1. def get_scalings(self, sigma): @@ -38,12 +96,12 @@ def forward(self, input, sigma, **kwargs): return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip -class DiscreteSchedule(nn.Module): +class DiscreteSchedule(BaseModelWrapper[TModel]): """A mapping between continuous noise levels (sigmas) and a list of discrete noise levels.""" - def __init__(self, sigmas, quantize): - super().__init__() + def __init__(self, sigmas, quantize, model: TModel): + super().__init__(model) self.register_buffer('sigmas', sigmas) self.register_buffer('log_sigmas', sigmas.log()) self.quantize = quantize @@ -86,13 +144,12 @@ def t_to_sigma(self, t): return log_sigma.exp() -class DiscreteEpsDDPMDenoiser(DiscreteSchedule): +class DiscreteEpsDDPMDenoiser(DiscreteSchedule[TModel]): """A wrapper for discrete schedule DDPM models that output eps (the predicted noise).""" - def __init__(self, model, alphas_cumprod, quantize): - super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) - self.inner_model = model + def __init__(self, model: TModel, alphas_cumprod, quantize): + super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize, model) self.sigma_data = 1. def get_scalings(self, sigma): @@ -115,10 +172,10 @@ def forward(self, input, sigma, **kwargs): return input + eps * c_out -class OpenAIDenoiser(DiscreteEpsDDPMDenoiser): +class OpenAIDenoiser(DiscreteEpsDDPMDenoiser[TModel]): """A wrapper for OpenAI diffusion models.""" - def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'): + def __init__(self, model: TModel, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'): alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32) super().__init__(model, alphas_cumprod, quantize=quantize) self.has_learned_sigmas = has_learned_sigmas @@ -130,22 +187,21 @@ def get_eps(self, *args, **kwargs): return model_output -class CompVisDenoiser(DiscreteEpsDDPMDenoiser): +class CompVisDenoiser(DiscreteEpsDDPMDenoiser[TCompVisModel]): """A wrapper for CompVis diffusion models.""" - def __init__(self, model, quantize=False, device='cpu'): + def __init__(self, model: TCompVisModel, quantize=False, device='cpu'): super().__init__(model, model.alphas_cumprod, quantize=quantize) def get_eps(self, *args, **kwargs): return self.inner_model.apply_model(*args, **kwargs) -class DiscreteVDDPMDenoiser(DiscreteSchedule): +class DiscreteVDDPMDenoiser(DiscreteSchedule[TModel]): """A wrapper for discrete schedule DDPM models that output v.""" - def __init__(self, model, alphas_cumprod, quantize): - super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) - self.inner_model = model + def __init__(self, model: TModel, alphas_cumprod, quantize): + super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize, model) self.sigma_data = 1. def get_scalings(self, sigma): @@ -169,10 +225,10 @@ def forward(self, input, sigma, **kwargs): return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip -class CompVisVDenoiser(DiscreteVDDPMDenoiser): +class CompVisVDenoiser(DiscreteVDDPMDenoiser[TCompVisModel]): """A wrapper for CompVis diffusion models that output v.""" - def __init__(self, model, quantize=False, device='cpu'): + def __init__(self, model: TCompVisModel, quantize=False, device='cpu'): super().__init__(model, model.alphas_cumprod, quantize=quantize) def get_v(self, x, t, cond, **kwargs): diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py index cf2a5c1..0328d50 100644 --- a/k_diffusion/sampling.py +++ b/k_diffusion/sampling.py @@ -8,6 +8,7 @@ from tqdm.auto import trange, tqdm from . import utils +from .external import WrappedModelProto def append_zero(x): @@ -115,7 +116,7 @@ def __call__(self, sigma, sigma_next): @torch.no_grad() -def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): +def sample_euler(model: WrappedModelProto, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -123,6 +124,7 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) + sigma_hat = model.discretize_sigma(sigma_hat) if gamma > 0: x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) @@ -136,7 +138,7 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() -def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): +def sample_euler_ancestral(model: WrappedModelProto, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """Ancestral sampling with Euler method steps.""" extra_args = {} if extra_args is None else extra_args noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler @@ -156,7 +158,7 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis @torch.no_grad() -def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): +def sample_heun(model: WrappedModelProto, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -164,6 +166,7 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) + sigma_hat = model.discretize_sigma(sigma_hat) if gamma > 0: x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) @@ -185,7 +188,7 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() -def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): +def sample_dpm_2(model: WrappedModelProto, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -193,6 +196,7 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) + sigma_hat = model.discretize_sigma(sigma_hat) if gamma > 0: x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) @@ -216,7 +220,7 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() -def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): +def sample_dpm_2_ancestral(model: WrappedModelProto, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """Ancestral sampling with DPM-Solver second-order steps.""" extra_args = {} if extra_args is None else extra_args noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler @@ -258,7 +262,7 @@ def fn(tau): @torch.no_grad() -def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4): +def sample_lms(model: WrappedModelProto, x, sigmas, extra_args=None, callback=None, disable=None, order=4): extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) sigmas_cpu = sigmas.detach().cpu().numpy() @@ -278,7 +282,7 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o @torch.no_grad() -def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4): +def log_likelihood(model: WrappedModelProto, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4): extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) v = torch.randint_like(x, 2) * 2 - 1 @@ -332,8 +336,9 @@ def propose_step(self, error): class DPMSolver(nn.Module): """DPM-Solver. See https://arxiv.org/abs/2206.00927.""" + model: WrappedModelProto - def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None): + def __init__(self, model: WrappedModelProto, extra_args=None, eps_callback=None, info_callback=None): super().__init__() self.model = model self.extra_args = {} if extra_args is None else extra_args @@ -479,7 +484,7 @@ def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078 @torch.no_grad() -def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None): +def sample_dpm_fast(model: WrappedModelProto, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None): """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927.""" if sigma_min <= 0 or sigma_max <= 0: raise ValueError('sigma_min and sigma_max must not be 0') @@ -491,7 +496,7 @@ def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback @torch.no_grad() -def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False): +def sample_dpm_adaptive(model: WrappedModelProto, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False): """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.""" if sigma_min <= 0 or sigma_max <= 0: raise ValueError('sigma_min and sigma_max must not be 0') @@ -506,7 +511,7 @@ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callbac @torch.no_grad() -def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): +def sample_dpmpp_2s_ancestral(model: WrappedModelProto, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" extra_args = {} if extra_args is None else extra_args noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler @@ -540,7 +545,7 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, @torch.no_grad() -def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): +def sample_dpmpp_sde(model: WrappedModelProto, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): """DPM-Solver++ (stochastic).""" sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler @@ -582,7 +587,7 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N @torch.no_grad() -def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None, warmup_lms=False): +def sample_dpmpp_2m(model: WrappedModelProto, x, sigmas, extra_args=None, callback=None, disable=None, warmup_lms=False): """DPM-Solver++(2M).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]])