Re: #3290 FP8 Blockwise Training Tracker, quantization benchmarks #3306
base: main
Conversation
…aive torch implementation
…entation torch_blockwise_scale_act_quant_lhs from existing blockwise_fp8_training/kernels
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3306
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @agolajko! Thank you for your pull request and welcome to our community.

Action Required
In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process
In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
@@ -0,0 +1,315 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
i see a space in the filename for this one, can you remove it?
y, s_reciprocal = torch_blockwise_scale_act_quant_lhs(
    x, tile_size=block_size)

# Convert scales from row-major to column-major format to match Triton kernel
Let's adjust the torch_blockwise_scale_act_quant_lhs to write to col major instead, since that is the layout required by torch._scaled_grouped_mm.
You can do this here with a s.transpose(-2,-1).contiguous().transpose(-2,-1) to convert from row major to col major.
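For reference, a minimal sketch of the suggested layout conversion (shapes are illustrative, not the PR's; `s` stands in for the 2D fp32 scale tensor):

```python
import torch

# Illustrative only: convert a row-major fp32 scale tensor to a
# column-major memory layout via transpose -> contiguous -> transpose.
s = torch.randn(128, 32, dtype=torch.float32)            # row-major, strides (32, 1)
s_col_major = s.transpose(-2, -1).contiguous().transpose(-2, -1)

print(s.stride())            # (32, 1)  -> row-major
print(s_col_major.stride())  # (1, 128) -> column-major
assert torch.equal(s, s_col_major)  # same values, different memory layout
```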
Thanks, that does make using the function easier
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Thanks for working on this @agolajko, can you run the benchmarks and include them in the PR description? Or do you not have access to a B200 GPU?
Yeah, I've been testing on a 5090 via Runpod, let me see if I can access their B200 as well |
Ok, any Blackwell GPU is fine
# ROBUST FIX: Handle potential dtype mismatches from torch.compile
# Convert both scales to float32 before any operations
if s_naive.dtype != torch.float32:
scales should already be fp32 for fp8 quantization, no need for this
s_triton = s_triton.to(torch.float32)

# Check scales are close
# Note: scales are in column-major format, need to read them correctly
we shouldn't have to change the memory layout to compare. plus after your change to the torch reference impl, both should be in col major anyway
s_triton_rowmajor = s_triton.as_strided(
    s_triton.shape, (s_triton.shape[1], 1))

if not torch.allclose(
Just a simple torch.testing.assert_close would be better. We want the test to fail, not just print stuff.
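A minimal sketch of that change (the tensors here are stand-ins and the tolerances are placeholders, not the PR's values):

```python
import torch

# Illustrative only: compare reference and Triton scales with
# torch.testing.assert_close so a mismatch raises an error with a
# useful message instead of just printing stats.
s_naive = torch.rand(128, 32, dtype=torch.float32)
s_triton = s_naive.clone()

torch.testing.assert_close(s_triton, s_naive, rtol=1e-6, atol=1e-6)
```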
)

# Benchmark Triton implementation
triton_impl_c = torch.compile(triton_fp8_blockwise_act_quant_lhs)
We don't need to compile the custom op wrapping the triton kernel (this is basically a no-op, as compile won't trace inside the custom op). Just call triton_fp8_blockwise_act_quant_lhs directly.
Gotcha, thanks
On that note, is there warmup needed like in this Gemm bench? https://github.com/pytorch/ao/blob/main/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x128_gemms.py#L86
discussed offline but documenting here: nice to have just in case but bench_cuda_function_microseconds should do a bunch of iterations and take the median - so robust against earlier outliers during warmup
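For illustration only, a rough sketch of that median-of-many-iterations idea (this is not the repo's actual helper; the name and details are made up):

```python
import statistics
import torch

def bench_cuda_microseconds_sketch(fn, *args, n_iters: int = 100) -> float:
    # Time a CUDA function over many iterations and return the median,
    # so early warmup/compile outliers don't skew the result.
    times_us = []
    for _ in range(n_iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn(*args)
        end.record()
        torch.cuda.synchronize()
        times_us.append(start.elapsed_time(end) * 1e3)  # elapsed_time is ms -> us
    return statistics.median(times_us)
```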
also, not sure if you've tried to make the change yet or not, but just in case (see the sketch below):

- torch_blockwise_scale_act_quant_lhs should be compiled (i see you commented out the compile). for torch native code, we use compile to autogen triton kernels to improve performance, so that is the baseline we are measuring against.
- triton_fp8_blockwise_act_quant_lhs should not be compiled. this is a handwritten triton kernel, compile will have no effect.
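A hedged sketch of that setup; the import path and the Triton kernel's exact signature are assumptions (check kernels.py), only the reference call mirrors the diff above:

```python
import torch

# Assumed import path, based on the blockwise_fp8_training/kernels module
# referenced earlier in the thread; adjust to the actual location.
from torchao.prototype.blockwise_fp8_training.kernels import (
    torch_blockwise_scale_act_quant_lhs,
    triton_fp8_blockwise_act_quant_lhs,
)

x = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
block_size = 128  # illustrative tile size

# Compile the torch-native reference: compile autogenerates Triton kernels,
# and that compiled version is the baseline being measured against.
torch_impl_c = torch.compile(torch_blockwise_scale_act_quant_lhs)
y_ref, s_ref = torch_impl_c(x, tile_size=block_size)

# Call the handwritten Triton kernel directly; wrapping it in torch.compile
# would be a no-op since compile can't trace inside the custom op.
y_triton, s_triton = triton_fp8_blockwise_act_quant_lhs(x, block_size)  # signature assumed
```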
# Memory bandwidth calculations
bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8
bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
bytes_per_scale_el = 4  # float32
For consistency:
torch.finfo(torch.float32).bits / 8
@danielvegamyhre added the results as well as some of the potentially relevant sys info
if this extends to gemms, I'd recommend also benchmarking
Thanks for the suggestion @vkuzo, for now just keeping this PR to benching the quantizations |
Btw @vkuzo, what are the other GEMM benchmarking scripts needed for the FP8 training, given there are already two (bench_1x128_128x128 and bench_1x128_128x1) here: https://github.com/pytorch/ao/tree/main/benchmarks/prototype/blockwise_fp8_training
verify_outputs(y_naive, s_naive, y_triton, s_triton)

# Memory bandwidth calculations
bytes_per_input_el = torch.finfo(torch.float32).bits / 8
you're using bf16 input tensor, but calculating memory bandwidth using fp32 bytes per element.
to avoid this type of bug, just use input_tensor.dtype rather than hardcoding. same goes for fp8 outputs and scales.
i'd make this change across all the bench scripts
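A small sketch of the dtype-driven version (tensor names and shapes here are illustrative, not the bench script's):

```python
import torch

# Illustrative only: derive bytes/element from each tensor's own dtype so
# the bandwidth math can't drift out of sync with the actual inputs.
x = torch.randn(4096, 4096, dtype=torch.bfloat16)        # stand-in input
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)       # stand-in fp8 output
s = torch.empty(4096, 4096 // 128, dtype=torch.float32)  # stand-in scales

bytes_per_input_el = torch.finfo(x.dtype).bits / 8
bytes_per_output_el = torch.finfo(y.dtype).bits / 8
bytes_per_scale_el = torch.finfo(s.dtype).bits / 8

read_bytes = x.numel() * bytes_per_input_el
write_bytes = y.numel() * bytes_per_output_el + s.numel() * bytes_per_scale_el
```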
)

# y must be column-major per RHS kernel contract
y = x.new_empty(M, K, dtype=torch.float8_e4m3fn).as_strided((M, K), (1, M))
just do:

y = y_rowmajor.t().contiguous().t()

to convert to col major. this is the common pattern across the codebase
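A quick sketch of what that produces, for comparison with the as_strided version above (M and K are illustrative):

```python
import torch

M, K = 256, 128  # illustrative sizes

# Allocate row-major, then t().contiguous().t() gives the same shape with
# column-major strides (1, M), matching the RHS kernel contract above.
y_rowmajor = torch.empty(M, K, dtype=torch.float8_e4m3fn)
y = y_rowmajor.t().contiguous().t()

print(y.shape)    # torch.Size([256, 128])
print(y.stride()) # (1, 256) -> column-major
```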
y_triton_float = y_triton.to(torch.float32)

try:
    torch.testing.assert_close(
just assert without try/catch and printing stats. these print statements seem nice but are not actually useful in practice i've found.
(same goes for all the bench scripts)
Summary
As discussed with @danielvegamyhre following #3290, this PR adds benchmark scripts for each quantization kernel from kernels.py.
The scripts compare the kernels' performance against a naive torch implementation.
For triton_fp8_blockwise_act_quant_lhs I was able to use the naive torch implementation in the kernels.py file; for the others I wrote new ones.

Bench results on a 5090
Python: 3.13.8 (main, Oct 8 2025, 08:53:25) [GCC 13.3.0]
PyTorch: 2.9.0+cu128
CUDA: 12.8
CuDNN: 91002
Driver: 570.153.02, GPU: NVIDIA GeForce RTX 5090, compute capability 12.0
Ubuntu 24.04.3 LTS
LHS Activation:
RHS Activation:
Transposed LHS Activation:
RHS Weights:
Transposed RHS Weights: