
Commit 7f8c51e

committed
v3 nodes: sd3, selfattent, s4_4xupscale, skiplayer
1 parent 27734d9 commit 7f8c51e

File tree

5 files changed, +569 −0 lines changed


comfy_extras/v3/nodes_sag.py

Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
from __future__ import annotations

import math

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum

import comfy.samplers
from comfy.ldm.modules.attention import optimized_attention
from comfy_api.v3 import io


# from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output
def attention_basic_with_sim(q, k, v, heads, mask=None, attn_precision=None):
    b, _, dim_head = q.shape
    dim_head //= heads
    scale = dim_head ** -0.5

    h = heads
    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(b, -1, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, -1, dim_head)
        .contiguous(),
        (q, k, v),
    )

    # force cast to fp32 to avoid overflowing
    if attn_precision == torch.float32:
        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
    else:
        sim = einsum('b i d, b j d -> b i j', q, k) * scale

    del q, k

    if mask is not None:
        mask = rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)

    # attention, what we cannot get enough of
    sim = sim.softmax(dim=-1)

    out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
    out = (
        out.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    return out, sim


def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
    # reshape and GAP the attention map
    _, hw1, hw2 = attn.shape
    b, _, lh, lw = x0.shape
    attn = attn.reshape(b, -1, hw1, hw2)
    # Global Average Pool
    mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold

    total = mask.shape[-1]
    x = round(math.sqrt((lh / lw) * total))
    xx = None
    for i in range(0, math.floor(math.sqrt(total) / 2)):
        for j in [(x + i), max(1, x - i)]:
            if total % j == 0:
                xx = j
                break
        if xx is not None:
            break

    x = xx
    y = total // x

    # Reshape
    mask = (
        mask.reshape(b, x, y)
        .unsqueeze(1)
        .type(attn.dtype)
    )
    # Upsample
    mask = F.interpolate(mask, (lh, lw))

    blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
    blurred = blurred * mask + x0 * (1 - mask)
    return blurred


def gaussian_blur_2d(img, kernel_size, sigma):
    ksize_half = (kernel_size - 1) * 0.5

    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)

    pdf = torch.exp(-0.5 * (x / sigma).pow(2))

    x_kernel = pdf / pdf.sum()
    x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)

    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])

    padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]

    img = F.pad(img, padding, mode="reflect")
    return F.conv2d(img, kernel2d, groups=img.shape[-3])


class SelfAttentionGuidance(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="SelfAttentionGuidance_V3",
            display_name="Self-Attention Guidance _V3",
            category="_for_testing",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01),
                io.Float.Input("blur_sigma", default=2.0, min=0.0, max=10.0, step=0.1),
            ],
            outputs=[
                io.Model.Output(),
            ],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, model, scale, blur_sigma):
        m = model.clone()

        attn_scores = None

        # TODO: make this work properly with chunked batches
        # currently, we can only save the attn from one UNet call
        def attn_and_record(q, k, v, extra_options):
            nonlocal attn_scores
            # if uncond, save the attention scores
            heads = extra_options["n_heads"]
            cond_or_uncond = extra_options["cond_or_uncond"]
            b = q.shape[0] // len(cond_or_uncond)
            if 1 in cond_or_uncond:
                uncond_index = cond_or_uncond.index(1)
                # do the entire attention operation, but save the attention scores to attn_scores
                (out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
                # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
                n_slices = heads * b
                attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index + 1)]
                return out
            else:
                return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])

        def post_cfg_function(args):
            nonlocal attn_scores
            uncond_attn = attn_scores

            sag_scale = scale
            sag_sigma = blur_sigma
            sag_threshold = 1.0
            model = args["model"]
            uncond_pred = args["uncond_denoised"]
            uncond = args["uncond"]
            cfg_result = args["denoised"]
            sigma = args["sigma"]
            model_options = args["model_options"]
            x = args["input"]
            if min(cfg_result.shape[2:]) <= 4:  # skip when too small to add padding
                return cfg_result

            # create the adversarially blurred image
            degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
            degraded_noised = degraded + x - uncond_pred
            # call into the UNet
            (sag,) = comfy.samplers.calc_cond_batch(model, [uncond], degraded_noised, sigma, model_options)
            return cfg_result + (degraded - sag) * sag_scale

        m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)

        # from diffusers:
        # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
        m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)

        return io.NodeOutput(m)


NODES_LIST = [SelfAttentionGuidance]
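
A quick sanity check of create_blur_map on dummy tensors (illustrative sketch, not part of the commit; the shapes and the import path comfy_extras.v3.nodes_sag are assumptions based on the file above). x0 stands in for a (batch, channels, height, width) model prediction and attn for the (batch * heads, h*w, h*w) self-attention map saved during the uncond pass.

import torch

from comfy_extras.v3.nodes_sag import create_blur_map  # assumed import path

x0 = torch.randn(1, 4, 8, 8)      # dummy prediction, [B, C, H, W]
attn = torch.rand(8, 64, 64)      # 8 heads x 1 batch, 64 = 8*8 spatial tokens
blurred = create_blur_map(x0, attn, sigma=2.0, threshold=1.0)
assert blurred.shape == x0.shape  # blur is applied only where attention mass exceeds the threshold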

comfy_extras/v3/nodes_sd3.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
from __future__ import annotations

import torch

import comfy.model_management
import comfy.sd
import folder_paths
import nodes
from comfy_api.v3 import io, resources
from comfy_extras.v3.nodes_slg import SkipLayerGuidanceDiT


class CLIPTextEncodeSD3(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="CLIPTextEncodeSD3_V3",
            category="advanced/conditioning",
            inputs=[
                io.Clip.Input("clip"),
                io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
                io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
                io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
                io.Combo.Input("empty_padding", options=["none", "empty_prompt"]),
            ],
            outputs=[
                io.Conditioning.Output(),
            ],
        )

    @classmethod
    def execute(cls, clip, clip_l, clip_g, t5xxl, empty_padding: str):
        no_padding = empty_padding == "none"

        tokens = clip.tokenize(clip_g)
        if len(clip_g) == 0 and no_padding:
            tokens["g"] = []

        if len(clip_l) == 0 and no_padding:
            tokens["l"] = []
        else:
            tokens["l"] = clip.tokenize(clip_l)["l"]

        if len(t5xxl) == 0 and no_padding:
            tokens["t5xxl"] = []
        else:
            tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
        if len(tokens["l"]) != len(tokens["g"]):
            empty = clip.tokenize("")
            while len(tokens["l"]) < len(tokens["g"]):
                tokens["l"] += empty["l"]
            while len(tokens["l"]) > len(tokens["g"]):
                tokens["g"] += empty["g"]
        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))


class EmptySD3LatentImage(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="EmptySD3LatentImage_V3",
            category="latent/sd3",
            inputs=[
                io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
            ],
            outputs=[
                io.Latent.Output(),
            ],
        )

    @classmethod
    def execute(cls, width: int, height: int, batch_size=1):
        latent = torch.zeros(
            [batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device()
        )
        return io.NodeOutput({"samples": latent})


class SkipLayerGuidanceSD3(SkipLayerGuidanceDiT):
    """
    Enhance guidance towards detailed structure by having another set of CFG negative with skipped layers.
    Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
    Experimental implementation by Dango233@StabilityAI.
    """
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="SkipLayerGuidanceSD3_V3",
            category="advanced/guidance",
            inputs=[
                io.Model.Input("model"),
                io.String.Input("layers", default="7, 8, 9", multiline=False),
                io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
                io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
                io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
            ],
            outputs=[
                io.Model.Output(),
            ],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, model, layers: str, scale: float, start_percent: float, end_percent: float):
        return SkipLayerGuidanceDiT.execute(
            model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers
        )


class TripleCLIPLoader(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="TripleCLIPLoader_V3",
            category="advanced/loaders",
            description="[Recipes]\n\nsd3: clip-l, clip-g, t5",
            inputs=[
                io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
                io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
                io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
            ],
            outputs=[
                io.Clip.Output(),
            ],
        )

    @classmethod
    def execute(cls, clip_name1: str, clip_name2: str, clip_name3: str):
        clip_data = [
            cls.resources.get(resources.TorchDictFolderFilename("text_encoders", clip_name1)),
            cls.resources.get(resources.TorchDictFolderFilename("text_encoders", clip_name2)),
            cls.resources.get(resources.TorchDictFolderFilename("text_encoders", clip_name3)),
        ]
        return io.NodeOutput(
            comfy.sd.load_text_encoder_state_dicts(
                clip_data, embedding_directory=folder_paths.get_folder_paths("embeddings")
            )
        )


NODES_LIST = [
    CLIPTextEncodeSD3,
    EmptySD3LatentImage,
    SkipLayerGuidanceSD3,
    TripleCLIPLoader,
]
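
Illustrative sketch (not part of the commit) of the padding step in CLIPTextEncodeSD3.execute: when clip_l and clip_g tokenize to different numbers of token batches, the shorter stream is padded with empty-prompt tokens until the lengths match. Plain lists stand in for the tokenizer output here; the real values come from clip.tokenize().

tokens_l = [["l-tokens-1"]]                      # hypothetical clip_l token batches
tokens_g = [["g-tokens-1"], ["g-tokens-2"]]      # hypothetical clip_g token batches
empty_l, empty_g = [["empty-l"]], [["empty-g"]]  # stand-ins for clip.tokenize("")
while len(tokens_l) < len(tokens_g):
    tokens_l += empty_l
while len(tokens_l) > len(tokens_g):
    tokens_g += empty_g
assert len(tokens_l) == len(tokens_g) == 2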

comfy_extras/v3/nodes_sdupscale.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
from __future__ import annotations

import torch

import comfy.utils
from comfy_api.v3 import io


class SD_4XUpscale_Conditioning(io.ComfyNodeV3):
    @classmethod
    def define_schema(cls):
        return io.SchemaV3(
            node_id="SD_4XUpscale_Conditioning_V3",
            category="conditioning/upscale_diffusion",
            inputs=[
                io.Image.Input("images"),
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Float.Input("scale_ratio", default=4.0, min=0.0, max=10.0, step=0.01),
                io.Float.Input("noise_augmentation", default=0.0, min=0.0, max=1.0, step=0.001),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, images, positive, negative, scale_ratio, noise_augmentation):
        width = max(1, round(images.shape[-2] * scale_ratio))
        height = max(1, round(images.shape[-3] * scale_ratio))

        pixels = comfy.utils.common_upscale(
            (images.movedim(-1, 1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center"
        )

        out_cp = []
        out_cn = []

        for t in positive:
            n = [t[0], t[1].copy()]
            n[1]['concat_image'] = pixels
            n[1]['noise_augmentation'] = noise_augmentation
            out_cp.append(n)

        for t in negative:
            n = [t[0], t[1].copy()]
            n[1]['concat_image'] = pixels
            n[1]['noise_augmentation'] = noise_augmentation
            out_cn.append(n)

        latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
        return io.NodeOutput(out_cp, out_cn, {"samples": latent})


NODES_LIST = [SD_4XUpscale_Conditioning]
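
Rough shape arithmetic for SD_4XUpscale_Conditioning (illustrative only; the 512x512 input is an assumption, not from the commit): ComfyUI image tensors are [batch, height, width, channels], so with scale_ratio=4.0 the node targets 4x the input resolution, then sizes both the conditioning image and the empty latent at one quarter of that target.

images_shape = (1, 512, 512, 3)                           # [B, H, W, C] input image
scale_ratio = 4.0
width = max(1, round(images_shape[-2] * scale_ratio))     # 2048
height = max(1, round(images_shape[-3] * scale_ratio))    # 2048
print(width // 4, height // 4)                            # 512 512 -> conditioning image and latent size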
