From c2416632ca29dd98292ce0cf12a38705910ecede Mon Sep 17 00:00:00 2001
From: Matthias Diener
Date: Fri, 26 Sep 2025 10:41:58 -0500
Subject: [PATCH] [AMD] Support float8_e5m2 in 03-matrix-multiplication.py for
 gfx950

This enables the float8_e5m2 part of the test for AMD's gfx950 devices,
which have gained support for this data type (see
https://rocm.docs.amd.com/en/latest/reference/precision-support.html).
The test was already enabled for CUDA devices. The patch also switches
matrix_instr_nonkdim from 16 to 32 in the HIP autotune configs.
---
 python/tutorials/03-matrix-multiplication.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py
index 526934c1d7c6..e46c8fb4e0b4 100644
--- a/python/tutorials/03-matrix-multiplication.py
+++ b/python/tutorials/03-matrix-multiplication.py
@@ -154,6 +154,8 @@
 import triton
 import triton.language as tl
 
+from triton.language.target_info import is_hip_cdna4
+
 DEVICE = triton.runtime.driver.active.get_active_torch_device()
 
 
@@ -210,7 +212,7 @@ def get_hip_autotune_config():
         {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4},
         {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
     ]
-    return [triton.Config(s | {'matrix_instr_nonkdim': 16}, num_warps=8, num_stages=2) for s in sizes]
+    return [triton.Config(s | {'matrix_instr_nonkdim': 32}, num_warps=8, num_stages=2) for s in sizes]
 
 
 def get_autotune_config():
@@ -372,7 +374,7 @@ def matmul(a, b, activation=""):
     print("❌ Triton and Torch differ")
 
 TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
-if TORCH_HAS_FP8 and is_cuda():
+if TORCH_HAS_FP8 and (is_cuda() or is_hip_cdna4()):
     torch.manual_seed(0)
     a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
     b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
@@ -403,7 +405,7 @@ def matmul(a, b, activation=""):
 
 configs = []
 for fp8_inputs in [False, True]:
-    if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()):
+    if fp8_inputs and (not TORCH_HAS_FP8 or (not is_cuda() and not is_hip_cdna4())):
         continue
     configs.append(
         triton.testing.Benchmark(
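
Note: the snippet below is a minimal sketch (not part of the patch) of the
device gating this change produces. It uses the is_hip_cdna4() helper from
triton.language.target_info that the patch imports, and reproduces the
tutorial's existing is_cuda() helper; both are assumed available in the
installed Triton build.

    import torch
    import triton
    from triton.language.target_info import is_hip_cdna4

    DEVICE = triton.runtime.driver.active.get_active_torch_device()

    def is_cuda():
        # The tutorial's existing helper: true when the active backend is CUDA.
        return triton.runtime.driver.active.get_current_target().backend == "cuda"

    # Run the fp8 path only where float8_e5m2 is supported:
    # NVIDIA GPUs, or AMD CDNA4 (gfx950) after this patch.
    TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
    if TORCH_HAS_FP8 and (is_cuda() or is_hip_cdna4()):
        torch.manual_seed(0)
        # Generate fp16 inputs and downcast to float8_e5m2, as the tutorial does.
        a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16).to(torch.float8_e5m2)
        b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16).to(torch.float8_e5m2)

On a gfx950 (CDNA4) machine, is_hip_cdna4() returns True, so the fp8 accuracy
check and the fp8 benchmark configs are now exercised, matching the behavior
that was already in place for CUDA devices.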