When using int8 weight-only quantization, multi-batch inference shows a significant performance drop compared to single-batch inference: single-batch performance is good, but throughput does not scale well as the batch size increases.
```python
import torch
import torch.nn.functional as F


class WeightOnlyInt8Linear(torch.nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # int8 weight with per-output-channel bf16 scales (weight-only quantization)
        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Dequantize on the fly: cast the int8 weight to the input dtype,
        # run the GEMM, then rescale the output channels.
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
```
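For reference, a minimal benchmarking sketch (not from the original report) that compares single-batch and multi-batch latency of this module under `torch.compile`. It assumes the `WeightOnlyInt8Linear` class above is defined in the same file and that a CUDA device is available; the 4096x4096 shape, the batch sizes, and the iteration counts are illustrative assumptions.

```python
# Hedged reproduction sketch: time the compiled WeightOnlyInt8Linear above at
# batch 1 vs. batch 8. Shapes, dtypes, and iteration counts are assumptions.
import time

import torch


def bench(fn, x, iters=50, warmup=5):
    # Warm up so compilation/recompilation is not included in the timing.
    for _ in range(warmup):
        fn(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


lin = WeightOnlyInt8Linear(4096, 4096).cuda()
compiled = torch.compile(lin)

for batch in (1, 8):
    x = torch.randn(batch, 4096, device="cuda", dtype=torch.bfloat16)
    print(f"batch={batch}: {bench(compiled, x) * 1e6:.1f} us/iter")
```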
Current Behavior
- The explicit `.to(dtype=input.dtype)` creates a separate type-conversion kernel.
- In the single-batch case, Inductor can successfully fuse this conversion with the GEMM.
- In the multi-batch case, the fusion fails and we get:
  - one kernel for the int8 -> fp16 conversion,
  - another kernel for the GEMM computation.
- This leads to extra memory traffic and lower performance (one way to inspect the generated kernels is sketched after this list).
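A hedged sketch for inspecting what Inductor actually generates, so the fused vs. unfused behavior can be compared directly. It assumes the `WeightOnlyInt8Linear` class above and a recent PyTorch 2.x build where `torch._logging.set_logs(output_code=True)` dumps the generated code for each compiled graph.

```python
# Hedged inspection sketch: print the code Inductor emits at each batch size.
# Assumes WeightOnlyInt8Linear (defined above) and a CUDA-enabled PyTorch 2.x.
import torch

torch._logging.set_logs(output_code=True)  # dump generated Triton/C++ code

lin = WeightOnlyInt8Linear(4096, 4096).cuda()
compiled = torch.compile(lin)

# Batch 1: per the report, the int8->bf16 cast is fused with the GEMM.
compiled(torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16))

# Batch 8: triggers a recompile; per the report, the dump shows a standalone
# dtype-conversion kernel followed by a separate GEMM kernel.
compiled(torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16))
```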