from __future__ import annotations

import math

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum

import comfy.samplers
from comfy.ldm.modules.attention import optimized_attention
from comfy_api.v3 import io


# from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output
def attention_basic_with_sim(q, k, v, heads, mask=None, attn_precision=None):
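    """Compute basic scaled dot-product attention, returning both the output
    and the post-softmax attention map.

    q, k, v are (batch, tokens, heads * dim_head); the returned sim has shape
    (batch * heads, q_tokens, k_tokens) so SAG can inspect it later.
    """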
    b, _, dim_head = q.shape
    dim_head //= heads
    scale = dim_head ** -0.5

    h = heads
    # fold heads into the batch dimension: (b, n, h * d) -> (b * h, n, d)
    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(b, -1, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, -1, dim_head)
        .contiguous(),
        (q, k, v),
    )

    # force cast to fp32 to avoid overflowing
    if attn_precision == torch.float32:
        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
    else:
        sim = einsum('b i d, b j d -> b i j', q, k) * scale

    del q, k

    if mask is not None:
        mask = rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)

    # attention, what we cannot get enough of
    sim = sim.softmax(dim=-1)

    out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
    out = (
        out.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    return out, sim


def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
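    """Build a binary mask from the saved attention map, upsample it to the
    latent resolution, and blend a Gaussian-blurred copy of x0 under it.
    """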
    # reshape and GAP the attention map
    _, hw1, hw2 = attn.shape
    b, _, lh, lw = x0.shape
    attn = attn.reshape(b, -1, hw1, hw2)
    # Global Average Pool over heads, then sum over query tokens
    mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold

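    # Recover a 2-D (x, y) grid from the flattened token count: start from an
    # aspect-ratio-matched estimate and search nearby integers for a divisor.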
    total = mask.shape[-1]
    x = round(math.sqrt((lh / lw) * total))
    xx = None
    for i in range(0, math.floor(math.sqrt(total) / 2)):
        for j in [(x + i), max(1, x - i)]:
            if total % j == 0:
                xx = j
                break
        if xx is not None:
            break
    if xx is None:
        # no divisor found near sqrt(total); fall back to a 1 x total strip
        # instead of crashing on `total // None` below
        xx = 1

    x = xx
    y = total // x
    # Reshape
    mask = (
        mask.reshape(b, x, y)
        .unsqueeze(1)
        .type(attn.dtype)
    )
    # Upsample (nearest-neighbour) to the latent resolution
    mask = F.interpolate(mask, (lh, lw))

    # blur everywhere, then keep the blur only where attention exceeded the threshold
    blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
    blurred = blurred * mask + x0 * (1 - mask)
    return blurred


def gaussian_blur_2d(img, kernel_size, sigma):
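    """Depthwise 2-D Gaussian blur with reflect padding; preserves img's shape."""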
    ksize_half = (kernel_size - 1) * 0.5

    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)

    pdf = torch.exp(-0.5 * (x / sigma).pow(2))

    # normalize the 1-D kernel, then take its outer product to form the 2-D kernel
    x_kernel = pdf / pdf.sum()
    x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)

    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])

    padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]

    # reflect-pad, then depthwise convolution (one kernel per channel)
    img = F.pad(img, padding, mode="reflect")
    return F.conv2d(img, kernel2d, groups=img.shape[-3])


class SelfAttentionGuidance(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="SelfAttentionGuidance_V3",
            display_name="Self-Attention Guidance _V3",
            category="_for_testing",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01),
                io.Float.Input("blur_sigma", default=2.0, min=0.0, max=10.0, step=0.1),
            ],
            outputs=[
                io.Model.Output(),
            ],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, model, scale, blur_sigma):
        m = model.clone()

        attn_scores = None

        # TODO: make this work properly with chunked batches
        # currently, we can only save the attn from one UNet call
        def attn_and_record(q, k, v, extra_options):
            nonlocal attn_scores
            # if uncond, save the attention scores
            heads = extra_options["n_heads"]
            cond_or_uncond = extra_options["cond_or_uncond"]
            b = q.shape[0] // len(cond_or_uncond)
            if 1 in cond_or_uncond:
                uncond_index = cond_or_uncond.index(1)
                # do the entire attention operation, but save the attention scores to attn_scores
                (out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
                # at higher batch sizes the result batch dimension is believed to be [uc1, ..., ucn, c1, ..., cn]
                n_slices = heads * b
                attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index + 1)]
                return out
            else:
                return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])

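        # SAG (https://arxiv.org/abs/2210.00939): blur the uncond prediction
        # where self-attention is strongest, re-noise it, run the model once
        # more on the degraded input, and steer the CFG result by the difference.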
        def post_cfg_function(args):
            nonlocal attn_scores
            uncond_attn = attn_scores

            sag_scale = scale
            sag_sigma = blur_sigma
            sag_threshold = 1.0
            model = args["model"]
            uncond_pred = args["uncond_denoised"]
            uncond = args["uncond"]
            cfg_result = args["denoised"]
            sigma = args["sigma"]
            model_options = args["model_options"]
            x = args["input"]
            if min(cfg_result.shape[2:]) <= 4:  # skip when the latent is too small for the blur's reflect padding
                return cfg_result

            # create the adversarially blurred image
            degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
            degraded_noised = degraded + x - uncond_pred
            # call into the UNet
            (sag,) = comfy.samplers.calc_cond_batch(model, [uncond], degraded_noised, sigma, model_options)
            return cfg_result + (degraded - sag) * sag_scale

        m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)

        # from diffusers:
        # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
        m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)

        return io.NodeOutput(m)


NODES_LIST = [SelfAttentionGuidance]
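

if __name__ == "__main__":
    # Minimal smoke test for the pure-tensor helpers above. This is a sketch:
    # it assumes ComfyUI is importable (the module-level imports need it) and
    # uses hypothetical shapes; the node itself is not exercised here.
    q = torch.randn(2, 16, 64)  # (batch, tokens, heads * dim_head)
    out, sim = attention_basic_with_sim(q, q, q, heads=8)
    assert out.shape == (2, 16, 64) and sim.shape == (16, 16, 16)

    x0 = torch.randn(1, 4, 64, 64)  # latent-like tensor
    attn = torch.rand(8, 64, 64)    # (batch * heads, tokens, tokens)
    blurred = create_blur_map(x0, attn, sigma=2.0, threshold=1.0)
    assert blurred.shape == x0.shape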