@@ -0,0 +1,236 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import List, Tuple

import torch
from tabulate import tabulate
from tqdm import tqdm

# Assuming these imports based on the kernel location
from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.blockwise_fp8_training.kernels import (
    torch_blockwise_scale_act_quant_lhs,
    triton_fp8_blockwise_act_quant_lhs,
)

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    input_shape: Tuple[int, int]  # (M, K)
    block_size: int


@dataclass(frozen=True)
class ExperimentResult:
    # time
    naive_us: float
    triton_us: float
    # mem bw
    naive_gbps: float
    triton_gbps: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
"""
Test configurations for typical transformer activation shapes.
Format: (batch_size * seq_len, hidden_dim)
"""
# Llama-style shapes: various batch*seq_len sizes with typical hidden dims
input_shapes = [
(512, 4096),
(1024, 4096),
(2048, 4096),
(4096, 4096),
(8192, 4096),
Contributor:
Can we make the leading total_M dims (seq_len * local_batch_size) bigger? e.g. a range of 8192, 8192*2, 8192*4, 8192*8, 8192*16. This is more representative of what we'll see in real training runs. The same applies to the act_quant_rhs benchmarks. (A sketch of what this would look like follows get_configs below.)

Author:
Thanks, is there any downside to having all of the above quantization benchmarks use these bigger values?

    ]

    configs = []
    block_sizes = [128]  # Standard block size for FP8

    for shape in input_shapes:
        for block_size in block_sizes:
            configs.append(
                ExperimentConfig(
                    input_shape=shape,
                    block_size=block_size,
                )
            )
    return configs
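
For illustration only, a sketch of how the config list would look with the larger total_M values the reviewer suggests (hidden dim of 4096 assumed; get_configs_large is a hypothetical name, not part of this diff):

def get_configs_large() -> List[ExperimentConfig]:
    """Hypothetical variant using the larger total_M shapes suggested in review."""
    input_shapes = [(8192 * i, 4096) for i in (1, 2, 4, 8, 16)]
    return [
        ExperimentConfig(input_shape=shape, block_size=128)
        for shape in input_shapes
    ]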


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    M, K = config.input_shape
    block_size = config.block_size

    def verify_outputs(
Contributor:
In these various verify_outputs functions, can we also validate that the memory layouts are the same, i.e. check that shapes and strides match? That way we are 100% sure we are doing a 1:1 comparison (writing to different memory layouts can drastically affect performance). (A sketch of such a check follows verify_outputs below.)

        y_naive: torch.Tensor,
        s_naive: torch.Tensor,
        y_triton: torch.Tensor,
        s_triton: torch.Tensor,
        input_tensor: torch.Tensor,
        block_size: int,
        rtol: float = 1e-2,
        atol: float = 1e-2,
    ):
"""Verify that Triton and naive implementations produce similar results."""

# Convert FP8 back to float for comparison
y_naive_float = y_naive.to(torch.float32)
y_triton_float = y_triton.to(torch.float32)

# Check quantized values are close
if not torch.allclose(y_naive_float, y_triton_float, rtol=rtol, atol=atol):
max_diff = (y_naive_float - y_triton_float).abs().max().item()
print(f"WARNING: Quantized values differ! Max diff: {max_diff}")
print(
f" Naive range: [{y_naive_float.min():.3f}, {y_naive_float.max():.3f}]"
)
print(
f" Triton range: [{y_triton_float.min():.3f}, {y_triton_float.max():.3f}]"
)

# ROBUST FIX: Handle potential dtype mismatches from torch.compile
# Convert both scales to float32 before any operations
if s_naive.dtype != torch.float32:
print(
f"INFO: Converting naive scales from {s_naive.dtype} to float32")
s_naive = s_naive.to(torch.float32)

if s_triton.dtype != torch.float32:
print(
f"INFO: Converting Triton scales from {s_triton.dtype} to float32")
s_triton = s_triton.to(torch.float32)

# Check scales are close
# Note: scales are in column-major format, need to read them correctly
s_naive_rowmajor = s_naive.as_strided(
s_naive.shape, (s_naive.shape[1], 1))
s_triton_rowmajor = s_triton.as_strided(
s_triton.shape, (s_triton.shape[1], 1))

if not torch.allclose(
s_naive_rowmajor, s_triton_rowmajor, rtol=rtol, atol=atol
):
max_diff = (s_naive_rowmajor -
s_triton_rowmajor).abs().max().item()
print(f"WARNING: Scales differ! Max diff: {max_diff}")
print(
f" Naive scale range: [{s_naive_rowmajor.min():.6f}, {s_naive_rowmajor.max():.6f}]"
)
print(
f" Triton scale range: [{s_triton_rowmajor.min():.6f}, {s_triton_rowmajor.max():.6f}]"
)
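
A minimal sketch of the layout check the reviewer asks for, assuming it would be called at the top of verify_outputs (check_memory_layouts is a hypothetical helper, not part of this diff):

def check_memory_layouts(y_naive, s_naive, y_triton, s_triton):
    """Hypothetical helper: ensure both implementations write identical memory layouts."""
    assert y_naive.shape == y_triton.shape, "quantized output shapes differ"
    assert y_naive.stride() == y_triton.stride(), "quantized output strides differ"
    assert s_naive.shape == s_triton.shape, "scale shapes differ"
    assert s_naive.stride() == s_triton.stride(), "scale strides differ"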

    input_tensor = torch.randn(
        M,
        K,
        dtype=torch.bfloat16,
        device=device,
    )

    # Benchmark naive implementation
    naive_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs)
Contributor:
Super nit: I wouldn't call the torch-native implementation "naive" per se. With torch.compile these can sometimes be quite fast, close to the speed of light; when that is not the case, we hand-implement kernels in Triton or CUDA (like we've done here). I think just replacing naive_impl => torch_impl or similar would be better.

Author:
I've replaced all mentions of "naive" in function and variable names with "torch".

    y_naive, s_naive = naive_impl_c(input_tensor, block_size)
    naive_time_us = benchmark_cuda_function_in_microseconds(
        naive_impl_c,
        input_tensor,
        block_size,
    )

    # Benchmark Triton implementation
    triton_impl_c = torch.compile(triton_fp8_blockwise_act_quant_lhs)
    y_triton, s_triton = triton_impl_c(input_tensor, block_size)
    triton_time_us = benchmark_cuda_function_in_microseconds(
        triton_impl_c,
        input_tensor,
        block_size,
    )

    # Verify correctness (optional; can comment out for pure benchmarking)
    verify_outputs(y_naive, s_naive, y_triton, s_triton, input_tensor, block_size)

    # Memory bandwidth calculations
    bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8
    bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
    bytes_per_scale_el = 4  # float32

    read_bytes = input_tensor.numel() * bytes_per_input_el
    write_bytes = (
        y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el
    )

    naive_gbps = ((read_bytes + write_bytes) / 1e9) / (naive_time_us / 1e6)
    triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6)
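
    # Worked example (assumes the scale tensor has shape (M, K // block_size)):
    # for M=8192, K=4096, read_bytes = 8192 * 4096 * 2 B ~= 67.1 MB (bf16 input) and
    # write_bytes = 8192 * 4096 * 1 B + 8192 * 32 * 4 B ~= 34.6 MB (fp8 output + fp32 scales),
    # so a kernel taking 100 us would report about 101.7 MB / 100 us ~= 1017 GB/s.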

    return ExperimentResult(
        naive_us=naive_time_us,
        triton_us=triton_time_us,
        naive_gbps=naive_gbps,
        triton_gbps=triton_gbps,
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "input_shape (M, K)",
        "block_size",
        "naive_us",
        "triton_us",
        "speedup",
        "naive_gbps",
        "triton_gbps",
    ]
    rows = []
    for experiment in experiments:
        speedup = experiment.result.naive_us / experiment.result.triton_us
        rows.append(
            [
                f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}",
                experiment.config.block_size,
                f"{experiment.result.naive_us:.2f}",
                f"{experiment.result.triton_us:.2f}",
                f"{speedup:.2f}x",
                f"{experiment.result.naive_gbps:.1f}",
                f"{experiment.result.triton_gbps:.1f}",
            ]
        )
    print(tabulate(rows, headers=headers, tablefmt="grid"))


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []

    print(f"Running {len(configs)} benchmark configurations...\n")

    for config in tqdm(configs, desc="Benchmarking"):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    print("\n" + "=" * 80)
    print("BENCHMARK RESULTS")
    print("=" * 80 + "\n")
    print_results(results)


if __name__ == "__main__":
    main()