diff --git a/comfy/ops.py b/comfy/ops.py
index a0ff4e8f1710..90027e7dabe7 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -464,6 +464,22 @@ def forward_comfy_cast_weights(self, input):
         uncast_bias_weight(self, weight, bias, offload_stream)
         return x
 
+
+def scale_hadamard(larger, smaller, k_value=None, divide=False):
+    if smaller.shape == torch.Size([]):
+        # unset
+        return larger
+    h, w = smaller.shape
+    if k_value is None:
+        # calculate from larger compared to smaller
+        k_value = larger.shape[-1] // smaller.shape[-1]
+    expected_shape = (h * k_value, w * k_value)
+    assert larger.shape == expected_shape, "weight_scale_inv mismatch, skipping"
+    if divide:
+        smaller = 1.0 / smaller
+    result = larger.view(h, k_value, w, k_value) * smaller.view(h, 1, w, 1)
+    return result.reshape(h * k_value, w * k_value)
+
 def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
     logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
     class scaled_fp8_op(manual_cast):
@@ -477,6 +493,9 @@ def reset_parameters(self):
                 if not hasattr(self, 'scale_weight'):
                     self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
 
+                if not hasattr(self, 'weight_scale_inv'):
+                    self.weight_scale_inv = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
+
                 if not scale_input:
                     self.scale_input = None
 
@@ -502,12 +521,17 @@ def forward_comfy_cast_weights(self, input):
             def convert_weight(self, weight, inplace=False, **kwargs):
                 if inplace:
                     weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
+                    if self.weight_scale_inv.shape == torch.Size([]):
+                        return weight
+                    weight = scale_hadamard(weight, self.weight_scale_inv.to(device=weight.device, dtype=weight.dtype), k_value=None, divide=False)
                     return weight
                 else:
-                    return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
+                    return scale_hadamard(weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32),
+                                          self.weight_scale_inv.to(device=weight.device, dtype=torch.float32), k_value=None, divide=False)
 
             def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
-                weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
+                weight = comfy.float.stochastic_rounding(scale_hadamard(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype),
+                                                         self.weight_scale_inv.to(device=weight.device, dtype=weight.dtype), k_value=None, divide=True), self.weight.dtype, seed=seed)
                 if return_weight:
                     return weight
                 if inplace_update:
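
For reviewers, here is a minimal, self-contained sketch (not part of the patch) of what the new scale_hadamard() helper computes: one scale entry per (k_value x k_value) block of the larger tensor is broadcast over that block, which is how block-quantized fp8 checkpoints store weight_scale_inv. The 4x4 weight, the 2x2 scale, and the values below are made up for illustration; real checkpoints typically use much larger blocks.

    import torch

    # Made-up illustration values: a 4x4 "weight" and a 2x2 per-block scale,
    # so each scale entry covers one 2x2 block (k = 2).
    weight = torch.arange(16, dtype=torch.float32).reshape(4, 4)
    weight_scale_inv = torch.tensor([[2.0, 0.5],
                                     [1.0, 4.0]])

    h, w = weight_scale_inv.shape
    k = weight.shape[-1] // w  # block size, inferred the same way as in the patch

    # Same view/broadcast trick as scale_hadamard(): view the weight as
    # (h, k, w, k) blocks, multiply each block by its scale, flatten back.
    dequant = (weight.view(h, k, w, k) * weight_scale_inv.view(h, 1, w, 1)).reshape(h * k, w * k)

    # Reference: scale each 2x2 block explicitly with a loop.
    reference = weight.clone()
    for i in range(h):
        for j in range(w):
            reference[i * k:(i + 1) * k, j * k:(j + 1) * k] *= weight_scale_inv[i, j]

    assert torch.equal(dequant, reference)

The divide=True path used by set_weight() is the inverse of this: it multiplies by 1.0 / smaller, dividing the block scales back out of a full-precision weight before stochastic rounding re-quantizes it to the stored fp8 dtype.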