@@ -368,12 +368,13 @@ def forward_with_cond_scale(
         scaled_logits = null_logits + (logits - null_logits) * cond_scale
 
         if rescaled_phi == 0.:
-            return scaled_logits
+            return scaled_logits, null_logits
 
         std_fn = partial(torch.std, dim = tuple(range(1, scaled_logits.ndim)), keepdim = True)
         rescaled_logits = scaled_logits * (std_fn(logits) / std_fn(scaled_logits))
+        interpolated_rescaled_logits = rescaled_logits * rescaled_phi + scaled_logits * (1. - rescaled_phi)
 
-        return rescaled_logits * rescaled_phi + scaled_logits * (1. - rescaled_phi)
+        return interpolated_rescaled_logits, null_logits
 
     def forward(
         self,
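For readers skimming the hunk above: `forward_with_cond_scale` now returns the unconditional (null) prediction alongside the guided one, so the sampler can implement CFG++; the guidance arithmetic itself is unchanged. A minimal standalone sketch of that arithmetic, using toy tensors rather than the repo's actual module:

```python
import torch

def guide(logits, null_logits, cond_scale = 6., rescaled_phi = 0.7):
    # classifier-free guidance: push the prediction away from the unconditional one
    scaled = null_logits + (logits - null_logits) * cond_scale

    if rescaled_phi == 0.:
        return scaled, null_logits

    # rescale the guided output to the conditional output's per-sample std,
    # then blend with the plain guided output by rescaled_phi
    std = lambda t: t.std(dim = tuple(range(1, t.ndim)), keepdim = True)
    rescaled = scaled * (std(logits) / std(scaled))
    out = rescaled * rescaled_phi + scaled * (1. - rescaled_phi)

    # the null prediction rides along so CFG++ can reuse it downstream
    return out, null_logits

# toy check with random "noise predictions"
cond, uncond = torch.randn(2, 3, 64, 64), torch.randn(2, 3, 64, 64)
guided, null = guide(cond, uncond)
```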
@@ -478,7 +479,8 @@ def __init__(
         ddim_sampling_eta = 1.,
         offset_noise_strength = 0.,
         min_snr_loss_weight = False,
-        min_snr_gamma = 5
+        min_snr_gamma = 5,
+        use_cfg_plus_plus = False  # https://arxiv.org/pdf/2406.08070
     ):
         super().__init__()
         assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
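A hedged usage sketch for the new constructor flag; apart from `use_cfg_plus_plus`, the argument names below simply follow the existing `GaussianDiffusion` signature in this file and are not exhaustive:

```python
# sketch only: `model` is assumed to be the class-conditional Unet defined in this file
diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,
    sampling_timesteps = 250,   # fewer than timesteps, so DDIM sampling is used
    use_cfg_plus_plus = True    # enable CFG++ (https://arxiv.org/pdf/2406.08070)
)
```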
@@ -507,6 +509,10 @@ def __init__(
         timesteps, = betas.shape
         self.num_timesteps = int(timesteps)
 
+        # use cfg++ when ddim sampling
+
+        self.use_cfg_plus_plus = use_cfg_plus_plus
+
         # sampling related parameters
 
         self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training
@@ -604,24 +610,33 @@ def q_posterior(self, x_start, x_t, t):
         return posterior_mean, posterior_variance, posterior_log_variance_clipped
 
     def model_predictions(self, x, t, classes, cond_scale = 6., rescaled_phi = 0.7, clip_x_start = False):
-        model_output = self.model.forward_with_cond_scale(x, t, classes, cond_scale = cond_scale, rescaled_phi = rescaled_phi)
+        model_output, model_output_null = self.model.forward_with_cond_scale(x, t, classes, cond_scale = cond_scale, rescaled_phi = rescaled_phi)
         maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
 
         if self.objective == 'pred_noise':
-            pred_noise = model_output
+            pred_noise = model_output if not self.use_cfg_plus_plus else model_output_null
+
             x_start = self.predict_start_from_noise(x, t, pred_noise)
             x_start = maybe_clip(x_start)
 
         elif self.objective == 'pred_x0':
             x_start = model_output
             x_start = maybe_clip(x_start)
-            pred_noise = self.predict_noise_from_start(x, t, x_start)
+            x_start_for_pred_noise = x_start if not self.use_cfg_plus_plus else maybe_clip(model_output_null)
+
+            pred_noise = self.predict_noise_from_start(x, t, x_start_for_pred_noise)
 
         elif self.objective == 'pred_v':
             v = model_output
             x_start = self.predict_start_from_v(x, t, v)
             x_start = maybe_clip(x_start)
-            pred_noise = self.predict_noise_from_start(x, t, x_start)
+
+            x_start_for_pred_noise = x_start
+            if self.use_cfg_plus_plus:
+                x_start_for_pred_noise = self.predict_start_from_v(x, t, model_output_null)
+                x_start_for_pred_noise = maybe_clip(x_start_for_pred_noise)
+
+            pred_noise = self.predict_noise_from_start(x, t, x_start_for_pred_noise)
 
         return ModelPrediction(pred_noise, x_start)
 
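The net effect of the hunk above: when `use_cfg_plus_plus` is enabled, the `pred_noise` handed back to the sampler is derived from the unconditional (null) prediction rather than the guided one. In a deterministic DDIM step that noise is what gets re-injected after denoising, which is the renoising rule the CFG++ paper proposes. A minimal sketch of that update with assumed variable names (not the repo's actual `ddim_sample` loop); `alpha_next` stands for the next step's cumulative ᾱ:

```python
import torch

def ddim_step(x_start, pred_noise, alpha_next):
    # deterministic DDIM update (eta = 0): renoise the predicted x_start.
    # under CFG++, pred_noise is the unconditional prediction, so guidance
    # steers the denoising direction while renoising stays on the
    # unconditional trajectory.
    return x_start * alpha_next.sqrt() + (1. - alpha_next).sqrt() * pred_noise
```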