Skip to content

Commit 7c1a4cf

Browse files
committed
add a new technique for countering oversaturation at higher cfg guidance strength
1 parent ef4421a commit 7c1a4cf

File tree

3 files changed

+42
-3
lines changed

3 files changed

+42
-3
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,3 +377,12 @@ You could consider adding a suitable metric to the training loop yourself after
377377
url = {https://api.semanticscholar.org/CorpusID:270391454}
378378
}
379379
```
380+
381+
```bibtex
382+
@inproceedings{Sadat2024EliminatingOA,
383+
title = {Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion Models},
384+
author = {Seyedmorteza Sadat and Otmar Hilliges and Romann M. Weber},
385+
year = {2024},
386+
url = {https://api.semanticscholar.org/CorpusID:273098845}
387+
}
388+
```

denoising_diffusion_pytorch/classifier_free_guidance.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch.nn.functional as F
1212
from torch.amp import autocast
1313

14-
from einops import rearrange, reduce, repeat
14+
from einops import rearrange, reduce, repeat, pack, unpack
1515
from einops.layers.torch import Rearrange
1616

1717
from tqdm.auto import tqdm
@@ -54,6 +54,15 @@ def convert_image_to_fn(img_type, image):
5454
return image.convert(img_type)
5555
return image
5656

57+
def pack_one_with_inverse(x, pattern):
    """Pack a single tensor with einops and hand back a matching un-packer.

    `x` is wrapped in a one-element list and passed to einops `pack` with
    `pattern`. The packed shape reported by `pack` is captured so the
    returned closure can later undo the packing via `unpack`.

    Returns a tuple `(packed, inverse)` where `inverse(x, inverse_pattern = None)`
    unpacks `x` with the captured shape and returns the single resulting
    tensor.
    """
    flat, recorded_shape = pack([x], pattern)

    def inverse(x, inverse_pattern = None):
        # `default` is a helper defined elsewhere in this file - presumably it
        # falls back to `pattern` when no `inverse_pattern` is given (TODO confirm).
        chosen_pattern = default(inverse_pattern, pattern)
        restored = unpack(x, recorded_shape, chosen_pattern)
        # only one tensor was packed, so only the first entry is returned
        return restored[0]

    return flat, inverse
65+
5766
# normalization functions
5867

5968
def normalize_to_neg_one_to_one(img):
@@ -75,6 +84,19 @@ def prob_mask_like(shape, prob, device):
7584
else:
7685
return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
7786

87+
def project(x, y):
    """Split `x` into its components parallel and orthogonal to `y`.

    Both inputs are packed with the pattern 'b *' (leading axis kept, all
    remaining axes flattened into one), and the projection is taken along
    that flattened last axis, independently for each leading-axis entry.

    Returns a tuple `(parallel, orthogonal)`; both are restored to the
    original shape of `x` and cast back to the dtype of `x`.
    """
    flat_x, restore = pack_one_with_inverse(x, 'b *')
    flat_y, _ = pack_one_with_inverse(y, 'b *')

    out_dtype = flat_x.dtype

    # the projection itself is carried out in float64 - presumably for
    # numerical precision; results are cast back to `out_dtype` on return
    flat_x = flat_x.double()
    flat_y = flat_y.double()

    # unit-length direction of y along the flattened axis
    direction = F.normalize(flat_y, dim = -1)

    # scalar projection coefficient, kept broadcastable against `direction`
    coeff = (flat_x * direction).sum(dim = -1, keepdim = True)

    along = coeff * direction
    across = flat_x - along

    return restore(along).to(out_dtype), restore(across).to(out_dtype)
99+
78100
# small helper modules
79101

80102
class Residual(nn.Module):
@@ -357,6 +379,8 @@ def forward_with_cond_scale(
357379
*args,
358380
cond_scale = 1.,
359381
rescaled_phi = 0.,
382+
remove_parallel_component = True,
383+
keep_parallel_frac = 0.,
360384
**kwargs
361385
):
362386
logits = self.forward(*args, cond_drop_prob = 0., **kwargs)
@@ -365,7 +389,13 @@ def forward_with_cond_scale(
365389
return logits
366390

367391
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
368-
scaled_logits = null_logits + (logits - null_logits) * cond_scale
392+
update = logits - null_logits
393+
394+
if remove_parallel_component:
395+
parallel, orthog = project(update, logits)
396+
update = orthog + parallel * keep_parallel_frac
397+
398+
scaled_logits = logits + update * (cond_scale - 1.)
369399

370400
if rescaled_phi == 0.:
371401
return scaled_logits, null_logits
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '2.0.17'
1+
__version__ = '2.0.18'

0 commit comments

Comments
 (0)