
Commit 3ab723d

borontion authored and yangshuxin committed
[AMD] Support TDM store on gfx1250 (triton-lang#8392)
This PR adds support for TDM store on gfx1250, following triton-lang#8333. It also groups the common TDM utilities shared by the load and store paths.
1 parent c9ce7a4 commit 3ab723d

File tree: 13 files changed (+611, -224 lines)

python/src/gluon_ir.cc

Lines changed: 6 additions & 0 deletions
@@ -789,6 +789,12 @@ void init_gluon_ir(py::module &&m) {
             self.create<ttag::AsyncTDMCopyGlobalToLocalOp>(descPtr, indices,
                                                            result, pred);
           })
+      .def("create_async_tdm_copy_local_to_global",
+           [](GluonOpBuilder &self, Value descPtr, std::vector<Value> &indices,
+              Value src) {
+             self.create<ttag::AsyncTDMCopyLocalToGlobalOp>(descPtr, indices,
+                                                            src);
+           })
       .def("create_async_tdm_wait", [](GluonOpBuilder &self, int num) {
         ValueRange tokens;
         self.create<ttag::AsyncTDMWait>(tokens, num);

python/test/gluon/test_frontend.py

Lines changed: 49 additions & 4 deletions
@@ -2816,7 +2816,7 @@ def kernel():
 
 
 @gluon.jit
-def amd_tdm_kernel(ptr):
+def amd_tdm_load_kernel(ptr):
     SHARED_LAYOUT: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([[32, 4]], [16, 64], [1, 0])
     BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0])
 
@@ -2831,17 +2831,17 @@ def amd_tdm_kernel(ptr):
 
 
 @pytest.mark.parametrize("target", [HIP_TARGET_GFX1250])
-def test_amd_tdm(target):
+def test_amd_tdm_load(target):
 
     ptr = MockTensor(ttgl.float16)
-    module = run_parser(amd_tdm_kernel, *make_args(ptr), target)
+    module = run_parser(amd_tdm_load_kernel, *make_args(ptr), target)
     expecttest.assert_expected_inline(
         anonymize_ir(module.str_nodebug()), """\
 #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [16, 64]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @amd_tdm_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+  tt.func public @amd_tdm_load_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
     %c32_i32 = arith.constant 32 : i32
     %c128_i32 = arith.constant 128 : i32
     %c128_i64 = arith.constant 128 : i64
@@ -2858,3 +2858,48 @@ def test_amd_tdm(target):
   }
 }
 """)
+
+
+@gluon.jit
+def amd_tdm_store_kernel(ptr):
+    SHARED_LAYOUT: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0])
+    BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0])
+
+    desc = ttgl.amd.gfx1250.tdm.make_tensor_descriptor(base=ptr, shape=(32, 128), strides=(128, 1),
+                                                       block_shape=(16, 64), layout=SHARED_LAYOUT)
+
+    value = ttgl.full([16, 64], 1.0, ttgl.float16, layout=BLOCKED_LAYOUT)
+    buffer = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout, value)
+
+    ttgl.amd.gfx1250.tdm.async_store(desc, offsets=[0, 2], src=buffer)
+    ttgl.amd.gfx1250.tdm.async_wait(0)
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_GFX1250])
+def test_amd_tdm_store(target):
+
+    ptr = MockTensor(ttgl.float16)
+    module = run_parser(amd_tdm_store_kernel, *make_args(ptr), target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(module.str_nodebug()), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @amd_tdm_store_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %c32_i32 = arith.constant 32 : i32
+    %c128_i32 = arith.constant 128 : i32
+    %c128_i64 = arith.constant 128 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %0 = tt.make_tensor_descriptor %arg0, [%c32_i32, %c128_i32], [%c128_i64, %c1_i64] : <f16>, <tensor<16x64xf16, #shared>>
+    %cst = arith.constant 1.000000e+00 : f16
+    %cst_0 = arith.constant dense<1.000000e+00> : tensor<16x64xf16, #blocked>
+    %1 = ttg.local_alloc %cst_0 : (tensor<16x64xf16, #blocked>) -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable>
+    %c0_i32 = arith.constant 0 : i32
+    %c2_i32 = arith.constant 2 : i32
+    amdgpu.async_tdm_copy_local_to_global %0[%c0_i32, %c2_i32] from %1 : !ttg.memdesc<16x64xf16, #shared, #smem, mutable> -> !tt.tensordesc<tensor<16x64xf16, #shared>>
+    %2 = amdgpu.async_tdm_wait {num = 0 : i32}
+    tt.return
+  }
+}
+""")

python/triton/experimental/gluon/language/amd/gfx1250/tdm.py

Lines changed: 22 additions & 5 deletions
@@ -3,7 +3,7 @@
 from dataclasses import dataclass
 
 import triton.experimental.gluon.language._core as ttgl
-from triton.experimental.gluon.language._layouts import PaddedSharedLayout
+from triton.experimental.gluon.language._layouts import PaddedSharedLayout, SwizzledSharedLayout
 from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
 
 if TYPE_CHECKING:
@@ -20,7 +20,7 @@ class tensor_descriptor_type(ttgl.base_type):
     block_type: ttgl.block_type
     shape_type: ttgl.tuple_type
     strides_type: ttgl.tuple_type
-    layout: PaddedSharedLayout
+    layout: PaddedSharedLayout | SwizzledSharedLayout
 
     def __str__(self) -> str:
         return f"tensor_descriptor<{self.block_type}, {self.layout}>"
@@ -84,15 +84,15 @@ def layout(self):
 @builtin
 def make_tensor_descriptor(base: ttgl.tensor, shape: List[ttgl.constexpr | ttgl.tensor],
                            strides: List[ttgl.constexpr | ttgl.tensor], block_shape: List[ttgl.constexpr],
-                           layout: PaddedSharedLayout, _semantic=None) -> tensor_descriptor:
+                           layout: PaddedSharedLayout | SwizzledSharedLayout, _semantic=None) -> tensor_descriptor:
     """Make a tensor descriptor object.
 
     Args:
         base (tensor): base pointer of the tensor in global memory.
         shape (List[int]): shape of the tensor.
         strides (List[int]): strides of the tensor.
        block_shape (List[int]): block shape of the tensor.
-        layout (PaddedSharedLayout): the layout of the tensor in shared memory.
+        layout (PaddedSharedLayout | SwizzledSharedLayout): the layout of the tensor in shared memory.
 
     Returns:
         tensor_descriptor: the created tensor descriptor object
@@ -105,7 +105,10 @@ def make_tensor_descriptor(base: ttgl.tensor, shape: List[ttgl.constexpr | ttgl.
     assert isinstance(base.dtype, ttgl.pointer_type), "Expected base to be a pointer"
 
     layout = _unwrap_if_constexpr(layout)
-    assert isinstance(layout, PaddedSharedLayout), "Expected layout to be a PaddedSharedLayout"
+    assert isinstance(layout, (PaddedSharedLayout, SwizzledSharedLayout)), \
+        "Expected layout to be a PaddedSharedLayout or SwizzledSharedLayout"
+    if isinstance(layout, SwizzledSharedLayout):
+        assert layout.max_phase == 1, "Expected max_phase to be 1 for SwizzledSharedLayout"
 
     base_handle = base.handle
     shape_handles = _semantic._convert_to_ir_values(shape, require_i64=False)  # i32 shape
@@ -137,6 +140,20 @@ def async_load(src: tensor_descriptor, offsets: List[ttgl.constexpr | ttgl.tenso
     _semantic.builder.create_async_tdm_copy_global_to_local(src.handle, offset_handles, dest.handle)
 
 
+@builtin
+def async_store(dest: tensor_descriptor, offsets: List[ttgl.constexpr | ttgl.tensor], src: shared_memory_descriptor,
+                _semantic=None) -> None:
+    """Store a block of the tensor specified by the tensor descriptor from shared memory to global memory asynchronously.
+
+    Args:
+        dest (tensor_descriptor): the destination tensor descriptor.
+        offsets (List[int]): the offsets from the base pointer in the tensor descriptor.
+        src (shared_memory_descriptor): the shared memory source of the data.
+    """
+    offset_handles = _semantic._convert_to_ir_values(offsets, require_i64=False)
+    _semantic.builder.create_async_tdm_copy_local_to_global(dest.handle, offset_handles, src.handle)
+
+
 @builtin
 def async_wait(num_outstanding=0, _semantic=None) -> None:
     """Wait for the outstanding asynchronous tensor operations to complete.
Lines changed: 36 additions & 10 deletions
@@ -1,25 +1,51 @@
-// RUN: triton-opt %s --split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx1250 --convert-builtin-func-to-llvm | FileCheck %s --check-prefixes=GFX1250
+// RUN: triton-opt %s --split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx1250 --convert-builtin-func-to-llvm | FileCheck %s
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [64, 64]}>
 #smem = #ttg.shared_memory
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
-  // GFX1250-LABEL: tdm_kernel
-  tt.func public @tdm_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: tdm_load
+  tt.func public @tdm_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
     %c_shape = arith.constant 128 : i32
     %c_stride0 = arith.constant 128 : i64
     %c_stride1 = arith.constant 1 : i64
     %c_offset = arith.constant 0 : i32
     %c_pred = arith.constant true
-    %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : <f16>, <tensor<64x64xf16>>
+    %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : <f16>, <tensor<64x64xf16, #shared>>
     %1 = ttg.local_alloc : () -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
-    // GFX1250-COUNT-4: llvm.insertelement{{.*}} : vector<4xi32>
-    // GFX1250-COUNT-8: llvm.insertelement{{.*}} : vector<8xi32>
-    // GFX1250: llvm.amdgcn.tensor.load.to.lds.d2{{.*}} : (vector<4xi32>, vector<8xi32>, i32) -> ()
-    %2 = amdgpu.async_tdm_copy_global_to_local %0[%c_offset, %c_offset] into %1, %c_pred : !tt.tensordesc<tensor<64x64xf16>> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
-    // GFX1250: llvm.amdgcn.s.wait.tensorcnt{{.*}} : (i16) -> ()
+    // CHECK-COUNT-4: llvm.insertelement{{.*}} : vector<4xi32>
+    // CHECK-COUNT-8: llvm.insertelement{{.*}} : vector<8xi32>
+    // CHECK: llvm.amdgcn.tensor.load.to.lds.d2{{.*}} : (vector<4xi32>, vector<8xi32>, i32) -> ()
+    %2 = amdgpu.async_tdm_copy_global_to_local %0[%c_offset, %c_offset] into %1, %c_pred : !tt.tensordesc<tensor<64x64xf16, #shared>> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
+    // CHECK: llvm.amdgcn.s.wait.tensorcnt{{.*}} : (i16) -> ()
     %3 = amdgpu.async_tdm_wait {num = 0 : i32}
     %4 = ttg.local_load %1 : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> tensor<64x64xf16, #blocked>
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: tdm_store
+  tt.func public @tdm_store(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %c_shape = arith.constant 128 : i32
+    %c_stride0 = arith.constant 128 : i64
+    %c_stride1 = arith.constant 1 : i64
+    %c_offset = arith.constant 0 : i32
+    %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : <f16>, <tensor<64x64xf16, #shared>>
+    %1 = ttg.local_alloc : () -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
+    %2 = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #blocked>
+    ttg.local_store %2, %1 : tensor<64x64xf16, #blocked> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
+    // CHECK-COUNT-4: llvm.insertelement{{.*}} : vector<4xi32>
+    // CHECK-COUNT-8: llvm.insertelement{{.*}} : vector<8xi32>
+    // CHECK: llvm.amdgcn.tensor.store.from.lds.d2{{.*}} : (vector<4xi32>, vector<8xi32>, i32) -> ()
+    amdgpu.async_tdm_copy_local_to_global %0[%c_offset, %c_offset] from %1 : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> !tt.tensordesc<tensor<64x64xf16, #shared>>
+    // CHECK: llvm.amdgcn.s.wait.tensorcnt{{.*}} : (i16) -> ()
+    %3 = amdgpu.async_tdm_wait {num = 0 : i32}
+    tt.return
+  }
+}

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 34 additions & 2 deletions
@@ -707,8 +707,9 @@ def AsyncTDMCopyGlobalToLocalOp : TT_AMDGPU_Op<"async_tdm_copy_global_to_local">
     This operation copies data from global memory to local memory
     asynchronously. This is analogous to tt.load except the data are copied to
     local memory pointed to by `result` instead of a distributed tensor. The data
-    copied depends on the global memory descriptor pointed to by `desc`. Setting
-    `pred` to false disables the copy.
+    copied depends on the global memory pointed to by `desc`. Setting `pred` to
+    false disables the copy. This operation does not support shared memory
+    swizzling.
   }];
 
   let arguments = (ins
@@ -724,6 +725,37 @@ def AsyncTDMCopyGlobalToLocalOp : TT_AMDGPU_Op<"async_tdm_copy_global_to_local">
     $desc `[` $indices `]` `into` $result `,` $pred
     attr-dict `:` qualified(type($desc)) `->` qualified(type($result))
   }];
+
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// AsyncTDMCopyLocalToGlobalOp
+//===----------------------------------------------------------------------===//
+
+def AsyncTDMCopyLocalToGlobalOp : TT_AMDGPU_Op<"async_tdm_copy_local_to_global"> {
+  let summary = "Copy data based on descriptor from local memory to global memory asynchronously";
+
+  let description = [{
+    This operation copies data from local memory to global memory
+    asynchronously. This is analogous to tt.store except the data are copied from
+    local memory pointed to by `src` instead of a distributed tensor. The copy
+    destination depends on the global memory pointed to by `desc`. This
+    operation does not support shared memory padding or swizzling.
+  }];
+
+  let arguments = (ins
+    Arg<TT_TensorDescType, "", [MemWrite<GlobalMemory>]>:$desc,
+    Variadic<I32>:$indices,
+    Arg<TTG_MemDescType, "", [MemRead<SharedMemory>]>:$src
+  );
+
+  let assemblyFormat = [{
+    $desc `[` $indices `]` `from` $src
+    attr-dict `:` qualified(type($src)) `->` qualified(type($desc))
+  }];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//

third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp

Lines changed: 53 additions & 0 deletions
@@ -607,4 +607,57 @@ void ConcatOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns,
   patterns.add(foldConcatOpFromSingleSource);
 }
 
+LogicalResult AsyncTDMCopyGlobalToLocalOp::verify() {
+  auto tensorDescTy = getDesc().getType();
+  auto smemTy = getResult().getType();
+
+  auto swizzledEnc =
+      llvm::dyn_cast<gpu::SwizzledSharedEncodingAttr>(smemTy.getEncoding());
+  if (swizzledEnc && swizzledEnc.getMaxPhase() != 1)
+    return emitOpError("TDM does not support swizzling");
+
+  auto paddedEnc =
+      llvm::dyn_cast<gpu::PaddedSharedEncodingAttr>(smemTy.getEncoding());
+  if (!paddedEnc && !swizzledEnc)
+    return emitOpError("Invalid shared memory layout for TDM");
+
+  Type elementType = smemTy.getElementType();
+  auto elementBitWidth = elementType.getIntOrFloatBitWidth();
+  if (paddedEnc) {
+    unsigned dwordSize = 32;
+    for (auto [interval, padding] :
+         llvm::zip(paddedEnc.getIntervals(), paddedEnc.getPaddings())) {
+      auto intervalInDwords = interval * elementBitWidth / dwordSize;
+      if (intervalInDwords < 2)
+        return emitOpError("TDM padding interval must be at least 2 dwords");
+
+      auto paddingInDwords = padding * elementBitWidth / dwordSize;
+      if (paddingInDwords < 1)
+        return emitOpError("TDM padding amount must be at least 1 dword");
+    }
+  }
+
+  return success();
+}
+
+LogicalResult AsyncTDMCopyLocalToGlobalOp::verify() {
+  auto tensorDescTy = getDesc().getType();
+  auto smemTy = getSrc().getType();
+
+  auto swizzledEnc =
+      llvm::dyn_cast<gpu::SwizzledSharedEncodingAttr>(smemTy.getEncoding());
+  if (swizzledEnc && swizzledEnc.getMaxPhase() != 1)
+    return emitOpError("TDM does not support swizzling");
+
+  auto paddedEnc =
+      llvm::dyn_cast<gpu::PaddedSharedEncodingAttr>(smemTy.getEncoding());
+  if (paddedEnc)
+    return emitOpError("TDM store does not support padding");
+
+  if (!paddedEnc && !swizzledEnc)
+    return emitOpError("Invalid shared memory layout for TDM");
+
+  return success();
+}
+
 } // namespace mlir::triton::amdgpu
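The padded-layout checks in AsyncTDMCopyGlobalToLocalOp::verify work at dword (32-bit) granularity: the padding interval must span at least 2 dwords and the padding amount at least 1 dword. A small Python sketch of that arithmetic follows; check_tdm_padding is an illustrative helper, not part of the codebase.

def check_tdm_padding(intervals, paddings, element_bit_width):
    # Mirrors the verifier's dword math (dwordSize = 32 bits).
    DWORD_BITS = 32
    for interval, padding in zip(intervals, paddings):
        # e.g. f16 (16-bit) with interval 32: 32 * 16 // 32 = 16 dwords >= 2
        if interval * element_bit_width // DWORD_BITS < 2:
            raise ValueError("TDM padding interval must be at least 2 dwords")
        # e.g. f16 with padding 4: 4 * 16 // 32 = 2 dwords >= 1
        if padding * element_bit_width // DWORD_BITS < 1:
            raise ValueError("TDM padding amount must be at least 1 dword")

# The PaddedSharedLayout used by the load tests, [[32, 4]] on f16, passes:
check_tdm_padding([32], [4], element_bit_width=16)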

third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ add_triton_library(TritonAMDGPUToLLVM
   Fp4ToFpOpToLLVM.cpp
   MembarUtility.cpp
   ScalarizePackedFOps.cpp
+  TDMUtility.cpp
 
   DEPENDS
   TritonAMDGPUConversionPassIncGen
