@@ -2816,7 +2816,7 @@ def kernel():


 @gluon.jit
-def amd_tdm_kernel(ptr):
+def amd_tdm_load_kernel(ptr):
     SHARED_LAYOUT: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([[32, 4]], [16, 64], [1, 0])
     BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0])

@@ -2831,17 +2831,17 @@ def amd_tdm_kernel(ptr):


 @pytest.mark.parametrize("target", [HIP_TARGET_GFX1250])
-def test_amd_tdm(target):
+def test_amd_tdm_load(target):

     ptr = MockTensor(ttgl.float16)
-    module = run_parser(amd_tdm_kernel, *make_args(ptr), target)
+    module = run_parser(amd_tdm_load_kernel, *make_args(ptr), target)
     expecttest.assert_expected_inline(
         anonymize_ir(module.str_nodebug()), """\
 #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [16, 64]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @amd_tdm_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+  tt.func public @amd_tdm_load_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
     %c32_i32 = arith.constant 32 : i32
     %c128_i32 = arith.constant 128 : i32
     %c128_i64 = arith.constant 128 : i64
@@ -2858,3 +2858,48 @@ def test_amd_tdm(target):
   }
 }
 """)
+
+
+@gluon.jit
+def amd_tdm_store_kernel(ptr):
+    SHARED_LAYOUT: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0])
+    BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0])
+
+    desc = ttgl.amd.gfx1250.tdm.make_tensor_descriptor(base=ptr, shape=(32, 128), strides=(128, 1),
+                                                       block_shape=(16, 64), layout=SHARED_LAYOUT)
+
+    value = ttgl.full([16, 64], 1.0, ttgl.float16, layout=BLOCKED_LAYOUT)
+    buffer = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout, value)
+
+    ttgl.amd.gfx1250.tdm.async_store(desc, offsets=[0, 2], src=buffer)
+    ttgl.amd.gfx1250.tdm.async_wait(0)
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_GFX1250])
+def test_amd_tdm_store(target):
+
+    ptr = MockTensor(ttgl.float16)
+    module = run_parser(amd_tdm_store_kernel, *make_args(ptr), target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(module.str_nodebug()), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @amd_tdm_store_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %c32_i32 = arith.constant 32 : i32
+    %c128_i32 = arith.constant 128 : i32
+    %c128_i64 = arith.constant 128 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %0 = tt.make_tensor_descriptor %arg0, [%c32_i32, %c128_i32], [%c128_i64, %c1_i64] : <f16>, <tensor<16x64xf16, #shared>>
+    %cst = arith.constant 1.000000e+00 : f16
+    %cst_0 = arith.constant dense<1.000000e+00> : tensor<16x64xf16, #blocked>
+    %1 = ttg.local_alloc %cst_0 : (tensor<16x64xf16, #blocked>) -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable>
+    %c0_i32 = arith.constant 0 : i32
+    %c2_i32 = arith.constant 2 : i32
+    amdgpu.async_tdm_copy_local_to_global %0[%c0_i32, %c2_i32] from %1 : !ttg.memdesc<16x64xf16, #shared, #smem, mutable> -> !tt.tensordesc<tensor<16x64xf16, #shared>>
+    %2 = amdgpu.async_tdm_wait {num = 0 : i32}
+    tt.return
+  }
+}
+""")