Commit 4019202

add option to use cfg ++
1 parent 5a0e07f commit 4019202

File tree

3 files changed: +34 additions, -8 deletions


README.md

Lines changed: 11 additions & 0 deletions
@@ -366,3 +366,14 @@ You could consider adding a suitable metric to the training loop yourself after
     url = {https://api.semanticscholar.org/CorpusID:270562607}
 }
 ```
+
+```bibtex
+@article{Chung2024CFGMC,
+    title = {CFG++: Manifold-constrained Classifier Free Guidance for Diffusion Models},
+    author = {Hyungjin Chung and Jeongsol Kim and Geon Yeong Park and Hyelin Nam and Jong Chul Ye},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2406.08070},
+    url = {https://api.semanticscholar.org/CorpusID:270391454}
+}
+```

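For context on the new citation: the core idea of CFG++ is to keep the classifier-free-guided combination when estimating the denoised image, but to reuse the unconditional noise prediction when renoising during DDIM sampling, with guidance scales typically in [0, 1]. The sketch below only illustrates that idea; the `model(x, t, classes)` noise-prediction signature and the precomputed `alpha_bar` schedule are assumptions, not this repository's API.

```python
import torch

def ddim_step_cfg_plus_plus(model, x_t, t, t_next, classes, alpha_bar, guidance_scale = 0.6):
    # unconditional and conditional noise predictions (hypothetical model signature)
    eps_uncond = model(x_t, t, classes = None)
    eps_cond = model(x_t, t, classes = classes)

    # classifier free guidance combination; CFG++ uses small scales, roughly in [0, 1]
    eps_guided = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

    # denoised estimate x0 computed from the guided noise
    x0 = (x_t - (1. - alpha_bar[t]).sqrt() * eps_guided) / alpha_bar[t].sqrt()
    x0 = x0.clamp(-1., 1.)

    # plain CFG would renoise with eps_guided; CFG++ renoises with the unconditional eps
    return alpha_bar[t_next].sqrt() * x0 + (1. - alpha_bar[t_next]).sqrt() * eps_uncond
```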
denoising_diffusion_pytorch/classifier_free_guidance.py

Lines changed: 22 additions & 7 deletions
@@ -368,12 +368,13 @@ def forward_with_cond_scale(
         scaled_logits = null_logits + (logits - null_logits) * cond_scale

         if rescaled_phi == 0.:
-            return scaled_logits
+            return scaled_logits, null_logits

         std_fn = partial(torch.std, dim = tuple(range(1, scaled_logits.ndim)), keepdim = True)
         rescaled_logits = scaled_logits * (std_fn(logits) / std_fn(scaled_logits))
+        interpolated_rescaled_logits = rescaled_logits * rescaled_phi + scaled_logits * (1. - rescaled_phi)

-        return rescaled_logits * rescaled_phi + scaled_logits * (1. - rescaled_phi)
+        return interpolated_rescaled_logits, null_logits

     def forward(
         self,
@@ -478,7 +479,8 @@ def __init__(
         ddim_sampling_eta = 1.,
         offset_noise_strength = 0.,
         min_snr_loss_weight = False,
-        min_snr_gamma = 5
+        min_snr_gamma = 5,
+        use_cfg_plus_plus = False # https://arxiv.org/pdf/2406.08070
     ):
         super().__init__()
         assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
@@ -507,6 +509,10 @@ def __init__(
         timesteps, = betas.shape
         self.num_timesteps = int(timesteps)

+        # use cfg++ when ddim sampling
+
+        self.use_cfg_plus_plus = use_cfg_plus_plus
+
         # sampling related parameters

         self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training
@@ -604,24 +610,33 @@ def q_posterior(self, x_start, x_t, t):
         return posterior_mean, posterior_variance, posterior_log_variance_clipped

     def model_predictions(self, x, t, classes, cond_scale = 6., rescaled_phi = 0.7, clip_x_start = False):
-        model_output = self.model.forward_with_cond_scale(x, t, classes, cond_scale = cond_scale, rescaled_phi = rescaled_phi)
+        model_output, model_output_null = self.model.forward_with_cond_scale(x, t, classes, cond_scale = cond_scale, rescaled_phi = rescaled_phi)
         maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity

         if self.objective == 'pred_noise':
-            pred_noise = model_output
+            pred_noise = model_output if not self.use_cfg_plus_plus else model_output_null
+
             x_start = self.predict_start_from_noise(x, t, pred_noise)
             x_start = maybe_clip(x_start)

         elif self.objective == 'pred_x0':
             x_start = model_output
             x_start = maybe_clip(x_start)
-            pred_noise = self.predict_noise_from_start(x, t, x_start)
+            x_start_for_pred_noise = x_start if not self.use_cfg_plus_plus else maybe_clip(model_output_null)
+
+            pred_noise = self.predict_noise_from_start(x, t, x_start_for_pred_noise)

         elif self.objective == 'pred_v':
             v = model_output
             x_start = self.predict_start_from_v(x, t, v)
             x_start = maybe_clip(x_start)
-            pred_noise = self.predict_noise_from_start(x, t, x_start)
+
+            x_start_for_pred_noise = x_start
+            if self.use_cfg_plus_plus:
+                x_start_for_pred_noise = self.predict_start_from_v(x, t, model_output_null)
+                x_start_for_pred_noise = maybe_clip(x_start_for_pred_noise)
+
+            pred_noise = self.predict_noise_from_start(x, t, x_start_for_pred_noise)

         return ModelPrediction(pred_noise, x_start)

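A usage sketch for the new option, adapted from the repository's classifier-free guidance README example; the argument values are illustrative, and `sampling_timesteps` is set below `timesteps` so that DDIM sampling, where CFG++ is applied, is used.

```python
import torch
from denoising_diffusion_pytorch.classifier_free_guidance import Unet, GaussianDiffusion

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    num_classes = 10,
    cond_drop_prob = 0.5
)

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,
    sampling_timesteps = 250,   # fewer than timesteps -> ddim sampling
    use_cfg_plus_plus = True    # option added in this commit
)

training_images = torch.rand(8, 3, 128, 128)   # images normalized from 0 to 1
image_classes = torch.randint(0, 10, (8,))

loss = diffusion(training_images, classes = image_classes)
loss.backward()

# after many training steps

sampled_images = diffusion.sample(
    classes = image_classes,
    cond_scale = 6.             # guidance strength, as in the README example
)
```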
denoising_diffusion_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
-__version__ = '2.0.15'
+__version__ = '2.0.16'
