8 changes: 5 additions & 3 deletions python/tutorials/03-matrix-multiplication.py
@@ -154,6 +154,8 @@
import triton
import triton.language as tl

+from triton.language.target_info import is_hip_cdna4
+
DEVICE = triton.runtime.driver.active.get_active_torch_device()


@@ -210,7 +212,7 @@ def get_hip_autotune_config():
        {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4},
        {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
    ]
-    return [triton.Config(s | {'matrix_instr_nonkdim': 16}, num_warps=8, num_stages=2) for s in sizes]
+    return [triton.Config(s | {'matrix_instr_nonkdim': 32}, num_warps=8, num_stages=2) for s in sizes]
Collaborator:
Why change this? It seems unrelated to the purpose of this pull request.

Author:
This is the only configuration I found that works with both fp16 and fp8. Would you prefer that I split the function into separate configs for fp16 and fp8?

Collaborator:
For fp8 to use the mfma_16x16x128 instruction, BLOCK_K needs to be >= 128. The current configs only have BLOCK_K = 64, which is why it does not work.
I think performance is not a concern in the tutorials; we are still lacking other optimizations, such as enabling buffer_load to avoid branches for masked loads. I'd suggest not touching the tuning config.
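
For illustration, here is an editor's sketch (not code from this PR; the block sizes are assumptions) of the split discussed above, where the fp8 variant keeps BLOCK_SIZE_K >= 128 so that mfma_16x16x128 applies:

import triton

# Hypothetical sketch: split the HIP autotune configs by input dtype.
def get_hip_autotune_config(use_fp8=False):
    if use_fp8:
        # fp8 needs BLOCK_SIZE_K >= 128 to map onto the mfma_16x16x128 instruction.
        sizes = [{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 4}]
    else:
        sizes = [{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6}]
    return [triton.Config(s | {'matrix_instr_nonkdim': 16}, num_warps=8, num_stages=2)
            for s in sizes]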

Author:
Would you prefer that I change BLOCK_SIZE_K to 128 instead of changing matrix_instr_nonkdim?

Collaborator:
I'd prefer you leave all tuning configs unchanged. We are in the middle of fixing issues related to buffer ops, so the current perf from the tutorial will be temporary anyway.

Author (@matthiasdiener, Sep 28, 2025):
Thanks for the clarification on keeping the tutorial configs stable.
Just to clarify why I touched this: it was not for performance reasons. With fp8, the example crashes at compile time on gfx950 during ConvertTritonAMDGPUToLLVM when using the unmodified tuning config:

python: /root/.triton/llvm/llvm-064f02da-ubuntu-x64/include/llvm/ADT/SmallVector.h:292: T& llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::operator[](size_type) [with T = mlir::Value; <template-parameter-1-2> = void; reference = mlir::Value&; size_type = long unsigned int]: Assertion `idx < size()' failed.
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 2], instrShape = [16, 16, 128], isTransposed = true}>
#shared = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 4, order = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 4, order = [0, 1]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @matmul_kernel(%arg0: !tt.ptr<f8E5M2> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f8E5M2> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<64> : tensor<32x64xi32, #blocked>
    %cst_0 = arith.constant dense<64> : tensor<64x32xi32, #blocked1>
    %c1_i32 = arith.constant 1 : i32
    %c31_i32 = arith.constant 31 : i32
    %c6_i32 = arith.constant 6 : i32
    %c0_i32 = arith.constant 0 : i32
    %true = arith.constant true
    %c32_i32 = arith.constant 32 : i32
    %c64_i32 = arith.constant 64 : i32
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x32xf8E5M2, #blocked1>
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<32x64xf8E5M2, #blocked>
    %c63_i32 = arith.constant 63 : i32
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
    %0 = tt.get_program_id x : i32
    %1 = arith.addi %arg3, %c31_i32 : i32
    %2 = arith.divsi %1, %c32_i32 : i32
    %3 = arith.addi %arg4, %c31_i32 : i32
    %4 = arith.divsi %3, %c32_i32 : i32
    %5 = arith.muli %4, %c6_i32 : i32
    %6 = arith.divsi %0, %5 : i32
    %7 = arith.muli %6, %c6_i32 : i32
    %8 = arith.subi %2, %7 : i32
    %9 = arith.minsi %8, %c6_i32 : i32
    %10 = arith.remsi %0, %5 : i32
    %11 = arith.remsi %10, %9 : i32
    %12 = arith.addi %7, %11 : i32
    %13 = arith.divsi %10, %9 : i32
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    %14 = arith.muli %12, %c32_i32 : i32
    %15 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %16 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
    %17 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %18 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
    %19 = tt.splat %14 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %20 = tt.splat %14 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
    %21 = arith.addi %19, %15 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %22 = arith.addi %20, %16 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
    %23 = tt.splat %arg3 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %24 = arith.remsi %21, %23 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %25 = arith.muli %13, %c32_i32 : i32
    %26 = tt.splat %25 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %27 = tt.splat %25 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
    %28 = arith.addi %26, %17 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %29 = arith.addi %27, %18 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
    %30 = tt.splat %arg4 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %31 = arith.remsi %28, %30 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %32 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
    %33 = tt.expand_dims %32 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
    %34 = tt.expand_dims %24 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
    %35 = tt.splat %arg6 : i32 -> tensor<32x1xi32, #blocked>
    %36 = arith.muli %34, %35 : tensor<32x1xi32, #blocked>
    %37 = tt.broadcast %36 : tensor<32x1xi32, #blocked> -> tensor<32x64xi32, #blocked>
    %38 = tt.broadcast %33 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
    %39 = arith.addi %37, %38 : tensor<32x64xi32, #blocked>
    %40 = arith.addi %arg5, %c63_i32 : i32
    %41 = arith.divsi %40, %c64_i32 : i32
    %42 = arith.cmpi sgt, %41, %c0_i32 : i32
    %43 = tt.splat %42 : i1 -> tensor<32x64xi1, #blocked>
    %44 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked>
    %45 = arith.cmpi slt, %33, %44 : tensor<1x64xi32, #blocked>
    %46 = tt.broadcast %45 : tensor<1x64xi1, #blocked> -> tensor<32x64xi1, #blocked>
    %47 = arith.andi %43, %46 : tensor<32x64xi1, #blocked>
    %48 = amdgpu.buffer_load %arg0[%39], %47 stride = %arg6 : tensor<32x64xf8E5M2, #blocked>
    %49 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %50 = tt.expand_dims %49 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1>
    %51 = tt.broadcast %50 : tensor<64x1xi32, #blocked1> -> tensor<64x32xi32, #blocked1>
    %52 = tt.expand_dims %31 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1>
    %53 = tt.splat %arg7 : i32 -> tensor<1x32xi32, #blocked1>
    %54 = arith.muli %52, %53 : tensor<1x32xi32, #blocked1>
    %55 = tt.broadcast %54 : tensor<1x32xi32, #blocked1> -> tensor<64x32xi32, #blocked1>
    %56 = arith.addi %51, %55 : tensor<64x32xi32, #blocked1>
    %57 = tt.splat %42 : i1 -> tensor<64x32xi1, #blocked1>
    %58 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked1>
    %59 = arith.cmpi slt, %50, %58 : tensor<64x1xi32, #blocked1>
    %60 = tt.broadcast %59 : tensor<64x1xi1, #blocked1> -> tensor<64x32xi1, #blocked1>
    %61 = arith.andi %57, %60 : tensor<64x32xi1, #blocked1>
    %62 = amdgpu.buffer_load %arg1[%56], %61 stride = %arg7 : tensor<64x32xf8E5M2, #blocked1>
    %63 = ttg.local_alloc : () -> !ttg.memdesc<1x32x64xf8E5M2, #shared, #smem, mutable>
    %64 = ttg.local_alloc : () -> !ttg.memdesc<1x64x32xf8E5M2, #shared1, #smem, mutable>
    %65 = ttg.memdesc_index %63[%c0_i32] : !ttg.memdesc<1x32x64xf8E5M2, #shared, #smem, mutable> -> !ttg.memdesc<32x64xf8E5M2, #shared, #smem, mutable, 1x32x64>
    ttg.local_store %48, %65 : tensor<32x64xf8E5M2, #blocked> -> !ttg.memdesc<32x64xf8E5M2, #shared, #smem, mutable, 1x32x64>
    %66 = ttg.memdesc_index %64[%c0_i32] : !ttg.memdesc<1x64x32xf8E5M2, #shared1, #smem, mutable> -> !ttg.memdesc<64x32xf8E5M2, #shared1, #smem, mutable, 1x64x32>
    ttg.local_store %62, %66 : tensor<64x32xf8E5M2, #blocked1> -> !ttg.memdesc<64x32xf8E5M2, #shared1, #smem, mutable, 1x64x32>
    %67 = arith.subi %41, %c1_i32 : i32
    %68:6 = scf.for %arg9 = %c0_i32 to %67 step %c1_i32 iter_args(%arg10 = %cst_3, %arg11 = %39, %arg12 = %56, %arg13 = %c0_i32, %arg14 = %65, %arg15 = %66) -> (tensor<32x32xf32, #mma>, tensor<32x64xi32, #blocked>, tensor<64x32xi32, #blocked1>, i32, !ttg.memdesc<32x64xf8E5M2, #shared, #smem, mutable, 1x32x64>, !ttg.memdesc<64x32xf8E5M2, #shared1, #smem, mutable, 1x64x32>)  : i32 {
      %95 = arith.addi %arg11, %cst : tensor<32x64xi32, #blocked>
      %96 = arith.addi %arg12, %cst_0 : tensor<64x32xi32, #blocked1>
      %97 = arith.addi %arg9, %c1_i32 : i32
      %98 = arith.muli %97, %c64_i32 : i32
      %99 = arith.subi %arg5, %98 : i32
      %100 = tt.splat %99 : i32 -> tensor<1x64xi32, #blocked>
      %101 = arith.cmpi slt, %33, %100 : tensor<1x64xi32, #blocked>
      %102 = tt.broadcast %101 : tensor<1x64xi1, #blocked> -> tensor<32x64xi1, #blocked>
      %103 = tt.splat %arg0 : !tt.ptr<f8E5M2> -> tensor<32x64x!tt.ptr<f8E5M2>, #blocked>
      %104 = tt.addptr %103, %95 : tensor<32x64x!tt.ptr<f8E5M2>, #blocked>, tensor<32x64xi32, #blocked>
      %105 = tt.load %104, %102, %cst_2 : tensor<32x64x!tt.ptr<f8E5M2>, #blocked>
      %106 = ttg.local_load %arg14 : !ttg.memdesc<32x64xf8E5M2, #shared, #smem, mutable, 1x32x64> -> tensor<32x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
      %107 = tt.splat %99 : i32 -> tensor<64x1xi32, #blocked1>
      %108 = arith.cmpi slt, %50, %107 : tensor<64x1xi32, #blocked1>
      %109 = tt.broadcast %108 : tensor<64x1xi1, #blocked1> -> tensor<64x32xi1, #blocked1>
      %110 = tt.splat %arg1 : !tt.ptr<f8E5M2> -> tensor<64x32x!tt.ptr<f8E5M2>, #blocked1>
      %111 = tt.addptr %110, %96 : tensor<64x32x!tt.ptr<f8E5M2>, #blocked1>, tensor<64x32xi32, #blocked1>
      %112 = tt.load %111, %109, %cst_1 : tensor<64x32x!tt.ptr<f8E5M2>, #blocked1>
      %113 = ttg.local_load %arg15 : !ttg.memdesc<64x32xf8E5M2, #shared1, #smem, mutable, 1x64x32> -> tensor<64x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
      %114 = tt.dot_scaled %106, %113, %arg10 lhs = e5m2 rhs = e5m2 {fastMath = false} : tensor<32x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<64x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<32x32xf32, #mma>
      %115 = arith.addi %arg13, %c1_i32 : i32
      %116 = arith.cmpi slt, %115, %c1_i32 : i32
      %117 = arith.select %116, %115, %c0_i32 : i32
      %118 = ttg.memdesc_index %63[%117] : !ttg.memdesc<1x32x64xf8E5M2, #shared, #smem, mutable> -> !ttg.memdesc<32x64xf8E5M2, #shared, #smem, mutable, 1x32x64>
      ttg.local_store %105, %118 : tensor<32x64xf8E5M2, #blocked> -> !ttg.memdesc<32x64xf8E5M2, #shared, #smem, mutable, 1x32x64>
      %119 = ttg.memdesc_index %64[%117] : !ttg.memdesc<1x64x32xf8E5M2, #shared1, #smem, mutable> -> !ttg.memdesc<64x32xf8E5M2, #shared1, #smem, mutable, 1x64x32>
      ttg.local_store %112, %119 : tensor<64x32xf8E5M2, #blocked1> -> !ttg.memdesc<64x32xf8E5M2, #shared1, #smem, mutable, 1x64x32>
      scf.yield %114, %95, %96, %117, %118, %119 : tensor<32x32xf32, #mma>, tensor<32x64xi32, #blocked>, tensor<64x32xi32, #blocked1>, i32, !ttg.memdesc<32x64xf8E5M2, #shared, #smem, mutable, 1x32x64>, !ttg.memdesc<64x32xf8E5M2, #shared1, #smem, mutable, 1x64x32>
    }
    %69 = arith.cmpi sge, %41, %c1_i32 : i32
    %70 = ttg.local_load %68#4 : !ttg.memdesc<32x64xf8E5M2, #shared, #smem, mutable, 1x32x64> -> tensor<32x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
    %71 = ttg.local_load %68#5 : !ttg.memdesc<64x32xf8E5M2, #shared1, #smem, mutable, 1x64x32> -> tensor<64x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
    %72 = scf.if %69 -> (tensor<32x32xf32, #mma>) {
      %95 = tt.dot_scaled %70, %71, %68#0 lhs = e5m2 rhs = e5m2 {fastMath = false} : tensor<32x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<64x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<32x32xf32, #mma>
      scf.yield %95 : tensor<32x32xf32, #mma>
    } else {
      scf.yield %68#0 : tensor<32x32xf32, #mma>
    }
    %73 = arith.select %69, %72, %68#0 : tensor<32x32xf32, #mma>
    ttg.local_dealloc %64 : !ttg.memdesc<1x64x32xf8E5M2, #shared1, #smem, mutable>
    ttg.local_dealloc %63 : !ttg.memdesc<1x32x64xf8E5M2, #shared, #smem, mutable>
    %74 = arith.truncf %73 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma>
    %75 = tt.expand_dims %22 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xi32, #mma>
    %76 = arith.muli %arg8, %14 : i32
    %77 = tt.expand_dims %29 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xi32, #mma>
    %78 = tt.expand_dims %16 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xi32, #mma>
    %79 = tt.splat %arg8 : i32 -> tensor<32x1xi32, #mma>
    %80 = arith.muli %79, %78 : tensor<32x1xi32, #mma>
    %81 = tt.broadcast %80 : tensor<32x1xi32, #mma> -> tensor<32x32xi32, #mma>
    %82 = tt.expand_dims %18 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xi32, #mma>
    %83 = tt.broadcast %82 : tensor<1x32xi32, #mma> -> tensor<32x32xi32, #mma>
    %84 = arith.addi %76, %25 : i32
    %85 = arith.addi %81, %83 : tensor<32x32xi32, #mma>
    %86 = tt.splat %84 : i32 -> tensor<32x32xi32, #mma>
    %87 = arith.addi %86, %85 : tensor<32x32xi32, #mma>
    %88 = tt.splat %arg3 : i32 -> tensor<32x1xi32, #mma>
    %89 = arith.cmpi slt, %75, %88 : tensor<32x1xi32, #mma>
    %90 = tt.splat %arg4 : i32 -> tensor<1x32xi32, #mma>
    %91 = arith.cmpi slt, %77, %90 : tensor<1x32xi32, #mma>
    %92 = tt.broadcast %89 : tensor<32x1xi1, #mma> -> tensor<32x32xi1, #mma>
    %93 = tt.broadcast %91 : tensor<1x32xi1, #mma> -> tensor<32x32xi1, #mma>
    %94 = arith.andi %92, %93 : tensor<32x32xi1, #mma>
    amdgpu.buffer_store %74, %arg2[%87], %94 : tensor<32x32xf16, #mma>
    tt.return
  }
}

{-#
  external_resources: {
    mlir_reproducer: {
      pipeline: "builtin.module(optimize-amd-lds-usage{lds-limit=0 target-arch=gfx950}, convert-scf-to-cf, gluon-inline, convert-index-to-llvm{index-bitwidth=0}, allocate-amdgpu-shared-memory, convert-triton-amdgpu-to-llvm{arch=gfx950 ftz=true}, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-cf-to-llvm{index-bitwidth=0}, convert-arith-to-llvm{index-bitwidth=0}, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info, convert-builtin-func-to-llvm{ftz=true})",
      disable_threading: true,
      verify_each: true
    }
  }
#-}
//tritonmm.py:91:0: error: Failures have been detected while processing an MLIR pass pipeline
//tritonmm.py:91:0: note: Pipeline failed while executing [ConvertTritonAMDGPUToLLVM on 'builtin.module' operation]: reproducer generated at std::errs, please share the reproducer above with Triton project.
Traceback (most recent call last):
  File "//tritonmm.py", line 231, in <module>
    triton_output = matmul(a, b)
                    ^^^^^^^^^^^^
  File "//tritonmm.py", line 197, in matmul
    matmul_kernel[grid](
  File "//venv/lib/python3.12/site-packages/triton/runtime/jit.py", line 359, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "//venv/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 240, in run
    benchmark()
  File "//venv/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 229, in benchmark
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "//venv/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 164, in _bench
    return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "//venv/lib/python3.12/site-packages/triton/testing.py", line 149, in do_bench
    fn()
  File "//venv/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 150, in kernel_call
    self.fn.run(
  File "//venv/lib/python3.12/site-packages/triton/runtime/jit.py", line 675, in run
    kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "//venv/lib/python3.12/site-packages/triton/runtime/jit.py", line 803, in _do_compile
    kernel = self.compile(src, target=target, options=options.__dict__)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "//venv/lib/python3.12/site-packages/triton/compiler/compiler.py", line 320, in compile
    next_module = compile_ir(module, metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "//venv/lib/python3.12/site-packages/triton/backends/amd/compiler.py", line 474, in <lambda>
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "//venv/lib/python3.12/site-packages/triton/backends/amd/compiler.py", line 326, in make_llir
    pm.run(mod, 'make_llir')
RuntimeError: PassManager::run failed

Author:

Based on this, would it be OK to include the change to matrix_instr_nonkdim = 32?
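
As an editor's sketch of that option (not PR code; the MFMA shapes are assumptions: CDNA4's fp8 variants appear to be 16x16x128 and 32x32x64, so with BLOCK_SIZE_K = 64 only the 32x32 shape fits, which would explain why nonkdim = 32 avoids the crash):

import triton

# Hypothetical sketch: choose the MFMA non-K dimension per dtype rather than
# changing it unconditionally; 32x32x64 fits BLOCK_SIZE_K = 64 for fp8.
def get_hip_autotune_config(use_fp8=False):
    sizes = [{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6}]
    nonkdim = 32 if use_fp8 else 16
    return [triton.Config(s | {'matrix_instr_nonkdim': nonkdim}, num_warps=8, num_stages=2)
            for s in sizes]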



def get_autotune_config():
@@ -372,7 +374,7 @@ def matmul(a, b, activation=""):
print("❌ Triton and Torch differ")

TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
-if TORCH_HAS_FP8 and is_cuda():
+if TORCH_HAS_FP8 and (is_cuda() or is_hip_cdna4()):
    torch.manual_seed(0)
    a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
    b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
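
For context, a self-contained sketch of the gated fp8 check (an editor's illustration: is_cuda is assumed to match the tutorial's existing backend helper, and the fp8 downcasts are assumptions about the lines the diff truncates):

import torch
import triton
from triton.language.target_info import is_hip_cdna4

DEVICE = triton.runtime.driver.active.get_active_torch_device()

def is_cuda():
    # Assumed to match the tutorial's existing helper.
    return triton.runtime.driver.active.get_current_target().backend == "cuda"

TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
if TORCH_HAS_FP8 and (is_cuda() or is_hip_cdna4()):  # CUDA or AMD CDNA4 (e.g. gfx950)
    torch.manual_seed(0)
    a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16).to(torch.float8_e5m2)
    b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16).to(torch.float8_e5m2)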
@@ -403,7 +405,7 @@ def matmul(a, b, activation=""):

configs = []
for fp8_inputs in [False, True]:
-    if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()):
+    if fp8_inputs and (not TORCH_HAS_FP8 or (not is_cuda() and not is_hip_cdna4())):
        continue
    configs.append(
        triton.testing.Benchmark(
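
To unpack the double negative in the benchmark gate above, an equivalent positive form (editor's sketch, reusing the tutorial's names):

# A benchmark config is emitted only when fp8 is off, or when both PyTorch
# and the active backend (CUDA or AMD CDNA4) support fp8.
def emit_benchmark(fp8_inputs):
    return (not fp8_inputs) or (TORCH_HAS_FP8 and (is_cuda() or is_hip_cdna4()))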