From a0f66a72623f8d2627e92d1d8e68869321ecc32e Mon Sep 17 00:00:00 2001
From: MagicBear
Date: Thu, 3 Jul 2025 16:47:31 +0800
Subject: [PATCH 1/6] feat(utils): Optimize tiled_scale_multidim to support
 multiple GPUs

- add DataParallel model support for tiled_scale_multidim

---
 comfy/utils.py | 207 +++++++++++++++++++++++++++++++++++--------------
 1 file changed, 150 insertions(+), 57 deletions(-)

diff --git a/comfy/utils.py b/comfy/utils.py
index 1f8d71292511..c1d04145e837 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -872,22 +872,35 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
     cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap))
     return rows * cols
 
+
 @torch.inference_mode()
-def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
+def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4,
+                         out_channels=3, output_device="cpu", downscale=False,
+                         index_formulas=None, pbar=None):
     dims = len(tile)
-
-    if not (isinstance(upscale_amount, (tuple, list))):
+    if not isinstance(upscale_amount, (tuple, list)):
         upscale_amount = [upscale_amount] * dims
-
-    if not (isinstance(overlap, (tuple, list))):
+    if not isinstance(overlap, (tuple, list)):
         overlap = [overlap] * dims
-
     if index_formulas is None:
         index_formulas = upscale_amount
-
-    if not (isinstance(index_formulas, (tuple, list))):
+    if not isinstance(index_formulas, (tuple, list)):
         index_formulas = [index_formulas] * dims
 
+    def pad_to_size(tensor, target_size):
+        """
+        Pad tensor to target_size (H, W) with zeros
+        """
+        c, h, w = tensor.shape[-3:]
+        pad_h = max(target_size[0] - h, 0)
+        pad_w = max(target_size[1] - w, 0)
+        if pad_h == 0 and pad_w == 0:
+            return tensor, (0, 0, 0, 0)
+
+        padding = [0, pad_w, 0, pad_h]  # left, right, top, bottom
+        padded = torch.nn.functional.pad(tensor, padding, mode='constant', value=0)
+        return padded, (0, 0, pad_h, pad_w)
+
     def get_upscale(dim, val):
         up = upscale_amount[dim]
         if callable(up):
@@ -929,60 +942,140 @@ def mult_list_upscale(a):
             out.append(round(get_scale(i, a[i])))
         return out
 
-    output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
+    output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]),
+                         device=output_device)
+    is_data_parallel = isinstance(function, DataParallel)
 
     for b in range(samples.shape[0]):
-        s = samples[b:b+1]
-
+        s = samples[b:b + 1]
         # handle entire input fitting in a single tile
-        if all(s.shape[d+2] <= tile[d] for d in range(dims)):
-            output[b:b+1] = function(s).to(output_device)
-            if pbar is not None:
-                pbar.update(1)
-            continue
-
-        out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
-        out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
-
-        positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
-
-        for it in itertools.product(*positions):
-            s_in = s
-            upscaled = []
-
-            for d in range(dims):
-                pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
-                l = min(tile[d], s.shape[d + 2] - pos)
-                s_in = s_in.narrow(d + 2, pos, l)
-                upscaled.append(round(get_pos(d, pos)))
-
-            ps = function(s_in).to(output_device)
-            mask = torch.ones_like(ps)
-
-            for d in range(2, dims + 2):
-                feather = round(get_scale(d - 2, overlap[d - 2]))
-                if feather >= mask.shape[d]:
-                    continue
-                for t in range(feather):
-                    a = (t + 1) / feather
-                    mask.narrow(d, t, 1).mul_(a)
-                    mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
-
-            o = out
-            o_d = out_div
-            for d in range(dims):
-                o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
-                o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
-
-            o.add_(ps * mask)
-            o_d.add_(mask)
-
-            if pbar is not None:
-                pbar.update(1)
-
-        output[b:b+1] = out/out_div
+        if all(s.shape[d + 2] <= tile[d] for d in range(dims)):
+            with torch.no_grad():
+                output[b:b + 1] = function(s).to(output_device)
+            if pbar is not None:
+                pbar.update(1)
+            continue
+
+        out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]),
+                          device=output_device)
+        out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]),
+                              device=output_device)
+
+        positions = [range(0, s.shape[d + 2] - overlap[d], tile[d] - overlap[d])
+                     if s.shape[d + 2] > tile[d] else [0] for d in range(dims)]
+        all_positions = list(itertools.product(*positions))
+        total_positions = len(all_positions)
+
+        if is_data_parallel:
+            target_tile_size = (tile[0], tile[1])  # H, W
+
+            tile_inputs = []
+            positions_list = []
+            pad_info_list = []
+
+            for it in all_positions:
+                s_in = s
+                upscaled = []
+                for d in range(dims):
+                    pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
+                    l = min(tile[d], s.shape[d + 2] - pos)
+                    s_in = s_in.narrow(d + 2, pos, l)
+                    upscaled.append(round(get_pos(d, pos)))
+
+                padded_s_in, pad_info = pad_to_size(s_in, target_tile_size)
+                tile_inputs.append(padded_s_in)
+                positions_list.append(upscaled)
+                pad_info_list.append(pad_info)
+
+            with torch.amp.autocast(device_type='cuda', enabled=(samples.device.type == 'cuda')):
+                batched_input = torch.cat(tile_inputs, dim=0)
+                batched_output = function(batched_input).to(output_device)
+
+            for idx, upscaled in enumerate(positions_list):
+                ps = batched_output[idx:idx + 1]
+
+                pad_t, pad_b, pad_l, pad_r = pad_info_list[idx]
+                if pad_t > 0 or pad_b > 0 or pad_l > 0 or pad_r > 0:
+                    ps = ps[..., pad_t:ps.shape[-2] - pad_b, pad_l:ps.shape[-1] - pad_r]
+
+                mask = torch.ones_like(ps)
+
+                for d in range(2, dims + 2):
+                    feather = round(get_scale(d - 2, overlap[d - 2]))
+                    if feather >= mask.shape[d]:
+                        continue
+                    for t in range(feather):
+                        a = (t + 1) / feather
+                        mask.narrow(d, t, 1).mul_(a)
+                        mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
+                o = out
+                o_d = out_div
+                ps_cropped = ps
+
+                for d in range(dims):
+                    dim_size = o.shape[d + 2]
+                    start = upscaled[d]
+                    length = mask.shape[d + 2]
+
+                    if start + length > dim_size:
+                        length = dim_size - start
+                        if length <= 0:
+                            continue
+                        o = o.narrow(d + 2, start, length)
+                        o_d = o_d.narrow(d + 2, start, length)
+                        ps_cropped = ps_cropped.narrow(d + 2, 0, length)
+                        mask = mask.narrow(d + 2, 0, length)
+                    else:
+                        o = o.narrow(d + 2, start, length)
+                        o_d = o_d.narrow(d + 2, start, length)
+
+                o.add_(ps_cropped * mask)
+                o_d.add_(mask)
+
+                if pbar is not None:
+                    pbar.update(1)
+
+            output[b:b + 1] = out / out_div
+        else:
+            for it in all_positions:
+                s_in = s
+                upscaled = []
+                for d in range(dims):
+                    pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
+                    l = min(tile[d], s.shape[d + 2] - pos)
+                    s_in = s_in.narrow(d + 2, pos, l)
+                    upscaled.append(round(get_pos(d, pos)))
+
+                with torch.amp.autocast(device_type='cuda', enabled=(samples.device.type == 'cuda')):
+                    ps = function(s_in).to(output_device)
+
+                mask = torch.ones_like(ps)
+
+                for d in range(2, dims + 2):
+                    feather = round(get_scale(d - 2, overlap[d - 2]))
+                    if feather >= mask.shape[d]:
+                        continue
+                    for t in range(feather):
+                        a = (t + 1) / feather
+                        mask.narrow(d, t, 1).mul_(a)
+                        mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
+
+                o = out
+                o_d = out_div
+                for d in range(dims):
+                    o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
+                    o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
+
+                o.add_(ps * mask)
+                o_d.add_(mask)
+
+                if pbar is not None:
+                    pbar.update(1)
+            output[b:b + 1] = out / out_div
+
     return output
 
+
 def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
     return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
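Both code paths in this patch keep the same feathered-blend accumulation: each tile's output is weighted by a mask that ramps down over the overlap region, the masks are summed into out_div, and the final division cancels the weights wherever tiles overlap. A minimal, self-contained sketch of that scheme in one dimension (blend_1d and its arguments are illustrative names, not part of the patch):

    import torch

    def blend_1d(tiles, positions, length, feather):
        # Accumulate weighted tile values and the weights themselves separately.
        out = torch.zeros(length)
        out_div = torch.zeros(length)
        for tile, pos in zip(tiles, positions):
            mask = torch.ones_like(tile)
            for t in range(feather):
                a = (t + 1) / feather
                mask[t] *= a        # ramp up across the leading overlap
                mask[-1 - t] *= a   # ramp down across the trailing overlap
            out[pos:pos + len(tile)] += tile * mask
            out_div[pos:pos + len(tile)] += mask
        return out / out_div        # overlap weights cancel in the division

    tiles = [torch.ones(8), torch.full((8,), 3.0)]
    print(blend_1d(tiles, [0, 4], 12, feather=4))  # smooth crossfade from 1 to 3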
From 2f386dcba1857ec89500b95ccf13eea75792b8da Mon Sep 17 00:00:00 2001
From: MagicBear
Date: Thu, 3 Jul 2025 16:54:51 +0800
Subject: [PATCH 2/6] feat(utils): Optimize tiled_scale_multidim to support
 multiple GPUs

- add DataParallel model support for tiled_scale_multidim

---
 comfy/utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/comfy/utils.py b/comfy/utils.py
index c1d04145e837..5e16ae7618df 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -27,6 +27,7 @@
 import logging
 import itertools
 from torch.nn.functional import interpolate
+from torch.nn import DataParallel
 from einops import rearrange
 from comfy.cli_args import args
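This import is what the new isinstance(function, DataParallel) check in tiled_scale_multidim keys off: a wrapped module takes the batched path, where all tiles are concatenated along dim 0 and DataParallel scatters that batch across its device_ids, while anything else keeps the per-tile path. A rough sketch of the two call shapes (the Conv2d stand-in is illustrative):

    import torch
    from torch.nn import DataParallel

    model = torch.nn.Conv2d(3, 3, 3, padding=1)  # stand-in for an upscale model

    if torch.cuda.device_count() > 1:
        # Batched path: one forward call, tiles split across GPUs along dim 0.
        function = DataParallel(model.cuda(), device_ids=[0, 1])
    else:
        # Per-tile path: one forward call per tile on a single device.
        function = model

    print(isinstance(function, DataParallel))  # selects the batched path when True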
From 71b505f902e599cb94602091f215d89d08ab67e5 Mon Sep 17 00:00:00 2001
From: MagicBear
Date: Fri, 4 Jul 2025 00:31:15 +0800
Subject: [PATCH 3/6] feat(upscale): Add parallel support for multiple GPUs

---
 comfy_extras/nodes_upscale_model.py | 28 ++++++++++++++++++++++++----
 1 file changed, 24 insertions(+), 4 deletions(-)

diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py
index 04c948341296..d17da8514b7c 100644
--- a/comfy_extras/nodes_upscale_model.py
+++ b/comfy_extras/nodes_upscale_model.py
@@ -4,6 +4,7 @@
 import torch
 import comfy.utils
 import folder_paths
+from torch.nn import DataParallel
 
 try:
     from spandrel_extra_arches import EXTRA_REGISTRY
@@ -39,16 +40,29 @@ def load_model(self, model_name):
 class ImageUpscaleWithModel:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "upscale_model": ("UPSCALE_MODEL",),
-                              "image": ("IMAGE",),
-                              }}
+        inputs = {"required": {
+            "upscale_model": ("UPSCALE_MODEL",),
+            "image": ("IMAGE",),
+        },
+            "optional": {}}
+        for i in range(torch.cuda.device_count()):
+            inputs["optional"]["cuda_%d" % i] = ("BOOLEAN", {"default": True, "tooltip": "Use device %s" % torch.cuda.get_device_name(i)})
+        return inputs
+
     RETURN_TYPES = ("IMAGE",)
     FUNCTION = "upscale"
 
     CATEGORY = "image/upscaling"
 
-    def upscale(self, upscale_model, image):
+    def upscale(self, upscale_model, image, **kwargs):
         device = model_management.get_torch_device()
+        device_ids = []
+        for k, v in kwargs.items():
+            if k.startswith("cuda_") and v:
+                device_ids.append(int(k[5:]))
+        if kwargs.get("cuda_0"):
+            device = "cuda:0"
+        elif len(device_ids) > 0:
+            device = "cuda:%d" % device_ids[0]
 
         memory_required = model_management.module_size(upscale_model.model)
         memory_required += (512 * 512 * 3) * image.element_size() * max(upscale_model.scale, 1.0) * 384.0 #The 384.0 is an estimate of how much some of these models take, TODO: make it more accurate
@@ -58,6 +72,11 @@ def upscale(self, upscale_model, image):
         upscale_model.to(device)
         in_img = image.movedim(-1,-3).to(device)
 
+        if len(device_ids) > 1:
+            parallel_model = DataParallel(upscale_model.model, device_ids=device_ids)
+        else:
+            parallel_model = upscale_model
+
         tile = 512
         overlap = 32
 
@@ -66,7 +85,7 @@ def upscale(self, upscale_model, image):
         try:
             steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
             pbar = comfy.utils.ProgressBar(steps)
-            s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
+            s = comfy.utils.tiled_scale(in_img, parallel_model, tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=device)
             oom = False
         except model_management.OOM_EXCEPTION as e:
             tile //= 2
@@ -75,6 +94,7 @@ def upscale(self, upscale_model, image):
 
         upscale_model.to("cpu")
         s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
+        s = s.to("cpu")
         return (s,)
 
 NODE_CLASS_MAPPINGS = {
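The generated cuda_%d booleans come back through **kwargs, and upscale() folds the ticked ones into device_ids. That selection logic can be exercised on its own; pick_devices and its CPU fallback below are illustrative, not part of the node:

    def pick_devices(**kwargs):
        device_ids = [int(k[5:]) for k, v in kwargs.items()
                      if k.startswith("cuda_") and v]
        if kwargs.get("cuda_0"):
            device = "cuda:0"
        elif device_ids:
            device = "cuda:%d" % device_ids[0]
        else:
            device = "cpu"  # assumed fallback when no GPU box is ticked
        return device, device_ids

    print(pick_devices(cuda_0=True, cuda_1=True))   # ('cuda:0', [0, 1])
    print(pick_devices(cuda_0=False, cuda_1=True))  # ('cuda:1', [1])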
From 3a3b5babf2eb9ca02d46874516472ce955acdb6e Mon Sep 17 00:00:00 2001
From: MagicBear
Date: Fri, 4 Jul 2025 00:34:18 +0800
Subject: [PATCH 4/6] refactor(comfy/utils.py): remove unused variable

---
 comfy/utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/comfy/utils.py b/comfy/utils.py
index 5e16ae7618df..3c2bb04a2b34 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -965,7 +965,6 @@ def mult_list_upscale(a):
         positions = [range(0, s.shape[d + 2] - overlap[d], tile[d] - overlap[d])
                      if s.shape[d + 2] > tile[d] else [0] for d in range(dims)]
         all_positions = list(itertools.product(*positions))
-        total_positions = len(all_positions)
 
         if is_data_parallel:
             target_tile_size = (tile[0], tile[1])  # H, W

From ddec938cde47cc229732b3846c7fa4cdd1f48d48 Mon Sep 17 00:00:00 2001
From: MagicBear
Date: Fri, 4 Jul 2025 02:34:01 +0800
Subject: [PATCH 5/6] fix(comfy_extras): Fix OOM tile fallback

---
 comfy_extras/nodes_upscale_model.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py
index d17da8514b7c..199c83d33682 100644
--- a/comfy_extras/nodes_upscale_model.py
+++ b/comfy_extras/nodes_upscale_model.py
@@ -87,14 +87,18 @@ def upscale(self, upscale_model, image, **kwargs):
             pbar = comfy.utils.ProgressBar(steps)
             s = comfy.utils.tiled_scale(in_img, parallel_model, tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=device)
             oom = False
+        except torch.OutOfMemoryError as e:
+            tile //= 2
+            if tile < 128:
+                raise e
         except model_management.OOM_EXCEPTION as e:
             tile //= 2
             if tile < 128:
                 raise e
 
         upscale_model.to("cpu")
-        s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
         s = s.to("cpu")
+        s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
         return (s,)
 
 NODE_CLASS_MAPPINGS = {
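Both except clauses feed the oom = True / while oom: retry loop in the surrounding code, which halves the tile edge after every out-of-memory failure and gives up below 128. A runnable sketch of that pattern; run_tiled_upscale is a hypothetical stand-in for the comfy.utils.tiled_scale call, and torch.OutOfMemoryError assumes a recent PyTorch:

    import torch

    def run_tiled_upscale(tile):  # hypothetical stand-in for tiled_scale
        if tile > 256:
            raise torch.OutOfMemoryError("simulated OOM for oversized tile")
        return "ok"

    tile = 512
    oom = True
    while oom:
        try:
            result = run_tiled_upscale(tile)
            oom = False
        except torch.OutOfMemoryError:
            tile //= 2            # retry with a smaller tile
            if tile < 128:        # below the minimum useful tile size
                raise
    print(tile, result)  # 256 ok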
From 26c55e145cff28813b26d92119bdf236a0c4da03 Mon Sep 17 00:00:00 2001
From: MagicBear
Date: Fri, 4 Jul 2025 04:08:33 +0800
Subject: [PATCH 6/6] fix(comfy/utils.py): fix crop pad_to_size bug

---
 comfy/utils.py | 22 +++++++++-------------
 1 file changed, 9 insertions(+), 13 deletions(-)

diff --git a/comfy/utils.py b/comfy/utils.py
index 3c2bb04a2b34..0b806abdef68 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -873,7 +873,6 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
     cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap))
     return rows * cols
 
-
 @torch.inference_mode()
 def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4,
                          out_channels=3, output_device="cpu", downscale=False,
@@ -889,18 +888,15 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
         index_formulas = [index_formulas] * dims
 
     def pad_to_size(tensor, target_size):
-        """
-        Pad tensor to target_size (H, W) with zeros
-        """
         c, h, w = tensor.shape[-3:]
         pad_h = max(target_size[0] - h, 0)
         pad_w = max(target_size[1] - w, 0)
         if pad_h == 0 and pad_w == 0:
-            return tensor, (0, 0, 0, 0)
+            return tensor, (0, 0, 0, 0)  # (left, right, top, bottom)
 
-        padding = [0, pad_w, 0, pad_h]  # left, right, top, bottom
+        padding = [0, pad_w, 0, pad_h]  # [left, right, top, bottom]
         padded = torch.nn.functional.pad(tensor, padding, mode='constant', value=0)
-        return padded, (0, 0, pad_h, pad_w)
+        return padded, (0, pad_w, 0, pad_h)  # explicitly return (left, right, top, bottom)
 
     def get_upscale(dim, val):
         up = upscale_amount[dim]
@@ -949,7 +945,6 @@ def mult_list_upscale(a):
 
     for b in range(samples.shape[0]):
         s = samples[b:b + 1]
-        # handle entire input fitting in a single tile
         if all(s.shape[d + 2] <= tile[d] for d in range(dims)):
             with torch.no_grad():
                 output[b:b + 1] = function(s).to(output_device)
@@ -992,11 +987,13 @@ def mult_list_upscale(a):
                 batched_output = function(batched_input).to(output_device)
 
             for idx, upscaled in enumerate(positions_list):
-                ps = batched_output[idx:idx + 1]
+                ps = batched_output[idx:idx+1]
+                left_pad, right_pad, top_pad, bottom_pad = pad_info_list[idx]
 
-                pad_t, pad_b, pad_l, pad_r = pad_info_list[idx]
-                if pad_t > 0 or pad_b > 0 or pad_l > 0 or pad_r > 0:
-                    ps = ps[..., pad_t:ps.shape[-2] - pad_b, pad_l:ps.shape[-1] - pad_r]
+                if any(x > 0 for x in (left_pad, right_pad, top_pad, bottom_pad)):
+                    ps = ps[...,
+                            top_pad:ps.shape[-2] - bottom_pad,
+                            left_pad:ps.shape[-1] - right_pad]
 
                 mask = torch.ones_like(ps)
 
@@ -1075,7 +1072,6 @@ def mult_list_upscale(a):
 
     return output
 
-
 def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
     return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
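With the pad info consistently ordered as (left, right, top, bottom), padding and cropping now round-trip exactly. A standalone check mirroring the patched helper (tensor sizes are arbitrary):

    import torch
    import torch.nn.functional as F

    def pad_to_size(tensor, target_size):
        c, h, w = tensor.shape[-3:]
        pad_h = max(target_size[0] - h, 0)
        pad_w = max(target_size[1] - w, 0)
        if pad_h == 0 and pad_w == 0:
            return tensor, (0, 0, 0, 0)
        padded = F.pad(tensor, [0, pad_w, 0, pad_h], mode='constant', value=0)
        return padded, (0, pad_w, 0, pad_h)  # (left, right, top, bottom)

    x = torch.randn(1, 3, 40, 50)
    padded, (l, r, t, b) = pad_to_size(x, (64, 64))
    cropped = padded[..., t:padded.shape[-2] - b, l:padded.shape[-1] - r]
    assert torch.equal(cropped, x)  # the old (0, 0, pad_h, pad_w) order failed this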