Skip to content

Commit d2ab4ac

Browse files
committed
another change for lbm
1 parent 34b430d commit d2ab4ac

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -545,8 +545,12 @@ def q_posterior(self, x_start, x_t, t):
545545
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
546546
return posterior_mean, posterior_variance, posterior_log_variance_clipped
547547

548-
def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):
549-
model_output = self.model(x, t, x_self_cond)
548+
def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False, model_forward_kwargs: dict = dict()):
549+
550+
if exists(x_self_cond):
551+
model_forward_kwargs = {**model_forward_kwargs, 'self_cond': x_self_cond}
552+
553+
model_output = self.model(x, t, **model_forward_kwargs)
550554
maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
551555

552556
if self.objective == 'pred_noise':
@@ -605,7 +609,7 @@ def p_sample_loop(self, shape):
605609
return img
606610

607611
@torch.no_grad()
608-
def ddim_sample(self, shape, clip_denoised = True):
612+
def ddim_sample(self, shape, clip_denoised = True, model_forward_kwargs: dict = dict()):
609613
batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
610614

611615
times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
@@ -619,7 +623,7 @@ def ddim_sample(self, shape, clip_denoised = True):
619623
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
620624
time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
621625
self_cond = x_start if self.self_condition else None
622-
pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = clip_denoised)
626+
pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = clip_denoised, model_forward_kwargs = model_forward_kwargs)
623627

624628
if time_next < 0:
625629
img = x_start
@@ -641,12 +645,12 @@ def ddim_sample(self, shape, clip_denoised = True):
641645
return img
642646

643647
@torch.no_grad()
def sample(self, batch_size = 16, model_forward_kwargs: dict = None):
    """
    Draw `batch_size` samples from the trained 1D diffusion model.

    Args:
        batch_size: number of sequences to generate.
        model_forward_kwargs: optional extra keyword arguments forwarded to
            the denoising network at each step. Only the DDIM sampler
            (`ddim_sample`) accepts them; defaults to an empty dict.
            (Default is `None` rather than `dict()` to avoid the shared
            mutable-default pitfall.)

    Returns:
        A sampled batch shaped (batch, channels, seq_length) when
        `self.channel_first` is truthy, else (batch, seq_length, channels).

    Raises:
        ValueError: if `model_forward_kwargs` are supplied while ancestral
            (non-DDIM) sampling is configured, since `p_sample_loop` does
            not accept them.
    """
    if model_forward_kwargs is None:
        model_forward_kwargs = {}

    seq_length, channels = self.seq_length, self.channels
    shape = (batch_size, channels, seq_length) if self.channel_first else (batch_size, seq_length, channels)

    if self.is_ddim_sampling:
        return self.ddim_sample(shape, model_forward_kwargs = model_forward_kwargs)

    # BUG FIX: the previous revision forwarded `model_forward_kwargs` to
    # `p_sample_loop` as well, but `p_sample_loop(self, shape)` takes no
    # such parameter — every ancestral-sampling call raised TypeError,
    # even with the default empty kwargs. Forward nothing on this path and
    # reject non-empty kwargs explicitly instead of failing obscurely.
    if model_forward_kwargs:
        raise ValueError('model_forward_kwargs is only supported with DDIM sampling')
    return self.p_sample_loop(shape)
650654

651655
@torch.no_grad()
652656
def interpolate(self, x1, x2, t = None, lam = 0.5):
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '2.2.0'
1+
__version__ = '2.2.1'

0 commit comments

Comments
 (0)