Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions iddm/model/samples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@
from .ddim import DDIMDiffusion
from .ddpm import DDPMDiffusion
from .plms import PLMSDiffusion
from .dpm2 import DPM2Diffusion
from .dpmpp import DPMPlusPlusDiffusion
27 changes: 14 additions & 13 deletions iddm/model/samples/dpmpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,11 @@ def _sample_loop(

# DPM++ 2M (second-order)
if self.order == 2:
if i < len(time_steps) - 1:
# Intermediate noise prediction
x_inter = torch.sqrt(alpha_prev) * x0 + torch.sqrt(
1 - alpha_prev - sigma ** 2) * predicted_noise
predicted_noise_inter = self._get_predicted_noise(model, x_inter, t_prev_tensor, labels,
cfg_scale)

# Second-order correction
predicted_noise = (3 * predicted_noise - predicted_noise_inter) / 2
# 2nd-order correction: only needs model at t_prev, available on all steps
sqrt_term = torch.sqrt(torch.clamp(1 - alpha_prev - sigma ** 2, min=1e-8))
x_inter = torch.sqrt(alpha_prev) * x0 + sqrt_term * predicted_noise
predicted_noise_inter = self._get_predicted_noise(model, x_inter, t_prev_tensor, labels, cfg_scale)
predicted_noise = (3 * predicted_noise - predicted_noise_inter) / 2

# DPM++ 3M (third-order)
elif self.order == 3:
Expand All @@ -134,20 +130,25 @@ def _sample_loop(
t_next_tensor = (torch.ones(n) * t_next).long().to(self.device)
alpha_next = self.alpha_hat[t_next_tensor][:, None, None, None]

# First intermediate step, 1e-8 to avoid NaN
# First intermediate step
sqrt_term1 = torch.sqrt(torch.clamp(1 - alpha_prev - sigma ** 2, min=1e-8))
x_inter1 = torch.sqrt(alpha_prev) * x0 + sqrt_term1 * predicted_noise
pred_noise1 = self._get_predicted_noise(model, x_inter1, t_prev_tensor, labels, cfg_scale)

# Second intermediate step, 1e-8 to avoid NaN
# Second intermediate step
sqrt_term2 = torch.sqrt(torch.clamp(1 - alpha_next - sigma ** 2, min=1e-8))
x_inter2 = torch.sqrt(alpha_next) * x0 + sqrt_term2 * pred_noise1
pred_noise2 = self._get_predicted_noise(model, x_inter2, t_next_tensor, labels, cfg_scale)

# Third-order correction
predicted_noise = (23 * predicted_noise - 16 * pred_noise1 + 5 * pred_noise2) / 12
# Or use a more stable variant
# predicted_noise = (18 * predicted_noise - 12 * pred_noise1 + 3 * pred_noise2) / 9
else:
# Last step: no look-ahead available, fall back to 2nd-order correction
sqrt_term_inter = torch.sqrt(torch.clamp(1 - alpha_prev - sigma ** 2, min=1e-8))
x_inter = torch.sqrt(alpha_prev) * x0 + sqrt_term_inter * predicted_noise
predicted_noise_inter = self._get_predicted_noise(model, x_inter, t_prev_tensor, labels,
cfg_scale)
predicted_noise = (3 * predicted_noise - predicted_noise_inter) / 2

# Add noise for stochastic sampling
noise = torch.randn_like(x) if t > 1 else torch.zeros_like(x)
Expand Down
6 changes: 1 addition & 5 deletions iddm/model/samples/plms.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,7 @@ def _sample_loop(
c1 = self.eta * torch.sqrt((1 - alpha_t / alpha_prev) * (1 - alpha_prev) / (1 - alpha_t))
c2 = torch.sqrt((1 - alpha_prev) - c1 ** 2)
p_x = torch.sqrt(alpha_prev) * x0_t + c2 * predicted_noise + c1 * noise
if labels is None and cfg_scale is None:
# Images and time steps input into the model
predicted_noise_next = model(p_x, p_t)
else:
predicted_noise_next = model(p_x, p_t, labels)
predicted_noise_next = self._get_predicted_noise(model, p_x, p_t, labels, cfg_scale)
predicted_noise_prime = (predicted_noise + predicted_noise_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
Expand Down
2 changes: 1 addition & 1 deletion iddm/utils/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def sample_initializer(sample, image_size, device, schedule_name="linear", **kwa
**kwargs)
elif sample == "dpmpp3m":
diffusion = DPMPlusPlusDiffusion(img_size=image_size, device=device, schedule_name=schedule_name, order=3,
**kwargs)
sample_steps=50, **kwargs)
else:
diffusion = DDPMDiffusion(img_size=image_size, device=device, schedule_name=schedule_name, **kwargs)
logger.warning(msg=f"[{device}]: Setting sample error, we has been automatically set to ddpm.")
Expand Down
Loading