
Commit 74f5aee

Virginia Fernandez authored and committed
Add classifier-free guidance to MONAI inferers and tests for them
Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk>
1 parent 70eb1d7 commit 74f5aee
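
For context: classifier-free guidance runs the diffusion model on a doubled batch, one unconditioned half and one conditioned half, and then extrapolates from the unconditioned prediction toward the conditioned one. A minimal sketch of the combination step (cfg_combine is a hypothetical helper name; the commit inlines this computation in each sampler):

import torch

def cfg_combine(model_output: torch.Tensor, cfg: float) -> torch.Tensor:
    # Batch layout matches the diffs below: unconditioned half first, conditioned half second.
    model_output_uncond, model_output_cond = model_output.chunk(2)
    # cfg = 1 recovers the conditioned prediction; larger values strengthen the conditioning.
    return model_output_uncond + cfg * (model_output_cond - model_output_uncond)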

File tree: 4 files changed (+176, -18 lines)

monai/inferers/inferer.py

Lines changed: 59 additions & 8 deletions
@@ -839,6 +839,7 @@ def sample(
         mode: str = "crossattn",
         verbose: bool = True,
         seg: torch.Tensor | None = None,
+        cfg: float | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Args:
@@ -851,6 +852,7 @@ def sample(
             mode: Conditioning mode for the network.
             verbose: if true, prints the progression bar of the sampling process.
             seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+            cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
         """
         if mode not in ["crossattn", "concat"]:
             raise NotImplementedError(f"{mode} condition is not supported")
@@ -877,15 +879,31 @@ def sample(
                 if isinstance(diffusion_model, SPADEDiffusionModelUNet)
                 else diffusion_model
             )
+            if (
+                cfg is not None
+            ):  # if classifier-free guidance is used, a conditioned and unconditioned bit is generated.
+                model_input = torch.cat([image] * 2, dim=0)
+                if conditioning is not None:
+                    uncondition = torch.ones_like(conditioning)
+                    uncondition.fill_(-1)
+                    conditioning_input = torch.cat([uncondition, conditioning], dim=0)
+                else:
+                    conditioning_input = None
+            else:
+                model_input = image
+                conditioning_input = conditioning
             if mode == "concat" and conditioning is not None:
-                model_input = torch.cat([image, conditioning], dim=1)
+                model_input = torch.cat([model_input, conditioning], dim=1)
                 model_output = diffusion_model(
                     model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None
                 )
             else:
                 model_output = diffusion_model(
-                    image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning
+                    model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning_input
                 )
+            if cfg is not None:
+                model_output_uncond, model_output_cond = model_output.chunk(2)
+                model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond)

             # 2. compute previous image: x_t -> x_t-1
             if not isinstance(scheduler, RFlowScheduler):
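
As the added lines above show, the unconditioned half of the doubled batch replaces the conditioning with a same-shaped tensor filled with -1, which acts as the null conditioning. A standalone sketch of that batch construction (shapes are illustrative, not from the commit):

import torch

conditioning = torch.randn(2, 1, 4)  # illustrative (batch, tokens, embedding) cross-attention context
uncondition = torch.ones_like(conditioning)
uncondition.fill_(-1)  # null conditioning: a tensor of -1s
conditioning_input = torch.cat([uncondition, conditioning], dim=0)  # unconditioned half first
model_input = torch.cat([torch.randn(2, 1, 8, 8)] * 2, dim=0)  # image batch doubled to match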
@@ -1166,6 +1184,7 @@ def sample(  # type: ignore[override]
         mode: str = "crossattn",
         verbose: bool = True,
         seg: torch.Tensor | None = None,
+        cfg: float | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Args:
@@ -1180,6 +1199,7 @@ def sample(  # type: ignore[override]
             verbose: if true, prints the progression bar of the sampling process.
             seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
                 is instance of SPADEAutoencoderKL, segmentation must be provided.
+            cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
         """

         if (
@@ -1203,6 +1223,7 @@ def sample(  # type: ignore[override]
             mode=mode,
             verbose=verbose,
             seg=seg,
+            cfg=cfg,
         )

         if save_intermediates:
@@ -1381,6 +1402,7 @@ def sample(  # type: ignore[override]
         mode: str = "crossattn",
         verbose: bool = True,
         seg: torch.Tensor | None = None,
+        cfg: float | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Args:
@@ -1395,6 +1417,7 @@ def sample(  # type: ignore[override]
             mode: Conditioning mode for the network.
             verbose: if true, prints the progression bar of the sampling process.
             seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+            cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
         """
         if mode not in ["crossattn", "concat"]:
             raise NotImplementedError(f"{mode} condition is not supported")
@@ -1413,14 +1436,31 @@ def sample(  # type: ignore[override]
             progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps))
         intermediates = []

+        if cfg is not None:
+            cn_cond = torch.cat([cn_cond] * 2, dim=0)
+
         for t, next_t in progress_bar:
+            # Controlnet prediction
+            if cfg is not None:
+                model_input = torch.cat([image] * 2, dim=0)
+                if conditioning is not None:
+                    uncondition = torch.ones_like(conditioning)
+                    uncondition.fill_(-1)
+                    conditioning_input = torch.cat([uncondition, conditioning], dim=0)
+                else:
+                    conditioning_input = None
+            else:
+                model_input = image
+                conditioning_input = conditioning
+
+            # Diffusion model prediction
             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
                 diffuse = partial(diffusion_model, seg=seg)

-            if mode == "concat" and conditioning is not None:
+            if mode == "concat" and conditioning_input is not None:
                 # 1. Conditioning
-                model_input = torch.cat([image, conditioning], dim=1)
+                model_input = torch.cat([model_input, conditioning_input], dim=1)
                 # 2. ControlNet forward
                 down_block_res_samples, mid_block_res_sample = controlnet(
                     x=model_input,
@@ -1437,20 +1477,28 @@ def sample(  # type: ignore[override]
                     mid_block_additional_residual=mid_block_res_sample,
                 )
             else:
+                # 1. Controlnet forward
                 down_block_res_samples, mid_block_res_sample = controlnet(
-                    x=image,
+                    x=model_input,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
                     controlnet_cond=cn_cond,
-                    context=conditioning,
+                    context=conditioning_input,
                 )
+                # 2. predict noise model_output
                 model_output = diffuse(
-                    image,
+                    model_input,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
-                    context=conditioning,
+                    context=conditioning_input,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
                 )

+            # If classifier-free guidance isn't None, we split and compute the weighting between
+            # conditioned and unconditioned output.
+            if cfg is not None:
+                model_output_uncond, model_output_cond = model_output.chunk(2)
+                model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond)
+
             # 3. compute previous image: x_t -> x_t-1
             if not isinstance(scheduler, RFlowScheduler):
                 image, _ = scheduler.step(model_output, t, image)  # type: ignore
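
One ControlNet-specific detail in the hunk above: the control image cn_cond is doubled once before the sampling loop so that its batch dimension matches the model input, which is doubled at every step. A standalone sketch (shapes are illustrative):

import torch

image = torch.randn(1, 1, 16, 16)            # current sample, batch size 1
cn_cond = torch.randn(1, 1, 16, 16)          # control image, batch size 1
cn_cond = torch.cat([cn_cond] * 2, dim=0)    # done once, before the loop
model_input = torch.cat([image] * 2, dim=0)  # done each step; batch sizes now agree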
@@ -1714,6 +1762,7 @@ def sample(  # type: ignore[override]
         mode: str = "crossattn",
         verbose: bool = True,
         seg: torch.Tensor | None = None,
+        cfg: float | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """
         Args:
@@ -1730,6 +1779,7 @@ def sample(  # type: ignore[override]
             verbose: if true, prints the progression bar of the sampling process.
             seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
                 is instance of SPADEAutoencoderKL, segmentation must be provided.
+            cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
         """

         if (
@@ -1757,6 +1807,7 @@ def sample(  # type: ignore[override]
             mode=mode,
             verbose=verbose,
             seg=seg,
+            cfg=cfg,
        )

         if save_intermediates:
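
The latent and ControlNet-latent inferers do not duplicate any guidance logic; they forward cfg to the underlying sampler, so guided sampling looks the same across all four entry points. A minimal, hedged usage sketch against the updated DiffusionInferer (the network hyperparameters are illustrative, not taken from the commit; cfg=None, the default, keeps the previous behaviour):

import torch
from monai.inferers import DiffusionInferer
from monai.networks.nets import DiffusionModelUNet
from monai.networks.schedulers import DDPMScheduler

# Toy 2D model; channels/norm_num_groups are chosen only so the example runs quickly.
model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(8, 8),
    norm_num_groups=8,
    attention_levels=(False, True),
    num_res_blocks=1,
    num_head_channels=8,
)
model.eval()

scheduler = DDPMScheduler(num_train_timesteps=10)
scheduler.set_timesteps(num_inference_steps=10)
inferer = DiffusionInferer(scheduler=scheduler)

noise = torch.randn(1, 1, 16, 16)
with torch.no_grad():
    # cfg=5 strengthens the conditioning; with no conditioning the batch is still
    # doubled and recombined, as in the test_sample_cfg test below.
    sample = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=scheduler, cfg=5.0)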

tests/inferers/test_controlnet_inferers.py

Lines changed: 14 additions & 10 deletions
@@ -482,16 +482,20 @@ def test_sample_intermediates(self, model_params, controlnet_params, input_shape):
         scheduler = DDPMScheduler(num_train_timesteps=10)
         inferer = ControlNetDiffusionInferer(scheduler=scheduler)
         scheduler.set_timesteps(num_inference_steps=10)
-        sample, intermediates = inferer.sample(
-            input_noise=noise,
-            diffusion_model=model,
-            scheduler=scheduler,
-            controlnet=controlnet,
-            cn_cond=mask,
-            save_intermediates=True,
-            intermediate_steps=1,
-        )
-        self.assertEqual(len(intermediates), 10)
+
+        for cfg in [5, None]:
+            sample, intermediates = inferer.sample(
+                input_noise=noise,
+                diffusion_model=model,
+                scheduler=scheduler,
+                controlnet=controlnet,
+                cn_cond=mask,
+                save_intermediates=True,
+                intermediate_steps=1,
+                cfg=cfg,
+            )
+
+            self.assertEqual(len(intermediates), 10)

     @parameterized.expand(CNDM_TEST_CASES)
     @skipUnless(has_einops, "Requires einops")

tests/inferers/test_diffusion_inferer.py

Lines changed: 53 additions & 0 deletions
@@ -88,6 +88,27 @@ def test_sample_intermediates(self, model_params, input_shape):
         )
         self.assertEqual(len(intermediates), 10)

+    @parameterized.expand(TEST_CASES)
+    @skipUnless(has_einops, "Requires einops")
+    def test_sample_cfg(self, model_params, input_shape):
+        model = DiffusionModelUNet(**model_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        model.eval()
+        noise = torch.randn(input_shape).to(device)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        inferer = DiffusionInferer(scheduler=scheduler)
+        scheduler.set_timesteps(num_inference_steps=10)
+        sample, intermediates = inferer.sample(
+            input_noise=noise,
+            diffusion_model=model,
+            scheduler=scheduler,
+            save_intermediates=True,
+            intermediate_steps=1,
+            cfg=5,
+        )
+        self.assertEqual(sample.shape, noise.shape)
+
     @parameterized.expand(TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
     def test_ddpm_sampler(self, model_params, input_shape):
@@ -244,6 +265,38 @@ def test_sampler_conditioned_concat(self, model_params, input_shape):
         )
         self.assertEqual(len(intermediates), 10)

+    @parameterized.expand(TEST_CASES)
+    @skipUnless(has_einops, "Requires einops")
+    def test_sampler_conditioned_concat_cfg(self, model_params, input_shape):
+        # copy the model_params dict to prevent from modifying test cases
+        model_params = model_params.copy()
+        n_concat_channel = 2
+        model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
+        model_params["cross_attention_dim"] = None
+        model_params["with_conditioning"] = False
+        model = DiffusionModelUNet(**model_params)
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        model.eval()
+        noise = torch.randn(input_shape).to(device)
+        conditioning_shape = list(input_shape)
+        conditioning_shape[1] = n_concat_channel
+        conditioning = torch.randn(conditioning_shape).to(device)
+        scheduler = DDIMScheduler(num_train_timesteps=1000)
+        inferer = DiffusionInferer(scheduler=scheduler)
+        scheduler.set_timesteps(num_inference_steps=10)
+        sample, intermediates = inferer.sample(
+            input_noise=noise,
+            diffusion_model=model,
+            scheduler=scheduler,
+            save_intermediates=True,
+            intermediate_steps=1,
+            conditioning=conditioning,
+            mode="concat",
+            cfg=5,
+        )
+        self.assertEqual(len(intermediates), 10)
+
     @parameterized.expand(TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
     def test_sampler_conditioned_concat_rflow(self, model_params, input_shape):

tests/inferers/test_latent_diffusion_inferer.py

Lines changed: 50 additions & 0 deletions
@@ -414,6 +414,56 @@ def test_sample_shape(
         )
         self.assertEqual(sample.shape, input_shape)

+    @parameterized.expand(TEST_CASES)
+    @skipUnless(has_einops, "Requires einops")
+    def test_sample_shape_with_cfg(
+        self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
+    ):
+        stage_1 = None
+
+        if ae_model_type == "AutoencoderKL":
+            stage_1 = AutoencoderKL(**autoencoder_params)
+        if ae_model_type == "VQVAE":
+            stage_1 = VQVAE(**autoencoder_params)
+        if dm_model_type == "SPADEDiffusionModelUNet":
+            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+        else:
+            stage_2 = DiffusionModelUNet(**stage_2_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        noise = torch.randn(latent_shape).to(device)
+
+        for scheduler in [DDPMScheduler(num_train_timesteps=10), RFlowScheduler(num_train_timesteps=1000)]:
+            inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+            scheduler.set_timesteps(num_inference_steps=10)
+
+            if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
+                input_shape_seg = list(input_shape)
+                if "label_nc" in stage_2_params.keys():
+                    input_shape_seg[1] = stage_2_params["label_nc"]
+                else:
+                    input_shape_seg[1] = autoencoder_params["label_nc"]
+                input_seg = torch.randn(input_shape_seg).to(device)
+                sample = inferer.sample(
+                    input_noise=noise,
+                    autoencoder_model=stage_1,
+                    diffusion_model=stage_2,
+                    scheduler=scheduler,
+                    seg=input_seg,
+                    cfg=5
+                )
+            else:
+                sample = inferer.sample(
+                    input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler,
+                    cfg=5
+                )
+            self.assertEqual(sample.shape, input_shape)
+
     @parameterized.expand(TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
     def test_sample_intermediates(
