From 701e4999dc9bd6cd413444d01d60b5ae9c4491d1 Mon Sep 17 00:00:00 2001
From: yoinked
Date: Wed, 26 Nov 2025 17:17:39 -0800
Subject: [PATCH 1/3] add weight_scale_inv impl

---
 comfy/ops.py | 28 ++++++++++++++++++++++++++--
 1 file changed, 26 insertions(+), 2 deletions(-)

diff --git a/comfy/ops.py b/comfy/ops.py
index a0ff4e8f1710..caceb14a25eb 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -464,6 +464,22 @@ def forward_comfy_cast_weights(self, input):
         uncast_bias_weight(self, weight, bias, offload_stream)
         return x
 
+
+def scale_hadamard(larger, smaller, k_value=None, divide=False):
+    if smaller.shape == torch.Size([]):
+        # unset
+        return larger
+    h, w = smaller.shape
+    if k_value is None:
+        #calculate from larger compared to smaller
+        k_value = larger.shape[-1] // smaller.shape[-1]
+    expected_shape = (h * k_value, w * k_value)
+    assert larger.shape == expected_shape, "weight_scale_inv mismatch, skipping"
+    if divide:
+        smaller = 1.0 / smaller
+    result = larger.view(h, k_value, w, k_value) * smaller.view(h, 1, w, 1)
+    return result.reshape(h * k_value, w * k_value)
+
 def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
     logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
     class scaled_fp8_op(manual_cast):
@@ -476,6 +492,9 @@ def __init__(self, *args, **kwargs):
         def reset_parameters(self):
             if not hasattr(self, 'scale_weight'):
                 self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
+
+            if not hasattr(self, 'weight_scale_inv'):
+                self.weight_scale_inv = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
 
             if not scale_input:
                 self.scale_input = None
@@ -502,12 +521,17 @@ def forward_comfy_cast_weights(self, input):
         def convert_weight(self, weight, inplace=False, **kwargs):
             if inplace:
                 weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
+                if self.weight_scale_inv.shape == torch.Size([]):
+                    return weight
+                weight = scale_hadamard(weight, self.weight_scale_inv.to(device=weight.device, dtype=weight.dtype), k_value=None, divide=False)
                 return weight
             else:
-                return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
+                return scale_hadamard(weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32),
+                                      self.weight_scale_inv.to(device=weight.device, dtype=torch.float32), k_value=None, divide=False)
 
         def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
-            weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
+            weight = comfy.float.stochastic_rounding(scale_hadamard(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype),
+                                                     self.weight_scale_inv.to(device=weight.device, dtype=weight.dtype), k_value=None), self.weight.dtype, seed=seed)
             if return_weight:
                 return weight
             if inplace_update:

From a4c2b6f5b203035c2e1c599633b130a45ff17632 Mon Sep 17 00:00:00 2001
From: yoinked
Date: Wed, 26 Nov 2025 17:31:25 -0800
Subject: [PATCH 2/3] divide when dividing

---
 comfy/ops.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/ops.py b/comfy/ops.py
index caceb14a25eb..a16ad44a1e8d 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -531,7 +531,7 @@ def convert_weight(self, weight, inplace=False, **kwargs):
 
         def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
             weight = comfy.float.stochastic_rounding(scale_hadamard(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype),
-                                                     self.weight_scale_inv.to(device=weight.device, dtype=weight.dtype), k_value=None), self.weight.dtype, seed=seed)
+                                                     self.weight_scale_inv.to(device=weight.device, dtype=weight.dtype), k_value=None, divide=True), self.weight.dtype, seed=seed)
             if return_weight:
                 return weight
             if inplace_update:

From af7becd2f197d992d2c6cd594bd5a0cc5c4ed199 Mon Sep 17 00:00:00 2001
From: yoinked
Date: Wed, 26 Nov 2025 18:07:59 -0800
Subject: [PATCH 3/3] lint

---
 comfy/ops.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/comfy/ops.py b/comfy/ops.py
index a16ad44a1e8d..90027e7dabe7 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -476,7 +476,7 @@ def scale_hadamard(larger, smaller, k_value=None, divide=False):
     expected_shape = (h * k_value, w * k_value)
     assert larger.shape == expected_shape, "weight_scale_inv mismatch, skipping"
     if divide:
-        smaller = 1.0 / smaller
+        smaller = 1.0 / smaller
     result = larger.view(h, k_value, w, k_value) * smaller.view(h, 1, w, 1)
     return result.reshape(h * k_value, w * k_value)
 
@@ -492,7 +492,7 @@ def __init__(self, *args, **kwargs):
         def reset_parameters(self):
             if not hasattr(self, 'scale_weight'):
                 self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
-
+
             if not hasattr(self, 'weight_scale_inv'):
                 self.weight_scale_inv = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
 
@@ -526,11 +526,11 @@ def convert_weight(self, weight, inplace=False, **kwargs):
                 weight = scale_hadamard(weight, self.weight_scale_inv.to(device=weight.device, dtype=weight.dtype), k_value=None, divide=False)
                 return weight
             else:
-                return scale_hadamard(weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32),
+                return scale_hadamard(weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32),
                                       self.weight_scale_inv.to(device=weight.device, dtype=torch.float32), k_value=None, divide=False)
 
         def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
-            weight = comfy.float.stochastic_rounding(scale_hadamard(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype),
+            weight = comfy.float.stochastic_rounding(scale_hadamard(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype),
                                                      self.weight_scale_inv.to(device=weight.device, dtype=weight.dtype), k_value=None, divide=True), self.weight.dtype, seed=seed)
             if return_weight:
                 return weight
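
Note (not part of the patches): the sketch below is a minimal, self-contained illustration of what the new scale_hadamard helper does, assuming only PyTorch. weight_scale_inv holds one scale per block of the weight matrix; scale_hadamard broadcasts each entry over its k_value x k_value block, multiplying by default and dividing out when divide=True. The 2x2 block size and the toy 4x4 weight are made up for demonstration; real checkpoints use much larger blocks.

import torch

# Same logic as the helper added to comfy/ops.py, copied here so the
# sketch runs on its own.
def scale_hadamard(larger, smaller, k_value=None, divide=False):
    if smaller.shape == torch.Size([]):
        return larger  # scalar placeholder means the scales are unset
    h, w = smaller.shape
    if k_value is None:
        k_value = larger.shape[-1] // smaller.shape[-1]
    assert larger.shape == (h * k_value, w * k_value)
    if divide:
        smaller = 1.0 / smaller
    # Split larger into an h x w grid of k_value x k_value blocks and
    # broadcast one scale over each block.
    result = larger.view(h, k_value, w, k_value) * smaller.view(h, 1, w, 1)
    return result.reshape(h * k_value, w * k_value)

weight = torch.arange(16.0).reshape(4, 4)        # toy dequantized weight
scales = torch.tensor([[2.0, 4.0], [0.5, 1.0]])  # one scale per 2x2 block

quantized = scale_hadamard(weight, scales, divide=True)  # as in set_weight
restored = scale_hadamard(quantized, scales)             # as in convert_weight
assert torch.allclose(restored, weight)

The round trip is why patch 2 matters: set_weight goes checkpoint-ward, so the block scales must be divided out before stochastic rounding (patch 1 mistakenly multiplied them in), while convert_weight goes compute-ward and multiplies them back.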