Issue: Performance degradation with int8 quantization in multi-batch scenarios #218

@kakarotzzz

Description

When using int8 weight-only quantization, there is a significant performance drop in multi-batch inference compared to single-batch inference: single-batch performance is good, but throughput does not scale well as the batch size increases.

import torch
import torch.nn.functional as F


class WeightOnlyInt8Linear(torch.nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # int8 weight plus a per-output-channel bf16 scale (weight-only quantization)
        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # the explicit dtype conversion here is what inductor needs to fuse into the gemm
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
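
For context, here is a minimal sketch of how such a module might be populated from an existing nn.Linear using per-channel symmetric int8 quantization. The quantize_linear helper, the scale formula, and the 4096 shapes are illustrative assumptions, not part of the original report:

# Hypothetical helper (assumption): per-output-channel symmetric int8
# quantization of an existing linear layer's weight.
@torch.no_grad()
def quantize_linear(linear: torch.nn.Linear) -> WeightOnlyInt8Linear:
    w = linear.weight.to(torch.float32)                    # [out_features, in_features]
    scales = w.abs().amax(dim=1).clamp(min=1e-8) / 127.0   # one scale per output row
    q = torch.clamp(torch.round(w / scales[:, None]), -128, 127).to(torch.int8)
    module = WeightOnlyInt8Linear(linear.in_features, linear.out_features, bias=False)
    module.weight.copy_(q)
    module.scales.copy_(scales.to(torch.bfloat16))
    return module

# Example usage (shapes are illustrative):
# qlayer = quantize_linear(torch.nn.Linear(4096, 4096, bias=False))
# out = qlayer(torch.randn(1, 4096, dtype=torch.bfloat16))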

Current Behavior

  1. The explicit .to(dtype=input.dtype) creates a separate type-conversion kernel.
  2. In the single-batch case, inductor successfully fuses this conversion into the gemm.
  3. In the multi-batch case the fusion fails, and we get (see the reproduction sketch below):
    • one kernel for the int8 -> fp16 conversion,
    • another kernel for the gemm computation,
    • and therefore extra memory traffic and lower performance.
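
A rough way to observe this is to compile the module and compare the average per-call time at batch size 1 against a larger batch. This is only a sketch under assumptions: the bench helper, the 4096 shapes, the batch sizes, and the iteration counts are illustrative and not from the original report.

import time
import torch

# Illustrative repro sketch (assumption): time the compiled module at two
# batch sizes; per the report, the larger batch loses fusion of the
# int8 -> bf16 conversion with the gemm and pays extra memory traffic.
def bench(module, batch_size, in_features, iters=100):
    x = torch.randn(batch_size, in_features, dtype=torch.bfloat16, device="cuda")
    compiled = torch.compile(module)
    for _ in range(10):            # warm-up so compilation cost is excluded
        compiled(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        compiled(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

# layer = WeightOnlyInt8Linear(4096, 4096).cuda()
# print(bench(layer, 1, 4096), bench(layer, 8, 4096))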
