feat: ImageUpscaleWithModel and utils/tiled_scale_multidim to support parallel load with multiple GPU #8776

Status: Open — wants to merge 6 commits into base branch `worksplit-multigpu`
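
Summary of the change: when more than one CUDA device is enabled on the node, `ImageUpscaleWithModel` wraps the upscale network in `torch.nn.DataParallel`. `tiled_scale_multidim` detects the wrapper, pads edge tiles to a uniform size, concatenates all tiles into one batch, and runs a single forward pass so `DataParallel` can split the batch across the selected GPUs. With a single device, the original tile-by-tile loop is kept.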
comfy/utils.py — 205 changes: 147 additions & 58 deletions
@@ -27,6 +27,7 @@
import logging
import itertools
from torch.nn.functional import interpolate
from torch.nn import DataParallel
from einops import rearrange
from comfy.cli_args import args

@@ -873,21 +874,30 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
    return rows * cols

@torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4,
                         out_channels=3, output_device="cpu", downscale=False,
                         index_formulas=None, pbar=None):
    dims = len(tile)

    if not isinstance(upscale_amount, (tuple, list)):
        upscale_amount = [upscale_amount] * dims

    if not isinstance(overlap, (tuple, list)):
        overlap = [overlap] * dims

    if index_formulas is None:
        index_formulas = upscale_amount

    if not isinstance(index_formulas, (tuple, list)):
        index_formulas = [index_formulas] * dims
    def pad_to_size(tensor, target_size):
        # DataParallel needs every tile in the concatenated batch to share one
        # spatial size, so smaller edge tiles are zero-padded on the right/bottom.
        c, h, w = tensor.shape[-3:]
        pad_h = max(target_size[0] - h, 0)
        pad_w = max(target_size[1] - w, 0)
        if pad_h == 0 and pad_w == 0:
            return tensor, (0, 0, 0, 0)  # (left, right, top, bottom)

        padding = [0, pad_w, 0, pad_h]  # [left, right, top, bottom]
        padded = torch.nn.functional.pad(tensor, padding, mode='constant', value=0)
        return padded, (0, pad_w, 0, pad_h)  # explicitly return (left, right, top, bottom)

    def get_upscale(dim, val):
        up = upscale_amount[dim]
        if callable(up):

@@ -929,58 +939,137 @@ def mult_list_upscale(a):
            out.append(round(get_scale(i, a[i])))
        return out

    output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]),
                         device=output_device)
    is_data_parallel = isinstance(function, DataParallel)

    for b in range(samples.shape[0]):
        s = samples[b:b + 1]

        # The whole input fits in one tile: run a single forward pass.
        if all(s.shape[d + 2] <= tile[d] for d in range(dims)):
            with torch.no_grad():
                output[b:b + 1] = function(s).to(output_device)
            if pbar is not None:
                pbar.update(1)
            continue

        out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]),
                          device=output_device)
        out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]),
                              device=output_device)

        positions = [range(0, s.shape[d + 2] - overlap[d], tile[d] - overlap[d])
                     if s.shape[d + 2] > tile[d] else [0] for d in range(dims)]
        all_positions = list(itertools.product(*positions))

        if is_data_parallel:
            # Batched path: collect every tile, pad all of them to one spatial
            # size, and run a single forward pass that DataParallel scatters
            # across the selected GPUs.
            target_tile_size = (tile[0], tile[1])  # (H, W)

            tile_inputs = []
            positions_list = []
            pad_info_list = []

            for it in all_positions:
                s_in = s
                upscaled = []
                for d in range(dims):
                    pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
                    l = min(tile[d], s.shape[d + 2] - pos)
                    s_in = s_in.narrow(d + 2, pos, l)
                    upscaled.append(round(get_pos(d, pos)))

                padded_s_in, pad_info = pad_to_size(s_in, target_tile_size)
                tile_inputs.append(padded_s_in)
                positions_list.append(upscaled)
                pad_info_list.append(pad_info)

            with torch.amp.autocast(device_type='cuda', enabled=(samples.device.type == 'cuda')):
                batched_input = torch.cat(tile_inputs, dim=0)
                batched_output = function(batched_input).to(output_device)

            for idx, upscaled in enumerate(positions_list):
                ps = batched_output[idx:idx + 1]
                left_pad, right_pad, top_pad, bottom_pad = pad_info_list[idx]

                # Crop the padded border off the model output before blending.
                if any(x > 0 for x in (left_pad, right_pad, top_pad, bottom_pad)):
                    ps = ps[...,
                            top_pad:ps.shape[-2] - bottom_pad,
                            left_pad:ps.shape[-1] - right_pad]

                mask = torch.ones_like(ps)

                # Feather the tile edges so overlapping tiles blend smoothly.
                for d in range(2, dims + 2):
                    feather = round(get_scale(d - 2, overlap[d - 2]))
                    if feather >= mask.shape[d]:
                        continue
                    for t in range(feather):
                        a = (t + 1) / feather
                        mask.narrow(d, t, 1).mul_(a)
                        mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)

                o = out
                o_d = out_div
                ps_cropped = ps

                for d in range(dims):
                    dim_size = o.shape[d + 2]
                    start = upscaled[d]
                    length = mask.shape[d + 2]

                    if start + length > dim_size:
                        # Clamp tiles that would write past the output edge.
                        length = dim_size - start
                        if length <= 0:
                            continue
                        o = o.narrow(d + 2, start, length)
                        o_d = o_d.narrow(d + 2, start, length)
                        ps_cropped = ps_cropped.narrow(d + 2, 0, length)
                        mask = mask.narrow(d + 2, 0, length)
                    else:
                        o = o.narrow(d + 2, start, length)
                        o_d = o_d.narrow(d + 2, start, length)

                o.add_(ps_cropped * mask)
                o_d.add_(mask)

                if pbar is not None:
                    pbar.update(1)

            output[b:b + 1] = out / out_div
        else:
            # Single-device path: the original tile-by-tile loop.
            for it in all_positions:
                s_in = s
                upscaled = []
                for d in range(dims):
                    pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
                    l = min(tile[d], s.shape[d + 2] - pos)
                    s_in = s_in.narrow(d + 2, pos, l)
                    upscaled.append(round(get_pos(d, pos)))

                with torch.amp.autocast(device_type='cuda', enabled=(samples.device.type == 'cuda')):
                    ps = function(s_in).to(output_device)

                mask = torch.ones_like(ps)

                for d in range(2, dims + 2):
                    feather = round(get_scale(d - 2, overlap[d - 2]))
                    if feather >= mask.shape[d]:
                        continue
                    for t in range(feather):
                        a = (t + 1) / feather
                        mask.narrow(d, t, 1).mul_(a)
                        mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)

                o = out
                o_d = out_div
                for d in range(dims):
                    o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
                    o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])

                o.add_(ps * mask)
                o_d.add_(mask)

                if pbar is not None:
                    pbar.update(1)
            output[b:b + 1] = out / out_div

    return output
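
For reference, a minimal sketch of how the batched path is reached. `load_some_upscaler` and the two-GPU setup are illustrative assumptions, not part of this PR:

```python
import torch
from torch.nn import DataParallel
import comfy.utils

# Hypothetical setup: `net` is any image-to-image nn.Module with a fixed
# integer scale factor (e.g. a 4x ESRGAN-style network).
net = load_some_upscaler().to("cuda:0")   # assumption: returns an nn.Module
parallel = DataParallel(net, device_ids=[0, 1])

img = torch.rand(1, 3, 1024, 1024, device="cuda:0")
# Because `parallel` is a DataParallel instance, tiled_scale_multidim takes
# the batched branch: tiles are padded to 512x512, concatenated along the
# batch dim, and scattered across cuda:0 and cuda:1 by DataParallel.
out = comfy.utils.tiled_scale_multidim(img, parallel, tile=(512, 512),
                                       overlap=32, upscale_amount=4,
                                       out_channels=3, output_device="cpu")
```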

def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
comfy_extras/nodes_upscale_model.py — 32 changes: 28 additions & 4 deletions
@@ -4,6 +4,7 @@
import torch
import comfy.utils
import folder_paths
from torch.nn import DataParallel

try:
    from spandrel_extra_arches import EXTRA_REGISTRY
@@ -39,16 +40,29 @@ def load_model(self, model_name):
class ImageUpscaleWithModel:
    @classmethod
    def INPUT_TYPES(s):
        inputs = {"required": {
                      "upscale_model": ("UPSCALE_MODEL",),
                      "image": ("IMAGE",),
                  },
                  "optional": {}}
        # Expose one boolean toggle per visible CUDA device.
        for i in range(torch.cuda.device_count()):
            inputs["optional"]["cuda_%d" % i] = ("BOOLEAN", {"default": True, "tooltip": "Use device %s" % torch.cuda.get_device_name(i)})
        return inputs

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"

    CATEGORY = "image/upscaling"

    def upscale(self, upscale_model, image, **kwargs):
        device = model_management.get_torch_device()
        # Collect the devices enabled via the cuda_N toggles.
        device_ids = []
        for k, v in kwargs.items():
            if k.startswith("cuda_") and v:
                device_ids.append(int(k[5:]))
        if kwargs.get("cuda_0"):
            device = "cuda:0"
        elif len(device_ids) > 0:
            device = "cuda:%d" % device_ids[0]

        memory_required = model_management.module_size(upscale_model.model)
        memory_required += (512 * 512 * 3) * image.element_size() * max(upscale_model.scale, 1.0) * 384.0  # The 384.0 is an estimate of how much some of these models take, TODO: make it more accurate
@@ -58,6 +72,11 @@ def upscale(self, upscale_model, image):
        upscale_model.to(device)
        in_img = image.movedim(-1, -3).to(device)

        if len(device_ids) > 1:
            # More than one device selected: wrap the network so DataParallel
            # can split tile batches across the chosen GPUs.
            parallel_model = DataParallel(upscale_model.model, device_ids=device_ids)
        else:
            parallel_model = upscale_model

        tile = 512
        overlap = 32
@@ -66,14 +85,19 @@
        try:
            steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
            pbar = comfy.utils.ProgressBar(steps)
            # Pass the (possibly DataParallel-wrapped) model itself instead of a
            # lambda so tiled_scale_multidim can detect the wrapper.
            s = comfy.utils.tiled_scale(in_img, parallel_model, tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=device)
            oom = False
        except torch.OutOfMemoryError as e:
            tile //= 2
            if tile < 128:
                raise e
        except model_management.OOM_EXCEPTION as e:
            tile //= 2
            if tile < 128:
                raise e

        upscale_model.to("cpu")
        s = s.to("cpu")  # Tensor.to is not in-place; assign the result
        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
        return (s,)
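
For context, a minimal sketch of what the `DataParallel` wrapper does with a batch of tiles. The `Conv2d` stand-in and two-GPU setup are illustrative assumptions, not part of this PR:

```python
import torch
from torch.nn import DataParallel

# Stand-in module; the real node wraps upscale_model.model instead.
model = torch.nn.Conv2d(3, 3, 3, padding=1).to("cuda:0")
wrapped = DataParallel(model, device_ids=[0, 1])

# DataParallel replicates the module onto each device, scatters the input
# along dim 0 (~4 tiles per GPU here), and gathers results on device_ids[0].
tiles = torch.rand(8, 3, 512, 512, device="cuda:0")
with torch.inference_mode():
    result = wrapped(tiles)  # shape (8, 3, 512, 512), on cuda:0
```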
