Re: #3290 FP8 Blockwise Training Tracker, quantization benchmarks #3306
@@ -0,0 +1,236 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import List, Tuple

import torch
from tabulate import tabulate
from tqdm import tqdm

# Assuming these imports based on the kernel location
from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.blockwise_fp8_training.kernels import (
    torch_blockwise_scale_act_quant_lhs,
    triton_fp8_blockwise_act_quant_lhs,
)

device = torch.device("cuda")

# Needed since changing args to the function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    input_shape: Tuple[int, int]  # (M, K)
    block_size: int


@dataclass(frozen=True)
class ExperimentResult:
    # time
    naive_us: float
    triton_us: float
    # mem bw
    naive_gbps: float
    triton_gbps: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    """
    Test configurations for typical transformer activation shapes.
    Format: (batch_size * seq_len, hidden_dim)
    """
    # Llama-style shapes: various batch*seq_len sizes with typical hidden dims
    input_shapes = [
        (512, 4096),
        (1024, 4096),
        (2048, 4096),
        (4096, 4096),
        (8192, 4096),
    ]

    configs = []
    block_sizes = [128]  # Standard block size for FP8

    for shape in input_shapes:
        for block_size in block_sizes:
            configs.append(
                ExperimentConfig(
                    input_shape=shape,
                    block_size=block_size,
                )
            )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    M, K = config.input_shape
    block_size = config.block_size

Contributor
This way we are 100% sure we are doing a 1:1 comparison (writing to different memory layouts can drastically affect performance).

    def verify_outputs(
        y_naive: torch.Tensor,
        s_naive: torch.Tensor,
        y_triton: torch.Tensor,
        s_triton: torch.Tensor,
        input_tensor: torch.Tensor,
        block_size: int,
        rtol: float = 1e-2,
        atol: float = 1e-2,
    ):
        """Verify that Triton and naive implementations produce similar results."""

        # Convert FP8 back to float for comparison
        y_naive_float = y_naive.to(torch.float32)
        y_triton_float = y_triton.to(torch.float32)

        # Check quantized values are close
        if not torch.allclose(y_naive_float, y_triton_float, rtol=rtol, atol=atol):
            max_diff = (y_naive_float - y_triton_float).abs().max().item()
            print(f"WARNING: Quantized values differ! Max diff: {max_diff}")
            print(
                f"  Naive range: [{y_naive_float.min():.3f}, {y_naive_float.max():.3f}]"
            )
            print(
                f"  Triton range: [{y_triton_float.min():.3f}, {y_triton_float.max():.3f}]"
            )

        # Handle potential dtype mismatches from torch.compile:
        # convert both scales to float32 before any operations.
        if s_naive.dtype != torch.float32:
            print(f"INFO: Converting naive scales from {s_naive.dtype} to float32")
            s_naive = s_naive.to(torch.float32)

        if s_triton.dtype != torch.float32:
            print(f"INFO: Converting Triton scales from {s_triton.dtype} to float32")
            s_triton = s_triton.to(torch.float32)

        # Check scales are close
        # Note: scales are in column-major format, need to read them correctly
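        # Both sets of scales are stored column-major, so re-viewing each tensor's
        # storage with row-major strides below compares the raw storage contents
        # element-for-element without materializing a transposed copy.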
        s_naive_rowmajor = s_naive.as_strided(s_naive.shape, (s_naive.shape[1], 1))
        s_triton_rowmajor = s_triton.as_strided(s_triton.shape, (s_triton.shape[1], 1))

        if not torch.allclose(
            s_naive_rowmajor, s_triton_rowmajor, rtol=rtol, atol=atol
        ):
            max_diff = (s_naive_rowmajor - s_triton_rowmajor).abs().max().item()
            print(f"WARNING: Scales differ! Max diff: {max_diff}")
            print(
                f"  Naive scale range: [{s_naive_rowmajor.min():.6f}, {s_naive_rowmajor.max():.6f}]"
            )
            print(
                f"  Triton scale range: [{s_triton_rowmajor.min():.6f}, {s_triton_rowmajor.max():.6f}]"
            )

    input_tensor = torch.randn(
        M,
        K,
        dtype=torch.bfloat16,
        device=device,
    )

    # Benchmark naive implementation
    naive_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs)
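    # Run the compiled op once before timing: this triggers compilation outside the
    # timed region and provides outputs for the correctness check below.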
    y_naive, s_naive = naive_impl_c(input_tensor, block_size)
    naive_time_us = benchmark_cuda_function_in_microseconds(
        naive_impl_c,
        input_tensor,
        block_size,
    )

    # Benchmark Triton implementation
    triton_impl_c = torch.compile(triton_fp8_blockwise_act_quant_lhs)
    y_triton, s_triton = triton_impl_c(input_tensor, block_size)
    triton_time_us = benchmark_cuda_function_in_microseconds(
        triton_impl_c,
        input_tensor,
        block_size,
    )

    # Verify correctness (optional, can comment out for pure benchmarking)
    verify_outputs(y_naive, s_naive, y_triton, s_triton, input_tensor, block_size)

    # Memory bandwidth calculations
    bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8
    bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
    bytes_per_scale_el = 4  # float32

    read_bytes = input_tensor.numel() * bytes_per_input_el
    write_bytes = (
        y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el
    )

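    # Effective bandwidth: total bytes moved divided by kernel time
    # (bytes -> GB via 1e9, microseconds -> seconds via 1e6).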
    naive_gbps = ((read_bytes + write_bytes) / 1e9) / (naive_time_us / 1e6)
    triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_time_us / 1e6)

    return ExperimentResult(
        naive_us=naive_time_us,
        triton_us=triton_time_us,
        naive_gbps=naive_gbps,
        triton_gbps=triton_gbps,
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "input_shape (M, K)",
        "block_size",
        "naive_us",
        "triton_us",
        "speedup",
        "naive_gbps",
        "triton_gbps",
    ]
    rows = []
    for experiment in experiments:
        speedup = experiment.result.naive_us / experiment.result.triton_us
        rows.append(
            [
                f"{experiment.config.input_shape[0]}x{experiment.config.input_shape[1]}",
                experiment.config.block_size,
                f"{experiment.result.naive_us:.2f}",
                f"{experiment.result.triton_us:.2f}",
                f"{speedup:.2f}x",
                f"{experiment.result.naive_gbps:.1f}",
                f"{experiment.result.triton_gbps:.1f}",
            ]
        )
    print(tabulate(rows, headers=headers, tablefmt="grid"))


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []

    print(f"Running {len(configs)} benchmark configurations...\n")

    for config in tqdm(configs, desc="Benchmarking"):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    print("\n" + "=" * 80)
    print("BENCHMARK RESULTS")
    print("=" * 80 + "\n")
    print_results(results)


if __name__ == "__main__":
    main()

Comment
Can we make the leading total_M dims (seq_len * local_batch_size) bigger? E.g. a range of 8192, 8192*2, 8192*4, 8192*8, 8192*16? This is more representative of what we'll see in real training runs. Same for the act_quant_rhs benchmarks.

Reply
Thanks, any downside to running all the above quantization benchmarks with these bigger values?
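
For reference, a hypothetical sketch of what the suggested larger configurations could look like in get_configs() (hidden dim kept at 4096 as in the current configs; exact values to be confirmed, not part of this PR as written):

    input_shapes = [
        (8192, 4096),
        (8192 * 2, 4096),
        (8192 * 4, 4096),
        (8192 * 8, 4096),
        (8192 * 16, 4096),
    ]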