
Conversation

@agolajko commented Nov 6, 2025

Summary

As discussed with @danielvegamyhre following #3290, this PR adds a benchmark script for each quantization kernel in kernels.py.

Each script compares the Triton kernel's performance against a naive torch implementation.

For triton_fp8_blockwise_act_quant_lhs I was able to reuse the naive torch implementation already in kernels.py; for the others I wrote new ones.
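
For reference, here is a hedged sketch of what such a naive 1x128 blockwise quantization baseline can look like. The function name, clamping details, and scale layout are illustrative only; the actual reference implementations in kernels.py and the new benchmark scripts may differ (for example by returning reciprocal scales in column-major layout).

import torch

def naive_blockwise_act_quant_lhs(x: torch.Tensor, block_size: int = 128):
    # Quantize an (M, K) bf16 tensor to fp8 e4m3 with one fp32 scale per 1x128 block.
    M, K = x.shape
    assert K % block_size == 0
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    x_blocks = x.reshape(M, K // block_size, block_size).to(torch.float32)
    # One scale per block: amax / fp8_max, clamped to avoid division by zero.
    amax = x_blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / fp8_max
    y = (x_blocks / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return y.reshape(M, K), scale.squeeze(-1)

x = torch.randn(512, 4096, dtype=torch.bfloat16, device="cuda")
y, s = naive_blockwise_act_quant_lhs(x)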

Benchmark results on an RTX 5090:

Python: 3.13.8 (main, Oct 8 2025, 08:53:25) [GCC 13.3.0]
PyTorch: 2.9.0+cu128
CUDA: 12.8
cuDNN: 91002
Driver: 570.153.02, GPU: NVIDIA GeForce RTX 5090 (compute capability 12.0)
OS: Ubuntu 24.04.3 LTS

LHS Activation:

+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| input_shape (M, K)   |   block_size |   naive_us |   triton_us | speedup   |   naive_gbps |   triton_gbps |
+======================+==============+============+=============+===========+==============+===============+
| 512x4096             |          128 |      64.06 |       14.34 | 4.47x     |        262.9 |        1174.9 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 1024x4096            |          128 |      88.06 |       16.38 | 5.38x     |        382.5 |        2056   |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 2048x4096            |          128 |     116.74 |       26.62 | 4.38x     |        577.1 |        2530.5 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 4096x4096            |          128 |     238.59 |       45.7  | 5.22x     |        564.7 |        2948.7 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 8192x4096            |          128 |     581.66 |       83.97 | 6.93x     |        463.3 |        3209.4 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+

RHS Activation:

+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| input_shape (M, K)   |   block_size |   naive_us |   triton_us | speedup   |   naive_gbps |   triton_gbps |
+======================+==============+============+=============+===========+==============+===============+
| 512x4096             |          128 |      42.34 |       19.34 | 2.19x     |        249.2 |         545.5 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 1024x4096            |          128 |      73.57 |       27.49 | 2.68x     |        286.8 |         767.7 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 2048x4096            |          128 |     135.33 |       33.84 | 4.00x     |        311.9 |        1247.2 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 4096x4096            |          128 |     266.24 |       58.5  | 4.55x     |        317   |        1443   |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 8192x4096            |          128 |     529.41 |      114.85 | 4.61x     |        318.9 |        1469.9 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+

Transposed LHS Activation:

+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| input_shape (M, K)   |   block_size |   naive_us |   triton_us | speedup   |   naive_gbps |   triton_gbps |
+======================+==============+============+=============+===========+==============+===============+
| 512x4096             |          128 |      42.4  |       16.8  | 2.52x     |        248.9 |         628.1 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 1024x4096            |          128 |      73.63 |       21.54 | 3.42x     |        286.6 |         979.9 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 2048x4096            |          128 |     135.17 |       25.82 | 5.23x     |        312.2 |        1634.3 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 4096x4096            |          128 |     264.19 |       40.26 | 6.56x     |        319.5 |        2096.8 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 8192x4096            |          128 |     525.31 |       68.74 | 7.64x     |        321.4 |        2456.1 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+

RHS Weights:

+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| input_shape (M, N)   |   block_size |   naive_us |   triton_us | speedup   |   naive_gbps |   triton_gbps |
+======================+==============+============+=============+===========+==============+===============+
| 512x4096             |          128 |      51.2  |       23.55 | 2.17x     |        204.8 |         445.2 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 1024x4096            |          128 |      75.36 |       34.43 | 2.19x     |        278.3 |         609.1 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 2048x4096            |          128 |     138.53 |       40.96 | 3.38x     |        302.8 |        1024.1 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 4096x4096            |          128 |     277.63 |       66.82 | 4.16x     |        302.2 |        1255.5 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 8192x4096            |          128 |     540.9  |      125.98 | 4.29x     |        310.2 |        1331.8 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+

Transposed RHS Weights:

+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| input_shape (M, N)   |   block_size |   naive_us |   triton_us | speedup   |   naive_gbps |   triton_gbps |
+======================+==============+============+=============+===========+==============+===============+
| 512x4096             |          128 |      12.29 |       23.04 | 0.53x     |        853.4 |         455.1 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 1024x4096            |          128 |      67.71 |       30.72 | 2.20x     |        309.7 |         682.7 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 2048x4096            |          128 |     124.19 |       39.94 | 3.11x     |        337.7 |        1050.3 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 4096x4096            |          128 |     243.01 |       66.62 | 3.65x     |        345.2 |        1259.2 |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+
| 8192x4096            |          128 |     482.51 |      123.55 | 3.91x     |        347.7 |        1358   |
+----------------------+--------------+------------+-------------+-----------+--------------+---------------+

@pytorch-bot (bot) commented Nov 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3306

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla (bot) commented Nov 6, 2025

Hi @agolajko!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@@ -0,0 +1,315 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor

i see a space in the filename for this one, can you remove it?

y, s_reciprocal = torch_blockwise_scale_act_quant_lhs(
    x, tile_size=block_size)

# Convert scales from row-major to column-major format to match Triton kernel
Contributor

Let's adjust the torch_blockwise_scale_act_quant_lhs to write to col major instead, since that is the layout required by torch._scaled_grouped_mm.

You can do this here with a s.transpose(-2,-1).contiguous().transpose(-2,-1) to convert from row major to col major.
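
A tiny illustration of that conversion (my own toy example, not code from this PR): the values are untouched, only the strides change so the scales end up column-major.

import torch

s = torch.rand(512, 32, device="cuda")                     # row-major scales, stride (32, 1)
s_col = s.transpose(-2, -1).contiguous().transpose(-2, -1)
assert s_col.stride() == (1, s.shape[0])                    # column-major layout
assert torch.equal(s, s_col)                                # same values, different layout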

Author

Thanks, that does make using the function easier

@meta-cla (bot) commented Nov 6, 2025

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

meta-cla bot added the CLA Signed label Nov 6, 2025
@danielvegamyhre (Contributor)

Thanks for working on this @agolajko, can you run the benchmarks and include them in the PR description? or do you not have access to a b200 gpu?

@agolajko (Author) commented Nov 6, 2025

Yeah, I've been testing on a 5090 via Runpod, let me see if I can access their B200 as well

@danielvegamyhre (Contributor)

> Yeah, I've been testing on a 5090 via Runpod, let me see if I can access their B200 as well

ok any blackwell gpu is fine


# ROBUST FIX: Handle potential dtype mismatches from torch.compile
# Convert both scales to float32 before any operations
if s_naive.dtype != torch.float32:
Contributor

scales should already be fp32 for fp8 quantization, no need for this

s_triton = s_triton.to(torch.float32)

# Check scales are close
# Note: scales are in column-major format, need to read them correctly
Contributor

we shouldn't have to change the memory layout to compare. plus after your change to the torch reference impl, both should be in col major anyway

s_triton_rowmajor = s_triton.as_strided(
    s_triton.shape, (s_triton.shape[1], 1))

if not torch.allclose(
Contributor

Just a simple torch.testing.assert_close would be better. We want the test to fail, not just print stuff.
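
A sketch of that change, assuming the variable names from the snippet above (y_naive/s_naive for the reference outputs, y_triton/s_triton for the kernel outputs): the assertion raises on mismatch instead of just printing.

# fp8 outputs are upcast before comparing; scales are fp32 already
torch.testing.assert_close(y_naive.to(torch.float32), y_triton.to(torch.float32))
torch.testing.assert_close(s_naive, s_triton)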

)

# Benchmark Triton implementation
triton_impl_c = torch.compile(triton_fp8_blockwise_act_quant_lhs)
Contributor

We don't need to compile the custom op wrapping the triton kernel (this is basically a no-op, as compile won't trace inside the custom op). Just call triton_fp8_blockwise_act_quant_lhs directly.

Author

Gotcha, thanks

On that note, is there warmup needed like in this Gemm bench? https://github.com/pytorch/ao/blob/main/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x128_gemms.py#L86

Contributor

discussed offline but documenting here: nice to have just in case but bench_cuda_function_microseconds should do a bunch of iterations and take the median - so robust against earlier outliers during warmup

Contributor

also, not sure if you've tried to make the change yet or not, but just in case:

  • torch_blockwise_scale_act_quant_lhs should be compiled (i see you commented out the compile). for torch native code, we use compile to autogen triton kernels to improve performance, so that is the baseline we are measuring against.
  • triton_fp8_blockwise_act_quant_lhs should not be compiled. this is a handwritten triton kernel, compile will have no effect. (see the sketch after this list)
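
A rough, self-contained sketch of that setup. It uses torch.utils.benchmark in place of the repo's bench_cuda_function_microseconds helper, and the import path and the Triton kernel's keyword argument are assumptions on my part, not verified against the actual scripts.

import torch
import torch.utils.benchmark as benchmark
# import path assumed from the PR context
from torchao.prototype.blockwise_fp8_training.kernels import (
    torch_blockwise_scale_act_quant_lhs,
    triton_fp8_blockwise_act_quant_lhs,
)

x = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")

# Baseline: torch-native reference, compiled so inductor can generate fused kernels.
naive_impl = torch.compile(torch_blockwise_scale_act_quant_lhs)
# Candidate: handwritten Triton kernel, called directly (compile would be a no-op).
triton_impl = triton_fp8_blockwise_act_quant_lhs

def time_us(fn, *args, **kwargs):
    # Median over many runs, similar in spirit to bench_cuda_function_microseconds.
    m = benchmark.Timer(
        stmt="fn(*args, **kwargs)",
        globals={"fn": fn, "args": args, "kwargs": kwargs},
    ).blocked_autorange()
    return m.median * 1e6

print(f"naive: {time_us(naive_impl, x, tile_size=128):.2f} us")
print(f"triton: {time_us(triton_impl, x, block_size=128):.2f} us")  # kwarg name assumed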

# Memory bandwidth calculations
bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8
bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
bytes_per_scale_el = 4 # float32
Contributor

For consistency:

torch.finfo(torch.float32).bits / 8

@agolajko (Author) commented Nov 7, 2025

@danielvegamyhre added the results as well as some of the potentially relevant sys info

@vkuzo (Contributor) commented Nov 7, 2025

if this extends to gemms, I'd recommend also benchmarking F.scaled_mm. Note that you need a nightly with pytorch/pytorch#166752 for correct numerics.

@agolajko (Author) commented Nov 7, 2025

Thanks for the suggestion @vkuzo, for now I'm keeping this PR to benchmarking just the quantizations.

@agolajko (Author) commented Nov 7, 2025

Btw @vkuzo, what other GEMM benchmarking scripts are needed for FP8 training, given there are already two (bench_1x128_128x128 and bench_1x128_128x1) here: https://github.com/pytorch/ao/tree/main/benchmarks/prototype/blockwise_fp8_training

verify_outputs(y_naive, s_naive, y_triton, s_triton)

# Memory bandwidth calculations
bytes_per_input_el = torch.finfo(torch.float32).bits / 8
Contributor

you're using a bf16 input tensor, but calculating memory bandwidth using fp32 bytes per element.

to avoid this type of bug, just use input_tensor.dtype rather than hardcoding. same goes for fp8 outputs and scales.

i'd make this change across all the bench scripts
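
A sketch of that fix, with x, y_triton, s_triton, and triton_us assumed to be the benchmark script's input, outputs, and measured time in microseconds: bytes per element come from the tensors themselves, so the bandwidth calculation cannot drift from the actual dtypes.

bytes_per_input_el = torch.finfo(x.dtype).bits / 8            # 2 for bf16
bytes_per_output_el = torch.finfo(y_triton.dtype).bits / 8    # 1 for float8_e4m3fn
bytes_per_scale_el = torch.finfo(s_triton.dtype).bits / 8     # 4 for float32

read_bytes = x.numel() * bytes_per_input_el
write_bytes = y_triton.numel() * bytes_per_output_el + s_triton.numel() * bytes_per_scale_el
triton_gbps = (read_bytes + write_bytes) / (triton_us * 1e-6) / 1e9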

)

# y must be column-major per RHS kernel contract
y = x.new_empty(M, K, dtype=torch.float8_e4m3fn).as_strided((M, K), (1, M))
@danielvegamyhre (Contributor) Nov 7, 2025

just do:

y = y_rowmajor.t().contiguous().t()

to convert to col major. this is the common pattern across the codebase
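
For illustration (my own toy snippet, not the PR's code): the suggested pattern produces the same (M, K) shape with column-major strides, which can be asserted directly.

import torch

M, K = 512, 4096
y_rowmajor = torch.empty(M, K, dtype=torch.float8_e4m3fn, device="cuda")
y = y_rowmajor.t().contiguous().t()
assert y.shape == (M, K)
assert y.stride() == (1, M)   # column-major, matching the RHS kernel contract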

y_triton_float = y_triton.to(torch.float32)

try:
    torch.testing.assert_close(
@danielvegamyhre (Contributor) Nov 7, 2025

just assert without try/catch and printing stats. these print statements seem nice but are not actually useful in practice i've found.

(same goes for all the bench scripts)
