From 79f357a45168a7dc5166a2fb9ff48515cd5dd56c Mon Sep 17 00:00:00 2001
From: Virginia Fernandez
Date: Fri, 9 May 2025 15:10:27 +0100
Subject: [PATCH 1/8] Fix AutoencoderKL docstrings.

---
 monai/networks/nets/autoencoderkl.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py
index af191e748b..f9cec5800e 100644
--- a/monai/networks/nets/autoencoderkl.py
+++ b/monai/networks/nets/autoencoderkl.py
@@ -153,9 +153,9 @@ class Encoder(nn.Module):
         channels: sequence of block output channels.
         out_channels: number of channels in the bottom layer (latent space) of the autoencoder.
         num_res_blocks: number of residual blocks (see _ResBlock) per level.
-        norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
+        norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number.
         norm_eps: epsilon for the normalization.
-        attention_levels: indicate which level from num_channels contain an attention block.
+        attention_levels: indicate which level from channels contain an attention block.
         with_nonlocal_attn: if True use non-local attention block.
         include_fc: whether to include the final linear layer. Default to True.
         use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
@@ -299,9 +299,9 @@ class Decoder(nn.Module):
         in_channels: number of channels in the bottom layer (latent space) of the autoencoder.
         out_channels: number of output channels.
         num_res_blocks: number of residual blocks (see _ResBlock) per level.
-        norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
+        norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number.
         norm_eps: epsilon for the normalization.
-        attention_levels: indicate which level from num_channels contain an attention block.
+        attention_levels: indicate which level from channels contain an attention block.
         with_nonlocal_attn: if True use non-local attention block.
         use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
         include_fc: whether to include the final linear layer. Default to True.
@@ -483,7 +483,7 @@ class AutoencoderKL(nn.Module):
         channels: number of output channels for each block.
         attention_levels: sequence of levels to add attention.
         latent_channels: latent embedding dimension.
-        norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
+        norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number.
         norm_eps: epsilon for the normalization.
         with_encoder_nonlocal_attn: if True use non-local attention block in the encoder.
         with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.
@@ -518,10 +518,10 @@ def __init__(
 
         # All number of channels should be multiple of num_groups
         if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
-            raise ValueError("AutoencoderKL expects all num_channels being multiple of norm_num_groups")
+            raise ValueError("AutoencoderKL expects all channels being multiple of norm_num_groups")
 
         if len(channels) != len(attention_levels):
-            raise ValueError("AutoencoderKL expects num_channels being same size of attention_levels")
+            raise ValueError("AutoencoderKL expects channels being same size of attention_levels")
 
         if isinstance(num_res_blocks, int):
             num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))
@@ -529,7 +529,7 @@ def __init__(
         if len(num_res_blocks) != len(channels):
             raise ValueError(
                 "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
-                "`num_channels`."
+                "`channels`."
             )
 
         self.encoder: nn.Module = Encoder(

From 35c2edfc542843dc7c6f685cf093437264d5b971 Mon Sep 17 00:00:00 2001
From: Virginia Fernandez
Date: Fri, 9 May 2025 15:15:45 +0100
Subject: [PATCH 2/8] Sign off.

Signed-off-by: Virginia Fernandez
---
 monai/networks/nets/autoencoderkl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py
index f9cec5800e..385e90450d 100644
--- a/monai/networks/nets/autoencoderkl.py
+++ b/monai/networks/nets/autoencoderkl.py
@@ -619,7 +619,7 @@ def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor:
 
         Args:
             z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image
-            z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image
+            z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image
 
         Returns:
             sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE]

From f60c087a53cfcdb73ddd2d87a940c94cb641629d Mon Sep 17 00:00:00 2001
From: Virginia Fernandez
Date: Fri, 9 May 2025 15:17:14 +0100
Subject: [PATCH 3/8] Remove empty space

Signed-off-by: Virginia Fernandez
---
 monai/networks/nets/autoencoderkl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py
index 385e90450d..f9cec5800e 100644
--- a/monai/networks/nets/autoencoderkl.py
+++ b/monai/networks/nets/autoencoderkl.py
@@ -619,7 +619,7 @@ def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor:
 
         Args:
             z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image
-            z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image
+            z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image
 
         Returns:
             sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE]

From 0381eea4c8b628db5f11f434d6d8eec03e486e67 Mon Sep 17 00:00:00 2001
From: Virginia Fernandez
Date: Fri, 9 May 2025 15:36:34 +0100
Subject: [PATCH 4/8] DCO Remediation Commit for Virginia Fernandez

I, Virginia Fernandez, hereby add my Signed-off-by to this commit:
79f357a45168a7dc5166a2fb9ff48515cd5dd56c

Signed-off-by: Virginia Fernandez
Signed-off-by: Virginia Fernandez
---
 monai/networks/nets/autoencoderkl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py
index f9cec5800e..32b5b3bf81 100644
--- a/monai/networks/nets/autoencoderkl.py
+++ b/monai/networks/nets/autoencoderkl.py
@@ -614,7 +614,7 @@ def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
     def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor:
         """
         From the mean and sigma representations resulting of encoding an image through the latent space,
-        obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and
+        obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and
         adding the mean.
 
         Args:

From 0ae0a685be0627d7896bfd3791d0727c56b578b8 Mon Sep 17 00:00:00 2001
From: Virginia Fernandez
Date: Fri, 9 May 2025 15:36:52 +0100
Subject: [PATCH 5/8] Remove empty space

Signed-off-by: Virginia Fernandez
---
 monai/networks/nets/autoencoderkl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py
index 32b5b3bf81..f9cec5800e 100644
--- a/monai/networks/nets/autoencoderkl.py
+++ b/monai/networks/nets/autoencoderkl.py
@@ -614,7 +614,7 @@ def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
     def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor:
         """
         From the mean and sigma representations resulting of encoding an image through the latent space,
-        obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and
+        obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and
         adding the mean.
 
         Args:

From 74f5aee8faa011ad78a139538d80b3997d1e18af Mon Sep 17 00:00:00 2001
From: Virginia Fernandez
Date: Fri, 23 May 2025 18:05:18 +0100
Subject: [PATCH 6/8] Add classifier-free guidance to MONAI inferers and tests
 for them

Signed-off-by: Virginia Fernandez
---
 monai/inferers/inferer.py                       | 67 ++++++++++++++---
 tests/inferers/test_controlnet_inferers.py      | 24 ++++---
 tests/inferers/test_diffusion_inferer.py        | 53 +++++++++++++++
 .../inferers/test_latent_diffusion_inferer.py   | 50 ++++++++++++++
 4 files changed, 176 insertions(+), 18 deletions(-)

diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py
index bfb2756ebe..67b87dc6db 100644
--- a/monai/inferers/inferer.py
+++ b/monai/inferers/inferer.py
@@ -839,6 +839,7 @@ def sample(
         mode: str = "crossattn",
         verbose: bool = True,
         seg: torch.Tensor | None = None,
+        cfg: float | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Args:
@@ -851,6 +852,7 @@ def sample(
             mode: Conditioning mode for the network.
             verbose: if true, prints the progression bar of the sampling process.
             seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+            cfg: classifier-free guidance scale, which controls how strongly the conditioning is applied.
         """
         if mode not in ["crossattn", "concat"]:
             raise NotImplementedError(f"{mode} condition is not supported")
@@ -877,15 +879,31 @@ def sample(
                 if isinstance(diffusion_model, SPADEDiffusionModelUNet)
                 else diffusion_model
             )
+            if (
+                cfg is not None
+            ):  # if classifier-free guidance is used, a conditioned and an unconditioned prediction are generated.
+                model_input = torch.cat([image] * 2, dim=0)
+                if conditioning is not None:
+                    uncondition = torch.ones_like(conditioning)
+                    uncondition.fill_(-1)
+                    conditioning_input = torch.cat([uncondition, conditioning], dim=0)
+                else:
+                    conditioning_input = None
+            else:
+                model_input = image
+                conditioning_input = conditioning
             if mode == "concat" and conditioning is not None:
-                model_input = torch.cat([image, conditioning], dim=1)
+                model_input = torch.cat([model_input, conditioning], dim=1)
                 model_output = diffusion_model(
                     model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None
                 )
             else:
                 model_output = diffusion_model(
-                    image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning
+                    model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning_input
                 )
+            if cfg is not None:
+                model_output_uncond, model_output_cond = model_output.chunk(2)
+                model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond)
 
             # 2. compute previous image: x_t -> x_t-1
             if not isinstance(scheduler, RFlowScheduler):
@@ -1166,6 +1184,7 @@ def sample(  # type: ignore[override]
         mode: str = "crossattn",
         verbose: bool = True,
         seg: torch.Tensor | None = None,
+        cfg: float | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Args:
@@ -1180,6 +1199,7 @@ def sample(  # type: ignore[override]
             verbose: if true, prints the progression bar of the sampling process.
             seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model is instance of
                 SPADEAutoencoderKL, segmentation must be provided.
+            cfg: classifier-free guidance scale, which controls how strongly the conditioning is applied.
         """
 
         if (
@@ -1203,6 +1223,7 @@ def sample(  # type: ignore[override]
             mode=mode,
             verbose=verbose,
             seg=seg,
+            cfg=cfg,
         )
 
         if save_intermediates:
@@ -1381,6 +1402,7 @@ def sample(  # type: ignore[override]
         mode: str = "crossattn",
         verbose: bool = True,
         seg: torch.Tensor | None = None,
+        cfg: float | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Args:
@@ -1395,6 +1417,7 @@ def sample(  # type: ignore[override]
             mode: Conditioning mode for the network.
             verbose: if true, prints the progression bar of the sampling process.
             seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+            cfg: classifier-free guidance scale, which controls how strongly the conditioning is applied.
         """
         if mode not in ["crossattn", "concat"]:
             raise NotImplementedError(f"{mode} condition is not supported")
@@ -1413,14 +1436,31 @@ def sample(  # type: ignore[override]
             progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))
         intermediates = []
+        if cfg is not None:
+            cn_cond = torch.cat([cn_cond] * 2, dim=0)
+
         for t, next_t in progress_bar:
+            # Controlnet prediction
+            if cfg is not None:
+                model_input = torch.cat([image] * 2, dim=0)
+                if conditioning is not None:
+                    uncondition = torch.ones_like(conditioning)
+                    uncondition.fill_(-1)
+                    conditioning_input = torch.cat([uncondition, conditioning], dim=0)
+                else:
+                    conditioning_input = None
+            else:
+                model_input = image
+                conditioning_input = conditioning
+
+            # Diffusion model prediction
             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
                 diffuse = partial(diffusion_model, seg=seg)
-            if mode == "concat" and conditioning is not None:
+            if mode == "concat" and conditioning_input is not None:
                 # 1. Conditioning
-                model_input = torch.cat([image, conditioning], dim=1)
+                model_input = torch.cat([model_input, conditioning_input], dim=1)
                 # 2. ControlNet forward
                 down_block_res_samples, mid_block_res_sample = controlnet(
                     x=model_input,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
                     controlnet_cond=cn_cond,
                     context=None,
                 )
                 # 3. predict noise model_output
                 model_output = diffuse(
                     model_input,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
                     context=None,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
                 )
             else:
+                # 1. Controlnet forward
                 down_block_res_samples, mid_block_res_sample = controlnet(
-                    x=image,
+                    x=model_input,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
                     controlnet_cond=cn_cond,
-                    context=conditioning,
+                    context=conditioning_input,
                 )
+                # 2. predict noise model_output
                 model_output = diffuse(
-                    image,
+                    model_input,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
-                    context=conditioning,
+                    context=conditioning_input,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
                 )
+            # If classifier-free guidance is enabled, split the batch and combine the
+            # conditioned and unconditioned outputs.
+            if cfg is not None:
+                model_output_uncond, model_output_cond = model_output.chunk(2)
+                model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond)
+
             # 3. compute previous image: x_t -> x_t-1
             if not isinstance(scheduler, RFlowScheduler):
                 image, _ = scheduler.step(model_output, t, image)  # type: ignore
@@ -1714,6 +1762,7 @@ def sample(  # type: ignore[override]
         mode: str = "crossattn",
         verbose: bool = True,
         seg: torch.Tensor | None = None,
+        cfg: float | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Args:
@@ -1730,6 +1779,7 @@ def sample(  # type: ignore[override]
             verbose: if true, prints the progression bar of the sampling process.
             seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model is instance of
                 SPADEAutoencoderKL, segmentation must be provided.
+            cfg: classifier-free guidance scale, which controls how strongly the conditioning is applied.
""" if ( @@ -1757,6 +1807,7 @@ def sample( # type: ignore[override] mode=mode, verbose=verbose, seg=seg, + cfg=cfg, ) if save_intermediates: diff --git a/tests/inferers/test_controlnet_inferers.py b/tests/inferers/test_controlnet_inferers.py index 1ce81a71d5..2b6777a75f 100644 --- a/tests/inferers/test_controlnet_inferers.py +++ b/tests/inferers/test_controlnet_inferers.py @@ -482,16 +482,20 @@ def test_sample_intermediates(self, model_params, controlnet_params, input_shape scheduler = DDPMScheduler(num_train_timesteps=10) inferer = ControlNetDiffusionInferer(scheduler=scheduler) scheduler.set_timesteps(num_inference_steps=10) - sample, intermediates = inferer.sample( - input_noise=noise, - diffusion_model=model, - scheduler=scheduler, - controlnet=controlnet, - cn_cond=mask, - save_intermediates=True, - intermediate_steps=1, - ) - self.assertEqual(len(intermediates), 10) + + for cfg in [5, None]: + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + cfg=cfg, + ) + + self.assertEqual(len(intermediates), 10) @parameterized.expand(CNDM_TEST_CASES) @skipUnless(has_einops, "Requires einops") diff --git a/tests/inferers/test_diffusion_inferer.py b/tests/inferers/test_diffusion_inferer.py index 59b320d8a7..02890a71d4 100644 --- a/tests/inferers/test_diffusion_inferer.py +++ b/tests/inferers/test_diffusion_inferer.py @@ -88,6 +88,27 @@ def test_sample_intermediates(self, model_params, input_shape): ) self.assertEqual(len(intermediates), 10) + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sample_cfg(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + cfg=5, + ) + self.assertEqual(sample.shape, noise.shape) + @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") def test_ddpm_sampler(self, model_params, input_shape): @@ -244,6 +265,38 @@ def test_sampler_conditioned_concat(self, model_params, input_shape): ) self.assertEqual(len(intermediates), 10) + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sampler_conditioned_concat_cfg(self, model_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + 
+            input_noise=noise,
+            diffusion_model=model,
+            scheduler=scheduler,
+            save_intermediates=True,
+            intermediate_steps=1,
+            conditioning=conditioning,
+            mode="concat",
+            cfg=5,
+        )
+        self.assertEqual(len(intermediates), 10)
+
     @parameterized.expand(TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
     def test_sampler_conditioned_concat_rflow(self, model_params, input_shape):

diff --git a/tests/inferers/test_latent_diffusion_inferer.py b/tests/inferers/test_latent_diffusion_inferer.py
index c20cb5d6ff..4063730820 100644
--- a/tests/inferers/test_latent_diffusion_inferer.py
+++ b/tests/inferers/test_latent_diffusion_inferer.py
@@ -414,6 +414,56 @@ def test_sample_shape(
         )
         self.assertEqual(sample.shape, input_shape)
 
+    @parameterized.expand(TEST_CASES)
+    @skipUnless(has_einops, "Requires einops")
+    def test_sample_shape_with_cfg(
+        self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
+    ):
+        stage_1 = None
+
+        if ae_model_type == "AutoencoderKL":
+            stage_1 = AutoencoderKL(**autoencoder_params)
+        if ae_model_type == "VQVAE":
+            stage_1 = VQVAE(**autoencoder_params)
+        if dm_model_type == "SPADEDiffusionModelUNet":
+            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+        else:
+            stage_2 = DiffusionModelUNet(**stage_2_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        noise = torch.randn(latent_shape).to(device)
+
+        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
+            inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+            scheduler.set_timesteps(num_inference_steps=10)
+
+            if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
+                input_shape_seg = list(input_shape)
+                if "label_nc" in stage_2_params.keys():
+                    input_shape_seg[1] = stage_2_params["label_nc"]
+                else:
+                    input_shape_seg[1] = autoencoder_params["label_nc"]
+                input_seg = torch.randn(input_shape_seg).to(device)
+                sample = inferer.sample(
+                    input_noise=noise,
+                    autoencoder_model=stage_1,
+                    diffusion_model=stage_2,
+                    scheduler=scheduler,
+                    seg=input_seg,
+                    cfg=5
+                )
+            else:
+                sample = inferer.sample(
+                    input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler,
+                    cfg=5
+                )
+            self.assertEqual(sample.shape, input_shape)
+
     @parameterized.expand(TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
     def test_sample_intermediates(

From 480114f834bb158ce84480a2d31af298ac831f79 Mon Sep 17 00:00:00 2001
From: Virginia Fernandez
Date: Fri, 23 May 2025 18:18:03 +0100
Subject: [PATCH 7/8] DCO Remediation Commit for Virginia Fernandez

I, Virginia Fernandez, hereby add my Signed-off-by to this commit:
79f357a45168a7dc5166a2fb9ff48515cd5dd56c

Signed-off-by: Virginia Fernandez
---
 tests/inferers/test_latent_diffusion_inferer.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/tests/inferers/test_latent_diffusion_inferer.py b/tests/inferers/test_latent_diffusion_inferer.py
index 4063730820..ed5e1a149e 100644
--- a/tests/inferers/test_latent_diffusion_inferer.py
+++ b/tests/inferers/test_latent_diffusion_inferer.py
@@ -455,12 +455,11 @@ def test_sample_shape_with_cfg(
                     diffusion_model=stage_2,
                     scheduler=scheduler,
                     seg=input_seg,
-                    cfg=5
+                    cfg=5,
                 )
             else:
                 sample = inferer.sample(
-                    input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler,
-                    cfg=5
+                    input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler, cfg=5
                 )
             self.assertEqual(sample.shape, input_shape)

From 473b8c3a5995500d36881f4f5cf136e201fc66f9 Mon Sep 17 00:00:00 2001
From: Virginia Fernandez
Date: Fri, 23 May 2025 22:25:36 +0100
Subject: [PATCH 8/8] Fix errors.

Signed-off-by: Virginia Fernandez
---
 monai/inferers/inferer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py
index 67b87dc6db..3687119baa 100644
--- a/monai/inferers/inferer.py
+++ b/monai/inferers/inferer.py
@@ -892,8 +892,8 @@ def sample(
             else:
                 model_input = image
                 conditioning_input = conditioning
-            if mode == "concat" and conditioning is not None:
-                model_input = torch.cat([model_input, conditioning], dim=1)
+            if mode == "concat" and conditioning_input is not None:
+                model_input = torch.cat([model_input, conditioning_input], dim=1)
                 model_output = diffusion_model(
                     model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None
                 )
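
Editor's note (not part of the patch series): the classifier-free guidance added in PATCH 6 and
corrected in PATCH 8 reduces to a single update rule applied to a doubled batch,
model_output_uncond + cfg * (model_output_cond - model_output_uncond). The sketch below restates
that rule as a standalone function, assuming only PyTorch; the names cfg_combine, eps_uncond, and
eps_cond are illustrative and do not exist in MONAI.

from __future__ import annotations

import torch


def cfg_combine(model_output: torch.Tensor, cfg: float | None) -> torch.Tensor:
    """Recombine a doubled-batch prediction into a single guided prediction.

    The first half of the batch is assumed to be the unconditioned branch and the second
    half the conditioned branch, matching the torch.cat([uncondition, conditioning], dim=0)
    ordering used in the patches above.
    """
    if cfg is None:
        return model_output
    # split along the batch dimension: unconditioned first, conditioned second
    eps_uncond, eps_cond = model_output.chunk(2)
    # cfg = 0 recovers the unconditioned output, cfg = 1 the conditioned output,
    # and cfg > 1 extrapolates beyond it to strengthen the conditioning
    return eps_uncond + cfg * (eps_cond - eps_uncond)


# Usage sketch: double the batch, run the model once, then recombine.
noisy = torch.randn(2, 1, 16, 16)
doubled = torch.cat([noisy] * 2, dim=0)
guided = cfg_combine(torch.randn_like(doubled), cfg=5.0)  # stand-in for a UNet call
assert guided.shape == noisy.shape

Because the unconditioned and conditioned branches are stacked into one batch, the model runs a
single forward pass per timestep at twice the batch size, which is the same trade-off the inferer
changes above make.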