We define 4 possible scaling parameters that should cover most recipes in the near future:
| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
|--------|---------------|--------------|----------------|-----------------|-------------|
| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
| svdquant_int4 | int8 (packed 4-bit) | - | - | - | - |
| svdquant_nvfp4 | int8 (packed 4-bit) | - | - | - | - |
| awq_int4 | int32 (packed 4-bit) | - | - | - | - |

For SVDQuant formats, additional parameters are stored:
- **wscales**: Weight quantization scales (shape: in_features // group_size, out_features)
- **smooth_factor**: Smoothing factors for inputs (shape: in_features)
- **smooth_factor_orig**: Original smoothing factors (shape: in_features)
- **proj_down**: Low-rank down projection (shape: in_features, rank)
- **proj_up**: Low-rank up projection (shape: out_features, rank)
- **wtscale**: Global weight scale (nvfp4 only, scalar float)
- **wcscales**: Channel-wise weight scales (nvfp4 only, shape: out_features)

For AWQ format, the following parameters are stored:
- **wscales**: Weight quantization scales (shape: in_features // group_size, out_features)
- **wzeros**: Weight zero points (shape: in_features // group_size, out_features)

You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).

Example:

```json
{
    "_quantization_metadata": {
        "format_version": "1.0",
        "layers": {
            "model.layers.0.mlp.up_proj": {"format": "float8_e4m3fn"},
            "model.layers.0.mlp.down_proj": {"format": "float8_e4m3fn"},
            "model.layers.1.mlp.up_proj": {"format": "float8_e4m3fn"}
        }
    }
}
```
Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters obtained through a calibration pass:

1. **Run calibration**: Pass representative inputs through the model
2. **Collect statistics**: Record activation ranges (e.g., absolute maxima) for each quantized layer
3. **Compute scales**: Derive `input_scale` from collected statistics
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights

The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.
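
As a concrete illustration of the "compute scales" step, here is a minimal sketch for `float8_e4m3fn` activations, assuming per-tensor absolute-maximum statistics were collected during calibration; the function name and the absmax-based recipe are illustrative, not the exact ComfyUI implementation:

```python
import torch

def compute_input_scale(collected_absmax: torch.Tensor) -> torch.Tensor:
    """Derive a per-tensor input_scale from calibration absmax statistics."""
    # float8_e4m3fn can represent magnitudes up to 448, so absmax / 448
    # maps the observed activation range into the representable range.
    f8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    return (collected_absmax.float() / f8_max).clamp(min=1e-12)

# At inference time, activations would be quantized roughly as:
#   x_fp8 = (x / input_scale).to(torch.float8_e4m3fn)
```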


## SVDQuant

SVDQuant is an advanced 4-bit quantization scheme that decomposes linear operations using low-rank factorization combined with residual quantization:

```
X*W = X * proj_down * proj_up + quantize(X) * quantize(R)
```

Where:
- `proj_down`, `proj_up`: Low-rank factorization matrices of the original weights
- `R`: Residual weights (quantized to 4-bit)
- `quantize()`: 4-bit quantization with smoothing factors
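
The sketch below is a dequantized, PyTorch-only reference for this decomposition (int4 flavor), assuming a symmetric int4 fake-quantizer and the convention of dividing activations by `smooth_factor` while folding the factor into the residual weights. Function names are illustrative; the real forward pass runs nunchaku's fused 4-bit CUDA kernels instead, so treat this purely as an illustration of the math.

```python
import torch

def fake_quant_4bit(t: torch.Tensor, group_size: int) -> torch.Tensor:
    """Symmetric int4 fake-quantization per group along the last dimension."""
    shape = t.shape
    g = t.reshape(-1, group_size)
    scale = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 7.0  # int4 range [-8, 7]
    return (torch.clamp(torch.round(g / scale), -8, 7) * scale).reshape(shape)

def svdquant_forward_reference(x, proj_down, proj_up, residual, smooth_factor,
                               group_size=64):
    # x: (batch, in_features); proj_down: (in_features, rank)
    # proj_up: (out_features, rank); residual: (out_features, in_features)
    # smooth_factor: (in_features,)
    low_rank = x @ proj_down @ proj_up.T                               # high-precision low-rank branch
    x_q = fake_quant_4bit(x / smooth_factor, group_size)               # assumed: divide activations
    r_q = fake_quant_4bit(residual * smooth_factor, group_size)        # ...and fold into the residual
    return low_rank + x_q @ r_q.T
```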

### Key Features

1. **Asymmetric Quantization**: Unlike FP8, where weights and activations share the same quantization scheme, SVDQuant:
   - Quantizes weights offline, with multiple parameters stored in the checkpoint
   - Quantizes activations on the fly during the forward pass using smoothing factors

2. **Two Precision Modes**:
   - `svdquant_int4`: 4-bit integer quantization with group_size=64
   - `svdquant_nvfp4`: 4-bit floating-point (NVIDIA FP4) with group_size=16; includes additional channel-wise scales

3. **Low-Rank Optimization**: Separates the easy-to-approximate low-rank component from the hard-to-quantize residual, improving accuracy.

### Implementation

SVDQuant requires the `nunchaku` library for optimized CUDA kernels:
```bash
pip install nunchaku
```

The implementation uses two main operations:
- `svdq_quantize_w4a4_act_fuse_lora_cuda`: Quantizes activations and computes low-rank hidden states
- `svdq_gemm_w4a4_cuda`: Performs the quantized GEMM with low-rank residual addition

### Checkpoint Format

SVDQuant checkpoints contain the standard weight tensor (packed 4-bit residuals in int8) plus additional parameters per quantized layer:

```python
{
    "layer_name.weight": tensor,              # Packed 4-bit residual weights (out_features, in_features // 2)
    "layer_name.wscales": tensor,             # Weight scales (in_features // group_size, out_features)
    "layer_name.smooth_factor": tensor,       # Smoothing factors (in_features,)
    "layer_name.smooth_factor_orig": tensor,  # Original smoothing factors (in_features,)
    "layer_name.proj_down": tensor,           # Low-rank down projection (in_features, rank)
    "layer_name.proj_up": tensor,             # Low-rank up projection (out_features, rank)

    # For nvfp4 only:
    "layer_name.wtscale": float,              # Global weight scale
    "layer_name.wcscales": tensor,            # Channel-wise scales (out_features,)
}
```
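
For reference, a minimal sketch of pulling these tensors out of a checkpoint with the `safetensors` library; the file path and layer name are placeholders:

```python
from safetensors.torch import load_file

state = load_file("model_svdquant.safetensors")   # placeholder path
layer = "model.layers.0.mlp.up_proj"              # placeholder layer name

packed_residual = state[f"{layer}.weight"]         # int8, packed 4-bit residual
wscales         = state[f"{layer}.wscales"]        # (in_features // group_size, out_features)
smooth_factor   = state[f"{layer}.smooth_factor"]  # (in_features,)
proj_down       = state[f"{layer}.proj_down"]      # (in_features, rank)
proj_up         = state[f"{layer}.proj_up"]        # (out_features, rank)
```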

The quantization metadata specifies which layers use SVDQuant:

```json
{
    "_quantization_metadata": {
        "format_version": "1.0",
        "layers": {
            "model.layers.0.mlp.up_proj": {"format": "svdquant_int4"},
            "model.layers.0.mlp.down_proj": {"format": "svdquant_int4"}
        }
    }
}
```

## AWQ

AWQ (Activation-aware Weight Quantization) is a 4-bit weight quantization scheme that keeps activations in 16-bit precision (W4A16):

```
Y = X @ W_quantized
```

Where:
- `X`: 16-bit activations (float16/bfloat16)
- `W_quantized`: 4-bit quantized weights with per-group scales and zero points
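
As a conceptual reference, the sketch below dequantizes the weights and runs a plain 16-bit matmul, assuming the 4-bit values have already been unpacked to integers in [0, 15] and that zero points share the per-group layout of the scales; the fused nunchaku kernel never materializes the dequantized weight like this, and the function name is illustrative.

```python
import torch

def awq_linear_reference(x, q_weight, wscales, wzeros, group_size=64):
    # x: (batch, in_features) in fp16/bf16
    # q_weight: (out_features, in_features), unpacked 4-bit values in [0, 15]
    # wscales, wzeros: (in_features // group_size, out_features)
    out_f, in_f = q_weight.shape
    groups = in_f // group_size
    q = q_weight.reshape(out_f, groups, group_size).to(torch.float32)
    scales = wscales.T.reshape(out_f, groups, 1).to(torch.float32)
    zeros = wzeros.T.reshape(out_f, groups, 1).to(torch.float32)
    w = ((q - zeros) * scales).reshape(out_f, in_f)  # per-group asymmetric dequant
    return x @ w.to(x.dtype).T                       # activations stay in 16-bit
```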

### Key Features

1. **W4A16 Quantization**:
   - Quantizes weights to 4-bit while keeping activations in 16-bit
   - Uses per-group quantization with configurable group size (typically 64)
   - Stores zero points for asymmetric quantization

2. **Activation-Aware**:
   - Quantization is calibrated based on activation statistics
   - Protects salient weights that are important for accuracy

3. **Hardware Efficient**:
   - Optimized for GPU inference
   - Significantly reduces memory footprint
   - Increases throughput with specialized kernels

### Implementation

AWQ requires the `nunchaku` library for optimized CUDA kernels:
```bash
pip install nunchaku
```

The implementation uses the `awq_gemv_w4a16_cuda` kernel for efficient W4A16 matrix multiplication.

### Checkpoint Format

AWQ checkpoints contain the standard weight tensor (packed 4-bit weights in int32) plus additional parameters per quantized layer:

```python
{
    "layer_name.weight": tensor,   # Packed 4-bit weights (out_features // 4, in_features // 2)
    "layer_name.wscales": tensor,  # Weight scales (in_features // group_size, out_features)
    "layer_name.wzeros": tensor,   # Zero points (in_features // group_size, out_features)
}
```
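
To see how the packed shape relates to the logical weight shape: each int32 holds eight 4-bit values, so `(out_features // 4, in_features // 2)` int32 elements cover `out_features x in_features` nibbles. The sketch below unpacks nibbles in plain little-endian order purely to illustrate that arithmetic; the actual AWQ/nunchaku kernels use an interleaved layout, so this does not reproduce the real packing.

```python
import torch

def unpack_nibbles(packed: torch.Tensor) -> torch.Tensor:
    """Unpack each int32 into eight 4-bit values (plain, non-interleaved order)."""
    shifts = torch.arange(0, 32, 4, dtype=torch.int32)   # 0, 4, ..., 28
    nibbles = (packed.unsqueeze(-1) >> shifts) & 0xF     # (..., 8) values in [0, 15]
    return nibbles.reshape(*packed.shape[:-1], -1).to(torch.uint8)
```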

The quantization metadata specifies which layers use AWQ:

```json
{
    "_quantization_metadata": {
        "format_version": "1.0",
        "layers": {
            "model.layers.0.mlp.up_proj": {"format": "awq_int4"},
            "model.layers.0.mlp.down_proj": {"format": "awq_int4"}
        }
    }
}
```
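
A hedged sketch of reading this metadata back with the `safetensors` API, assuming it is stored as a JSON string in the file's header metadata (the file name is a placeholder):

```python
import json
from safetensors import safe_open

with safe_open("model_awq.safetensors", framework="pt") as f:
    header = f.metadata() or {}
    quant_meta = json.loads(header.get("_quantization_metadata", "{}"))

for layer_name, info in quant_meta.get("layers", {}).items():
    print(layer_name, "->", info["format"])
```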