From 26b94beb5c0c130322186233b1cf249cd957b354 Mon Sep 17 00:00:00 2001 From: cheny Date: Thu, 30 Apr 2026 10:49:11 +0800 Subject: [PATCH 1/3] fix(agent): plms.py first iteration bypasses CFG. --- iddm/model/samples/plms.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/iddm/model/samples/plms.py b/iddm/model/samples/plms.py index 9e89098..54dac63 100644 --- a/iddm/model/samples/plms.py +++ b/iddm/model/samples/plms.py @@ -111,11 +111,7 @@ def _sample_loop( c1 = self.eta * torch.sqrt((1 - alpha_t / alpha_prev) * (1 - alpha_prev) / (1 - alpha_t)) c2 = torch.sqrt((1 - alpha_prev) - c1 ** 2) p_x = torch.sqrt(alpha_prev) * x0_t + c2 * predicted_noise + c1 * noise - if labels is None and cfg_scale is None: - # Images and time steps input into the model - predicted_noise_next = model(p_x, p_t) - else: - predicted_noise_next = model(p_x, p_t, labels) + predicted_noise_next = self._get_predicted_noise(model, p_x, p_t, labels, cfg_scale) predicted_noise_prime = (predicted_noise + predicted_noise_next) / 2 elif len(old_eps) == 1: # 2nd order Pseudo Linear Multistep (Adams-Bashforth) From 235e0f8b95d075ead761295acc58013f050d7a7b Mon Sep 17 00:00:00 2001 From: cheny Date: Thu, 30 Apr 2026 10:50:56 +0800 Subject: [PATCH 2/3] fix(agent): dpmpp.py last step high-order correction missing. --- iddm/model/samples/dpmpp.py | 27 ++++++++++++++------------- iddm/utils/initializer.py | 2 +- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/iddm/model/samples/dpmpp.py b/iddm/model/samples/dpmpp.py index c06c2ff..c2ac162 100644 --- a/iddm/model/samples/dpmpp.py +++ b/iddm/model/samples/dpmpp.py @@ -117,15 +117,11 @@ def _sample_loop( # DPM++ 2M (second-order) if self.order == 2: - if i < len(time_steps) - 1: - # Intermediate noise prediction - x_inter = torch.sqrt(alpha_prev) * x0 + torch.sqrt( - 1 - alpha_prev - sigma ** 2) * predicted_noise - predicted_noise_inter = self._get_predicted_noise(model, x_inter, t_prev_tensor, labels, - cfg_scale) - - # Second-order correction - predicted_noise = (3 * predicted_noise - predicted_noise_inter) / 2 + # 2nd-order correction: only needs model at t_prev, available on all steps + sqrt_term = torch.sqrt(torch.clamp(1 - alpha_prev - sigma ** 2, min=1e-8)) + x_inter = torch.sqrt(alpha_prev) * x0 + sqrt_term * predicted_noise + predicted_noise_inter = self._get_predicted_noise(model, x_inter, t_prev_tensor, labels, cfg_scale) + predicted_noise = (3 * predicted_noise - predicted_noise_inter) / 2 # DPM++ 3M (third-order) elif self.order == 3: @@ -134,20 +130,25 @@ def _sample_loop( t_next_tensor = (torch.ones(n) * t_next).long().to(self.device) alpha_next = self.alpha_hat[t_next_tensor][:, None, None, None] - # First intermediate step, 1e-8 to avoid NaN + # First intermediate step sqrt_term1 = torch.sqrt(torch.clamp(1 - alpha_prev - sigma ** 2, min=1e-8)) x_inter1 = torch.sqrt(alpha_prev) * x0 + sqrt_term1 * predicted_noise pred_noise1 = self._get_predicted_noise(model, x_inter1, t_prev_tensor, labels, cfg_scale) - # Second intermediate step, 1e-8 to avoid NaN + # Second intermediate step sqrt_term2 = torch.sqrt(torch.clamp(1 - alpha_next - sigma ** 2, min=1e-8)) x_inter2 = torch.sqrt(alpha_next) * x0 + sqrt_term2 * pred_noise1 pred_noise2 = self._get_predicted_noise(model, x_inter2, t_next_tensor, labels, cfg_scale) # Third-order correction predicted_noise = (23 * predicted_noise - 16 * pred_noise1 + 5 * pred_noise2) / 12 - # Or use a more stable variant - # predicted_noise = (18 * predicted_noise - 12 * pred_noise1 + 3 * pred_noise2) / 9 + else: + # Last step: no look-ahead available, fallback to 2nd-order correction + sqrt_term_inter = torch.sqrt(torch.clamp(1 - alpha_prev - sigma ** 2, min=1e-8)) + x_inter = torch.sqrt(alpha_prev) * x0 + sqrt_term_inter * predicted_noise + predicted_noise_inter = self._get_predicted_noise(model, x_inter, t_prev_tensor, labels, + cfg_scale) + predicted_noise = (3 * predicted_noise - predicted_noise_inter) / 2 # Add noise for stochastic sampling noise = torch.randn_like(x) if t > 1 else torch.zeros_like(x) diff --git a/iddm/utils/initializer.py b/iddm/utils/initializer.py index 67db1fd..de9bddc 100644 --- a/iddm/utils/initializer.py +++ b/iddm/utils/initializer.py @@ -236,7 +236,7 @@ def sample_initializer(sample, image_size, device, schedule_name="linear", **kwa **kwargs) elif sample == "dpmpp3m": diffusion = DPMPlusPlusDiffusion(img_size=image_size, device=device, schedule_name=schedule_name, order=3, - **kwargs) + sample_steps=50, **kwargs) else: diffusion = DDPMDiffusion(img_size=image_size, device=device, schedule_name=schedule_name, **kwargs) logger.warning(msg=f"[{device}]: Setting sample error, we has been automatically set to ddpm.") From 7352f06899ff8de626b418c16bd17bebee3ca570 Mon Sep 17 00:00:00 2001 From: cheny Date: Thu, 30 Apr 2026 10:51:35 +0800 Subject: [PATCH 3/3] fix(agent): init .py missing DPM2/DPM++ exports. --- iddm/model/samples/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/iddm/model/samples/__init__.py b/iddm/model/samples/__init__.py index c2f8c97..70b487a 100644 --- a/iddm/model/samples/__init__.py +++ b/iddm/model/samples/__init__.py @@ -24,3 +24,5 @@ from .ddim import DDIMDiffusion from .ddpm import DDPMDiffusion from .plms import PLMSDiffusion +from .dpm2 import DPM2Diffusion +from .dpmpp import DPMPlusPlusDiffusion