diff --git a/iddm/model/samples/__init__.py b/iddm/model/samples/__init__.py index c2f8c97..70b487a 100644 --- a/iddm/model/samples/__init__.py +++ b/iddm/model/samples/__init__.py @@ -24,3 +24,5 @@ from .ddim import DDIMDiffusion from .ddpm import DDPMDiffusion from .plms import PLMSDiffusion +from .dpm2 import DPM2Diffusion +from .dpmpp import DPMPlusPlusDiffusion diff --git a/iddm/model/samples/dpmpp.py b/iddm/model/samples/dpmpp.py index c06c2ff..c2ac162 100644 --- a/iddm/model/samples/dpmpp.py +++ b/iddm/model/samples/dpmpp.py @@ -117,15 +117,11 @@ def _sample_loop( # DPM++ 2M (second-order) if self.order == 2: - if i < len(time_steps) - 1: - # Intermediate noise prediction - x_inter = torch.sqrt(alpha_prev) * x0 + torch.sqrt( - 1 - alpha_prev - sigma ** 2) * predicted_noise - predicted_noise_inter = self._get_predicted_noise(model, x_inter, t_prev_tensor, labels, - cfg_scale) - - # Second-order correction - predicted_noise = (3 * predicted_noise - predicted_noise_inter) / 2 + # 2nd-order correction: only needs model at t_prev, available on all steps + sqrt_term = torch.sqrt(torch.clamp(1 - alpha_prev - sigma ** 2, min=1e-8)) + x_inter = torch.sqrt(alpha_prev) * x0 + sqrt_term * predicted_noise + predicted_noise_inter = self._get_predicted_noise(model, x_inter, t_prev_tensor, labels, cfg_scale) + predicted_noise = (3 * predicted_noise - predicted_noise_inter) / 2 # DPM++ 3M (third-order) elif self.order == 3: @@ -134,20 +130,25 @@ def _sample_loop( t_next_tensor = (torch.ones(n) * t_next).long().to(self.device) alpha_next = self.alpha_hat[t_next_tensor][:, None, None, None] - # First intermediate step, 1e-8 to avoid NaN + # First intermediate step sqrt_term1 = torch.sqrt(torch.clamp(1 - alpha_prev - sigma ** 2, min=1e-8)) x_inter1 = torch.sqrt(alpha_prev) * x0 + sqrt_term1 * predicted_noise pred_noise1 = self._get_predicted_noise(model, x_inter1, t_prev_tensor, labels, cfg_scale) - # Second intermediate step, 1e-8 to avoid NaN + # Second intermediate step sqrt_term2 = torch.sqrt(torch.clamp(1 - alpha_next - sigma ** 2, min=1e-8)) x_inter2 = torch.sqrt(alpha_next) * x0 + sqrt_term2 * pred_noise1 pred_noise2 = self._get_predicted_noise(model, x_inter2, t_next_tensor, labels, cfg_scale) # Third-order correction predicted_noise = (23 * predicted_noise - 16 * pred_noise1 + 5 * pred_noise2) / 12 - # Or use a more stable variant - # predicted_noise = (18 * predicted_noise - 12 * pred_noise1 + 3 * pred_noise2) / 9 + else: + # Last step: no look-ahead available, fallback to 2nd-order correction + sqrt_term_inter = torch.sqrt(torch.clamp(1 - alpha_prev - sigma ** 2, min=1e-8)) + x_inter = torch.sqrt(alpha_prev) * x0 + sqrt_term_inter * predicted_noise + predicted_noise_inter = self._get_predicted_noise(model, x_inter, t_prev_tensor, labels, + cfg_scale) + predicted_noise = (3 * predicted_noise - predicted_noise_inter) / 2 # Add noise for stochastic sampling noise = torch.randn_like(x) if t > 1 else torch.zeros_like(x) diff --git a/iddm/model/samples/plms.py b/iddm/model/samples/plms.py index 9e89098..54dac63 100644 --- a/iddm/model/samples/plms.py +++ b/iddm/model/samples/plms.py @@ -111,11 +111,7 @@ def _sample_loop( c1 = self.eta * torch.sqrt((1 - alpha_t / alpha_prev) * (1 - alpha_prev) / (1 - alpha_t)) c2 = torch.sqrt((1 - alpha_prev) - c1 ** 2) p_x = torch.sqrt(alpha_prev) * x0_t + c2 * predicted_noise + c1 * noise - if labels is None and cfg_scale is None: - # Images and time steps input into the model - predicted_noise_next = model(p_x, p_t) - else: - predicted_noise_next = model(p_x, p_t, labels) + predicted_noise_next = self._get_predicted_noise(model, p_x, p_t, labels, cfg_scale) predicted_noise_prime = (predicted_noise + predicted_noise_next) / 2 elif len(old_eps) == 1: # 2nd order Pseudo Linear Multistep (Adams-Bashforth) diff --git a/iddm/utils/initializer.py b/iddm/utils/initializer.py index 67db1fd..de9bddc 100644 --- a/iddm/utils/initializer.py +++ b/iddm/utils/initializer.py @@ -236,7 +236,7 @@ def sample_initializer(sample, image_size, device, schedule_name="linear", **kwa **kwargs) elif sample == "dpmpp3m": diffusion = DPMPlusPlusDiffusion(img_size=image_size, device=device, schedule_name=schedule_name, order=3, - **kwargs) + sample_steps=50, **kwargs) else: diffusion = DDPMDiffusion(img_size=image_size, device=device, schedule_name=schedule_name, **kwargs) logger.warning(msg=f"[{device}]: Setting sample error, we has been automatically set to ddpm.")