11
11
import torch .nn .functional as F
12
12
from torch .amp import autocast
13
13
14
- from einops import rearrange , reduce , repeat
14
+ from einops import rearrange , reduce , repeat , pack , unpack
15
15
from einops .layers .torch import Rearrange
16
16
17
17
from tqdm .auto import tqdm
@@ -54,6 +54,15 @@ def convert_image_to_fn(img_type, image):
54
54
return image .convert (img_type )
55
55
return image
56
56
57
def pack_one_with_inverse(x, pattern):
    """Pack a single tensor with einops and return it with an unpacking closure.

    Args:
        x: The tensor to pack; it is wrapped in a one-element list for ``pack``.
        pattern: The einops packing pattern (for example ``'b *'``).

    Returns:
        A ``(packed, inverse)`` pair. ``inverse(t, inverse_pattern=None)``
        unpacks ``t`` using the shape recorded while packing ``x`` and returns
        the single resulting tensor.
    """
    packed, recorded_shape = pack([x], pattern)

    def inverse(x, inverse_pattern = None):
        # Reuse the packing pattern unless the caller overrides it
        # (presumably `default` returns its first argument when it is set --
        # it is defined elsewhere in this file; confirm there).
        chosen_pattern = default(inverse_pattern, pattern)

        # Exactly one tensor was packed, so exactly one comes back out.
        (restored,) = unpack(x, recorded_shape, chosen_pattern)
        return restored

    return packed, inverse
65
+
57
66
# normalization functions
58
67
59
68
def normalize_to_neg_one_to_one (img ):
@@ -75,6 +84,19 @@ def prob_mask_like(shape, prob, device):
75
84
else :
76
85
return torch .zeros (shape , device = device ).float ().uniform_ (0 , 1 ) < prob
77
86
87
def project(x, y):
    """Split ``x`` into components parallel and orthogonal to ``y``.

    Both inputs are packed with the ``'b *'`` pattern, which keeps the leading
    axis and flattens the remaining axes, so the projection is computed
    independently for each entry along the first axis.

    Args:
        x: Tensor to decompose.
        y: Tensor giving the direction to project onto.

    Returns:
        A ``(parallel, orthogonal)`` pair, each unpacked back to the shape
        recorded for ``x`` and cast to the dtype of ``x``.
    """
    flat_x, restore = pack_one_with_inverse(x, 'b *')
    flat_y, _ = pack_one_with_inverse(y, 'b *')

    # Remember the incoming dtype; the arithmetic below runs in float64 and
    # the results are cast back before returning.
    original_dtype = flat_x.dtype
    flat_x = flat_x.double()
    flat_y = flat_y.double()

    # Unit-length direction of y along the flattened axis.
    direction = F.normalize(flat_y, dim = -1)

    # Scalar projection coefficient per row, then scaled back onto the direction.
    coefficient = (flat_x * direction).sum(dim = -1, keepdim = True)
    parallel = coefficient * direction
    orthogonal = flat_x - parallel

    parallel_out = restore(parallel).to(original_dtype)
    orthogonal_out = restore(orthogonal).to(original_dtype)
    return parallel_out, orthogonal_out
99
+
78
100
# small helper modules
79
101
80
102
class Residual (nn .Module ):
@@ -357,6 +379,8 @@ def forward_with_cond_scale(
357
379
* args ,
358
380
cond_scale = 1. ,
359
381
rescaled_phi = 0. ,
382
+ remove_parallel_component = True ,
383
+ keep_parallel_frac = 0. ,
360
384
** kwargs
361
385
):
362
386
logits = self .forward (* args , cond_drop_prob = 0. , ** kwargs )
@@ -365,7 +389,13 @@ def forward_with_cond_scale(
365
389
return logits
366
390
367
391
null_logits = self .forward (* args , cond_drop_prob = 1. , ** kwargs )
368
- scaled_logits = null_logits + (logits - null_logits ) * cond_scale
392
+ update = logits - null_logits
393
+
394
+ if remove_parallel_component :
395
+ parallel , orthog = project (update , logits )
396
+ update = orthog + parallel * keep_parallel_frac
397
+
398
+ scaled_logits = logits + update * (cond_scale - 1. )
369
399
370
400
if rescaled_phi == 0. :
371
401
return scaled_logits , null_logits
0 commit comments