Commit feb755c

Merge OpenAI Triton commit 4734af3 (#5460)
This PR changes the Triton base from 00cf53f to 4734af3 (Oct 24). Pass rate: 94.59%
2 parents (ef7d239 + f1e5aad) · commit feb755c

14 files changed: +366 -398 lines changed

python/test/unit/language/test_matmul.py

Lines changed: 4 additions & 1 deletion
@@ -779,8 +779,11 @@ def generate_gemm_input(dim0, dim1, dtype):
     triton_out = triton_out.to(torch.float32)
     torch.testing.assert_close(torch_out, triton_out, atol=2e-5, rtol=1e-4)
     if is_hip() and preshuffle:
-        assert "tilesPerWarp = [2, 2]" in k.asm["ttgir"]
         assert "ds_read_u8" not in k.asm["amdgcn"]
+        if mfma_nonkdim == 16:
+            assert "tilesPerWarp = [2, 2]" in k.asm["ttgir"]
+        elif mfma_nonkdim == 32:  # default tilesPerWarp = [1, 1]
+            assert "tilesPerWarp" not in k.asm["ttgir"]
 
 
 @pytest.mark.parametrize("M, N, K", [(1024, 512, 512), (998, 111, 512), (63, 128, 512)])
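
Note on the pattern above: launching a @triton.jit kernel returns a compiled-kernel handle whose .asm dictionary exposes the intermediate representations, which is what the test greps for "tilesPerWarp" and "ds_read_u8". A minimal sketch of that inspection flow (the trivial kernel below is hypothetical; only the .asm lookups mirror the test):

import torch
import triton
import triton.language as tl

@triton.jit
def _store_one(x_ptr):
    tl.store(x_ptr, 1.0)

if torch.cuda.is_available():  # also true on ROCm/HIP builds of PyTorch
    x = torch.zeros(1, device="cuda")
    k = _store_one[(1,)](x)        # the launch returns the compiled kernel
    ttgir = k.asm["ttgir"]         # TritonGPU IR, where tilesPerWarp appears
    print("tilesPerWarp" in ttgir)
    # k.asm["amdgcn"] is only present when compiling for an AMD (HIP) target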

python/test/unit/language/test_tensor_descriptor.py

Lines changed: 1 addition & 1 deletion
@@ -385,7 +385,7 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
 @pytest.mark.interpreter
 def test_tensor_descriptor_padding(device):
     if is_xpu():
-        pytest.skip("padding is unsupported")
+        pytest.skip("FIXME: issue #5400")
 
     @triton.jit
     def device_tma_load(in_ptr, out_ptr, IM, IN, YM, YN, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr,

python/triton_kernels/reduce.py

Lines changed: 0 additions & 282 deletions
This file was deleted.

python/triton_kernels/triton_kernels/reduce.py

Lines changed: 2 additions & 0 deletions
@@ -147,6 +147,8 @@ def reduce(
     Returns:
     - output: torch.Tensor
         The reduced tensor with `dim` removed.
+    - output_mxscale: Optional[torch.Tensor]
+        The output mx scale if input is micro-scaled, else None.
     """
     if x.ndim != 3:
         raise NotImplementedError("reduce only supports 3D inputs in this implementation")
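
For reference, a hypothetical call showing the two return values (the exact reduce() signature is not visible in this diff, so the argument list here is an assumption based on the docstring):

import torch
from triton_kernels.reduce import reduce

x = torch.randn(4, 128, 256, device="cuda")  # reduce() only accepts 3D inputs
out, out_mxscale = reduce(x, dim=1)          # assumed call shape: (x, dim, ...)
# `out` has `dim` removed, i.e. shape (4, 256); `out_mxscale` is None for a
# plain (non micro-scaled) input, per the docstring above.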
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx1250 | FileCheck %s
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 8, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} {
+  // CHECK-LABEL: async_copy_with_swizzle
+  tt.func public @async_copy_with_swizzle(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+                                          %arg2: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
+    // We need the splat to allow the AxisAnalysis to work during lowering
+    %1 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x32x!tt.ptr<f32>, #blocked>
+    // Each thread needs to load 8 elements and we load 1 (sizePerThread) per global.load.lds
+    // CHECK-COUNT-8: llvm.amdgcn.global.load.async.to.lds.b32
+    // CHECK-NOT: llvm.amdgcn.global.load.async.to.lds
+    %2 = ttg.async_copy_global_to_local %1, %arg2 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} {
+  // CHECK-LABEL: async_load_strided_into_lds_with_swizzle
+  tt.func public @async_load_strided_into_lds_with_swizzle(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>},
+                                                           %arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) {
+    // Each thread loads 256 contiguous bits so we split into 2 128bit loads. This was not possible on GFX9
+    // CHECK-COUNT-2: llvm.amdgcn.global.load.async.to.lds.b128
+    // CHECK-NOT: llvm.amdgcn.global.load.async.to.lds
+    %6 = ttg.async_copy_global_to_local %arg0, %arg1 : tensor<32x32x!tt.ptr<f32>, #blocked> -> <32x32xf32, #shared, #smem, mutable>
+    tt.return
+  }
+}
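
The instruction counts checked in these two cases follow directly from the layouts. A standalone sanity check of the arithmetic (a sketch, not part of the lit test):

# Derive per-thread work from the #ttg.blocked layouts used above.
def elems_per_thread(shape, threads_per_warp, warps_per_cta):
    threads = 1
    for t, w in zip(threads_per_warp, warps_per_cta):
        threads *= t * w
    total = 1
    for s in shape:
        total *= s
    return total // threads

# Case 1: 32x32xf32, sizePerThread = [1, 1] -> 8 elements per thread, one
# 32-bit element per instruction => 8x global.load.async.to.lds.b32.
assert elems_per_thread((32, 32), (1, 32), (4, 1)) == 8

# Case 2: sizePerThread = [1, 8] -> 8 contiguous f32 = 256 bits per thread,
# emitted as two 128-bit loads => 2x global.load.async.to.lds.b128.
assert elems_per_thread((32, 32), (8, 4), (4, 1)) == 8
assert (8 * 32) // 128 == 2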

test/TritonGPU/amd/accelerate-amd-matmul-mfma-gfx950.mlir

Lines changed: 36 additions & 0 deletions
@@ -320,3 +320,39 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}>
+#blocked5 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked6 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [1, 0]}>
+#blocked7 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 2, 2, 1], threadsPerWarp = [1, 1, 4, 16, 1, 1, 1], warpsPerCTA = [4, 1, 1, 1, 1, 1, 1], order = [6, 5, 4, 3, 2, 1, 0]}>
+#blocked8 = #ttg.blocked<{sizePerThread = [1, 2, 1, 1, 2, 1, 1], threadsPerWarp = [1, 1, 16, 1, 1, 4, 1], warpsPerCTA = [4, 1, 1, 1, 1, 1, 1], order = [6, 1, 4, 2, 5, 3, 0]}>
+#linear = #ttg.linear<{register = [[16, 0], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp = [[32, 0], [64, 0]], block = []}>
+
+// MFMA16: [[$linear1:#.*]] = #ttg.linear<{register = {{\[\[}}0, 4{{]]}}, lane = {{\[\[}}1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2{{]]}}, warp = {{\[\[}}0, 0], [0, 0{{]]}}, block = []}>
+// MFMA16: [[$linear2:#.*]] = #ttg.linear<{register = {{\[\[}}0, 4], [16, 0{{]]}}, lane = {{\[\[}}1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2{{]]}}, warp = {{\[\[}}32, 0], [64, 0{{]]}}, block = []}>
+// MFMA16: [[$mma:#.*]] = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 4], instrShape = [16, 16, 128], isTransposed = true, tilesPerWarp = [1, 2]}>
+// MFMA16-LABEL: mfma_dot_scaled_fp8_mxfp4
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @mfma_dot_scaled_fp8_mxfp4(
+      %arg0: tensor<16x256xf8E4M3FN, #blocked6>,
+      %arg1: tensor<4x256x!tt.ptr<i8>, #blocked5>,
+      %arg2: tensor<128x128xi8, #blocked1>,
+      %arg3: tensor<16x128x!tt.ptr<f32>, #blocked1>
+  ) {
+    // MFMA16: [[SCALE0:%.+]] = ttg.convert_layout {{.*}} : {{.*}} -> tensor<16x8xi8, [[$linear1]]>
+    // MFMA16: [[SCALE1:%.+]] = ttg.convert_layout {{.*}} : {{.*}} -> tensor<128x8xi8, [[$linear2]]>
+    // MFMA16: tt.dot_scaled {{.*}} scale [[SCALE0]], {{.*}} scale [[SCALE1]], {{.*}} -> tensor<16x128xf32, [[$mma]]>
+    %cst0 = arith.constant dense<127> : tensor<16x8xi8, #blocked>
+    %cst1 = arith.constant dense<0.000000e+00> : tensor<16x128xf32, #blocked1>
+    %load = tt.load %arg1 : tensor<4x256x!tt.ptr<i8>, #blocked5>
+    %reshape0 = tt.reshape %load : tensor<4x256xi8, #blocked5> -> tensor<4x1x4x16x2x2x1xi8, #blocked7>
+    %trans = tt.trans %reshape0 {order = array<i32: 0, 5, 3, 1, 4, 2, 6>} : tensor<4x1x4x16x2x2x1xi8, #blocked7> -> tensor<4x2x16x1x2x4x1xi8, #blocked8>
+    %reshape1 = tt.reshape %trans : tensor<4x2x16x1x2x4x1xi8, #blocked8> -> tensor<128x8xi8, #linear>
+    %scale = ttg.convert_layout %reshape1 : tensor<128x8xi8, #linear> -> tensor<128x8xi8, #blocked>
+    %1 = tt.dot_scaled %arg0 scale %cst0, %arg2 scale %scale, %cst1 lhs = e4m3 rhs = e2m1 {fastMath = true} : tensor<16x256xf8E4M3FN, #blocked6>, tensor<16x8xi8, #blocked> * tensor<128x128xi8, #blocked1>, tensor<128x8xi8, #blocked> -> tensor<16x128xf32, #blocked1>
+    tt.store %arg3, %1 : tensor<16x128x!tt.ptr<f32>, #blocked1>
+    tt.return
+  }
+}
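
The scale-tensor shapes in this test are consistent with MX micro-scaling, which uses one scale per 32-element group along K; e2m1 packs two fp4 values per byte, so the 128x128xi8 operand spans a logical K of 256, matching the fp8 operand. A short check of that arithmetic (a sketch, not part of the test):

K_FP8 = 256          # columns of tensor<16x256xf8E4M3FN>
K_BYTES_MXFP4 = 128  # columns of tensor<128x128xi8>, two e2m1 values per byte
GROUP = 32           # elements covered by one mx scale (MX block size, assumed)

assert 2 * K_BYTES_MXFP4 == K_FP8  # packed fp4 K matches the fp8 operand's K
assert K_FP8 // GROUP == 8         # -> the 16x8 and 128x8 scale tensors above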
