@@ -320,3 +320,39 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked5 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked6 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [1, 0]}>
+#blocked7 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 2, 2, 1], threadsPerWarp = [1, 1, 4, 16, 1, 1, 1], warpsPerCTA = [4, 1, 1, 1, 1, 1, 1], order = [6, 5, 4, 3, 2, 1, 0]}>
+#blocked8 = #ttg.blocked<{sizePerThread = [1, 2, 1, 1, 2, 1, 1], threadsPerWarp = [1, 1, 16, 1, 1, 4, 1], warpsPerCTA = [4, 1, 1, 1, 1, 1, 1], order = [6, 1, 4, 2, 5, 3, 0]}>
+#linear = #ttg.linear<{register = [[16, 0], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp = [[32, 0], [64, 0]], block = []}>
+
+// MFMA16: [[$linear1:#.*]] = #ttg.linear<{register = {{\[\[}}0, 4{{]]}}, lane = {{\[\[}}1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2{{]]}}, warp = {{\[\[}}0, 0], [0, 0{{]]}}, block = []}>
+// MFMA16: [[$linear2:#.*]] = #ttg.linear<{register = {{\[\[}}0, 4], [16, 0{{]]}}, lane = {{\[\[}}1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2{{]]}}, warp = {{\[\[}}32, 0], [64, 0{{]]}}, block = []}>
+// MFMA16: [[$mma:#.*]] = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 4], instrShape = [16, 16, 128], isTransposed = true, tilesPerWarp = [1, 2]}>
+// MFMA16-LABEL: mfma_dot_scaled_fp8_mxfp4
337+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , ttg.target = " hip:gfx950" , " ttg.threads-per-warp" = 64 : i32 } {
338+ tt.func public @mfma_dot_scaled_fp8_mxfp4 (
339+ %arg0: tensor <16 x256 xf8 E4 M3 FN, #blocked6 >,
340+ %arg1: tensor <4 x256 x!tt.ptr <i8 >, #blocked5 >,
341+ %arg2: tensor <128 x128 xi8 , #blocked1 >,
342+ %arg3: tensor <16 x128 x!tt.ptr <f32 >, #blocked1 >
343+ ) {
344+ // MFMA16: [[SCALE0:%.+]] = ttg.convert_layout {{.*}} : {{.*}} -> tensor<16x8xi8, [[$linear1]]>
345+ // MFMA16: [[SCALE1:%.+]] = ttg.convert_layout {{.*}} : {{.*}} -> tensor<128x8xi8, [[$linear2]]>
346+ // MFMA16: tt.dot_scaled {{.*}} scale [[SCALE0]], {{.*}} scale [[SCALE1]], {{.*}} -> tensor<16x128xf32, [[$mma]]>
347+ %cst0 = arith.constant dense <127 > : tensor <16 x8 xi8 , #blocked >
348+ %cst1 = arith.constant dense <0.000000e+00 > : tensor <16 x128 xf32 , #blocked1 >
349+ %load = tt.load %arg1 : tensor <4 x256 x!tt.ptr <i8 >, #blocked5 >
350+ %reshape0 = tt.reshape %load : tensor <4 x256 xi8 , #blocked5 > -> tensor <4 x1 x4 x16 x2 x2 x1 xi8 , #blocked7 >
351+ %trans = tt.trans %reshape0 {order = array<i32 : 0 , 5 , 3 , 1 , 4 , 2 , 6 >} : tensor <4 x1 x4 x16 x2 x2 x1 xi8 , #blocked7 > -> tensor <4 x2 x16 x1 x2 x4 x1 xi8 , #blocked8 >
352+ %reshape1 = tt.reshape %trans : tensor <4 x2 x16 x1 x2 x4 x1 xi8 , #blocked8 > -> tensor <128 x8 xi8 , #linear >
353+ %scale = ttg.convert_layout %reshape1 : tensor <128 x8 xi8 , #linear > -> tensor <128 x8 xi8 , #blocked >
354+ %1 = tt.dot_scaled %arg0 scale %cst0 , %arg2 scale %scale , %cst1 lhs = e4m3 rhs = e2m1 {fastMath = true } : tensor <16 x256 xf8 E4 M3 FN, #blocked6 >, tensor <16 x8 xi8 , #blocked > * tensor <128 x128 xi8 , #blocked1 >, tensor <128 x8 xi8 , #blocked > -> tensor <16 x128 xf32 , #blocked1 >
355+ tt.store %arg3 , %1 : tensor <16 x128 x!tt.ptr <f32 >, #blocked1 >
356+ tt.return
357+ }
358+ }