@@ -545,8 +545,12 @@ def q_posterior(self, x_start, x_t, t):
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

-    def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):
-        model_output = self.model(x, t, x_self_cond)
+    def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False, model_forward_kwargs: dict = dict()):
+
+        if exists(x_self_cond):
+            model_forward_kwargs = {**model_forward_kwargs, 'self_cond': x_self_cond}
+
+        model_output = self.model(x, t, **model_forward_kwargs)

        maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity

        if self.objective == 'pred_noise':
@@ -605,7 +609,7 @@ def p_sample_loop(self, shape):
        return img

    @torch.no_grad()
-    def ddim_sample(self, shape, clip_denoised = True):
+    def ddim_sample(self, shape, clip_denoised = True, model_forward_kwargs: dict = dict()):
        batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

        times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)   # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
@@ -619,7 +623,7 @@ def ddim_sample(self, shape, clip_denoised = True):
        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
            self_cond = x_start if self.self_condition else None
-            pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = clip_denoised)
+            pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = clip_denoised, model_forward_kwargs = model_forward_kwargs)

            if time_next < 0:
                img = x_start
@@ -641,12 +645,12 @@ def ddim_sample(self, shape, clip_denoised = True):
        return img

    @torch.no_grad()
-    def sample(self, batch_size = 16):
+    def sample(self, batch_size = 16, model_forward_kwargs: dict = dict()):
        seq_length, channels = self.seq_length, self.channels
        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample

        shape = (batch_size, channels, seq_length) if self.channel_first else (batch_size, seq_length, channels)
-        return sample_fn(shape)
+        return sample_fn(shape, model_forward_kwargs = model_forward_kwargs)

    @torch.no_grad()
    def interpolate(self, x1, x2, t = None, lam = 0.5):