@@ -839,6 +839,7 @@ def sample(
839
839
mode : str = "crossattn" ,
840
840
verbose : bool = True ,
841
841
seg : torch .Tensor | None = None ,
842
+ cfg : float | None = None ,
842
843
) -> torch .Tensor | tuple [torch .Tensor , list [torch .Tensor ]]:
843
844
"""
844
845
Args:
@@ -851,6 +852,7 @@ def sample(
851
852
mode: Conditioning mode for the network.
852
853
verbose: if true, prints the progression bar of the sampling process.
853
854
seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
855
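+ cfg: classifier-free guidance scale. If None, guidance is disabled. Otherwise the model is run on an unconditioned and a conditioned copy of the input and the output is unconditioned + cfg * (conditioned - unconditioned), so larger values strengthen the conditioning.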
854
856
"""
855
857
if mode not in ["crossattn" , "concat" ]:
856
858
raise NotImplementedError (f"{ mode } condition is not supported" )
@@ -877,15 +879,31 @@ def sample(
877
879
if isinstance (diffusion_model , SPADEDiffusionModelUNet )
878
880
else diffusion_model
879
881
)
882
+ if (
883
+ cfg is not None
884
+ ): # if classifier-free guidance is used, the batch is duplicated into an unconditioned half and a conditioned half.
885
+ model_input = torch .cat ([image ] * 2 , dim = 0 )
886
+ if conditioning is not None :
887
+ uncondition = torch .ones_like (conditioning )
888
+ uncondition .fill_ (- 1 )
889
+ conditioning_input = torch .cat ([uncondition , conditioning ], dim = 0 )
890
+ else :
891
+ conditioning_input = None
892
+ else :
893
+ model_input = image
894
+ conditioning_input = conditioning
880
895
if mode == "concat" and conditioning is not None :
881
- model_input = torch .cat ([image , conditioning ], dim = 1 )
896
+ model_input = torch .cat ([model_input , conditioning ], dim = 1 )
882
897
model_output = diffusion_model (
883
898
model_input , timesteps = torch .Tensor ((t ,)).to (input_noise .device ), context = None
884
899
)
885
900
else :
886
901
model_output = diffusion_model (
887
- image , timesteps = torch .Tensor ((t ,)).to (input_noise .device ), context = conditioning
902
+ model_input , timesteps = torch .Tensor ((t ,)).to (input_noise .device ), context = conditioning_input
888
903
)
904
+ if cfg is not None :
905
+ model_output_uncond , model_output_cond = model_output .chunk (2 )
906
+ model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond )
889
907
890
908
# 2. compute previous image: x_t -> x_t-1
891
909
if not isinstance (scheduler , RFlowScheduler ):
@@ -1166,6 +1184,7 @@ def sample( # type: ignore[override]
1166
1184
mode : str = "crossattn" ,
1167
1185
verbose : bool = True ,
1168
1186
seg : torch .Tensor | None = None ,
1187
+ cfg : float | None = None ,
1169
1188
) -> torch .Tensor | tuple [torch .Tensor , list [torch .Tensor ]]:
1170
1189
"""
1171
1190
Args:
@@ -1180,6 +1199,7 @@ def sample( # type: ignore[override]
1180
1199
verbose: if true, prints the progression bar of the sampling process.
1181
1200
seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
1182
1201
is instance of SPADEAutoencoderKL, segmentation must be provided.
1202
+ cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
1183
1203
"""
1184
1204
1185
1205
if (
@@ -1203,6 +1223,7 @@ def sample( # type: ignore[override]
1203
1223
mode = mode ,
1204
1224
verbose = verbose ,
1205
1225
seg = seg ,
1226
+ cfg = cfg ,
1206
1227
)
1207
1228
1208
1229
if save_intermediates :
@@ -1381,6 +1402,7 @@ def sample( # type: ignore[override]
1381
1402
mode : str = "crossattn" ,
1382
1403
verbose : bool = True ,
1383
1404
seg : torch .Tensor | None = None ,
1405
+ cfg : float | None = None ,
1384
1406
) -> torch .Tensor | tuple [torch .Tensor , list [torch .Tensor ]]:
1385
1407
"""
1386
1408
Args:
@@ -1395,6 +1417,7 @@ def sample( # type: ignore[override]
1395
1417
mode: Conditioning mode for the network.
1396
1418
verbose: if true, prints the progression bar of the sampling process.
1397
1419
seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
1420
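+ cfg: classifier-free guidance scale. If None, guidance is disabled. Otherwise the model is run on an unconditioned and a conditioned copy of the input and the output is unconditioned + cfg * (conditioned - unconditioned), so larger values strengthen the conditioning.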
1398
1421
"""
1399
1422
if mode not in ["crossattn" , "concat" ]:
1400
1423
raise NotImplementedError (f"{ mode } condition is not supported" )
@@ -1413,14 +1436,31 @@ def sample( # type: ignore[override]
1413
1436
progress_bar = iter (zip (scheduler .timesteps , all_next_timesteps ))
1414
1437
intermediates = []
1415
1438
1439
+ if cfg is not None :
1440
+ cn_cond = torch .cat ([cn_cond ] * 2 , dim = 0 )
1441
+
1416
1442
for t , next_t in progress_bar :
1443
+ # Classifier-free guidance: prepare the model and conditioning inputs (batch duplicated when cfg is set)
1444
+ if cfg is not None :
1445
+ model_input = torch .cat ([image ] * 2 , dim = 0 )
1446
+ if conditioning is not None :
1447
+ uncondition = torch .ones_like (conditioning )
1448
+ uncondition .fill_ (- 1 )
1449
+ conditioning_input = torch .cat ([uncondition , conditioning ], dim = 0 )
1450
+ else :
1451
+ conditioning_input = None
1452
+ else :
1453
+ model_input = image
1454
+ conditioning_input = conditioning
1455
+
1456
+ # Diffusion model prediction
1417
1457
diffuse = diffusion_model
1418
1458
if isinstance (diffusion_model , SPADEDiffusionModelUNet ):
1419
1459
diffuse = partial (diffusion_model , seg = seg )
1420
1460
1421
- if mode == "concat" and conditioning is not None :
1461
+ if mode == "concat" and conditioning_input is not None :
1422
1462
# 1. Conditioning
1423
- model_input = torch .cat ([image , conditioning ], dim = 1 )
1463
+ model_input = torch .cat ([model_input , conditioning_input ], dim = 1 )
1424
1464
# 2. ControlNet forward
1425
1465
down_block_res_samples , mid_block_res_sample = controlnet (
1426
1466
x = model_input ,
@@ -1437,20 +1477,28 @@ def sample( # type: ignore[override]
1437
1477
mid_block_additional_residual = mid_block_res_sample ,
1438
1478
)
1439
1479
else :
1480
+ # 1. Controlnet forward
1440
1481
down_block_res_samples , mid_block_res_sample = controlnet (
1441
- x = image ,
1482
+ x = model_input ,
1442
1483
timesteps = torch .Tensor ((t ,)).to (input_noise .device ),
1443
1484
controlnet_cond = cn_cond ,
1444
- context = conditioning ,
1485
+ context = conditioning_input ,
1445
1486
)
1487
+ # 2. predict noise model_output
1446
1488
model_output = diffuse (
1447
- image ,
1489
+ model_input ,
1448
1490
timesteps = torch .Tensor ((t ,)).to (input_noise .device ),
1449
- context = conditioning ,
1491
+ context = conditioning_input ,
1450
1492
down_block_additional_residuals = down_block_res_samples ,
1451
1493
mid_block_additional_residual = mid_block_res_sample ,
1452
1494
)
1453
1495
1496
+ # If classifier-free guidance isn't None, we split and compute the weighting between
1497
+ # conditioned and unconditioned output.
1498
+ if cfg is not None :
1499
+ model_output_uncond , model_output_cond = model_output .chunk (2 )
1500
+ model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond )
1501
+
1454
1502
# 3. compute previous image: x_t -> x_t-1
1455
1503
if not isinstance (scheduler , RFlowScheduler ):
1456
1504
image , _ = scheduler .step (model_output , t , image ) # type: ignore
@@ -1714,6 +1762,7 @@ def sample( # type: ignore[override]
1714
1762
mode : str = "crossattn" ,
1715
1763
verbose : bool = True ,
1716
1764
seg : torch .Tensor | None = None ,
1765
+ cfg : float | None = None ,
1717
1766
) -> torch .Tensor | tuple [torch .Tensor , list [torch .Tensor ]]:
1718
1767
"""
1719
1768
Args:
@@ -1730,6 +1779,7 @@ def sample( # type: ignore[override]
1730
1779
verbose: if true, prints the progression bar of the sampling process.
1731
1780
seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
1732
1781
is instance of SPADEAutoencoderKL, segmentation must be provided.
1782
+ cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
1733
1783
"""
1734
1784
1735
1785
if (
@@ -1757,6 +1807,7 @@ def sample( # type: ignore[override]
1757
1807
mode = mode ,
1758
1808
verbose = verbose ,
1759
1809
seg = seg ,
1810
+ cfg = cfg ,
1760
1811
)
1761
1812
1762
1813
if save_intermediates :
0 commit comments