From c8794e1155fa64a58164eca7dca9168e903a1a50 Mon Sep 17 00:00:00 2001 From: Yu Li Date: Sat, 29 Nov 2025 12:01:17 -0600 Subject: [PATCH] add svdquant int4 support, modify qwen model to support nunchaku style merged qkv --- QUANTIZATION.md | 157 +++++- comfy/ldm/qwen_image/model.py | 185 +++++-- comfy/model_detection.py | 24 + comfy/ops.py | 67 ++- comfy/quant_ops.py | 501 +++++++++++++++++- comfy/svdquant_converter.py | 377 +++++++++++++ convert_svdquant_checkpoint.py | 116 ++++ tests-unit/comfy_quant/test_quant_registry.py | 359 ++++++++++++- 8 files changed, 1720 insertions(+), 66 deletions(-) create mode 100644 comfy/svdquant_converter.py create mode 100644 convert_svdquant_checkpoint.py diff --git a/QUANTIZATION.md b/QUANTIZATION.md index 1693e13f32e2..98a8eadee7e0 100644 --- a/QUANTIZATION.md +++ b/QUANTIZATION.md @@ -124,6 +124,22 @@ We define 4 possible scaling parameters that should cover most recipes in the ne | Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale | |--------|---------------|--------------|----------------|-----------------|-------------| | float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) | +| svdquant_int4 | int8 (packed 4-bit) | - | - | - | - | +| svdquant_nvfp4 | int8 (packed 4-bit) | - | - | - | - | +| awq_int4 | int32 (packed 4-bit) | - | - | - | - | + +For SVDQuant formats, additional parameters are stored: +- **wscales**: Weight quantization scales (shape: in_features // group_size, out_features) +- **smooth_factor**: Smoothing factors for inputs (shape: in_features) +- **smooth_factor_orig**: Original smoothing factors (shape: in_features) +- **proj_down**: Low-rank down projection (shape: in_features, rank) +- **proj_up**: Low-rank up projection (shape: out_features, rank) +- **wtscale**: Global weight scale (nvfp4 only, scalar float) +- **wcscales**: Channel-wise weight scales (nvfp4 only, shape: out_features) + +For AWQ format, the following parameters are stored: +- **wscales**: Weight quantization scales (shape: in_features // group_size, out_features) +- **wzeros**: Weight zero points (shape: in_features // group_size, out_features) You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS). @@ -139,9 +155,9 @@ Example: "_quantization_metadata": { "format_version": "1.0", "layers": { - "model.layers.0.mlp.up_proj": "float8_e4m3fn", - "model.layers.0.mlp.down_proj": "float8_e4m3fn", - "model.layers.1.mlp.up_proj": "float8_e4m3fn" + "model.layers.0.mlp.up_proj": {"format": "float8_e4m3fn"}, + "model.layers.0.mlp.down_proj": {"format": "float8_e4m3fn"}, + "model.layers.1.mlp.up_proj": {"format": "float8_e4m3fn"} } } } @@ -165,4 +181,137 @@ Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_s 3. **Compute scales**: Derive `input_scale` from collected statistics 4. **Store in checkpoint**: Save `input_scale` parameters alongside weights -The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters. \ No newline at end of file +The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters. 
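+
+Below is a minimal sketch of steps 2–4, assuming the common convention of deriving `input_scale` from each layer's observed activation absolute maximum divided by the FP8 E4M3 range (448). The function name, the hook wiring, and the choice of observing every `torch.nn.Linear` input are illustrative assumptions, not part of ComfyUI's API:
+
+```python
+import torch
+
+FP8_E4M3_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max
+
+def calibrate_input_scales(model, calibration_batches):
+    """Observe per-layer activation amax and derive input_scale values (sketch)."""
+    amax, hooks = {}, []
+
+    def make_hook(name):
+        def hook(module, inputs):
+            observed = inputs[0].detach().abs().amax().item()
+            amax[name] = max(amax.get(name, 0.0), observed)
+        return hook
+
+    # Step 2: run forward passes while recording activation statistics
+    for name, module in model.named_modules():
+        if isinstance(module, torch.nn.Linear):
+            hooks.append(module.register_forward_pre_hook(make_hook(name)))
+    with torch.no_grad():
+        for batch in calibration_batches:
+            model(batch)
+    for h in hooks:
+        h.remove()
+
+    # Step 3: map the observed range onto the FP8 representable range
+    return {name: torch.tensor(v / FP8_E4M3_MAX, dtype=torch.float32)
+            for name, v in amax.items()}
+```
+
+The resulting scalars can then be stored in the checkpoint as per-layer `input_scale` parameters (step 4), which the quantized linear forward pass picks up to quantize activations on the fly.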
+ + +## SVDQuant + +SVDQuant is an advanced 4-bit quantization scheme that decomposes linear operations using low-rank factorization combined with residual quantization: + +``` +X*W = X * proj_down * proj_up + quantize(X) * quantize(R) +``` + +Where: +- `proj_down`, `proj_up`: Low-rank factorization matrices of the original weights +- `R`: Residual weights (quantized to 4-bit) +- `quantize()`: 4-bit quantization with smoothing factors + +### Key Features + +1. **Asymmetric Quantization**: Unlike FP8 where both weights and activations are quantized offline or use the same quantization scheme, SVDQuant: + - Quantizes weights offline with multiple parameters stored in the checkpoint + - Quantizes activations on-the-fly during forward pass using smoothing factors + +2. **Two Precision Modes**: + - `svdquant_int4`: 4-bit integer quantization with group_size=64 + - `svdquant_nvfp4`: 4-bit floating-point (NVIDIA FP4) with group_size=16, includes additional channel-wise scales + +3. **Low-Rank Optimization**: Separates the easy-to-approximate low-rank component from the hard-to-quantize residual, improving accuracy. + +### Implementation + +SVDQuant requires the `nunchaku` library for optimized CUDA kernels: +```bash +pip install nunchaku +``` + +The implementation uses two main operations: +- `svdq_quantize_w4a4_act_fuse_lora_cuda`: Quantizes activations and computes low-rank hidden states +- `svdq_gemm_w4a4_cuda`: Performs the quantized GEMM with low-rank residual addition + +### Checkpoint Format + +SVDQuant checkpoints contain the standard weight tensor (packed 4-bit residuals in int8) plus additional parameters per quantized layer: + +```python +{ + "layer_name.weight": tensor, # Packed 4-bit residual weights (out_features, in_features // 2) + "layer_name.wscales": tensor, # Weight scales (in_features // group_size, out_features) + "layer_name.smooth_factor": tensor, # Smoothing factors (in_features,) + "layer_name.smooth_factor_orig": tensor, # Original smoothing factors (in_features,) + "layer_name.proj_down": tensor, # Low-rank down projection (in_features, rank) + "layer_name.proj_up": tensor, # Low-rank up projection (out_features, rank) + + # For nvfp4 only: + "layer_name.wtscale": float, # Global weight scale + "layer_name.wcscales": tensor, # Channel-wise scales (out_features,) +} +``` + +The quantization metadata specifies which layers use SVDQuant: + +```json +{ + "_quantization_metadata": { + "format_version": "1.0", + "layers": { + "model.layers.0.mlp.up_proj": {"format": "svdquant_int4"}, + "model.layers.0.mlp.down_proj": {"format": "svdquant_int4"} + } + } +} +``` + +## AWQ + +AWQ (Activation-aware Weight Quantization) is a 4-bit weight quantization scheme that keeps activations in 16-bit precision (W4A16): + +``` +Y = X @ W_quantized +``` + +Where: +- `X`: 16-bit activations (float16/bfloat16) +- `W_quantized`: 4-bit quantized weights with per-group scales and zero points + +### Key Features + +1. **W4A16 Quantization**: + - Quantizes weights to 4-bit while keeping activations in 16-bit + - Uses per-group quantization with configurable group size (typically 64) + - Stores zero points for asymmetric quantization + +2. **Activation-Aware**: + - Quantization is calibrated based on activation statistics + - Protects salient weights that are important for accuracy + +3. 
**Hardware Efficient**: + - Optimized for GPU inference + - Significantly reduces memory footprint + - Increases throughput with specialized kernels + +### Implementation + +AWQ requires the `nunchaku` library for optimized CUDA kernels: +```bash +pip install nunchaku +``` + +The implementation uses the `awq_gemv_w4a16_cuda` kernel for efficient W4A16 matrix multiplication. + +### Checkpoint Format + +AWQ checkpoints contain the standard weight tensor (packed 4-bit weights in int32) plus additional parameters per quantized layer: + +```python +{ + "layer_name.weight": tensor, # Packed 4-bit weights (out_features // 4, in_features // 2) + "layer_name.wscales": tensor, # Weight scales (in_features // group_size, out_features) + "layer_name.wzeros": tensor, # Zero points (in_features // group_size, out_features) +} +``` + +The quantization metadata specifies which layers use AWQ: + +```json +{ + "_quantization_metadata": { + "format_version": "1.0", + "layers": { + "model.layers.0.mlp.up_proj": {"format": "awq_int4"}, + "model.layers.0.mlp.down_proj": {"format": "awq_int4"} + } + } +} +``` \ No newline at end of file diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 8c75670cd89a..272a67c31d18 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -4,7 +4,6 @@ import torch.nn.functional as F from typing import Optional, Tuple from einops import repeat - from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND @@ -12,8 +11,9 @@ import comfy.patcher_extension from comfy.ldm.flux.math import apply_rope1 + class GELU(nn.Module): - def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None): + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None, **kwargs): super().__init__() self.proj = operations.Linear(dim_in, dim_out, bias=bias, dtype=dtype, device=device) self.approximate = approximate @@ -33,7 +33,9 @@ def __init__( dropout: float = 0.0, inner_dim=None, bias: bool = True, - dtype=None, device=None, operations=None + dtype=None, device=None, operations=None, + svdquant_format=False, + **kwargs, ): super().__init__() if inner_dim is None: @@ -41,7 +43,7 @@ def __init__( dim_out = dim_out if dim_out is not None else dim self.net = nn.ModuleList([]) - self.net.append(GELU(dim, inner_dim, approximate="tanh", bias=bias, dtype=dtype, device=device, operations=operations)) + self.net.append(GELU(dim, inner_dim, approximate="tanh", bias=bias, dtype=dtype, device=device, operations=operations, **kwargs)) self.net.append(nn.Dropout(dropout)) self.net.append(operations.Linear(inner_dim, dim_out, bias=bias, dtype=dtype, device=device)) @@ -92,7 +94,9 @@ def __init__( out_context_dim: int = None, dtype=None, device=None, - operations=None + operations=None, + svdquant_format=False, + **kwargs, ): super().__init__() self.inner_dim = out_dim if out_dim is not None else dim_head * heads @@ -109,21 +113,30 @@ def __init__( self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) + self.svdquant_format = svdquant_format + # Image stream projections - self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device) - self.to_k = 
operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) - self.to_v = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) + if self.svdquant_format: # svdq merged qkv for better perf + self.to_qkv = operations.Linear(query_dim, self.inner_dim + self.inner_kv_dim * 2, bias=bias, dtype=dtype, device=device) + else: + self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device) + self.to_k = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) + self.to_v = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) # Text stream projections - self.add_q_proj = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device) - self.add_k_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) - self.add_v_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) + if self.svdquant_format: + self.add_qkv_proj = operations.Linear(query_dim, self.inner_dim + self.inner_kv_dim * 2, bias=bias, dtype=dtype, device=device) + else: + self.add_q_proj = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device) + self.add_k_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) + self.add_v_proj = operations.Linear(query_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device) # Output projections self.to_out = nn.ModuleList([ operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device), nn.Dropout(dropout) ]) + self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device) def forward( @@ -140,29 +153,64 @@ def forward( seq_txt = encoder_hidden_states.shape[1] # Project and reshape to BHND format (batch, heads, seq, dim) - img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() - img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() - img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2) + if self.svdquant_format: + img_qkv = self.to_qkv(hidden_states) + img_query, img_key, img_value = img_qkv.chunk(3, dim=-1) + # Reshape for multi-head attention to [B, L, H, D] + img_query = img_query.unflatten(-1, (self.heads, -1)) # [B, L, H, D] + img_key = img_key.unflatten(-1, (self.heads, -1)) + img_value = img_value.unflatten(-1, (self.heads, -1)) + else: + img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() + img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() + img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2) + if self.svdquant_format: + txt_qkv = self.add_qkv_proj(encoder_hidden_states) + txt_query, txt_key, txt_value = txt_qkv.chunk(3, dim=-1) + # Reshape for multi-head attention to [B, L, H, D] + txt_query = txt_query.unflatten(-1, (self.heads, -1)) + txt_key = txt_key.unflatten(-1, (self.heads, -1)) + txt_value = txt_value.unflatten(-1, (self.heads, -1)) + else: + txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous() + txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous() + 
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2) - txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous() - txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous() - txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2) img_query = self.norm_q(img_query) img_key = self.norm_k(img_key) txt_query = self.norm_added_q(txt_query) txt_key = self.norm_added_k(txt_key) - joint_query = torch.cat([txt_query, img_query], dim=2) - joint_key = torch.cat([txt_key, img_key], dim=2) - joint_value = torch.cat([txt_value, img_value], dim=2) + if self.svdquant_format: + # Concatenate image and text streams for joint attention + joint_query = torch.cat([txt_query, img_query], dim=1) + joint_key = torch.cat([txt_key, img_key], dim=1) + joint_value = torch.cat([txt_value, img_value], dim=1) + + # Apply rotary embeddings to concatenated tensors + joint_query = apply_rotary_emb(joint_query, image_rotary_emb) + joint_key = apply_rotary_emb(joint_key, image_rotary_emb) - joint_query = apply_rope1(joint_query, image_rotary_emb) - joint_key = apply_rope1(joint_key, image_rotary_emb) + # Flatten to [B, L, H*D] for attention + joint_query = joint_query.flatten(start_dim=2) + joint_key = joint_key.flatten(start_dim=2) + joint_value = joint_value.flatten(start_dim=2) - joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, - attention_mask, transformer_options=transformer_options, - skip_reshape=True) + joint_hidden_states = optimized_attention_masked( + joint_query, joint_key, joint_value, self.heads, attention_mask + ) + else: + joint_query = torch.cat([txt_query, img_query], dim=2) + joint_key = torch.cat([txt_key, img_key], dim=2) + joint_value = torch.cat([txt_value, img_value], dim=2) + + joint_query = apply_rope1(joint_query, image_rotary_emb) + joint_key = apply_rope1(joint_key, image_rotary_emb) + + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, + attention_mask, transformer_options=transformer_options, + skip_reshape=True) txt_attn_output = joint_hidden_states[:, :seq_txt, :] img_attn_output = joint_hidden_states[:, seq_txt:, :] @@ -183,28 +231,38 @@ def __init__( eps: float = 1e-6, dtype=None, device=None, - operations=None + operations=None, + scale_shift: float = None, + svdquant_format=False, + **kwargs, ): super().__init__() self.dim = dim self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim + self.svdquant_format = svdquant_format + # For svdquant, scale_shift should be 0 as the shift is fused into weights + if scale_shift is None: + scale_shift = 0.0 if self.svdquant_format else 1.0 + self.scale_shift = scale_shift self.img_mod = nn.Sequential( nn.SiLU(), operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device), ) + self.img_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device) self.img_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device) - self.img_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations) + self.img_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations, **kwargs) self.txt_mod = nn.Sequential( nn.SiLU(), operations.Linear(dim, 6 * dim, bias=True, 
dtype=dtype, device=device), ) + self.txt_norm1 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device) self.txt_norm2 = operations.LayerNorm(dim, elementwise_affine=False, eps=eps, dtype=dtype, device=device) - self.txt_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations) + self.txt_mlp = FeedForward(dim=dim, dim_out=dim, dtype=dtype, device=device, operations=operations, **kwargs) self.attn = Attention( query_dim=dim, @@ -216,11 +274,18 @@ def __init__( dtype=dtype, device=device, operations=operations, + svdquant_format=svdquant_format, + **kwargs, ) def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: shift, scale, gate = torch.chunk(mod_params, 3, dim=-1) - return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1) + if self.svdquant_format: + if self.scale_shift != 0: + scale.add_(self.scale_shift) + return x * scale.unsqueeze(1) + shift.unsqueeze(1), gate.unsqueeze(1) + else: + return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1) def forward( self, @@ -233,21 +298,42 @@ def forward( ) -> Tuple[torch.Tensor, torch.Tensor]: img_mod_params = self.img_mod(temb) txt_mod_params = self.txt_mod(temb) + + # Nunchaku's mod_params layout is [B, dim*6] with different ordering + # Need to reshape from [B, dim*6] to correct layout + + if self.svdquant_format: + img_mod_params = ( + img_mod_params.view(img_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(img_mod_params.shape[0], -1) + ) + txt_mod_params = ( + txt_mod_params.view(txt_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(txt_mod_params.shape[0], -1) + ) + img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) + img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1) del img_mod1 txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1) del txt_mod1 - img_attn_output, txt_attn_output = self.attn( - hidden_states=img_modulated, - encoder_hidden_states=txt_modulated, - encoder_hidden_states_mask=encoder_hidden_states_mask, - image_rotary_emb=image_rotary_emb, - transformer_options=transformer_options, - ) + if self.svdquant_format: + img_attn_output, txt_attn_output = self.attn( + hidden_states=img_modulated, + encoder_hidden_states=txt_modulated, + encoder_hidden_states_mask=encoder_hidden_states_mask, + image_rotary_emb=image_rotary_emb, + ) + else: + img_attn_output, txt_attn_output = self.attn( + hidden_states=img_modulated, + encoder_hidden_states=txt_modulated, + encoder_hidden_states_mask=encoder_hidden_states_mask, + image_rotary_emb=image_rotary_emb, + transformer_options=transformer_options, + ) del img_modulated del txt_modulated @@ -258,6 +344,8 @@ def forward( del img_gate1 del txt_gate1 + + img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2) hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2)) @@ -307,7 +395,15 @@ def __init__( dtype=None, device=None, operations=None, + scale_shift: float = None, + svdquant_format=False, + **kwargs, ): + # For svdquant, scale_shift should be 0 as the shift is fused into weights + self.svdquant_format = svdquant_format + + if scale_shift is None: + scale_shift = 0.0 if self.svdquant_format else 1.0 super().__init__() self.dtype = dtype self.patch_size = patch_size @@ -336,7 +432,10 @@ def __init__( 
attention_head_dim=attention_head_dim, dtype=dtype, device=device, - operations=operations + operations=operations, + scale_shift=scale_shift, + svdquant_format=svdquant_format, + **kwargs ) for _ in range(num_layers) ]) @@ -384,10 +483,12 @@ def _forward( control=None, **kwargs ): + #from safetensors import safe_open + #with safe_open("/root/nck_x.safetensors", framework="pt", device="cuda") as f: + # x = f.get_tensor("nck_x") timestep = timesteps encoder_hidden_states = context encoder_hidden_states_mask = attention_mask - hidden_states, img_ids, orig_shape = self.process_img(x) num_embeds = hidden_states.shape[1] @@ -419,7 +520,10 @@ def _forward( txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) ids = torch.cat((txt_ids, img_ids), dim=1) - image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() + if self.svdquant_format: + image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + else: + image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() del ids, txt_ids, img_ids hidden_states = self.img_in(hidden_states) @@ -441,6 +545,9 @@ def _forward( transformer_options["total_blocks"] = len(self.transformer_blocks) transformer_options["block_type"] = "double" + + + for i, block in enumerate(self.transformer_blocks): transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7d0517e611b1..0986bb2a7953 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -623,6 +623,30 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["image_model"] = "qwen_image" dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1] dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') + + # Add SVDQuant linear support + if '{}transformer_blocks.0.attn.add_qkv_proj.weight'.format(key_prefix) in state_dict_keys: + # try import nunchaku: + try: + from nunchaku.models.linear import SVDQW4A4Linear + except ImportError: + raise ImportError( + "SVDQuant requires the nunchaku library. 
" + "Please follow the instructions in https://nunchaku.tech/docs/nunchaku/installation/installation.html to install nunchaku" + ) + + dit_config["svdquant_format"] = True + + if metadata is not None and 'config' in metadata: + if 'quantization_config' in metadata: + import json + metadata_quantization_config = json.loads(metadata['quantization_config']) + if 'weight' in metadata_quantization_config: + if metadata_quantization_config["weight"]["dtype"] == "fp4_e2m1_all": + if metadata_quantization_config["weight"]["group_size"] == 16: + dit_config['precision'] = "nvfp4" + elif metadata_quantization_config["weight"]["dtype"] == "int4": + dit_config['precision'] = "int4" return dit_config if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: diff --git a/comfy/ops.py b/comfy/ops.py index 61a2f0754f85..a27a05a40dbd 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -23,6 +23,7 @@ import comfy.float import comfy.rmsnorm import contextlib +from comfy.quant_ops import QuantizedTensor def run_every_op(): if torch.compiler.is_compiling(): @@ -582,11 +583,17 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): device = self.factory_kwargs["device"] + if device is None and self.bias is not None: + device = self.bias.device + layer_name = prefix.rstrip('.') weight_key = f"{prefix}weight" weight = state_dict.pop(weight_key, None) if weight is None: raise ValueError(f"Missing weight for layer {layer_name}") + + if device is None: + device = weight.device manually_loaded_keys = [weight_key] @@ -600,27 +607,58 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, qconfig = QUANT_ALGOS[quant_format] self.layout_type = qconfig["comfy_tensor_layout"] - weight_scale_key = f"{prefix}weight_scale" + # Build layout_params - start with basic parameters layout_params = { - 'scale': state_dict.pop(weight_scale_key, None), 'orig_dtype': MixedPrecisionOps._compute_dtype, - 'block_size': qconfig.get("group_size", None), + 'is_weight': True, # Mark this as a weight tensor } - if layout_params['scale'] is not None: + + # Add group_size and precision if present in qconfig + if 'group_size' in qconfig: + layout_params['group_size'] = qconfig['group_size'] + if 'precision' in qconfig: + layout_params['precision'] = qconfig['precision'] + + # Handle weight_scale + weight_scale_key = f"{prefix}weight_scale" + weight_scale = state_dict.pop(weight_scale_key, None) + if weight_scale is not None: + layout_params['scale'] = weight_scale manually_loaded_keys.append(weight_scale_key) + # custom_layer_params_keys are loaded into layout_params from state_dict + if 'custom_layer_params_keys' in qconfig: + for param_name in qconfig['custom_layer_params_keys']: + param_key = f"{prefix}{param_name}" + param_value = state_dict.pop(param_key, None) + if param_value is not None: + layout_params[param_name] = param_value.to(device=device).contiguous() + manually_loaded_keys.append(param_key) + else: + logging.warning(f"Missing custom parameter {param_name} for layer {layer_name}") + + # parameters are loaded into module attributes from state_dict + for param_name in qconfig["parameters"]: + if param_name in layout_params: + continue # Already loaded via custom_layer_params_keys or weight_scale + + param_key = f"{prefix}{param_name}" + param_value = state_dict.pop(param_key, None) + if param_value is not None: + # For standard parameters, store as module attributes + setattr(self, param_name, 
torch.nn.Parameter(param_value.to(device=device), requires_grad=False)) + manually_loaded_keys.append(param_key) + + # Create the quantized weight tensor + quantized_weight = QuantizedTensor(weight.to(device=device), + self.layout_type, layout_params) + + self.weight_prefix = prefix self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device), self.layout_type, layout_params), + quantized_weight, requires_grad=False ) - - for param_name in qconfig["parameters"]: - param_key = f"{prefix}{param_name}" - _v = state_dict.pop(param_key, None) - if _v is None: - continue - setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) - manually_loaded_keys.append(param_key) + self.weight.requires_grad = False super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) @@ -646,6 +684,7 @@ def forward(self, input, *args, **kwargs): getattr(self, 'input_scale', None) is not None and not isinstance(input, QuantizedTensor)): input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype) + return self._forward(input, self.weight, self.bias) def convert_weight(self, weight, inplace=False, **kwargs): @@ -696,4 +735,4 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_ if compute_dtype is None or weight_dtype == compute_dtype: return disable_weight_init - return manual_cast + return manual_cast \ No newline at end of file diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index bb1fb860ca52..af04ae696bc0 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -46,13 +46,28 @@ def decorator(handler_func): def _get_layout_from_args(args): + def _extract_layout(obj): + if isinstance(obj, QuantizedTensor): + return obj._layout_type + # For torch.nn.Parameter wrapping QuantizedTensor, check the data attribute + if isinstance(obj, torch.nn.Parameter): + if isinstance(obj.data, QuantizedTensor): + return obj.data._layout_type + if hasattr(obj.data, "_layout_type"): + return getattr(obj.data, "_layout_type", None) + if hasattr(obj, "_layout_type"): + return getattr(obj, "_layout_type", None) + return None + for arg in args: - if isinstance(arg, QuantizedTensor): - return arg._layout_type - elif isinstance(arg, (list, tuple)): + layout = _extract_layout(arg) + if layout is not None: + return layout + if isinstance(arg, (list, tuple)): for item in arg: - if isinstance(item, QuantizedTensor): - return item._layout_type + layout = _extract_layout(item) + if layout is not None: + return layout return None @@ -438,6 +453,46 @@ def get_plain_tensors(cls, qtensor): "parameters": {"weight_scale", "input_scale"}, "comfy_tensor_layout": "TensorCoreFP8Layout", }, + "svdquant_int4": { + "storage_t": torch.int8, # Packed 4-bit stored in int8 + "parameters": { + "wscales", + "smooth_factor", + "smooth_factor_orig", + "proj_down", + "proj_up", + }, + "custom_layer_params_keys": ["wscales", "smooth_factor", "smooth_factor_orig", "proj_down", "proj_up"], + "comfy_tensor_layout": "SVDQuantLayout", + "group_size": 64, + "precision": "int4", + }, + "svdquant_nvfp4": { + "storage_t": torch.int8, # Packed 4-bit stored in int8 + "parameters": { + "wscales", + "smooth_factor", + "smooth_factor_orig", + "proj_down", + "proj_up", + "wtscale", + "wcscales", + }, + "custom_layer_params_keys": ["wscales", "smooth_factor", "smooth_factor_orig", "proj_down", "proj_up", "wtscale", "wcscales"], + "comfy_tensor_layout": "SVDQuantLayout", + "group_size": 16, + 
"precision": "nvfp4", + }, + "awq_int4": { + "storage_t": torch.int32, # Packed 4-bit stored in int32 + "parameters": { + "wscales", + "wzeros", + }, + "custom_layer_params_keys": ["wscales", "wzeros"], + "comfy_tensor_layout": "AWQQuantLayout", + "group_size": 64, + }, } LAYOUTS = { @@ -571,3 +626,439 @@ def fp8_func(func, args, kwargs): ar[0] = plain_input return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params) return func(*args, **kwargs) + + +# ============================================================================== +# SVDQuant Layout + Operation Handlers +# ============================================================================== + +class SVDQuantLayout(QuantizedLayout): + """ + SVDQuant W4A4 quantization layout. + + SVDQuant decomposes linear operations as: + X*W = X * proj_up * proj_down + quantize(X) * quantize(R) + + Where: + - proj_up, proj_down: Low-rank factorization of weights + - R: Residual weights (quantized to 4-bit) + - quantize(): 4-bit quantization with smoothing factors + + Storage format: + For weights (is_weight=True): + - qdata: Packed quantized residual weights (out_features, in_features // 2), int8 + - wscales: Weight quantization scales + - smooth_factor: Smoothing factors for inputs + - proj_down: Low-rank down projection + - proj_up: Low-rank up projection + - group_size: Quantization group size (64 for int4, 16 for nvfp4) + - precision: 'int4' or 'nvfp4' + - rank: SVD rank + - wtscale: Global weight scale (nvfp4 only) + - wcscales: Channel-wise weight scales (nvfp4 only) + - act_unsigned: Whether activations are unsigned (int4 only) + - orig_dtype: Original dtype before quantization + + For activations (is_weight=False): + - qdata: Original activation tensor (not quantized yet) + - orig_dtype: Original dtype + - is_weight: False marker + """ + + @classmethod + def quantize(cls, tensor, is_weight=True, **kwargs): + """ + For SVDQuant, we don't perform online quantization. + - Weights are pre-quantized offline and loaded from checkpoint + - Activations are stored as-is and quantized during forward pass + """ + orig_dtype = tensor.dtype + + if is_weight: + # This shouldn't be called for weights as they're loaded pre-quantized + raise NotImplementedError( + "SVDQuant weights should be loaded pre-quantized from checkpoint, " + "not quantized on-the-fly" + ) + else: + # For activations, just store the tensor as-is + # It will be quantized during the linear operation + layout_params = { + 'orig_dtype': orig_dtype, + 'is_weight': False + } + return tensor, layout_params + + @staticmethod + def dequantize(qdata, is_weight=True, orig_dtype=None, **kwargs): + """ + Dequantization for SVDQuant. + - Activations: return as-is (not actually quantized) + - Weights: full dequantization not supported (would need to reconstruct from SVD + residual) + """ + if not is_weight: + # Activations aren't actually quantized, just return them + return qdata.to(orig_dtype) if orig_dtype else qdata + else: + # Full weight dequantization is complex and not typically needed + # Would require: proj_down @ proj_up.T + dequantize(qweight) + raise NotImplementedError( + "Full dequantization of SVDQuant weights is not supported. " + "Use the quantized forward pass instead." 
+ ) + + @classmethod + def get_plain_tensors(cls, qtensor): + """Extract the raw tensors needed for SVDQuant computation.""" + if qtensor._layout_params.get('is_weight', True): + # For weights, return all the necessary components + return { + 'qweight': qtensor._qdata, + 'wscales': qtensor._layout_params.get('wscales'), + 'smooth_factor': qtensor._layout_params.get('smooth_factor'), + 'proj_down': qtensor._layout_params.get('proj_down'), + 'proj_up': qtensor._layout_params.get('proj_up'), + 'group_size': qtensor._layout_params.get('group_size'), + 'precision': qtensor._layout_params.get('precision', 'int4'), + 'wtscale': qtensor._layout_params.get('wtscale'), + 'wcscales': qtensor._layout_params.get('wcscales'), + 'act_unsigned': qtensor._layout_params.get('act_unsigned', False), + } + else: + # For activations, just return the tensor + return qtensor._qdata + + +@register_layout_op(torch.ops.aten.addmm.default, "SVDQuantLayout") +@register_layout_op(torch.ops.aten.linear.default, "SVDQuantLayout") +def svdquant_linear(func, args, kwargs): + """ + SVDQuant linear operation handler. + + Implements: X*W = X * proj_up * proj_down + quantize(X) * quantize(R) + + Handles both aten.linear and aten.addmm (which linear decomposes into). + """ + # Handle both linear and addmm calling conventions + if func == torch.ops.aten.addmm.default: + # addmm(bias, input, weight.t()) -> out + bias = args[0] if len(args) > 0 else None + input_tensor = args[1] if len(args) > 1 else None + weight = args[2] if len(args) > 2 else None + # Weight comes transposed in addmm, but SVDQuant stores it non-transposed + # So we need to transpose it back + need_transpose = True + else: + # linear(input, weight, bias) -> out + input_tensor = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + need_transpose = False + + # Unwrap Parameter if necessary + if isinstance(weight, torch.nn.Parameter): + weight = weight.data + + # Check if weight is SVDQuant quantized + if not isinstance(weight, QuantizedTensor) or weight._layout_type != "SVDQuantLayout": + # Fallback to standard linear + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + if isinstance(input_tensor, QuantizedTensor): + input_tensor = input_tensor.dequantize() + if func == torch.ops.aten.addmm.default: + return torch.addmm(bias, input_tensor, weight) + else: + return torch.nn.functional.linear(input_tensor, weight, bias) + + # Extract weight parameters + weight_params = SVDQuantLayout.get_plain_tensors(weight) + qweight = weight_params['qweight'] + wscales = weight_params['wscales'] + smooth_factor = weight_params['smooth_factor'] + proj_down = weight_params['proj_down'] + proj_up = weight_params['proj_up'] + group_size = weight_params['group_size'] + precision = weight_params['precision'] + wtscale = weight_params['wtscale'] + wcscales = weight_params['wcscales'] + act_unsigned = weight_params['act_unsigned'] + + # Get activation tensor (dequantize if it's a QuantizedTensor) + if isinstance(input_tensor, QuantizedTensor): + if input_tensor._layout_type == "SVDQuantLayout": + x = SVDQuantLayout.get_plain_tensors(input_tensor) + else: + x = input_tensor.dequantize() + else: + x = input_tensor + + # Import nunchaku operations + try: + from nunchaku.ops.quantize import svdq_quantize_w4a4_act_fuse_lora_cuda + from nunchaku.ops.gemm import svdq_gemm_w4a4_cuda + except ImportError: + raise ImportError( + "SVDQuant requires the nunchaku library. 
" + "Install it with: pip install nunchaku" + ) + + # Handle batch dimensions + original_shape = x.shape + if len(original_shape) == 2: + batch_size, channels = original_shape + seq_len = 1 + x = x.view(batch_size, seq_len, channels) + elif len(original_shape) == 3: + batch_size, seq_len, channels = original_shape + else: + raise ValueError(f"SVDQuant linear expects 2D or 3D input, got {len(original_shape)}D") + + # Reshape to 2D for computation + x_2d = x.reshape(batch_size * seq_len, channels) + original_batch_size = x_2d.shape[0] # Track original size before padding + + # Step 1: Quantize activations and compute low-rank hidden states + # Output: quantized_x, ascales, lora_act_out + quantized_x, ascales, lora_act_out = svdq_quantize_w4a4_act_fuse_lora_cuda( + x_2d, + lora_down=proj_down, + smooth=smooth_factor, + fp4=(precision == "nvfp4"), + pad_size=256 + ) + + # Step 2: Compute quantized GEMM with low-rank residual + # Output shape: (N_padded, out_features) where N_padded may be larger due to padding + out_features = qweight.shape[0] + output = torch.empty( + quantized_x.shape[0], + out_features, + dtype=proj_up.dtype, + device=x.device + ) + + svdq_gemm_w4a4_cuda( + act=quantized_x, + wgt=qweight, + out=output, + ascales=ascales, + wscales=wscales, + lora_act_in=lora_act_out, + lora_up=proj_up, + bias=bias, + fp4=(precision == "nvfp4"), + alpha=wtscale, + wcscales=wcscales, + act_unsigned=act_unsigned, + ) + + # Slice to remove padding and reshape back to original batch dimensions + output = output[:original_batch_size, :] # Remove padding + if len(original_shape) == 2: + output = output.view(batch_size, out_features) + else: + output = output.view(batch_size, seq_len, out_features) + + return output + + +# ============================================================================== +# AWQ Layout + Operation Handlers +# ============================================================================== + +class AWQQuantLayout(QuantizedLayout): + """ + AWQ W4A16 quantization layout. + + AWQ (Activation-aware Weight Quantization) quantizes weights to 4-bit + while keeping activations in 16-bit precision (float16/bfloat16). + + Storage format: + For weights (is_weight=True): + - qdata: Packed quantized weights (out_features // 4, in_features // 2), int32 + - wscales: Weight quantization scales (in_features // group_size, out_features) + - wzeros: Weight zero points (in_features // group_size, out_features) + - group_size: Quantization group size (default 64) + - orig_dtype: Original dtype before quantization + + For activations (is_weight=False): + - qdata: Original activation tensor (not quantized) + - orig_dtype: Original dtype + - is_weight: False marker + """ + + @classmethod + def quantize(cls, tensor, is_weight=True, **kwargs): + """ + For AWQ, we don't perform online quantization. + - Weights are pre-quantized offline and loaded from checkpoint + - Activations remain in 16-bit precision + """ + orig_dtype = tensor.dtype + + if is_weight: + # This shouldn't be called for weights as they're loaded pre-quantized + raise NotImplementedError( + "AWQ weights should be loaded pre-quantized from checkpoint, " + "not quantized on-the-fly" + ) + else: + # For activations, just store the tensor as-is + layout_params = { + 'orig_dtype': orig_dtype, + 'is_weight': False + } + return tensor, layout_params + + @staticmethod + def dequantize(qdata, is_weight=True, orig_dtype=None, wscales=None, wzeros=None, group_size=64, **kwargs): + """ + Dequantization for AWQ. 
+ - Activations: return as-is (not quantized) + - Weights: unpack and dequantize from 4-bit + """ + if not is_weight: + # Activations aren't quantized, just return them + return qdata.to(orig_dtype) if orig_dtype else qdata + else: + # Dequantize 4-bit weights + # qdata shape: (out_features // 4, in_features // 2), dtype int32 + # Output shape should be: (out_features, in_features) + + # This is a complex operation that requires unpacking 4-bit values + # For now, we'll raise an error and rely on the quantized forward pass + raise NotImplementedError( + "Full dequantization of AWQ weights is not yet supported. " + "Use the quantized forward pass instead." + ) + + @classmethod + def get_plain_tensors(cls, qtensor): + """Extract the raw tensors needed for AWQ computation.""" + if qtensor._layout_params.get('is_weight', True): + # For weights, return all the necessary components + return { + 'qweight': qtensor._qdata, + 'wscales': qtensor._layout_params.get('wscales'), + 'wzeros': qtensor._layout_params.get('wzeros'), + 'group_size': qtensor._layout_params.get('group_size', 64), + } + else: + # For activations, just return the tensor + return qtensor._qdata + + +@register_layout_op(torch.ops.aten.addmm.default, "AWQQuantLayout") +@register_layout_op(torch.ops.aten.linear.default, "AWQQuantLayout") +def awq_linear(func, args, kwargs): + """ + AWQ linear operation handler. + + Implements W4A16 quantized linear using AWQ format. + + Handles both aten.linear and aten.addmm (which linear decomposes into). + """ + # Handle both linear and addmm calling conventions + if func == torch.ops.aten.addmm.default: + # addmm(bias, input, weight.t()) -> out + bias = args[0] if len(args) > 0 else None + input_tensor = args[1] if len(args) > 1 else None + weight = args[2] if len(args) > 2 else None + else: + # linear(input, weight, bias) -> out + input_tensor = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + + # Unwrap Parameter if necessary + if isinstance(weight, torch.nn.Parameter): + weight = weight.data + + # Check if weight is AWQ quantized + if not isinstance(weight, QuantizedTensor) or weight._layout_type != "AWQQuantLayout": + # Fallback to standard linear + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + if isinstance(input_tensor, QuantizedTensor): + input_tensor = input_tensor.dequantize() + if func == torch.ops.aten.addmm.default: + return torch.addmm(bias, input_tensor, weight) + else: + return torch.nn.functional.linear(input_tensor, weight, bias) + + # Extract weight parameters + weight_params = AWQQuantLayout.get_plain_tensors(weight) + qweight = weight_params['qweight'] + wscales = weight_params['wscales'] + wzeros = weight_params['wzeros'] + group_size = weight_params['group_size'] + + # Get activation tensor (dequantize if it's a QuantizedTensor) + if isinstance(input_tensor, QuantizedTensor): + if input_tensor._layout_type == "AWQQuantLayout": + x = AWQQuantLayout.get_plain_tensors(input_tensor) + else: + x = input_tensor.dequantize() + else: + x = input_tensor + + # Import nunchaku AWQ operation + try: + from nunchaku.ops.gemv import awq_gemv_w4a16_cuda + except ImportError: + raise ImportError( + "AWQ requires the nunchaku library. 
" + "Install it with: pip install nunchaku" + ) + + # Calculate output dimensions from packed weight shape + # qweight shape: (out_features // 4, in_features // 2) + out_features = qweight.shape[0] * 4 + in_features = qweight.shape[1] * 2 + + + # Handle batch dimensions - preserve original shape + # Important: nunchaku expects 2D input only, so we reshape 3D to 2D + original_shape = x.shape + if len(original_shape) == 2: + # (batch_size, in_features) + batch_size = original_shape[0] + x_2d = x + #elif len(original_shape) == 3: + # # (batch_size, seq_len, in_features) -> (batch_size * seq_len, in_features) + # batch_size, seq_len, _ = original_shape + # x_2d = x.reshape(batch_size * seq_len, in_features) + else: + raise ValueError(f"AWQ linear expects 2D or 3D input, got {len(original_shape)}D") + + # Ensure input is contiguous (required by CUDA kernel) + # Only create a contiguous copy if absolutely necessary + #if not x_2d.is_contiguous(): + # x_2d = x_2d.contiguous() + + output = awq_gemv_w4a16_cuda( + in_feats=x_2d, + kernel=qweight, + scaling_factors=wscales, + zeros=wzeros, + m=x_2d.shape[0], + n=out_features, + k=in_features, + group_size=group_size, + ) + + # Add bias if present + if bias is not None: + view_shape = [1] * (output.ndim - 1) + [-1] + output = output + bias.view(view_shape) + + # Reshape back to original batch dimensions + #if len(original_shape) == 3: + # output = output.view(batch_size, seq_len, out_features) + + return output + + +LAYOUTS["SVDQuantLayout"] = SVDQuantLayout +LAYOUTS["AWQQuantLayout"] = AWQQuantLayout \ No newline at end of file diff --git a/comfy/svdquant_converter.py b/comfy/svdquant_converter.py new file mode 100644 index 000000000000..08aaf50d1128 --- /dev/null +++ b/comfy/svdquant_converter.py @@ -0,0 +1,377 @@ +import json +from dataclasses import dataclass +from typing import Dict, List, Tuple + +import torch +from safetensors import safe_open +from safetensors.torch import save_file + + +# Note: Fused layer splitting is no longer used + + +@dataclass +class ConvertedState: + tensors: Dict[str, torch.Tensor] + quant_layers: Dict[str, str] + + +def _is_svd_prefix(keys: set[str], prefix: str) -> bool: + return ( + f"{prefix}.qweight" in keys + and f"{prefix}.smooth_factor" in keys + and f"{prefix}.proj_down" in keys + and f"{prefix}.proj_up" in keys + ) + + +def _is_awq_prefix(keys: set[str], prefix: str) -> bool: + return ( + f"{prefix}.qweight" in keys + and f"{prefix}.wscales" in keys + and f"{prefix}.wzeros" in keys + and f"{prefix}.smooth_factor" not in keys # Distinguish from SVDQuant + ) + + +def _detect_svd_prefixes(state_dict: Dict[str, torch.Tensor]) -> List[str]: + prefixes = set() + keys = set(state_dict.keys()) + for key in keys: + if not key.endswith(".qweight"): + continue + prefix = key[: -len(".qweight")] + if _is_svd_prefix(keys, prefix): + prefixes.add(prefix) + return sorted(prefixes) + + +def _detect_awq_prefixes(state_dict: Dict[str, torch.Tensor]) -> List[str]: + prefixes = set() + keys = set(state_dict.keys()) + for key in keys: + if not key.endswith(".qweight"): + continue + prefix = key[: -len(".qweight")] + if _is_awq_prefix(keys, prefix): + prefixes.add(prefix) + return sorted(prefixes) + + +def _detect_format(wscales: torch.Tensor) -> str: + if wscales.dtype == torch.float8_e4m3fn: + return "svdquant_nvfp4" + return "svdquant_int4" + + +class _SVDQuantConverter: + def __init__(self, state_dict: Dict[str, torch.Tensor]) -> None: + self.src = dict(state_dict) + self.dst: Dict[str, torch.Tensor] = {} + self.quant_layers: 
Dict[str, str] = {} + + def convert(self) -> ConvertedState: + prefixes = _detect_svd_prefixes(self.src) + for prefix in prefixes: + self._convert_single(prefix) + + for key, tensor in self.src.items(): + if key not in self.dst: + self.dst[key] = tensor + + return ConvertedState(self.dst, self.quant_layers) + + def _pop_tensor(self, key: str) -> torch.Tensor: + try: + return self.src.pop(key) + except KeyError as exc: + raise KeyError(f"Missing key '{key}' in SVDQuant checkpoint") from exc + + def _pop_optional(self, key: str) -> torch.Tensor | None: + return self.src.pop(key, None) + + def _convert_single(self, prefix: str) -> None: + # Ensure all tensors are contiguous to avoid CUDA alignment issues + self.dst[f"{prefix}.weight"] = self._pop_tensor(f"{prefix}.qweight").contiguous() + wscales = self._pop_tensor(f"{prefix}.wscales").contiguous() + self.dst[f"{prefix}.wscales"] = wscales + format_name = _detect_format(wscales) + + self.dst[f"{prefix}.smooth_factor"] = self._pop_tensor(f"{prefix}.smooth_factor").contiguous() + self.dst[f"{prefix}.smooth_factor_orig"] = self._pop_tensor( + f"{prefix}.smooth_factor_orig" + ).contiguous() + self.dst[f"{prefix}.proj_down"] = self._pop_tensor(f"{prefix}.proj_down").contiguous() + self.dst[f"{prefix}.proj_up"] = self._pop_tensor(f"{prefix}.proj_up").contiguous() + + bias = self._pop_optional(f"{prefix}.bias") + if bias is not None: + self.dst[f"{prefix}.bias"] = bias.contiguous() + + wtscale = self._pop_optional(f"{prefix}.wtscale") + if wtscale is not None: + self.dst[f"{prefix}.wtscale"] = wtscale.contiguous() if isinstance(wtscale, torch.Tensor) else wtscale + + wcscales = self._pop_optional(f"{prefix}.wcscales") + if wcscales is not None: + self.dst[f"{prefix}.wcscales"] = wcscales.contiguous() + + self.quant_layers[prefix] = format_name + + +class _AWQConverter: + def __init__(self, state_dict: Dict[str, torch.Tensor]) -> None: + self.src = dict(state_dict) + self.dst: Dict[str, torch.Tensor] = {} + self.quant_layers: Dict[str, str] = {} + + def convert(self) -> ConvertedState: + prefixes = _detect_awq_prefixes(self.src) + for prefix in prefixes: + self._convert_single(prefix) + + for key, tensor in self.src.items(): + if key not in self.dst: + self.dst[key] = tensor + + return ConvertedState(self.dst, self.quant_layers) + + def _pop_tensor(self, key: str) -> torch.Tensor: + try: + return self.src.pop(key) + except KeyError as exc: + raise KeyError(f"Missing key '{key}' in AWQ checkpoint") from exc + + def _pop_optional(self, key: str) -> torch.Tensor | None: + return self.src.pop(key, None) + + def _convert_single(self, prefix: str) -> None: + # Ensure all tensors are contiguous to avoid CUDA alignment issues + self.dst[f"{prefix}.weight"] = self._pop_tensor(f"{prefix}.qweight").contiguous() + self.dst[f"{prefix}.wscales"] = self._pop_tensor(f"{prefix}.wscales").contiguous() + self.dst[f"{prefix}.wzeros"] = self._pop_tensor(f"{prefix}.wzeros").contiguous() + + bias = self._pop_optional(f"{prefix}.bias") + if bias is not None: + self.dst[f"{prefix}.bias"] = bias.contiguous() + + self.quant_layers[prefix] = "awq_int4" + + +def convert_svdquant_state_dict(state_dict: Dict[str, torch.Tensor]) -> ConvertedState: + return _SVDQuantConverter(state_dict).convert() + + +def convert_awq_state_dict(state_dict: Dict[str, torch.Tensor]) -> ConvertedState: + return _AWQConverter(state_dict).convert() + + +def detect_quantization_formats(state_dict: Dict[str, torch.Tensor]) -> Dict[str, List[str]]: + """ + Detect quantization formats present in a state 
dict. + + Parameters + ---------- + state_dict : Dict[str, torch.Tensor] + State dictionary to analyze + + Returns + ------- + Dict[str, List[str]] + Dictionary mapping format names to lists of layer prefixes + Example: { + "svdquant_int4": ["layer1.attn.qkv", "layer2.mlp.up"], + "svdquant_nvfp4": ["layer3.attn.qkv"], + "awq_int4": ["layer1.mlp.down", "layer4.attn.qkv"] + } + """ + result = {} + + # Detect SVDQuant layers + svd_prefixes = _detect_svd_prefixes(state_dict) + if svd_prefixes: + # Determine if int4 or nvfp4 based on wscales dtype + for prefix in svd_prefixes: + wscales_key = f"{prefix}.wscales" + if wscales_key in state_dict: + format_name = _detect_format(state_dict[wscales_key]) + if format_name not in result: + result[format_name] = [] + result[format_name].append(prefix) + + # Detect AWQ layers + awq_prefixes = _detect_awq_prefixes(state_dict) + if awq_prefixes: + result["awq_int4"] = awq_prefixes + + return result + + +def convert_awq_file( + input_path: str, + output_path: str, + format_version: str = "1.0", +) -> Tuple[int, Dict[str, str]]: + with safe_open(input_path, framework="pt", device="cpu") as f: + tensors = {key: f.get_tensor(key) for key in f.keys()} + metadata = dict(f.metadata()) + + converted = convert_awq_state_dict(tensors) + + # Convert layer format dict to expected metadata format + # From: {"layer": "awq_int4"} + # To: {"layer": {"format": "awq_int4"}} + layers_metadata = {k: {"format": v} for k, v in converted.quant_layers.items()} + + metadata["_quantization_metadata"] = json.dumps( + {"format_version": format_version, "layers": layers_metadata}, sort_keys=True + ) + + save_file(converted.tensors, output_path, metadata=metadata) + return len(converted.quant_layers), converted.quant_layers + + +def convert_svdquant_file( + input_path: str, + output_path: str, + format_version: str = "1.0", +) -> Tuple[int, Dict[str, str]]: + with safe_open(input_path, framework="pt", device="cpu") as f: + tensors = {key: f.get_tensor(key) for key in f.keys()} + metadata = dict(f.metadata()) + + converted = convert_svdquant_state_dict(tensors) + + # Convert layer format dict to expected metadata format + # From: {"layer": "svdquant_int4"} + # To: {"layer": {"format": "svdquant_int4"}} + layers_metadata = {k: {"format": v} for k, v in converted.quant_layers.items()} + + metadata["_quantization_metadata"] = json.dumps( + {"format_version": format_version, "layers": layers_metadata}, sort_keys=True + ) + metadata["model_class"] = "QwenImageTransformer2DModel" + + save_file(converted.tensors, output_path, metadata=metadata) + return len(converted.quant_layers), converted.quant_layers + + +def convert_quantized_file( + input_path: str, + output_path: str, + format_version: str = "1.0", + quant_format: str = "auto", +) -> Tuple[int, Dict[str, str]]: + """ + Auto-detect and convert quantized checkpoint to ComfyUI format. + + Supports mixed-format models where some layers are SVDQuant and others are AWQ. + Each layer is independently detected and converted to the appropriate format. 
+ + Parameters + ---------- + input_path : str + Path to input checkpoint file + output_path : str + Path to output checkpoint file + format_version : str, optional + Quantization metadata format version (default: "1.0") + quant_format : str, optional + Quantization format: "auto", "svdquant", or "awq" (default: "auto") + + Returns + ------- + Tuple[int, Dict[str, str]] + Number of quantized layers and mapping of layer names to formats + """ + with safe_open(input_path, framework="pt", device="cpu") as f: + tensors = {key: f.get_tensor(key) for key in f.keys()} + metadata = dict(f.metadata()) + + # Auto-detect format if needed + if quant_format == "auto": + svd_prefixes = _detect_svd_prefixes(tensors) + awq_prefixes = _detect_awq_prefixes(tensors) + + if svd_prefixes and awq_prefixes: + # Mixed format - partition tensors by format and convert separately + + # Build sets of all quantized prefixes + all_svd_prefixes = set(svd_prefixes) + all_awq_prefixes = set(awq_prefixes) + + # Helper to check if a key belongs to a specific quantized layer + def belongs_to_prefix(key, prefix): + """Check if key belongs to a specific layer prefix.""" + return key == prefix or key.startswith(f"{prefix}.") + + def is_svd_key(key): + """Check if key belongs to any SVDQuant layer.""" + return any(belongs_to_prefix(key, prefix) for prefix in all_svd_prefixes) + + def is_awq_key(key): + """Check if key belongs to any AWQ layer.""" + return any(belongs_to_prefix(key, prefix) for prefix in all_awq_prefixes) + + # Partition tensors by format + svd_tensors = {} + awq_tensors = {} + other_tensors = {} + + for key, tensor in tensors.items(): + if is_svd_key(key): + svd_tensors[key] = tensor + elif is_awq_key(key): + awq_tensors[key] = tensor + else: + other_tensors[key] = tensor + + # Convert each format separately with only its relevant tensors + svd_converted = _SVDQuantConverter(svd_tensors).convert() + awq_converted = _AWQConverter(awq_tensors).convert() + + # Merge results - each converter only has its own layer tensors + converted_tensors = {} + + # Add SVDQuant converted tensors + converted_tensors.update(svd_converted.tensors) + + # Add AWQ converted tensors + converted_tensors.update(awq_converted.tensors) + + # Add non-quantized tensors + converted_tensors.update(other_tensors) + + # Merge quantization layer metadata + quant_layers = {} + quant_layers.update(svd_converted.quant_layers) + quant_layers.update(awq_converted.quant_layers) + + converted = ConvertedState(converted_tensors, quant_layers) + elif svd_prefixes: + converted = convert_svdquant_state_dict(tensors) + elif awq_prefixes: + converted = convert_awq_state_dict(tensors) + else: + raise ValueError("No quantized layers detected in checkpoint") + elif quant_format == "svdquant": + converted = convert_svdquant_state_dict(tensors) + elif quant_format == "awq": + converted = convert_awq_state_dict(tensors) + else: + raise ValueError(f"Unknown quantization format: {quant_format}") + + # Convert layer format dict to expected metadata format + # From: {"layer": "awq_int4"} + # To: {"layer": {"format": "awq_int4"}} + layers_metadata = {k: {"format": v} for k, v in converted.quant_layers.items()} + + metadata["_quantization_metadata"] = json.dumps( + {"format_version": format_version, "layers": layers_metadata}, sort_keys=True + ) + metadata["model_class"] = "QwenImageTransformer2DModel" + + save_file(converted.tensors, output_path, metadata=metadata) + return len(converted.quant_layers), converted.quant_layers + + diff --git 
a/convert_svdquant_checkpoint.py b/convert_svdquant_checkpoint.py new file mode 100644 index 000000000000..f761e707bac0 --- /dev/null +++ b/convert_svdquant_checkpoint.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 +""" +Convert quantized checkpoints (SVDQuant, AWQ, or mixed) into the ComfyUI quantization format. +""" + +import argparse +from pathlib import Path +from safetensors import safe_open + +from comfy.svdquant_converter import ( + convert_quantized_file, + convert_svdquant_file, + convert_awq_file, + detect_quantization_formats, +) + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Convert quantized .safetensors files (SVDQuant, AWQ, or mixed) " + "into the ComfyUI format with per-layer metadata for MixedPrecisionOps." + ) + parser.add_argument("input", type=Path, help="Path to the source quantized .safetensors file.") + parser.add_argument( + "-o", + "--output", + type=Path, + help="Destination path for the converted checkpoint. " + "Defaults to _comfy.safetensors in the same directory.", + ) + parser.add_argument( + "--format-version", + default="1.0", + help="Format version to store inside _quantization_metadata (default: 1.0).", + ) + parser.add_argument( + "--format", + choices=["auto", "svdquant", "awq"], + default="auto", + help="Quantization format (default: auto-detect).", + ) + parser.add_argument( + "--detect-only", + action="store_true", + help="Only detect and report quantization formats without converting.", + ) + return parser + + +def main() -> None: + parser = _build_parser() + args = parser.parse_args() + + input_path = args.input.expanduser().resolve() + + # Detect formats if requested + if args.detect_only: + print(f"[Quantization Detector] Analyzing: {input_path}") + with safe_open(str(input_path), framework="pt", device="cpu") as f: + tensors = {key: f.get_tensor(key) for key in f.keys()} + + formats = detect_quantization_formats(tensors) + + if not formats: + print("[Quantization Detector] No quantized layers detected.") + return + + print(f"[Quantization Detector] Detected formats:") + total_layers = 0 + for format_name, layer_prefixes in sorted(formats.items()): + print(f"\n {format_name}: {len(layer_prefixes)} layers") + for prefix in sorted(layer_prefixes)[:5]: # Show first 5 + print(f" - {prefix}") + if len(layer_prefixes) > 5: + print(f" ... 
and {len(layer_prefixes) - 5} more") + total_layers += len(layer_prefixes) + + print(f"\n[Quantization Detector] Total: {total_layers} quantized layers") + print(f"[Quantization Detector] Use without --detect-only to convert.") + return + + # Convert checkpoint + if args.output is None: + output_path = input_path.with_name(f"{input_path.stem}_comfy.safetensors") + else: + output_path = args.output.expanduser().resolve() + + layer_count, quant_layers = convert_quantized_file( + str(input_path), + str(output_path), + format_version=args.format_version, + quant_format=args.format, + ) + + # Group layers by format for display + format_groups = {} + for layer_name, fmt in quant_layers.items(): + if fmt not in format_groups: + format_groups[fmt] = [] + format_groups[fmt].append(layer_name) + + print(f"[Quantization Converter] Converted {layer_count} layers.") + print(f"[Quantization Converter] Output saved to: {output_path}") + print(f"\n[Quantization Converter] Quantized layers by format:") + + for fmt, layers in sorted(format_groups.items()): + print(f"\n {fmt}: {len(layers)} layers") + for layer_name in sorted(layers)[:5]: # Show first 5 + print(f" - {layer_name}") + if len(layers) > 5: + print(f" ... and {len(layers) - 5} more") + + +if __name__ == "__main__": + main() + diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py index 9cb54ede8026..2a1f60208e97 100644 --- a/tests-unit/comfy_quant/test_quant_registry.py +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -1,7 +1,10 @@ +import os +import sys import unittest +from pathlib import Path + import torch -import sys -import os +from safetensors.torch import load_file # Add comfy to path sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) @@ -13,7 +16,9 @@ def has_gpu(): if not has_gpu(): args.cpu = True -from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout +from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout, AWQQuantLayout, SVDQuantLayout +from comfy.ops import mixed_precision_ops +from comfy.svdquant_converter import convert_svdquant_state_dict, convert_awq_state_dict class TestQuantizedTensor(unittest.TestCase): @@ -156,6 +161,199 @@ def test_dequantize(self): self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1)) +class TestAWQQuantLayout(unittest.TestCase): + """Test the AWQQuantLayout implementation""" + + def test_awq_layout_creation(self): + """Test creating an AWQ quantized tensor""" + # AWQ uses pre-quantized weights loaded from checkpoints + # Create dummy AWQ quantized weights + out_features, in_features = 256, 128 + group_size = 64 + + qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32) + wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16) + wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16) + + layout_params = { + 'wscales': wscales, + 'wzeros': wzeros, + 'group_size': group_size, + 'orig_dtype': torch.bfloat16, + 'is_weight': True, + } + + qt = QuantizedTensor(qweight, "AWQQuantLayout", layout_params) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.shape, qweight.shape) + self.assertEqual(qt.dtype, torch.int32) + self.assertEqual(qt._layout_type, "AWQQuantLayout") + self.assertEqual(qt._layout_params['group_size'], group_size) + + def test_awq_quantize_not_supported(self): + """Test that online quantization raises NotImplementedError for AWQ""" + # AWQ doesn't support online 
quantization - weights must be pre-quantized + float_tensor = torch.randn(32, 64, dtype=torch.float32) + + with self.assertRaises(NotImplementedError): + AWQQuantLayout.quantize(float_tensor, is_weight=True) + + def test_awq_get_plain_tensors(self): + """Test extracting plain tensors from AWQ quantized tensor""" + out_features, in_features = 256, 128 + group_size = 64 + + qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32) + wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16) + wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16) + + layout_params = { + 'wscales': wscales, + 'wzeros': wzeros, + 'group_size': group_size, + 'orig_dtype': torch.bfloat16, + 'is_weight': True, + } + + qt = QuantizedTensor(qweight, "AWQQuantLayout", layout_params) + plain_tensors = AWQQuantLayout.get_plain_tensors(qt) + + # Verify we can extract all necessary components + self.assertIsInstance(plain_tensors, dict) + self.assertIn('qweight', plain_tensors) + self.assertIn('wscales', plain_tensors) + self.assertIn('wzeros', plain_tensors) + self.assertIn('group_size', plain_tensors) + self.assertTrue(torch.equal(plain_tensors['qweight'], qweight)) + self.assertTrue(torch.equal(plain_tensors['wscales'], wscales)) + self.assertTrue(torch.equal(plain_tensors['wzeros'], wzeros)) + + +class TestSVDQuantLayout(unittest.TestCase): + """Test the SVDQuantLayout implementation""" + + def test_svdquant_layout_creation(self): + """Test creating an SVDQuant quantized tensor""" + # SVDQuant uses pre-quantized weights loaded from checkpoints + out_features, in_features = 256, 128 + rank = 32 + group_size = 64 + precision = "int4" + + # Create dummy SVDQuant quantized weights (int8 range is -128 to 127) + qweight = torch.randint(-128, 127, (out_features, in_features // 2), dtype=torch.int8) + wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16) + smooth_factor = torch.randn(in_features, dtype=torch.bfloat16) + smooth_factor_orig = torch.randn(in_features, dtype=torch.bfloat16) + proj_down = torch.randn(in_features, rank, dtype=torch.bfloat16) + proj_up = torch.randn(out_features, rank, dtype=torch.bfloat16) + + layout_params = { + 'wscales': wscales, + 'smooth_factor': smooth_factor, + 'smooth_factor_orig': smooth_factor_orig, + 'proj_down': proj_down, + 'proj_up': proj_up, + 'group_size': group_size, + 'precision': precision, + 'orig_dtype': torch.bfloat16, + 'is_weight': True, + 'act_unsigned': False, + 'wtscale': None, + 'wcscales': None, + } + + qt = QuantizedTensor(qweight, "SVDQuantLayout", layout_params) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.shape, qweight.shape) + self.assertEqual(qt.dtype, torch.int8) + self.assertEqual(qt._layout_type, "SVDQuantLayout") + self.assertEqual(qt._layout_params['group_size'], group_size) + self.assertEqual(qt._layout_params['precision'], precision) + + def test_svdquant_quantize_not_supported(self): + """Test that online quantization raises NotImplementedError for SVDQuant""" + # SVDQuant doesn't support online quantization - weights must be pre-quantized + float_tensor = torch.randn(32, 64, dtype=torch.float32) + + with self.assertRaises(NotImplementedError): + SVDQuantLayout.quantize(float_tensor, is_weight=True) + + def test_svdquant_dequantize_not_supported(self): + """Test that weight dequantization raises NotImplementedError for SVDQuant""" + # Full weight dequantization is not supported (complex operation) + out_features, 
in_features = 256, 128 + rank = 32 + group_size = 64 + + qweight = torch.randint(-128, 127, (out_features, in_features // 2), dtype=torch.int8) + wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16) + smooth_factor = torch.randn(in_features, dtype=torch.bfloat16) + proj_down = torch.randn(in_features, rank, dtype=torch.bfloat16) + proj_up = torch.randn(out_features, rank, dtype=torch.bfloat16) + + with self.assertRaises(NotImplementedError): + SVDQuantLayout.dequantize( + qweight, + is_weight=True, + wscales=wscales, + smooth_factor=smooth_factor, + proj_down=proj_down, + proj_up=proj_up, + group_size=group_size, + precision="int4", + orig_dtype=torch.bfloat16 + ) + + def test_svdquant_get_plain_tensors(self): + """Test extracting plain tensors from SVDQuant quantized tensor""" + out_features, in_features = 256, 128 + rank = 32 + group_size = 64 + + qweight = torch.randint(-128, 127, (out_features, in_features // 2), dtype=torch.int8) + wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16) + smooth_factor = torch.randn(in_features, dtype=torch.bfloat16) + smooth_factor_orig = torch.randn(in_features, dtype=torch.bfloat16) + proj_down = torch.randn(in_features, rank, dtype=torch.bfloat16) + proj_up = torch.randn(out_features, rank, dtype=torch.bfloat16) + + layout_params = { + 'wscales': wscales, + 'smooth_factor': smooth_factor, + 'smooth_factor_orig': smooth_factor_orig, + 'proj_down': proj_down, + 'proj_up': proj_up, + 'group_size': group_size, + 'precision': "int4", + 'orig_dtype': torch.bfloat16, + 'is_weight': True, + 'act_unsigned': False, + 'wtscale': None, + 'wcscales': None, + } + + qt = QuantizedTensor(qweight, "SVDQuantLayout", layout_params) + plain_tensors = SVDQuantLayout.get_plain_tensors(qt) + + # Verify we can extract all necessary components + self.assertIsInstance(plain_tensors, dict) + self.assertIn('qweight', plain_tensors) + self.assertIn('wscales', plain_tensors) + self.assertIn('smooth_factor', plain_tensors) + self.assertIn('proj_down', plain_tensors) + self.assertIn('proj_up', plain_tensors) + self.assertIn('group_size', plain_tensors) + self.assertIn('precision', plain_tensors) + self.assertTrue(torch.equal(plain_tensors['qweight'], qweight)) + self.assertTrue(torch.equal(plain_tensors['wscales'], wscales)) + self.assertTrue(torch.equal(plain_tensors['smooth_factor'], smooth_factor)) + self.assertTrue(torch.equal(plain_tensors['proj_down'], proj_down)) + self.assertTrue(torch.equal(plain_tensors['proj_up'], proj_up)) + + class TestFallbackMechanism(unittest.TestCase): """Test fallback for unsupported operations""" @@ -186,5 +384,158 @@ def test_unsupported_op_dequantizes(self): self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") +class TestAWQConversion(unittest.TestCase): + """Test AWQ checkpoint conversion""" + + def test_awq_single_layer_conversion(self): + """Test converting a single AWQ layer""" + in_features, out_features = 128, 256 + group_size = 64 + + # Create AWQ checkpoint format + state_dict = { + "layer.qweight": torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32), + "layer.wscales": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16), + "layer.wzeros": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16), + "layer.bias": torch.randn(out_features, dtype=torch.bfloat16), + } + + converted = convert_awq_state_dict(state_dict) + + # Check that qweight was renamed to weight + 
self.assertIn("layer.weight", converted.tensors) + self.assertNotIn("layer.qweight", converted.tensors) + + # Check other parameters preserved + self.assertIn("layer.wscales", converted.tensors) + self.assertIn("layer.wzeros", converted.tensors) + self.assertIn("layer.bias", converted.tensors) + + # Check quantization metadata + self.assertIn("layer", converted.quant_layers) + self.assertEqual(converted.quant_layers["layer"], "awq_int4") + + def test_awq_tensor_shapes(self): + """Test that converted AWQ tensors have correct shapes""" + in_features, out_features = 3072, 18432 + group_size = 64 + + state_dict = { + "layer.qweight": torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32), + "layer.wscales": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16), + "layer.wzeros": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16), + } + + converted = convert_awq_state_dict(state_dict) + + # Check qweight shape (packed 4-bit) + qweight = converted.tensors["layer.weight"] + self.assertEqual(qweight.shape, (out_features // 4, in_features // 2)) + self.assertEqual(qweight.dtype, torch.int32) + + # Check wscales shape + wscales = converted.tensors["layer.wscales"] + self.assertEqual(wscales.shape, (in_features // group_size, out_features)) + self.assertEqual(wscales.dtype, torch.bfloat16) + + # Check wzeros shape + wzeros = converted.tensors["layer.wzeros"] + self.assertEqual(wzeros.shape, (in_features // group_size, out_features)) + self.assertEqual(wzeros.dtype, torch.bfloat16) + + +class TestAWQLinearOperation(unittest.TestCase): + """Test AWQ linear operations with actual nunchaku kernels""" + + @unittest.skipUnless(has_gpu(), "GPU required for AWQ operations") + def test_awq_linear_basic(self): + """Test basic AWQ linear operation by calling kernel directly""" + try: + from nunchaku.ops.gemv import awq_gemv_w4a16_cuda + except ImportError: + self.skipTest("nunchaku package not available") + + device = torch.device("cuda") + in_features, out_features = 128, 256 + group_size = 64 + batch_size = 4 + + # Create AWQ quantized weight tensors + qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32, device=device) + wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device) + wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device) + bias = torch.randn(out_features, dtype=torch.bfloat16, device=device) + + # Create layout params + layout_params = { + 'wscales': wscales, + 'wzeros': wzeros, + 'group_size': group_size, + 'orig_dtype': torch.bfloat16, + 'is_weight': True, + } + + weight = QuantizedTensor(qweight, "AWQQuantLayout", layout_params) + + # Check that weight is a QuantizedTensor + self.assertIsInstance(weight, QuantizedTensor) + self.assertEqual(weight._layout_type, "AWQQuantLayout") + + # Create input + x = torch.randn(batch_size, in_features, dtype=torch.bfloat16, device=device) + + # Call AWQ linear handler directly + from comfy.quant_ops import awq_linear + output = awq_linear(torch.ops.aten.linear.default, (x, weight, bias), {}) + + # Check output shape and dtype + self.assertEqual(output.shape, (batch_size, out_features)) + self.assertEqual(output.dtype, torch.bfloat16) + + @unittest.skipUnless(has_gpu(), "GPU required for AWQ operations") + def test_awq_linear_2d_input(self): + """Test AWQ linear with 2D input (batch, features) by calling kernel directly""" + try: + from nunchaku.ops.gemv import 
awq_gemv_w4a16_cuda + except ImportError: + self.skipTest("nunchaku package not available") + + device = torch.device("cuda") + in_features, out_features = 128, 256 + group_size = 64 + batch_size = 4 + + # Create AWQ quantized weight tensors + qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32, device=device) + wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device) + wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device) + + # Create layout params + layout_params = { + 'wscales': wscales, + 'wzeros': wzeros, + 'group_size': group_size, + 'orig_dtype': torch.bfloat16, + 'is_weight': True, + } + + weight = QuantizedTensor(qweight, "AWQQuantLayout", layout_params) + + # Check that weight is a QuantizedTensor + self.assertIsInstance(weight, QuantizedTensor) + self.assertEqual(weight._layout_type, "AWQQuantLayout") + + # Create 2D input + x = torch.randn(batch_size, in_features, dtype=torch.bfloat16, device=device) + + # Call AWQ linear handler directly + from comfy.quant_ops import awq_linear + output = awq_linear(torch.ops.aten.linear.default, (x, weight, None), {}) + + # Check output shape + self.assertEqual(output.shape, (batch_size, out_features)) + self.assertEqual(output.dtype, torch.bfloat16) + + if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file