diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py
index bfb2756ebe..3687119baa 100644
--- a/monai/inferers/inferer.py
+++ b/monai/inferers/inferer.py
@@ -839,6 +839,7 @@ def sample(
         mode: str = "crossattn",
         verbose: bool = True,
         seg: torch.Tensor | None = None,
+        cfg: float | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Args:
@@ -851,6 +852,7 @@
             mode: Conditioning mode for the network.
             verbose: if true, prints the progression bar of the sampling process.
             seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+            cfg: classifier-free guidance scale, which controls how strongly the conditioning is applied.
         """
         if mode not in ["crossattn", "concat"]:
             raise NotImplementedError(f"{mode} condition is not supported")
@@ -877,15 +879,31 @@
                 if isinstance(diffusion_model, SPADEDiffusionModelUNet)
                 else diffusion_model
             )
-            if mode == "concat" and conditioning is not None:
-                model_input = torch.cat([image, conditioning], dim=1)
+            if (
+                cfg is not None
+            ):  # with classifier-free guidance, conditioned and unconditioned predictions are batched together.
+                model_input = torch.cat([image] * 2, dim=0)
+                if conditioning is not None:
+                    uncondition = torch.ones_like(conditioning)
+                    uncondition.fill_(-1)
+                    conditioning_input = torch.cat([uncondition, conditioning], dim=0)
+                else:
+                    conditioning_input = None
+            else:
+                model_input = image
+                conditioning_input = conditioning
+            if mode == "concat" and conditioning_input is not None:
+                model_input = torch.cat([model_input, conditioning_input], dim=1)
                 model_output = diffusion_model(
                     model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None
                 )
             else:
                 model_output = diffusion_model(
-                    image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning
+                    model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning_input
                 )
+            if cfg is not None:
+                model_output_uncond, model_output_cond = model_output.chunk(2)
+                model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond)

             # 2. compute previous image: x_t -> x_t-1
             if not isinstance(scheduler, RFlowScheduler):
@@ -1166,6 +1184,7 @@ def sample(  # type: ignore[override]
         mode: str = "crossattn",
         verbose: bool = True,
         seg: torch.Tensor | None = None,
+        cfg: float | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Args:
@@ -1180,6 +1199,7 @@
             verbose: if true, prints the progression bar of the sampling process.
             seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model is instance of
                 SPADEAutoencoderKL, segmentation must be provided.
+            cfg: classifier-free guidance scale, which controls how strongly the conditioning is applied.
         """

         if (
@@ -1203,6 +1223,7 @@
             mode=mode,
             verbose=verbose,
             seg=seg,
+            cfg=cfg,
         )

         if save_intermediates:
@@ -1381,6 +1402,7 @@ def sample(  # type: ignore[override]
         mode: str = "crossattn",
         verbose: bool = True,
         seg: torch.Tensor | None = None,
+        cfg: float | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Args:
@@ -1395,6 +1417,7 @@
             mode: Conditioning mode for the network.
             verbose: if true, prints the progression bar of the sampling process.
             seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+            cfg: classifier-free guidance scale, which controls how strongly the conditioning is applied.
         """
         if mode not in ["crossattn", "concat"]:
             raise NotImplementedError(f"{mode} condition is not supported")
@@ -1413,14 +1436,31 @@
             progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))
         intermediates = []

+        if cfg is not None:
+            cn_cond = torch.cat([cn_cond] * 2, dim=0)
+
        for t, next_t in progress_bar:
+            # Prepare model input and conditioning (batch is doubled for classifier-free guidance)
+            if cfg is not None:
+                model_input = torch.cat([image] * 2, dim=0)
+                if conditioning is not None:
+                    uncondition = torch.ones_like(conditioning)
+                    uncondition.fill_(-1)
+                    conditioning_input = torch.cat([uncondition, conditioning], dim=0)
+                else:
+                    conditioning_input = None
+            else:
+                model_input = image
+                conditioning_input = conditioning
+
+            # Diffusion model prediction
             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
                 diffuse = partial(diffusion_model, seg=seg)

-            if mode == "concat" and conditioning is not None:
+            if mode == "concat" and conditioning_input is not None:
                 # 1. Conditioning
-                model_input = torch.cat([image, conditioning], dim=1)
+                model_input = torch.cat([model_input, conditioning_input], dim=1)
                 # 2. ControlNet forward
                 down_block_res_samples, mid_block_res_sample = controlnet(
                     x=model_input,
@@ -1437,20 +1477,28 @@
                     mid_block_additional_residual=mid_block_res_sample,
                 )
             else:
+                # 1. ControlNet forward
                 down_block_res_samples, mid_block_res_sample = controlnet(
-                    x=image,
+                    x=model_input,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
                     controlnet_cond=cn_cond,
-                    context=conditioning,
+                    context=conditioning_input,
                 )
+                # 2. predict noise model_output
                 model_output = diffuse(
-                    image,
+                    model_input,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
-                    context=conditioning,
+                    context=conditioning_input,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
                 )

+            # If classifier-free guidance is enabled, split the doubled batch and blend the
+            # conditioned and unconditioned predictions according to the guidance scale.
+            if cfg is not None:
+                model_output_uncond, model_output_cond = model_output.chunk(2)
+                model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond)
+
             # 3. compute previous image: x_t -> x_t-1
             if not isinstance(scheduler, RFlowScheduler):
                 image, _ = scheduler.step(model_output, t, image)  # type: ignore
@@ -1714,6 +1762,7 @@ def sample(  # type: ignore[override]
         mode: str = "crossattn",
         verbose: bool = True,
         seg: torch.Tensor | None = None,
+        cfg: float | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Args:
@@ -1730,6 +1779,7 @@
             verbose: if true, prints the progression bar of the sampling process.
             seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model is instance of
                 SPADEAutoencoderKL, segmentation must be provided.
+            cfg: classifier-free guidance scale, which controls how strongly the conditioning is applied.
        """

         if (
@@ -1757,6 +1807,7 @@
             mode=mode,
             verbose=verbose,
             seg=seg,
+            cfg=cfg,
         )

         if save_intermediates:
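The blending step added in this file is the standard classifier-free guidance rule: the doubled batch is split into an unconditioned and a conditioned prediction, and the final output is `uncond + cfg * (cond - uncond)`, so `cfg=1` reproduces the conditioned prediction exactly and larger values extrapolate further away from the unconditioned one. A tiny self-contained check of that arithmetic (tensor values are made up for illustration):

import torch

# Stand-in for a doubled-batch model output: first half is the unconditioned
# prediction, second half the conditioned one (values are arbitrary).
model_output = torch.tensor([[1.0, 2.0], [3.0, 6.0]])

for cfg in (1.0, 2.0):
    # chunk(2) splits along the batch dimension, mirroring the code above
    uncond, cond = model_output.chunk(2)
    guided = uncond + cfg * (cond - uncond)
    print(cfg, guided)
    # cfg=1.0 -> tensor([[3., 6.]])   exactly the conditioned prediction
    # cfg=2.0 -> tensor([[5., 10.]])  extrapolates past the conditioned prediction

Note also that the unconditioned context is synthesised by filling a tensor with -1 (`uncondition.fill_(-1)` above), so guidance is only meaningful for models that were trained with the same null-conditioning convention.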
diff --git a/tests/inferers/test_controlnet_inferers.py b/tests/inferers/test_controlnet_inferers.py
index 1ce81a71d5..2b6777a75f 100644
--- a/tests/inferers/test_controlnet_inferers.py
+++ b/tests/inferers/test_controlnet_inferers.py
@@ -482,16 +482,20 @@ def test_sample_intermediates(self, model_params, controlnet_params, input_shape
         scheduler = DDPMScheduler(num_train_timesteps=10)
         inferer = ControlNetDiffusionInferer(scheduler=scheduler)
         scheduler.set_timesteps(num_inference_steps=10)
-        sample, intermediates = inferer.sample(
-            input_noise=noise,
-            diffusion_model=model,
-            scheduler=scheduler,
-            controlnet=controlnet,
-            cn_cond=mask,
-            save_intermediates=True,
-            intermediate_steps=1,
-        )
-        self.assertEqual(len(intermediates), 10)
+
+        for cfg in [5, None]:
+            sample, intermediates = inferer.sample(
+                input_noise=noise,
+                diffusion_model=model,
+                scheduler=scheduler,
+                controlnet=controlnet,
+                cn_cond=mask,
+                save_intermediates=True,
+                intermediate_steps=1,
+                cfg=cfg,
+            )
+
+            self.assertEqual(len(intermediates), 10)

     @parameterized.expand(CNDM_TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
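For reference, a minimal sketch of guided sampling with `ControlNetDiffusionInferer`, mirroring the test above. The network hyperparameters are illustrative assumptions (in practice the ControlNet would be initialised from the trained UNet's weights), and, as in the test, no `conditioning` is passed; in that case the two halves of the doubled batch coincide and the blend is a no-op, so supply `conditioning` for the guidance scale to have an effect.

import torch

from monai.inferers import ControlNetDiffusionInferer
from monai.networks.nets import ControlNet, DiffusionModelUNet
from monai.networks.schedulers import DDPMScheduler

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Hypothetical small models; all hyperparameters here are assumptions.
model = DiffusionModelUNet(
    spatial_dims=2, in_channels=1, out_channels=1, num_res_blocks=1,
    channels=(8, 8), attention_levels=(False, True),
    norm_num_groups=8, num_head_channels=8,
).to(device)
controlnet = ControlNet(
    spatial_dims=2, in_channels=1, num_res_blocks=1,
    channels=(8, 8), attention_levels=(False, True),
    norm_num_groups=8, num_head_channels=8,
    conditioning_embedding_num_channels=(16,),
).to(device)
model.eval()
controlnet.eval()

noise = torch.randn((1, 1, 16, 16), device=device)
mask = torch.randn((1, 1, 16, 16), device=device)  # ControlNet conditioning image

scheduler = DDPMScheduler(num_train_timesteps=10)
scheduler.set_timesteps(num_inference_steps=10)
inferer = ControlNetDiffusionInferer(scheduler=scheduler)

with torch.no_grad():
    sample = inferer.sample(
        input_noise=noise,
        diffusion_model=model,
        scheduler=scheduler,
        controlnet=controlnet,
        cn_cond=mask,  # doubled internally when cfg is not None
        cfg=5,         # cfg=None restores the previous, unguided path
    )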
""" if ( @@ -1757,6 +1807,7 @@ def sample( # type: ignore[override] mode=mode, verbose=verbose, seg=seg, + cfg=cfg, ) if save_intermediates: diff --git a/tests/inferers/test_controlnet_inferers.py b/tests/inferers/test_controlnet_inferers.py index 1ce81a71d5..2b6777a75f 100644 --- a/tests/inferers/test_controlnet_inferers.py +++ b/tests/inferers/test_controlnet_inferers.py @@ -482,16 +482,20 @@ def test_sample_intermediates(self, model_params, controlnet_params, input_shape scheduler = DDPMScheduler(num_train_timesteps=10) inferer = ControlNetDiffusionInferer(scheduler=scheduler) scheduler.set_timesteps(num_inference_steps=10) - sample, intermediates = inferer.sample( - input_noise=noise, - diffusion_model=model, - scheduler=scheduler, - controlnet=controlnet, - cn_cond=mask, - save_intermediates=True, - intermediate_steps=1, - ) - self.assertEqual(len(intermediates), 10) + + for cfg in [5, None]: + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + cfg=cfg, + ) + + self.assertEqual(len(intermediates), 10) @parameterized.expand(CNDM_TEST_CASES) @skipUnless(has_einops, "Requires einops") diff --git a/tests/inferers/test_diffusion_inferer.py b/tests/inferers/test_diffusion_inferer.py index 59b320d8a7..02890a71d4 100644 --- a/tests/inferers/test_diffusion_inferer.py +++ b/tests/inferers/test_diffusion_inferer.py @@ -88,6 +88,27 @@ def test_sample_intermediates(self, model_params, input_shape): ) self.assertEqual(len(intermediates), 10) + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sample_cfg(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + cfg=5, + ) + self.assertEqual(sample.shape, noise.shape) + @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") def test_ddpm_sampler(self, model_params, input_shape): @@ -244,6 +265,38 @@ def test_sampler_conditioned_concat(self, model_params, input_shape): ) self.assertEqual(len(intermediates), 10) + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_sampler_conditioned_concat_cfg(self, model_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = None + model_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + 
diff --git a/tests/inferers/test_latent_diffusion_inferer.py b/tests/inferers/test_latent_diffusion_inferer.py
index c20cb5d6ff..ed5e1a149e 100644
--- a/tests/inferers/test_latent_diffusion_inferer.py
+++ b/tests/inferers/test_latent_diffusion_inferer.py
@@ -414,6 +414,55 @@ def test_sample_shape(
         )
         self.assertEqual(sample.shape, input_shape)

+    @parameterized.expand(TEST_CASES)
+    @skipUnless(has_einops, "Requires einops")
+    def test_sample_shape_with_cfg(
+        self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
+    ):
+        stage_1 = None
+
+        if ae_model_type == "AutoencoderKL":
+            stage_1 = AutoencoderKL(**autoencoder_params)
+        if ae_model_type == "VQVAE":
+            stage_1 = VQVAE(**autoencoder_params)
+        if dm_model_type == "SPADEDiffusionModelUNet":
+            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+        else:
+            stage_2 = DiffusionModelUNet(**stage_2_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        noise = torch.randn(latent_shape).to(device)
+
+        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
+            inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+            scheduler.set_timesteps(num_inference_steps=10)
+
+            if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
+                input_shape_seg = list(input_shape)
+                if "label_nc" in stage_2_params.keys():
+                    input_shape_seg[1] = stage_2_params["label_nc"]
+                else:
+                    input_shape_seg[1] = autoencoder_params["label_nc"]
+                input_seg = torch.randn(input_shape_seg).to(device)
+                sample = inferer.sample(
+                    input_noise=noise,
+                    autoencoder_model=stage_1,
+                    diffusion_model=stage_2,
+                    scheduler=scheduler,
+                    seg=input_seg,
+                    cfg=5,
+                )
+            else:
+                sample = inferer.sample(
+                    input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler, cfg=5
+                )
+            self.assertEqual(sample.shape, input_shape)
+
     @parameterized.expand(TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
     def test_sample_intermediates(
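As the inferer.py hunks show, `LatentDiffusionInferer.sample` (and the ControlNet latent variant) simply forwards `cfg` to the base sampler, so guidance operates in latent space before decoding. A minimal two-stage sketch under assumed hyperparameters; the autoencoder configuration below downsamples once (16 -> 8), so it must match the latent noise shape.

import torch

from monai.inferers import LatentDiffusionInferer
from monai.networks.nets import AutoencoderKL, DiffusionModelUNet
from monai.networks.schedulers import DDPMScheduler

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Hypothetical two-stage setup; all hyperparameters are illustrative.
autoencoder = AutoencoderKL(
    spatial_dims=2, in_channels=1, out_channels=1,
    channels=(4, 4), latent_channels=3, num_res_blocks=1,
    norm_num_groups=4, attention_levels=(False, False),
).to(device)
unet = DiffusionModelUNet(
    spatial_dims=2, in_channels=3, out_channels=3, num_res_blocks=1,
    channels=(8, 8), attention_levels=(False, True),
    norm_num_groups=8, num_head_channels=8,
).to(device)
autoencoder.eval()
unet.eval()

scheduler = DDPMScheduler(num_train_timesteps=10)
scheduler.set_timesteps(num_inference_steps=10)
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)

# Noise lives in latent space: latent_channels=3, one downsampling stage.
noise = torch.randn((1, 3, 8, 8), device=device)

with torch.no_grad():
    sample = inferer.sample(
        input_noise=noise,
        autoencoder_model=autoencoder,
        diffusion_model=unet,
        scheduler=scheduler,
        cfg=5,  # forwarded unchanged to the underlying DiffusionInferer.sample
    )
print(sample.shape)  # expected (1, 1, 16, 16) given one downsampling stage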