69 changes: 60 additions & 9 deletions monai/inferers/inferer.py
@@ -839,6 +839,7 @@ def sample(
mode: str = "crossattn",
verbose: bool = True,
seg: torch.Tensor | None = None,
cfg: float | None = None,
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
"""
Args:
@@ -851,6 +852,7 @@ def sample(
mode: Conditioning mode for the network.
verbose: if true, prints the progression bar of the sampling process.
seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
cfg: classifier-free guidance scale; higher values strengthen the influence of the conditioning.
"""
if mode not in ["crossattn", "concat"]:
raise NotImplementedError(f"{mode} condition is not supported")
@@ -877,15 +879,31 @@
if isinstance(diffusion_model, SPADEDiffusionModelUNet)
else diffusion_model
)
if mode == "concat" and conditioning is not None:
model_input = torch.cat([image, conditioning], dim=1)
if cfg is not None:  # classifier-free guidance: compute conditioned and unconditioned predictions in one doubled batch
model_input = torch.cat([image] * 2, dim=0)
if conditioning is not None:
uncondition = torch.ones_like(conditioning)
uncondition.fill_(-1)
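# NOTE: a tensor filled with -1 acts as the null (unconditioned) context; this assumes the model treats -1-filled conditioning as "no conditioning" (e.g., it was trained that way)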
conditioning_input = torch.cat([uncondition, conditioning], dim=0)
else:
conditioning_input = None
else:
model_input = image
conditioning_input = conditioning
if mode == "concat" and conditioning_input is not None:
model_input = torch.cat([model_input, conditioning_input], dim=1)
model_output = diffusion_model(
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None
)
else:
model_output = diffusion_model(
image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning_input
)
if cfg is not None:
model_output_uncond, model_output_cond = model_output.chunk(2)
model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond)
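# cfg == 1 recovers the conditioned prediction; larger values extrapolate further away from the unconditioned one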

# 2. compute previous image: x_t -> x_t-1
if not isinstance(scheduler, RFlowScheduler):
@@ -1166,6 +1184,7 @@ def sample( # type: ignore[override]
mode: str = "crossattn",
verbose: bool = True,
seg: torch.Tensor | None = None,
cfg: float | None = None,
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
"""
Args:
@@ -1180,6 +1199,7 @@ def sample( # type: ignore[override]
verbose: if true, prints the progression bar of the sampling process.
seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
is instance of SPADEAutoencoderKL, segmentation must be provided.
cfg: classifier-free guidance scale; higher values strengthen the influence of the conditioning.
"""

if (
@@ -1203,6 +1223,7 @@
mode=mode,
verbose=verbose,
seg=seg,
cfg=cfg,
)

if save_intermediates:
@@ -1381,6 +1402,7 @@ def sample( # type: ignore[override]
mode: str = "crossattn",
verbose: bool = True,
seg: torch.Tensor | None = None,
cfg: float | None = None,
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
"""
Args:
@@ -1395,6 +1417,7 @@ def sample( # type: ignore[override]
mode: Conditioning mode for the network.
verbose: if true, prints the progression bar of the sampling process.
seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
cfg: classifier-free guidance scale; higher values strengthen the influence of the conditioning.
"""
if mode not in ["crossattn", "concat"]:
raise NotImplementedError(f"{mode} condition is not supported")
@@ -1413,14 +1436,31 @@
progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))
intermediates = []

if cfg is not None:
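# duplicate the ControlNet conditioning so it matches the doubled (unconditioned + conditioned) batch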
cn_cond = torch.cat([cn_cond] * 2, dim=0)

for t, next_t in progress_bar:
# prepare model inputs (doubled batch when classifier-free guidance is used)
if cfg is not None:
model_input = torch.cat([image] * 2, dim=0)
if conditioning is not None:
uncondition = torch.ones_like(conditioning)
uncondition.fill_(-1)
conditioning_input = torch.cat([uncondition, conditioning], dim=0)
else:
conditioning_input = None
else:
model_input = image
conditioning_input = conditioning

# Diffusion model prediction
diffuse = diffusion_model
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
diffuse = partial(diffusion_model, seg=seg)

if mode == "concat" and conditioning is not None:
if mode == "concat" and conditioning_input is not None:
# 1. Conditioning
model_input = torch.cat([image, conditioning], dim=1)
model_input = torch.cat([model_input, conditioning_input], dim=1)
# 2. ControlNet forward
down_block_res_samples, mid_block_res_sample = controlnet(
x=model_input,
@@ -1437,20 +1477,28 @@
mid_block_additional_residual=mid_block_res_sample,
)
else:
# 1. Controlnet forward
down_block_res_samples, mid_block_res_sample = controlnet(
x=image,
x=model_input,
timesteps=torch.Tensor((t,)).to(input_noise.device),
controlnet_cond=cn_cond,
context=conditioning,
context=conditioning_input,
)
# 2. predict noise model_output
model_output = diffuse(
image,
model_input,
timesteps=torch.Tensor((t,)).to(input_noise.device),
context=conditioning,
context=conditioning_input,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
)

# with classifier-free guidance, split the doubled batch and blend the
# conditioned and unconditioned outputs
if cfg is not None:
model_output_uncond, model_output_cond = model_output.chunk(2)
model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond)

# 3. compute previous image: x_t -> x_t-1
if not isinstance(scheduler, RFlowScheduler):
image, _ = scheduler.step(model_output, t, image) # type: ignore
@@ -1714,6 +1762,7 @@ def sample( # type: ignore[override]
mode: str = "crossattn",
verbose: bool = True,
seg: torch.Tensor | None = None,
cfg: float | None = None,
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
"""
Args:
@@ -1730,6 +1779,7 @@
verbose: if true, prints the progression bar of the sampling process.
seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
is instance of SPADEAutoencoderKL, segmentation must be provided.
cfg: classifier-free guidance scale; higher values strengthen the influence of the conditioning.
"""

if (
@@ -1757,6 +1807,7 @@
mode=mode,
verbose=verbose,
seg=seg,
cfg=cfg,
)

if save_intermediates:
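For reference, the classifier-free guidance pattern this diff threads through the inferers can be reduced to a minimal standalone sketch. It illustrates only the batching and blending logic, not the inferer API: model, x_t, t, context, and cfg_scale are hypothetical stand-ins, and the -1-filled unconditioned context mirrors the convention used in the changes above.

import torch

def cfg_prediction(model, x_t, t, context, cfg_scale):
    # double the latent batch: first half unconditioned, second half conditioned
    model_input = torch.cat([x_t] * 2, dim=0)
    # a -1-filled tensor stands in for the null (unconditioned) context
    uncond = torch.full_like(context, -1)
    context_input = torch.cat([uncond, context], dim=0)
    timesteps = torch.Tensor((t,)).to(x_t.device)
    out = model(model_input, timesteps=timesteps, context=context_input)
    # split the doubled batch and extrapolate from the unconditioned toward the conditioned prediction
    out_uncond, out_cond = out.chunk(2)
    return out_uncond + cfg_scale * (out_cond - out_uncond)

Concatenating along the batch dimension lets a single forward pass produce both predictions, which is why the changes double the batch rather than call the model twice.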
24 changes: 14 additions & 10 deletions tests/inferers/test_controlnet_inferers.py
@@ -482,16 +482,20 @@ def test_sample_intermediates(self, model_params, controlnet_params, input_shape
scheduler = DDPMScheduler(num_train_timesteps=10)
inferer = ControlNetDiffusionInferer(scheduler=scheduler)
scheduler.set_timesteps(num_inference_steps=10)
sample, intermediates = inferer.sample(
input_noise=noise,
diffusion_model=model,
scheduler=scheduler,
controlnet=controlnet,
cn_cond=mask,
save_intermediates=True,
intermediate_steps=1,
)
self.assertEqual(len(intermediates), 10)

for cfg in [5, None]:
sample, intermediates = inferer.sample(
input_noise=noise,
diffusion_model=model,
scheduler=scheduler,
controlnet=controlnet,
cn_cond=mask,
save_intermediates=True,
intermediate_steps=1,
cfg=cfg,
)

self.assertEqual(len(intermediates), 10)

@parameterized.expand(CNDM_TEST_CASES)
@skipUnless(has_einops, "Requires einops")
53 changes: 53 additions & 0 deletions tests/inferers/test_diffusion_inferer.py
@@ -88,6 +88,27 @@ def test_sample_intermediates(self, model_params, input_shape):
)
self.assertEqual(len(intermediates), 10)

@parameterized.expand(TEST_CASES)
@skipUnless(has_einops, "Requires einops")
def test_sample_cfg(self, model_params, input_shape):
model = DiffusionModelUNet(**model_params)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
noise = torch.randn(input_shape).to(device)
scheduler = DDPMScheduler(num_train_timesteps=10)
inferer = DiffusionInferer(scheduler=scheduler)
scheduler.set_timesteps(num_inference_steps=10)
sample, intermediates = inferer.sample(
input_noise=noise,
diffusion_model=model,
scheduler=scheduler,
save_intermediates=True,
intermediate_steps=1,
cfg=5,
)
self.assertEqual(sample.shape, noise.shape)

@parameterized.expand(TEST_CASES)
@skipUnless(has_einops, "Requires einops")
def test_ddpm_sampler(self, model_params, input_shape):
@@ -244,6 +265,38 @@ def test_sampler_conditioned_concat(self, model_params, input_shape):
)
self.assertEqual(len(intermediates), 10)

@parameterized.expand(TEST_CASES)
@skipUnless(has_einops, "Requires einops")
def test_sampler_conditioned_concat_cfg(self, model_params, input_shape):
# copy the model_params dict to avoid modifying the shared test cases
model_params = model_params.copy()
n_concat_channel = 2
model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
model_params["cross_attention_dim"] = None
model_params["with_conditioning"] = False
model = DiffusionModelUNet(**model_params)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
noise = torch.randn(input_shape).to(device)
conditioning_shape = list(input_shape)
conditioning_shape[1] = n_concat_channel
conditioning = torch.randn(conditioning_shape).to(device)
scheduler = DDIMScheduler(num_train_timesteps=1000)
inferer = DiffusionInferer(scheduler=scheduler)
scheduler.set_timesteps(num_inference_steps=10)
sample, intermediates = inferer.sample(
input_noise=noise,
diffusion_model=model,
scheduler=scheduler,
save_intermediates=True,
intermediate_steps=1,
conditioning=conditioning,
mode="concat",
cfg=5,
)
self.assertEqual(len(intermediates), 10)

@parameterized.expand(TEST_CASES)
@skipUnless(has_einops, "Requires einops")
def test_sampler_conditioned_concat_rflow(self, model_params, input_shape):
49 changes: 49 additions & 0 deletions tests/inferers/test_latent_diffusion_inferer.py
@@ -414,6 +414,55 @@ def test_sample_shape(
)
self.assertEqual(sample.shape, input_shape)

@parameterized.expand(TEST_CASES)
@skipUnless(has_einops, "Requires einops")
def test_sample_shape_with_cfg(
self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
):
stage_1 = None

if ae_model_type == "AutoencoderKL":
stage_1 = AutoencoderKL(**autoencoder_params)
if ae_model_type == "VQVAE":
stage_1 = VQVAE(**autoencoder_params)
if dm_model_type == "SPADEDiffusionModelUNet":
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
else:
stage_2 = DiffusionModelUNet(**stage_2_params)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
stage_1.to(device)
stage_2.to(device)
stage_1.eval()
stage_2.eval()

noise = torch.randn(latent_shape).to(device)

for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
scheduler.set_timesteps(num_inference_steps=10)

if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
input_shape_seg = list(input_shape)
if "label_nc" in stage_2_params.keys():
input_shape_seg[1] = stage_2_params["label_nc"]
else:
input_shape_seg[1] = autoencoder_params["label_nc"]
input_seg = torch.randn(input_shape_seg).to(device)
sample = inferer.sample(
input_noise=noise,
autoencoder_model=stage_1,
diffusion_model=stage_2,
scheduler=scheduler,
seg=input_seg,
cfg=5,
)
else:
sample = inferer.sample(
input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler, cfg=5
)
self.assertEqual(sample.shape, input_shape)

@parameterized.expand(TEST_CASES)
@skipUnless(has_einops, "Requires einops")
def test_sample_intermediates(