From 9a4fdec25f5a0b1a7df08fd4edbba7e6cd9081ec Mon Sep 17 00:00:00 2001 From: "Zhao, Pengzhan" Date: Tue, 7 Oct 2025 13:36:41 -0700 Subject: [PATCH 1/5] support tdm store --- python/src/gluon_ir.cc | 6 + python/test/gluon/test_frontend.py | 53 ++++- .../gluon/language/amd/gfx1250/tdm.py | 27 ++- .../Conversion/amd/tritongpu_tdm_to_llvm.mlir | 46 +++- .../TritonAMDGPU/IR/TritonAMDGPUOps.td | 36 ++- .../lib/Dialect/TritonAMDGPU/IR/Dialect.cpp | 53 +++++ .../amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt | 1 + .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 218 +++++------------- .../amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp | 154 +++++++++++++ .../amd/lib/TritonAMDGPUToLLVM/TDMUtility.h | 34 +++ .../TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp | 15 +- .../amd/python/test/test_gluon_gfx1250.py | 118 ++++++++-- 12 files changed, 550 insertions(+), 211 deletions(-) create mode 100644 third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp create mode 100644 third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h diff --git a/python/src/gluon_ir.cc b/python/src/gluon_ir.cc index a2418d8e6bcc..4ea1d0c2d6db 100644 --- a/python/src/gluon_ir.cc +++ b/python/src/gluon_ir.cc @@ -777,6 +777,12 @@ void init_gluon_ir(py::module &&m) { self.create(descPtr, indices, result, pred); }) + .def("create_async_tdm_copy_local_to_global", + [](GluonOpBuilder &self, Value descPtr, std::vector &indices, + Value src) { + self.create(descPtr, indices, + src); + }) .def("create_async_tdm_wait", [](GluonOpBuilder &self, int num) { ValueRange tokens; self.create(tokens, num); diff --git a/python/test/gluon/test_frontend.py b/python/test/gluon/test_frontend.py index 71e5a1a3704b..473f79a652cb 100644 --- a/python/test/gluon/test_frontend.py +++ b/python/test/gluon/test_frontend.py @@ -2733,7 +2733,7 @@ def kernel(): @gluon.jit -def amd_tdm_kernel(ptr): +def amd_tdm_load_kernel(ptr): SHARED_LAYOUT: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([[32, 4]], [16, 64], [1, 0]) BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0]) @@ -2748,17 +2748,17 @@ def amd_tdm_kernel(ptr): @pytest.mark.parametrize("target", [HIP_TARGET_GFX1250]) -def test_amd_tdm(target): +def test_amd_tdm_load(target): ptr = MockTensor(ttgl.float16) - module = run_parser(amd_tdm_kernel, *make_args(ptr), target) + module = run_parser(amd_tdm_load_kernel, *make_args(ptr), target) expecttest.assert_expected_inline( anonymize_ir(module.str_nodebug()), """\ #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [16, 64]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} { - tt.func public @amd_tdm_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + tt.func public @amd_tdm_load_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %c32_i32 = arith.constant 32 : i32 %c128_i32 = arith.constant 128 : i32 %c128_i64 = arith.constant 128 : i64 @@ -2775,3 +2775,48 @@ def test_amd_tdm(target): } } """) + + +@gluon.jit +def amd_tdm_store_kernel(ptr): + SHARED_LAYOUT: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0]) + BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0]) + + desc = ttgl.amd.gfx1250.tdm.make_tensor_descriptor(base=ptr, shape=(32, 128), strides=(128, 1), + block_shape=(16, 64), layout=SHARED_LAYOUT) + + 
value = ttgl.full([16, 64], 1.0, ttgl.float16, layout=BLOCKED_LAYOUT) + buffer = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout, value) + + ttgl.amd.gfx1250.tdm.async_store(desc, offsets=[0, 2], src=buffer) + ttgl.amd.gfx1250.tdm.async_wait(0) + + +@pytest.mark.parametrize("target", [HIP_TARGET_GFX1250]) +def test_amd_tdm_store(target): + + ptr = MockTensor(ttgl.float16) + module = run_parser(amd_tdm_store_kernel, *make_args(ptr), target) + expecttest.assert_expected_inline( + anonymize_ir(module.str_nodebug()), """\ +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @amd_tdm_store_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %c128_i32 = arith.constant 128 : i32 + %c128_i64 = arith.constant 128 : i64 + %c1_i64 = arith.constant 1 : i64 + %0 = tt.make_tensor_descriptor %arg0, [%c32_i32, %c128_i32], [%c128_i64, %c1_i64] : , > + %cst = arith.constant 1.000000e+00 : f16 + %cst_0 = arith.constant dense<1.000000e+00> : tensor<16x64xf16, #blocked> + %1 = ttg.local_alloc %cst_0 : (tensor<16x64xf16, #blocked>) -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable> + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + amdgpu.async_tdm_copy_local_to_global %0[%c0_i32, %c2_i32] from %1 : !ttg.memdesc<16x64xf16, #shared, #smem, mutable> -> !tt.tensordesc> + %2 = amdgpu.async_tdm_wait {num = 0 : i32} + tt.return + } +} +""") diff --git a/python/triton/experimental/gluon/language/amd/gfx1250/tdm.py b/python/triton/experimental/gluon/language/amd/gfx1250/tdm.py index e50d1e25a344..41520b59d698 100644 --- a/python/triton/experimental/gluon/language/amd/gfx1250/tdm.py +++ b/python/triton/experimental/gluon/language/amd/gfx1250/tdm.py @@ -3,7 +3,7 @@ from dataclasses import dataclass import triton.experimental.gluon.language._core as ttgl -from triton.experimental.gluon.language._layouts import PaddedSharedLayout +from triton.experimental.gluon.language._layouts import PaddedSharedLayout, SwizzledSharedLayout from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr if TYPE_CHECKING: @@ -20,7 +20,7 @@ class tensor_descriptor_type(ttgl.base_type): block_type: ttgl.block_type shape_type: ttgl.tuple_type strides_type: ttgl.tuple_type - layout: PaddedSharedLayout + layout: PaddedSharedLayout | SwizzledSharedLayout def __str__(self) -> str: return f"tensor_descriptor<{self.block_type}, {self.layout}>" @@ -84,7 +84,7 @@ def layout(self): @builtin def make_tensor_descriptor(base: ttgl.tensor, shape: List[ttgl.constexpr | ttgl.tensor], strides: List[ttgl.constexpr | ttgl.tensor], block_shape: List[ttgl.constexpr], - layout: PaddedSharedLayout, _semantic=None) -> tensor_descriptor: + layout: PaddedSharedLayout | SwizzledSharedLayout, _semantic=None) -> tensor_descriptor: """Make a tensor descriptor object. Args: @@ -92,7 +92,7 @@ def make_tensor_descriptor(base: ttgl.tensor, shape: List[ttgl.constexpr | ttgl. shape (List[int]): shape of the tensor. strides (List[int]): strides of the tensor. block_shape (List[int]): block shape of the tensor. - layout (PaddedSharedLayout): the layout of the tensor in shared memory. 
+        layout (PaddedSharedLayout | SwizzledSharedLayout): the layout of the tensor in shared memory.
 
     Returns:
         tensor_descriptor: the created tensor descriptor object
@@ -105,7 +105,10 @@ def make_tensor_descriptor(base: ttgl.tensor, shape: List[ttgl.constexpr | ttgl.
     assert isinstance(base.dtype, ttgl.pointer_type), "Expected base to be a pointer"
 
     layout = _unwrap_if_constexpr(layout)
-    assert isinstance(layout, PaddedSharedLayout), "Expected layout to be a PaddedSharedLayout"
+    assert isinstance(layout, (PaddedSharedLayout, SwizzledSharedLayout)), \
+        "Expected layout to be a PaddedSharedLayout or SwizzledSharedLayout"
+    if isinstance(layout, SwizzledSharedLayout):
+        assert layout.max_phase == 1, "Expected max_phase to be 1 for SwizzledSharedLayout"
 
     base_handle = base.handle
     shape_handles = _semantic._convert_to_ir_values(shape, require_i64=False)  # i32 shape
@@ -137,6 +140,20 @@ def async_load(src: tensor_descriptor, offsets: List[ttgl.constexpr | ttgl.tenso
     _semantic.builder.create_async_tdm_copy_global_to_local(src.handle, offset_handles, dest.handle)
 
 
+@builtin
+def async_store(dest: tensor_descriptor, offsets: List[ttgl.constexpr | ttgl.tensor], src: shared_memory_descriptor,
+                _semantic=None) -> None:
+    """Store a block of the tensor specified by the tensor descriptor from shared memory to global memory asynchronously.
+
+    Args:
+        dest (tensor_descriptor): the destination tensor descriptor.
+        offsets (List[int]): the offsets from the base pointer in the tensor descriptor.
+        src (shared_memory_descriptor): the shared memory source of the data to store.
+    """
+    offset_handles = _semantic._convert_to_ir_values(offsets, require_i64=False)
+    _semantic.builder.create_async_tdm_copy_local_to_global(dest.handle, offset_handles, src.handle)
+
+
 @builtin
 def async_wait(num_outstanding=0, _semantic=None) -> None:
     """Wait for the outstanding asynchronous tensor operations to complete.
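
As a usage reference, here is a minimal Gluon-level sketch of the new store path, mirroring the amd_tdm_store_kernel test added above. The import paths and the concrete shapes/layouts are assumptions borrowed from the existing Gluon tests, not something this patch prescribes:

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl


@gluon.jit
def tdm_store_example(out_ptr):
    # TDM requires a non-swizzled shared layout (max_phase == 1); the store path
    # additionally rejects padded layouts (see the verifier added in Dialect.cpp below).
    SHARED_LAYOUT: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0])
    BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0])

    # Describe the global tensor: base pointer, full shape, strides, and the block to move.
    desc = ttgl.amd.gfx1250.tdm.make_tensor_descriptor(base=out_ptr, shape=(32, 128), strides=(128, 1),
                                                       block_shape=(16, 64), layout=SHARED_LAYOUT)

    # Stage one block in shared memory, then write it out through the descriptor.
    value = ttgl.full([16, 64], 1.0, ttgl.float16, layout=BLOCKED_LAYOUT)
    buffer = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout, value)
    ttgl.amd.gfx1250.tdm.async_store(desc, offsets=[0, 0], src=buffer)

    # Block until all outstanding TDM operations have completed.
    ttgl.amd.gfx1250.tdm.async_wait(0)
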
diff --git a/test/Conversion/amd/tritongpu_tdm_to_llvm.mlir b/test/Conversion/amd/tritongpu_tdm_to_llvm.mlir index 9dc27a6e165f..79597d55a918 100644 --- a/test/Conversion/amd/tritongpu_tdm_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_tdm_to_llvm.mlir @@ -1,25 +1,51 @@ -// RUN: triton-opt %s --split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx1250 --convert-builtin-func-to-llvm | FileCheck %s --check-prefixes=GFX1250 +// RUN: triton-opt %s --split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx1250 --convert-builtin-func-to-llvm | FileCheck %s #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [64, 64]}> #smem = #ttg.shared_memory -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} { - // GFX1250-LABEL: tdm_kernel - tt.func public @tdm_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tdm_load + tt.func public @tdm_load(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { %c_shape = arith.constant 128 : i32 %c_stride0 = arith.constant 128 : i64 %c_stride1 = arith.constant 1 : i64 %c_offset = arith.constant 0 : i32 %c_pred = arith.constant true - %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : , > + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : , > %1 = ttg.local_alloc : () -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> - // GFX1250-COUNT-4: llvm.insertelement{{.*}} : vector<4xi32> - // GFX1250-COUNT-8: llvm.insertelement{{.*}} : vector<8xi32> - // GFX1250: llvm.amdgcn.tensor.load.to.lds.d2{{.*}} : (vector<4xi32>, vector<8xi32>, i32) -> () - %2 = amdgpu.async_tdm_copy_global_to_local %0[%c_offset, %c_offset] into %1, %c_pred : !tt.tensordesc> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> - // GFX1250: llvm.amdgcn.s.wait.tensorcnt{{.*}} : (i16) -> () + // CHECK-COUNT-4: llvm.insertelement{{.*}} : vector<4xi32> + // CHECK-COUNT-8: llvm.insertelement{{.*}} : vector<8xi32> + // CHECK: llvm.amdgcn.tensor.load.to.lds.d2{{.*}} : (vector<4xi32>, vector<8xi32>, i32) -> () + %2 = amdgpu.async_tdm_copy_global_to_local %0[%c_offset, %c_offset] into %1, %c_pred : !tt.tensordesc> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable> + // CHECK: llvm.amdgcn.s.wait.tensorcnt{{.*}} : (i16) -> () %3 = amdgpu.async_tdm_wait {num = 0 : i32} %4 = ttg.local_load %1 : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> tensor<64x64xf16, #blocked> tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tdm_store + tt.func public @tdm_store(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c_shape = arith.constant 128 : i32 + %c_stride0 = arith.constant 128 : i64 + %c_stride1 = arith.constant 1 : i64 + %c_offset = arith.constant 0 : i32 + %0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : , > + %1 = 
ttg.local_alloc : () -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
+    %2 = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #blocked>
+    ttg.local_store %2, %1 : tensor<64x64xf16, #blocked> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
+    // CHECK-COUNT-4: llvm.insertelement{{.*}} : vector<4xi32>
+    // CHECK-COUNT-8: llvm.insertelement{{.*}} : vector<8xi32>
+    // CHECK: llvm.amdgcn.tensor.store.from.lds.d2{{.*}} : (vector<4xi32>, vector<8xi32>, i32) -> ()
+    amdgpu.async_tdm_copy_local_to_global %0[%c_offset, %c_offset] from %1: !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> !tt.tensordesc<tensor<64x64xf16, #shared>>
+    // CHECK: llvm.amdgcn.s.wait.tensorcnt{{.*}} : (i16) -> ()
+    %3 = amdgpu.async_tdm_wait {num = 0 : i32}
+    tt.return
+  }
+}
diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td
index 04cc779c7eff..f7a0116ea1c6 100644
--- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td
+++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td
@@ -707,8 +707,9 @@ def AsyncTDMCopyGlobalToLocalOp : TT_AMDGPU_Op<"async_tdm_copy_global_to_local">
     This operation copies data from global memory to local memory
     asynchronously. This is analogue to tt.load except the data are copied to
     local memory pointed by `result` instead of a distributed tensor. The data
-    copied depends on the global memory descriptor pointed to by `desc`. Set
-    `pred` to false will disable the copy.
+    copied depends on the global memory pointed to by `desc`. Setting `pred` to
+    false disables the copy. This operation does not support shared memory
+    swizzling.
   }];
 
   let arguments = (ins
@@ -724,6 +725,37 @@ def AsyncTDMCopyGlobalToLocalOp : TT_AMDGPU_Op<"async_tdm_copy_global_to_local">
     $desc `[` $indices `]` `into` $result `,` $pred attr-dict
       `:` qualified(type($desc)) `->` qualified(type($result))
   }];
+
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// AsyncTDMCopyLocalToGlobalOp
+//===----------------------------------------------------------------------===//
+
+def AsyncTDMCopyLocalToGlobalOp : TT_AMDGPU_Op<"async_tdm_copy_local_to_global"> {
+  let summary = "Copy data based on descriptor from local memory to global memory asynchronously";
+
+  let description = [{
+    This operation copies data from local memory to global memory
+    asynchronously. This is analogous to tt.store except that the data are
+    copied from local memory pointed to by `src` instead of a distributed
+    tensor. The copy destination depends on the global memory pointed to by
+    `desc`. This operation does not support shared memory padding or swizzling.
+ }]; + + let arguments = (ins + Arg, MemWrite]>:$desc, + Variadic:$indices, + Arg]>:$src + ); + + let assemblyFormat = [{ + $desc `[` $indices `]` `from` $src + attr-dict `:` qualified(type($src)) `->` qualified(type($desc)) + }]; + + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index e68389625cd7..070f61f43048 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -607,4 +607,57 @@ void ConcatOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns, patterns.add(foldConcatOpFromSingleSource); } +LogicalResult AsyncTDMCopyGlobalToLocalOp::verify() { + auto tensorDescTy = getDesc().getType(); + auto smemTy = getResult().getType(); + + auto swizzledEnc = + llvm::dyn_cast(smemTy.getEncoding()); + if (swizzledEnc && swizzledEnc.getMaxPhase() != 1) + return emitOpError("TDM does not support swizzling"); + + auto paddedEnc = + llvm::dyn_cast(smemTy.getEncoding()); + if (!paddedEnc && !swizzledEnc) + return emitOpError("Invalid shared memory layout for TDM"); + + Type elementType = smemTy.getElementType(); + auto elementBitWidth = elementType.getIntOrFloatBitWidth(); + if (paddedEnc) { + unsigned dwordSize = 32; + for (auto [interval, padding] : + llvm::zip(paddedEnc.getIntervals(), paddedEnc.getPaddings())) { + auto intervalInDwords = interval * elementBitWidth / dwordSize; + if (intervalInDwords < 2) + return emitOpError("TDM padding interval must be at least 2 dwords"); + + auto paddingInDwords = padding * elementBitWidth / dwordSize; + if (paddingInDwords < 1) + return emitOpError("TDM padding amount must be at least 1 dword"); + } + } + + return success(); +} + +LogicalResult AsyncTDMCopyLocalToGlobalOp::verify() { + auto tensorDescTy = getDesc().getType(); + auto smemTy = getSrc().getType(); + + auto swizzledEnc = + llvm::dyn_cast(smemTy.getEncoding()); + if (swizzledEnc && swizzledEnc.getMaxPhase() != 1) + return emitOpError("TDM does not support swizzling"); + + auto paddedEnc = + llvm::dyn_cast(smemTy.getEncoding()); + if (paddedEnc) + return emitOpError("TDM store does not support padding"); + + if (!paddedEnc && !swizzledEnc) + return emitOpError("Invalid shared memory layout for TDM"); + + return success(); +} + } // namespace mlir::triton::amdgpu diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt index 32534d049b03..e08889a529d9 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt @@ -27,6 +27,7 @@ add_triton_library(TritonAMDGPUToLLVM Fp4ToFpOpToLLVM.cpp MembarUtility.cpp ScalarizePackedFOps.cpp + TDMUtility.cpp DEPENDS TritonAMDGPUConversionPassIncGen diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index c1e9a9e53ce2..7297c3246363 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -3,6 +3,7 @@ #include "BufferOpsEmitter.h" #include "Dialect/TritonAMDGPU/IR/Dialect.h" #include "PatternTritonGPUOpToLLVM.h" +#include "TDMUtility.h" #include "TargetInfo.h" #include "Utility.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -1006,81 +1007,6 @@ struct 
AsyncTDMCopyGlobalToLocalOpConversion : ConvertOpToLLVMPattern(converter, benefit), LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} - std::pair createTDMDescriptors( - RewriterBase &rewriter, Location loc, - const LLVMTypeConverter *typeConverter, int64_t elementSizeInBytes, - ArrayRef tensorShape, ArrayRef blockShape, - ArrayRef tensorStride, Value srcPtr, Value dstPtr, Value pred, - Value multicastMask, unsigned padIntervalInDwords, - unsigned padAmountInDwords) const { - assert(tensorShape.size() == 2 && tensorStride.size() == 2 && - blockShape.size() == 2); - auto b = TritonLLVMOpBuilder(loc, rewriter); - - Value ldsAddr = b.ptrtoint(i32_ty, dstPtr); - Value globalAddr = b.ptrtoint(i64_ty, srcPtr); - - // group0 (128 bits / 4 dwords) effective bit encoding: - // [1:0]: pred - // [63:32]: lds address - // [120:64]: global address - // [127:126]: type - currently always set to 0x2 - SmallVector group0(4, b.i32_val(0)); - group0[0] = b.zext(i32_ty, pred); - group0[1] = ldsAddr; - group0[2] = b.trunc(i32_ty, globalAddr); - group0[3] = b.trunc(i32_ty, b.lshr(globalAddr, b.i64_val(32))); - group0[3] = b.or_(group0[3], b.i32_val(0x80000000)); - - VectorType vecTy0 = vec_ty(i32_ty, 4); - Value group0Vec = b.undef(vecTy0); - for (unsigned ii = 0; ii < 4; ++ii) { - Value vecIdx = createIndexAttrConstant(rewriter, loc, - typeConverter->getIndexType(), ii); - group0Vec = b.insert_element(vecTy0, group0Vec, group0[ii], vecIdx); - } - - // group1 (256 bits / 8 dwords) effective bit encoding: - // [15:0]: multicast mask - // [17:16]: data size - log2(element size in bytes) - // [20]: enable padding - // [24:22]: pad interval - log2(pad interval in dwords) - 1 - // [31:25]: pad amount - pad amount in dwords - 1 - // [79:48]: tensor shape dim inner - // [111:80]: tensor shape dim outer - // [127:112]: block shape dim inner - // [143:128]: block shape dim outer - // [207:160]: tensor stride dim outer (we only use 32 bits) - SmallVector group1(8, b.i32_val(0)); - int32_t dataSize = log2(elementSizeInBytes); - group1[0] = multicastMask; - group1[0] = b.or_(group1[0], b.i32_val(dataSize << 16)); - if (padIntervalInDwords > 0 && padAmountInDwords > 0) { - assert(llvm::isPowerOf2_32(padIntervalInDwords)); - int32_t log2PadInterval = log2(padIntervalInDwords); - group1[0] = b.or_(group1[0], b.i32_val(1 << 20)); - group1[0] = b.or_(group1[0], b.i32_val((log2PadInterval - 1) << 22)); - group1[0] = b.or_(group1[0], b.i32_val((padAmountInDwords - 1) << 25)); - } - group1[1] = b.shl(tensorShape[1], b.i32_val(16)); - group1[2] = b.lshr(tensorShape[1], b.i32_val(16)); - group1[2] = b.or_(group1[2], b.shl(tensorShape[0], b.i32_val(16))); - group1[3] = b.lshr(tensorShape[0], b.i32_val(16)); - group1[3] = b.or_(group1[3], b.i32_val(blockShape[1] << 16)); - group1[4] = b.i32_val(blockShape[0] & 0xFFFF); - group1[5] = tensorStride[0]; - - VectorType vecTy1 = vec_ty(i32_ty, 8); - Value group1Vec = b.undef(vecTy1); - for (unsigned ii = 0; ii < 8; ++ii) { - Value vecIdx = createIndexAttrConstant(rewriter, loc, - typeConverter->getIndexType(), ii); - group1Vec = b.insert_element(vecTy1, group1Vec, group1[ii], vecIdx); - } - - return {group0Vec, group1Vec}; - } - LogicalResult matchAndRewrite(triton::amdgpu::AsyncTDMCopyGlobalToLocalOp op, OpAdaptor adaptor, @@ -1089,23 +1015,12 @@ struct AsyncTDMCopyGlobalToLocalOpConversion auto loc = op.getLoc(); auto b = TritonLLVMOpBuilder(loc, rewriter); - auto mod = op->getParentOfType(); auto tensorDescTy = op.getDesc().getType(); auto smemTy = op.getResult().getType(); - - 
auto swizzledEnc = - llvm::dyn_cast(smemTy.getEncoding()); - if (swizzledEnc && swizzledEnc.getMaxPhase() != 1) - return rewriter.notifyMatchFailure(op, "TDM does not support swizzling"); - auto paddedEnc = llvm::dyn_cast(smemTy.getEncoding()); - if (!paddedEnc && !swizzledEnc) - return rewriter.notifyMatchFailure( - op, "Invalid shared memory layout for TDM."); - - Type llvmElemTy = getTypeConverter()->convertType(smemTy.getElementType()); - auto elementBitWidth = llvmElemTy.getIntOrFloatBitWidth(); + Type elementType = getTypeConverter()->convertType(smemTy.getElementType()); + auto elementBitWidth = elementType.getIntOrFloatBitWidth(); unsigned padInterval = 0; unsigned padAmount = 0; @@ -1117,86 +1032,74 @@ struct AsyncTDMCopyGlobalToLocalOpConversion padInterval = paddedEnc.getIntervals()[0]; padAmount = paddedEnc.getPaddings()[0]; } - unsigned dwordSize = 32; - auto padIntervalInDwords = padInterval * elementBitWidth / dwordSize; - auto padAmountInDwords = padAmount * elementBitWidth / dwordSize; - if (padInterval > 0 && padIntervalInDwords < 2) - return rewriter.notifyMatchFailure( - op, "TDM padding interval must be at least 2 dwords"); - if (padAmount > 0 && padAmountInDwords < 1) - return rewriter.notifyMatchFailure( - op, "TDM padding amount must be at least 1 dword"); - - // [base, shape0, shape1, stride0, stride1] - SmallVector descriptorFields = - unpackLLElements(loc, adaptor.getDesc(), rewriter); - if (descriptorFields.size() != 5) - return rewriter.notifyMatchFailure(op, "NYI: TDM > 2D cases."); - - Value base = descriptorFields[0]; - SmallVector tensorShape{descriptorFields[1], descriptorFields[2]}; - SmallVector tensorStride{descriptorFields[3], descriptorFields[4]}; - // Cast strides from i64 to i32 - tensorStride[0] = b.trunc(i32_ty, tensorStride[0]); - tensorStride[1] = b.trunc(i32_ty, tensorStride[1]); + auto mod = op->getParentOfType(); + int numCTAs = TritonGPUDialect::getNumCTAs(mod); + if (numCTAs > 1) + return rewriter.notifyMatchFailure(op, "NYI: Support multicast."); - SmallVector offset = adaptor.getIndices(); SmallVector blockShape = llvm::to_vector(tensorDescTy.getBlockType().getShape()); - SmallVector blockShapePerCTA = blockShape; - - int numCTAs = TritonGPUDialect::getNumCTAs(mod); - Value multicastMask = b.i32_val(0); - if (numCTAs > 1) { - return rewriter.notifyMatchFailure(op, "NYI: Support multicast."); - } + auto [srcPtr, tensorShape, tensorStride] = + LLVM::AMD::unpackTensorDesc(rewriter, loc, adaptor.getDesc()); + auto dstMemObj = LLVM::getSharedMemoryObjectFromStruct( + loc, adaptor.getResult(), elementType, rewriter); + Value dstPtr = dstMemObj.getBase(); + SmallVector offset = adaptor.getIndices(); + int numWraps = triton::gpu::lookupNumWarps(op); - Type globalPtrTy = ptr_ty(ctx, 1); - Type sharedPtrTy = ptr_ty(ctx, 3); + auto [group0, group1] = LLVM::AMD::createTDMDescriptor( + rewriter, loc, getTypeConverter(), elementType, blockShape, tensorShape, + tensorStride, offset, srcPtr, dstPtr, op.getPred(), numWraps, + padInterval, padAmount); + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, + "llvm.amdgcn.tensor.load.to.lds.d2", {}, + {group0, group1, b.i32_val(0)}); - // For block shape [M, N], each warp will handle shape [M/numWarps, N]. 
- auto numWarps = triton::gpu::lookupNumWarps(op); - auto warpId = getLaneAndWarpId(rewriter, loc).second; + rewriter.eraseOp(op); + return success(); + } +}; - int outerBlockShape = blockShapePerCTA[0]; - int outerBlockShapePerWarp = ceil(outerBlockShape, numWarps); - int outerBlockStride = blockShapePerCTA[1]; +struct AsyncTDMCopyLocalToGlobalOpConversion + : public ConvertOpToLLVMPattern< + triton::amdgpu::AsyncTDMCopyLocalToGlobalOp>, + public LoadStoreConversionBase { + AsyncTDMCopyLocalToGlobalOpConversion( + LLVMTypeConverter &converter, const AMD::TargetInfo &targetInfo, + ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + LoadStoreConversionBase(targetInfo, axisAnalysisPass) {} - // Shift global pointer by offset - Value outerOffset = b.mul(b.i32_val(outerBlockShapePerWarp), warpId); - offset[0] = b.add(offset[0], outerOffset); + LogicalResult + matchAndRewrite(triton::amdgpu::AsyncTDMCopyLocalToGlobalOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto ctx = rewriter.getContext(); + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); - Value baseOffset = b.add(b.mul(tensorStride[0], offset[0]), - b.mul(tensorStride[1], offset[1])); - base = b.gep(globalPtrTy, llvmElemTy, base, baseOffset); + auto tensorDescTy = op.getDesc().getType(); + auto smemTy = op.getSrc().getType(); + Type elementType = getTypeConverter()->convertType(smemTy.getElementType()); - // Shift shared pointer by offset + SmallVector blockShape = + llvm::to_vector(tensorDescTy.getBlockType().getShape()); + auto [srcPtr, tensorShape, tensorStride] = + LLVM::AMD::unpackTensorDesc(rewriter, loc, adaptor.getDesc()); auto dstMemObj = LLVM::getSharedMemoryObjectFromStruct( - loc, adaptor.getResult(), llvmElemTy, rewriter); - Value dstBase = dstMemObj.getBase(); - Value dstOffset = b.mul(b.i32_val(outerBlockStride), outerOffset); - if (paddedEnc) { - Value padding = emitPadding(loc, rewriter, paddedEnc, elementBitWidth, - dstOffset, false); - dstOffset = b.add(dstOffset, padding); - } - dstBase = b.gep(sharedPtrTy, llvmElemTy, dstBase, dstOffset); - - // Update tensor shape and block shape based on offset - Value zero = b.i32_val(0); - tensorShape[0] = b.smax(zero, b.sub(tensorShape[0], offset[0])); - tensorShape[1] = b.smax(zero, b.sub(tensorShape[1], offset[1])); - - blockShapePerCTA[0] = outerBlockShapePerWarp; + loc, adaptor.getSrc(), elementType, rewriter); + Value dstPtr = dstMemObj.getBase(); + SmallVector offset = adaptor.getIndices(); + int numWraps = triton::gpu::lookupNumWarps(op); + Value pred = b.true_val(); - auto elementSizeInBytes = elementBitWidth / 8; - auto [group0, group1] = createTDMDescriptors( - rewriter, loc, getTypeConverter(), elementSizeInBytes, tensorShape, - blockShapePerCTA, tensorStride, base, dstBase, op.getPred(), - multicastMask, padIntervalInDwords, padAmountInDwords); + auto [group0, group1] = LLVM::AMD::createTDMDescriptor( + rewriter, loc, getTypeConverter(), elementType, blockShape, tensorShape, + tensorStride, offset, srcPtr, dstPtr, pred, numWraps, + /*padInterval=*/0, /*padAmount=*/0); LLVM::createLLVMIntrinsicCallOp(rewriter, loc, - "llvm.amdgcn.tensor.load.to.lds.d2", {}, + "llvm.amdgcn.tensor.store.from.lds.d2", {}, {group0, group1, b.i32_val(0)}); rewriter.eraseOp(op); @@ -2030,8 +1933,9 @@ void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, StoreOpConversion, BufferLoadOpConversion, BufferLoadToLocalOpConversion, 
BufferStoreOpConversion, BufferAtomicRMWOpConversion, AsyncCopyGlobalToLocalOpConversion, - AsyncTDMCopyGlobalToLocalOpConversion, BufferAtomicCASOpConversion>( - typeConverter, targetInfo, axisInfoAnalysis, benefit); + BufferAtomicCASOpConversion, AsyncTDMCopyGlobalToLocalOpConversion, + AsyncTDMCopyLocalToGlobalOpConversion>(typeConverter, targetInfo, + axisInfoAnalysis, benefit); patterns.add(typeConverter, targetInfo, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp new file mode 100644 index 000000000000..b7f41234c819 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp @@ -0,0 +1,154 @@ +#include "TDMUtility.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace mlir::LLVM::AMD { + +std::pair +createTDMDescriptor(RewriterBase &rewriter, Location loc, + const LLVMTypeConverter *typeConverter, Type elementType, + SmallVector blockShape, + SmallVector tensorShape, + SmallVector tensorStride, SmallVector offset, + Value srcPtr, Value dstPtr, Value pred, int numWarps, + unsigned padInterval, unsigned padAmount) { + assert(tensorShape.size() == 2 && tensorStride.size() == 2 && + blockShape.size() == 2 && offset.size() == 2 && + "NYI: TDM > 2D cases."); + auto ctx = rewriter.getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + auto elementBitWidth = elementType.getIntOrFloatBitWidth(); + auto elementSizeInBytes = elementBitWidth / 8; + + Type globalPtrTy = ptr_ty(ctx, 1); + Type sharedPtrTy = ptr_ty(ctx, 3); + + // Cast strides from i64 to i32 + tensorStride[0] = b.trunc(i32_ty, tensorStride[0]); + tensorStride[1] = b.trunc(i32_ty, tensorStride[1]); + + // For block shape [M, N], each warp will handle shape [M/numWarps, N]. 
+ auto warpId = getLaneAndWarpId(rewriter, loc).second; + int outerBlockShape = blockShape[0]; + int outerBlockShapePerWarp = ceil(outerBlockShape, numWarps); + int outerBlockStride = blockShape[1]; + + // Shift global pointer by offset + Value outerOffset = b.mul(b.i32_val(outerBlockShapePerWarp), warpId); + offset[0] = b.add(offset[0], outerOffset); + + Value baseOffset = b.add(b.mul(tensorStride[0], offset[0]), + b.mul(tensorStride[1], offset[1])); + srcPtr = b.gep(globalPtrTy, elementType, srcPtr, baseOffset); + + // Shift shared pointer by offset + Value dstOffset = b.mul(b.i32_val(outerBlockStride), outerOffset); + if (padInterval > 0 && padAmount > 0) { + Value iVal = b.i32_val(log2(padInterval)); + Value pVal = b.i32_val(log2(padAmount)); + Value padOffset = b.shl(i32_ty, b.ashr(dstOffset, iVal), pVal); + dstOffset = b.add(dstOffset, padOffset); + } + dstPtr = b.gep(sharedPtrTy, elementType, dstPtr, dstOffset); + + // Update tensor shape and block shape based on offset + Value zero = b.i32_val(0); + tensorShape[0] = b.smax(zero, b.sub(tensorShape[0], offset[0])); + tensorShape[1] = b.smax(zero, b.sub(tensorShape[1], offset[1])); + + blockShape[0] = outerBlockShapePerWarp; + + // group0 (128 bits / 4 dwords) effective bit encoding: + // [1:0]: pred + // [63:32]: lds address + // [120:64]: global address + // [127:126]: type - currently always set to 0x2 + SmallVector group0(4, b.i32_val(0)); + Value globalAddr = b.ptrtoint(i64_ty, srcPtr); + Value ldsAddr = b.ptrtoint(i32_ty, dstPtr); + group0[0] = b.zext(i32_ty, pred); + group0[1] = ldsAddr; + group0[2] = b.trunc(i32_ty, globalAddr); + group0[3] = b.trunc(i32_ty, b.lshr(globalAddr, b.i64_val(32))); + group0[3] = b.or_(group0[3], b.i32_val(0x80000000)); + + VectorType vecTy0 = vec_ty(i32_ty, 4); + Value group0Vec = b.undef(vecTy0); + for (unsigned ii = 0; ii < 4; ++ii) { + Value vecIdx = rewriter.create( + loc, typeConverter->getIndexType(), rewriter.getI32IntegerAttr(ii)); + group0Vec = b.insert_element(vecTy0, group0Vec, group0[ii], vecIdx); + } + + // group1 (256 bits / 8 dwords) effective bit encoding: + // [15:0]: multicast mask + // [17:16]: data size - log2(element size in bytes) + // [20]: enable padding + // [24:22]: pad interval - log2(pad interval in dwords) - 1 + // [31:25]: pad amount - pad amount in dwords - 1 + // [79:48]: tensor shape dim inner + // [111:80]: tensor shape dim outer + // [127:112]: block shape dim inner + // [143:128]: block shape dim outer + // [207:160]: tensor stride dim outer (we only use 32 bits) + SmallVector group1(8, b.i32_val(0)); + int32_t dataSize = log2(elementSizeInBytes); + unsigned dwordSize = 32; + auto padIntervalInDwords = padInterval * elementBitWidth / dwordSize; + auto padAmountInDwords = padAmount * elementBitWidth / dwordSize; + group1[0] = b.or_(group1[0], b.i32_val(dataSize << 16)); + if (padIntervalInDwords > 0 && padAmountInDwords > 0) { + assert(llvm::isPowerOf2_32(padIntervalInDwords)); + int32_t log2PadInterval = log2(padIntervalInDwords); + group1[0] = b.or_(group1[0], b.i32_val(1 << 20)); + group1[0] = b.or_(group1[0], b.i32_val((log2PadInterval - 1) << 22)); + group1[0] = b.or_(group1[0], b.i32_val((padAmountInDwords - 1) << 25)); + } + group1[1] = b.shl(tensorShape[1], b.i32_val(16)); + group1[2] = b.lshr(tensorShape[1], b.i32_val(16)); + group1[2] = b.or_(group1[2], b.shl(tensorShape[0], b.i32_val(16))); + group1[3] = b.lshr(tensorShape[0], b.i32_val(16)); + group1[3] = b.or_(group1[3], b.i32_val(blockShape[1] << 16)); + group1[4] = b.i32_val(blockShape[0] & 0xFFFF); 
+ group1[5] = tensorStride[0]; + + VectorType vecTy1 = vec_ty(i32_ty, 8); + Value group1Vec = b.undef(vecTy1); + for (unsigned ii = 0; ii < 8; ++ii) { + Value vecIdx = rewriter.create( + loc, typeConverter->getIndexType(), rewriter.getIndexAttr(ii)); + group1Vec = b.insert_element(vecTy1, group1Vec, group1[ii], vecIdx); + } + + return {group0Vec, group1Vec}; +} + +Value packTensorDesc(RewriterBase &rewriter, Location loc, + const LLVMTypeConverter *typeConverter, Value base, + ValueRange tensorShape, ValueRange tensorStride, + Type resultTy) { + SmallVector elems; + + elems.push_back(base); + llvm::append_range(elems, tensorShape); + llvm::append_range(elems, tensorStride); + return packLLElements(loc, typeConverter, elems, rewriter, resultTy); +} + +std::tuple, SmallVector> +unpackTensorDesc(RewriterBase &rewriter, Location loc, Value desc) { + SmallVector descriptorFields = unpackLLElements(loc, desc, rewriter); + auto length = descriptorFields.size(); + assert(length >= 5 && "invalid tensor descriptor"); + + Value base = descriptorFields[0]; + SmallVector tensorShape; + SmallVector tensorStride; + for (int i = 1; i < (length - 1) / 2 + 1; i++) + tensorShape.push_back(descriptorFields[i]); + for (int i = (length - 1) / 2 + 1; i < length; i++) + tensorStride.push_back(descriptorFields[i]); + return {base, tensorShape, tensorStride}; +} + +} // namespace mlir::LLVM::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h new file mode 100644 index 000000000000..6432eb044b3a --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h @@ -0,0 +1,34 @@ +#ifndef TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_TDMUTILITY_H +#define TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_TDMUTILITY_H + +#include "TargetInfo.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" + +using mlir::triton::AMD::TargetInfo; + +namespace mlir::LLVM::AMD { + +// Create a TDM descriptor, divided into 2 groups. +std::pair createTDMDescriptor( + RewriterBase &rewriter, Location loc, + const LLVMTypeConverter *typeConverter, Type elementType, + SmallVector blockShape, SmallVector tensorShape, + SmallVector tensorStride, SmallVector tensorOffset, + Value srcPtr, Value dstPtr, Value pred, int numWarps, unsigned padInterval, + unsigned padAmount); + +// Pack base pointer, shape, and stride from a tensor descriptor into a single +// llvm struct value. +Value packTensorDesc(RewriterBase &rewriter, Location loc, + const LLVMTypeConverter *typeConverter, Value base, + ValueRange tensorShape, ValueRange tensorStride, + Type resultTy); + +// Unpack a tensor descriptor from a single llvm struct value into +// (base, [shape0, shape1, ...], [stride0, stride1, ...]). 
+std::tuple, SmallVector> +unpackTensorDesc(RewriterBase &rewriter, Location loc, Value desc); + +} // namespace mlir::LLVM::AMD + +#endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_TDMUTILITY_H diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp index 5523c1f7ca5e..170dba874702 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp @@ -1,4 +1,5 @@ #include "PatternTritonGPUOpToLLVM.h" +#include "TDMUtility.h" #include "Utility.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/IR/BuiltinTypes.h" @@ -19,19 +20,15 @@ struct MakeTensorDescOpConversion ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto basePtr = adaptor.getBase(); auto tensorShape = adaptor.getShape(); auto tensorStride = adaptor.getStrides(); - auto basePtr = adaptor.getBase(); auto result = op.getResult(); - SmallVector elems; - elems.push_back(basePtr); - llvm::append_range(elems, tensorShape); - llvm::append_range(elems, tensorStride); - - auto newValue = packLLElements(op.getLoc(), getTypeConverter(), elems, - rewriter, result.getType()); - rewriter.replaceOp(op, newValue); + Value desc = + LLVM::AMD::packTensorDesc(rewriter, loc, getTypeConverter(), basePtr, + tensorShape, tensorStride, result.getType()); + rewriter.replaceOp(op, desc); return success(); } }; diff --git a/third_party/amd/python/test/test_gluon_gfx1250.py b/third_party/amd/python/test/test_gluon_gfx1250.py index 674ff8fc8f3e..ce57ef41ec29 100644 --- a/third_party/amd/python/test/test_gluon_gfx1250.py +++ b/third_party/amd/python/test/test_gluon_gfx1250.py @@ -364,9 +364,8 @@ def torch_gemm_mxfp(a, b, a_scale, b_scale, scale_block, M, N, K): @gluon.jit -def tensor_copy_kernel(a_ptr, b_ptr, # - M, N, # - BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr): +def tensor_copy_kernel(a_ptr, b_ptr, M, N, # + BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr, NUM_BUFFERS: ttgl.constexpr): SHARED_LAYOUT: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([[32, 4]], [BLOCK_M, BLOCK_N], [1, 0]) BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0]) @@ -377,53 +376,124 @@ def tensor_copy_kernel(a_ptr, b_ptr, # a_desc = ttgl.amd.gfx1250.tdm.make_tensor_descriptor(base=a_ptr, shape=(M, N), strides=(N, 1), block_shape=(BLOCK_M, BLOCK_N), layout=SHARED_LAYOUT) + a_buffer = ttgl.allocate_shared_memory(a_desc.dtype, [NUM_BUFFERS] + a_desc.block_shape, a_desc.layout) - a_buffer = ttgl.allocate_shared_memory(a_desc.dtype, shape=a_desc.block_shape, layout=a_desc.layout) - ttgl.amd.gfx1250.tdm.async_load(a_desc, [pid_m * BLOCK_M, pid_n * BLOCK_N], a_buffer) + idx_m = pid_m * BLOCK_M + for i in ttgl.static_range(0, NUM_BUFFERS): + idx_n = pid_n * (BLOCK_N * NUM_BUFFERS) + i * BLOCK_N + ttgl.amd.gfx1250.tdm.async_load(a_desc, [idx_m, idx_n], a_buffer.index(i)) ttgl.amd.gfx1250.tdm.async_wait(0) - a = a_buffer.load(layout=BLOCKED_LAYOUT) - b_offsets = (pid_m * BLOCK_M + ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, BLOCKED_LAYOUT)))[:, None] * N + \ - (pid_n * BLOCK_N + ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, BLOCKED_LAYOUT)))[None, :] - ttgl.store(b_ptr + b_offsets, a) + for i in ttgl.static_range(0, NUM_BUFFERS): + idx_n = pid_n * (BLOCK_N * NUM_BUFFERS) + i * BLOCK_N + a = a_buffer.index(i).load(layout=BLOCKED_LAYOUT) + + offs_bm = idx_m + ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, 
BLOCKED_LAYOUT)) + offs_bn = idx_n + ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, BLOCKED_LAYOUT)) + offs_b = (offs_bm[:, None] * N) + offs_bn[None, :] + mask_b = (offs_bm[:, None] < M) & (offs_bn[None, :] < N) + + ttgl.store(b_ptr + offs_b, a, mask=mask_b) @pytest.mark.parametrize("BLOCK_M,BLOCK_N", [(32, 32), (32, 64), (64, 64)]) -def test_compile_tensor_copy(BLOCK_M, BLOCK_N): +@pytest.mark.parametrize("NUM_BUFFERS", [1, 2]) +def test_compile_tensor_copy(BLOCK_M, BLOCK_N, NUM_BUFFERS): k = triton.compile( gluon._runtime.GluonASTSource( fn=tensor_copy_kernel, signature={ - "a_ptr": "*bf16", "b_ptr": "*bf16", "M": "i32", "N": "i32", "BLOCK_M": "constexpr", "BLOCK_N": - "constexpr" - }, constexprs={"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N}), target=GPUTarget("hip", 'gfx1250', 32)) + "a_ptr": "*fp16", "b_ptr": "*fp16", "M": "i32", "N": "i32", # + "BLOCK_M": "constexpr", "BLOCK_N": "constexpr", "NUM_BUFFERS": "constexpr" + }, constexprs={"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "NUM_BUFFERS": NUM_BUFFERS}), + target=GPUTarget("hip", 'gfx1250', 32)) amdgcn = k.asm["amdgcn"] - - tensor_pattern = r"tensor_load_to_lds" - assert re.search(tensor_pattern, amdgcn) - - wait_pattern = r"s_wait_tensorcnt 0x0" - assert re.search(wait_pattern, amdgcn) + for pattern in ("tensor_load_to_lds", "s_wait_tensorcnt 0x0"): + assert re.search(pattern, amdgcn) @pytest.mark.parametrize("BLOCK_M,BLOCK_N", [(32, 32), (32, 64), (64, 64)]) -def test_runtime_tensor_copy(BLOCK_M, BLOCK_N): - M, N = 1024, 1024 - +@pytest.mark.parametrize("NUM_BUFFERS", [1, 2]) +@pytest.mark.parametrize("M,N", [(1024, 1024), (1000, 1000)]) +def test_runtime_tensor_copy(M, N, BLOCK_M, BLOCK_N, NUM_BUFFERS): torch.manual_seed(42) a = torch.randint(0x0, 0xFFFF, (M, N), dtype=torch.uint16) b = torch.zeros_like(a) a_device = a.cuda() b_device = b.cuda() - grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1) - tensor_copy_kernel[grid](a_device, b_device, M, N, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N) + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N * NUM_BUFFERS), 1) + tensor_copy_kernel[grid](a_device, b_device, M, N, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, NUM_BUFFERS=NUM_BUFFERS) b_triton = b_device.cpu() assert torch.equal(b_triton, a) +@gluon.jit +def tensor_fill_kernel(a_ptr, M, N, BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr, NUM_BUFFERS: ttgl.constexpr): + SHARED_LAYOUT: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0]) + BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0]) + + pid = ttgl.program_id(axis=0) + num_pid_m = ttgl.cdiv(M, BLOCK_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + + a_desc = ttgl.amd.gfx1250.tdm.make_tensor_descriptor(base=a_ptr, shape=(M, N), strides=(N, 1), + block_shape=(BLOCK_M, BLOCK_N), layout=SHARED_LAYOUT) + a_buffer = ttgl.allocate_shared_memory(a_desc.dtype, [NUM_BUFFERS] + a_desc.block_shape, a_desc.layout) + + idx_m = pid_m * BLOCK_M + for i in ttgl.static_range(0, NUM_BUFFERS): + idx_n = pid_n * (BLOCK_N * NUM_BUFFERS) + i * BLOCK_N + vm = idx_m + ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, BLOCKED_LAYOUT)) + vn = idx_n + ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, BLOCKED_LAYOUT)) + v = (vm[:, None] * N) + vn[None, :] + v = v.to(a_desc.dtype) + a_buffer.index(i).store(v) + + for i in ttgl.static_range(0, NUM_BUFFERS): + idx_n = pid_n * (BLOCK_N * NUM_BUFFERS) + i * BLOCK_N + ttgl.amd.gfx1250.tdm.async_store(a_desc, [idx_m, idx_n], a_buffer.index(i)) + + ttgl.amd.gfx1250.tdm.async_wait(0) + + 
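
For reference, the control bits that createTDMDescriptor packs into the first dword of group1 (the bit layout documented in TDMUtility.cpp earlier in this patch) can be modeled with a short Python sketch. This is only an illustration of that encoding, assuming the comments above are authoritative; it is not code the compiler runs:

import math


def pack_group1_dword0(element_size_bytes, pad_interval_dwords=0, pad_amount_dwords=0, multicast_mask=0):
    dword = multicast_mask & 0xFFFF                               # [15:0]  multicast mask
    dword |= int(math.log2(element_size_bytes)) << 16             # [17:16] data size
    if pad_interval_dwords > 0 and pad_amount_dwords > 0:
        assert pad_interval_dwords & (pad_interval_dwords - 1) == 0, "pad interval must be a power of 2"
        dword |= 1 << 20                                          # [20]    enable padding
        dword |= (int(math.log2(pad_interval_dwords)) - 1) << 22  # [24:22] pad interval
        dword |= (pad_amount_dwords - 1) << 25                    # [31:25] pad amount
    return dword


# fp16 elements with the padded layout [[32, 4]] used in tensor_copy_kernel above
# translate to a 16-dword interval and 2 dwords of padding:
assert pack_group1_dword0(2, pad_interval_dwords=16, pad_amount_dwords=2) == \
    (1 << 16) | (1 << 20) | (3 << 22) | (1 << 25)
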
+@pytest.mark.parametrize("BLOCK_M,BLOCK_N", [(32, 32), (32, 64), (64, 64)]) +@pytest.mark.parametrize("NUM_BUFFERS", [1, 2]) +def test_compile_tensor_fill(BLOCK_M, BLOCK_N, NUM_BUFFERS): + k = triton.compile( + gluon._runtime.GluonASTSource( + fn=tensor_fill_kernel, signature={ + "a_ptr": "*fp16", "M": "i32", "N": "i32", # + "BLOCK_M": "constexpr", "BLOCK_N": "constexpr", "NUM_BUFFERS": "constexpr" + }, constexprs={"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "NUM_BUFFERS": NUM_BUFFERS}), + target=GPUTarget("hip", 'gfx1250', 32)) + + amdgcn = k.asm["amdgcn"] + + for pattern in ("tensor_store_from_lds", "s_wait_tensorcnt 0x0"): + assert re.search(pattern, amdgcn) + + +@pytest.mark.parametrize("BLOCK_M,BLOCK_N", [(32, 32), (32, 64), (64, 64)]) +@pytest.mark.parametrize("NUM_BUFFERS", [1, 2]) +@pytest.mark.parametrize("M,N", [(1024, 1024), (1000, 1000)]) +def test_runtime_tensor_fill(M, N, BLOCK_M, BLOCK_N, NUM_BUFFERS): + a = torch.zeros((M, N), dtype=torch.uint16) + + a_device = a.cuda() + grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N * NUM_BUFFERS), 1) + tensor_fill_kernel[grid](a_device, M, N, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, NUM_BUFFERS=NUM_BUFFERS) + + a_triton = a_device.cpu() + a_ref = torch.arange(M, dtype=torch.int16).unsqueeze(1) * N + \ + torch.arange(N, dtype=torch.int16).unsqueeze(0) + a_ref = a_ref.to(torch.uint16) + assert torch.equal(a_triton, a_ref) + + @gluon.jit def mxgemm_kernel(a_ptr, b_ptr, c_ptr, a_scale, b_scale, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scale, DTYPE_A: ttgl.constexpr, DTYPE_B: ttgl.constexpr, From 78f0f8e8f17bc4202a9e6a4cad0be2dcbf32deb6 Mon Sep 17 00:00:00 2001 From: Pengzhan Zhao Date: Tue, 7 Oct 2025 16:50:50 -0700 Subject: [PATCH 2/5] remove read side effect --- .../amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index f7a0116ea1c6..7f5d0803355d 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -745,7 +745,7 @@ def AsyncTDMCopyLocalToGlobalOp : TT_AMDGPU_Op<"async_tdm_copy_local_to_global"> }]; let arguments = (ins - Arg, MemWrite]>:$desc, + Arg]>:$desc, Variadic:$indices, Arg]>:$src ); From 5a0847f84f1b111ac1663c4a4d6a4d4ffe5538db Mon Sep 17 00:00:00 2001 From: Pengzhan Zhao Date: Wed, 8 Oct 2025 13:28:46 -0700 Subject: [PATCH 3/5] typo --- .../amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 7297c3246363..1bf8a3c6ac5f 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1046,11 +1046,11 @@ struct AsyncTDMCopyGlobalToLocalOpConversion loc, adaptor.getResult(), elementType, rewriter); Value dstPtr = dstMemObj.getBase(); SmallVector offset = adaptor.getIndices(); - int numWraps = triton::gpu::lookupNumWarps(op); + int numWarps = triton::gpu::lookupNumWarps(op); auto [group0, group1] = LLVM::AMD::createTDMDescriptor( rewriter, loc, getTypeConverter(), elementType, blockShape, tensorShape, - tensorStride, offset, srcPtr, dstPtr, op.getPred(), numWraps, + tensorStride, offset, srcPtr, dstPtr, 
op.getPred(), numWarps, padInterval, padAmount); LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.amdgcn.tensor.load.to.lds.d2", {}, @@ -1091,12 +1091,12 @@ struct AsyncTDMCopyLocalToGlobalOpConversion loc, adaptor.getSrc(), elementType, rewriter); Value dstPtr = dstMemObj.getBase(); SmallVector offset = adaptor.getIndices(); - int numWraps = triton::gpu::lookupNumWarps(op); + int numWarps = triton::gpu::lookupNumWarps(op); Value pred = b.true_val(); auto [group0, group1] = LLVM::AMD::createTDMDescriptor( rewriter, loc, getTypeConverter(), elementType, blockShape, tensorShape, - tensorStride, offset, srcPtr, dstPtr, pred, numWraps, + tensorStride, offset, srcPtr, dstPtr, pred, numWarps, /*padInterval=*/0, /*padAmount=*/0); LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.amdgcn.tensor.store.from.lds.d2", {}, From 46b6ca5b02a801b390ca182336985b364ae22abb Mon Sep 17 00:00:00 2001 From: Pengzhan Zhao Date: Wed, 8 Oct 2025 14:47:43 -0700 Subject: [PATCH 4/5] directly create tdm descriptor --- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 46 +++-- .../amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp | 195 ++++++++++-------- .../amd/lib/TritonAMDGPUToLLVM/TDMUtility.h | 36 ++-- .../TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp | 38 +++- .../TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp | 14 +- 5 files changed, 187 insertions(+), 142 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 1bf8a3c6ac5f..54aba24624bc 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1020,15 +1020,12 @@ struct AsyncTDMCopyGlobalToLocalOpConversion auto paddedEnc = llvm::dyn_cast(smemTy.getEncoding()); Type elementType = getTypeConverter()->convertType(smemTy.getElementType()); - auto elementBitWidth = elementType.getIntOrFloatBitWidth(); unsigned padInterval = 0; unsigned padAmount = 0; if (paddedEnc) { - if (paddedEnc.getIntervals().size() != 1 || - paddedEnc.getPaddings().size() != 1) - return rewriter.notifyMatchFailure( - op, "NYI: Multiple interval-padding pairs in TDM."); + assert(paddedEnc.getIntervals().size() == 1 && + paddedEnc.getPaddings().size() == 1); padInterval = paddedEnc.getIntervals()[0]; padAmount = paddedEnc.getPaddings()[0]; } @@ -1038,20 +1035,27 @@ struct AsyncTDMCopyGlobalToLocalOpConversion if (numCTAs > 1) return rewriter.notifyMatchFailure(op, "NYI: Support multicast."); + SmallVector desc = + unpackLLElements(loc, adaptor.getDesc(), rewriter); + assert(desc.size() == 12); + auto group0Vec = SmallVector(desc.begin(), desc.begin() + 4); + auto group1Vec = SmallVector(desc.begin() + 4, desc.end()); + SmallVector blockShape = llvm::to_vector(tensorDescTy.getBlockType().getShape()); - auto [srcPtr, tensorShape, tensorStride] = - LLVM::AMD::unpackTensorDesc(rewriter, loc, adaptor.getDesc()); auto dstMemObj = LLVM::getSharedMemoryObjectFromStruct( loc, adaptor.getResult(), elementType, rewriter); Value dstPtr = dstMemObj.getBase(); SmallVector offset = adaptor.getIndices(); int numWarps = triton::gpu::lookupNumWarps(op); - auto [group0, group1] = LLVM::AMD::createTDMDescriptor( - rewriter, loc, getTypeConverter(), elementType, blockShape, tensorShape, - tensorStride, offset, srcPtr, dstPtr, op.getPred(), numWarps, - padInterval, padAmount); + LLVM::AMD::fillTDMDescriptor(rewriter, loc, getTypeConverter(), elementType, + blockShape, numWarps, padInterval, padAmount, + group0Vec, group1Vec, offset, dstPtr, + 
op.getPred()); + + auto group0 = packLLVector(loc, group0Vec, rewriter); + auto group1 = packLLVector(loc, group1Vec, rewriter); LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.amdgcn.tensor.load.to.lds.d2", {}, {group0, group1, b.i32_val(0)}); @@ -1083,21 +1087,27 @@ struct AsyncTDMCopyLocalToGlobalOpConversion auto smemTy = op.getSrc().getType(); Type elementType = getTypeConverter()->convertType(smemTy.getElementType()); + SmallVector desc = + unpackLLElements(loc, adaptor.getDesc(), rewriter); + assert(desc.size() == 12); + auto group0Vec = SmallVector(desc.begin(), desc.begin() + 4); + auto group1Vec = SmallVector(desc.begin() + 4, desc.end()); + SmallVector blockShape = llvm::to_vector(tensorDescTy.getBlockType().getShape()); - auto [srcPtr, tensorShape, tensorStride] = - LLVM::AMD::unpackTensorDesc(rewriter, loc, adaptor.getDesc()); auto dstMemObj = LLVM::getSharedMemoryObjectFromStruct( loc, adaptor.getSrc(), elementType, rewriter); Value dstPtr = dstMemObj.getBase(); SmallVector offset = adaptor.getIndices(); int numWarps = triton::gpu::lookupNumWarps(op); - Value pred = b.true_val(); - auto [group0, group1] = LLVM::AMD::createTDMDescriptor( - rewriter, loc, getTypeConverter(), elementType, blockShape, tensorShape, - tensorStride, offset, srcPtr, dstPtr, pred, numWarps, - /*padInterval=*/0, /*padAmount=*/0); + LLVM::AMD::fillTDMDescriptor(rewriter, loc, getTypeConverter(), elementType, + blockShape, numWarps, /*padInterval=*/0, + /*padAmount=*/0, group0Vec, group1Vec, offset, + dstPtr, b.true_val()); + + auto group0 = packLLVector(loc, group0Vec, rewriter); + auto group1 = packLLVector(loc, group1Vec, rewriter); LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.amdgcn.tensor.store.from.lds.d2", {}, {group0, group1, b.i32_val(0)}); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp index b7f41234c819..bd0558320f7d 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.cpp @@ -3,82 +3,72 @@ namespace mlir::LLVM::AMD { -std::pair +namespace { + +// Decode a TDM descriptor from group vectors into +// (base, [shape0, shape1], [stride0, stride1]). 
+std::tuple, SmallVector> +decodeTDMDescriptor(RewriterBase &rewriter, Location loc, + ArrayRef group0, ArrayRef group1) { + auto ctx = rewriter.getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Type globalPtrTy = ptr_ty(ctx, 1); + + Value globalAddrLow = group0[2]; + Value globalAddrHigh = b.and_(group0[3], b.i32_val(0x7FFFFFFF)); + globalAddrLow = b.zext(i64_ty, globalAddrLow); + globalAddrHigh = b.shl(b.zext(i64_ty, globalAddrHigh), b.i64_val(32)); + Value globalAddr = b.or_(globalAddrLow, globalAddrHigh); + Value srcPtr = b.inttoptr(globalPtrTy, globalAddr); + + Value tensorStride0 = group1[5]; + Value tensorStride1 = b.i32_val(1); + SmallVector tensorStride = {tensorStride0, tensorStride1}; + + Value tensorShape1Low = b.lshr(group1[1], b.i32_val(16)); + Value tensorShape1High = b.shl(group1[2], b.i32_val(16)); + Value tensorShape1 = b.or_(tensorShape1Low, tensorShape1High); + Value tensorShape0Low = b.lshr(group1[2], b.i32_val(16)); + Value tensorShape0High = b.shl(group1[3], b.i32_val(16)); + Value tensorShape0 = b.or_(tensorShape0Low, tensorShape0High); + SmallVector tensorShape = {tensorShape0, tensorShape1}; + + return {srcPtr, tensorShape, tensorStride}; +} +} // namespace + +std::pair, SmallVector> createTDMDescriptor(RewriterBase &rewriter, Location loc, const LLVMTypeConverter *typeConverter, Type elementType, - SmallVector blockShape, + SmallVector blockShape, int numWarps, + unsigned padInterval, unsigned padAmount, SmallVector tensorShape, - SmallVector tensorStride, SmallVector offset, - Value srcPtr, Value dstPtr, Value pred, int numWarps, - unsigned padInterval, unsigned padAmount) { + SmallVector tensorStride, Value srcPtr) { assert(tensorShape.size() == 2 && tensorStride.size() == 2 && - blockShape.size() == 2 && offset.size() == 2 && - "NYI: TDM > 2D cases."); + blockShape.size() == 2 && "NYI: TDM > 2D cases."); auto ctx = rewriter.getContext(); auto b = TritonLLVMOpBuilder(loc, rewriter); auto elementBitWidth = elementType.getIntOrFloatBitWidth(); auto elementSizeInBytes = elementBitWidth / 8; - Type globalPtrTy = ptr_ty(ctx, 1); - Type sharedPtrTy = ptr_ty(ctx, 3); - // Cast strides from i64 to i32 tensorStride[0] = b.trunc(i32_ty, tensorStride[0]); tensorStride[1] = b.trunc(i32_ty, tensorStride[1]); // For block shape [M, N], each warp will handle shape [M/numWarps, N]. 
- auto warpId = getLaneAndWarpId(rewriter, loc).second; - int outerBlockShape = blockShape[0]; - int outerBlockShapePerWarp = ceil(outerBlockShape, numWarps); - int outerBlockStride = blockShape[1]; - - // Shift global pointer by offset - Value outerOffset = b.mul(b.i32_val(outerBlockShapePerWarp), warpId); - offset[0] = b.add(offset[0], outerOffset); - - Value baseOffset = b.add(b.mul(tensorStride[0], offset[0]), - b.mul(tensorStride[1], offset[1])); - srcPtr = b.gep(globalPtrTy, elementType, srcPtr, baseOffset); - - // Shift shared pointer by offset - Value dstOffset = b.mul(b.i32_val(outerBlockStride), outerOffset); - if (padInterval > 0 && padAmount > 0) { - Value iVal = b.i32_val(log2(padInterval)); - Value pVal = b.i32_val(log2(padAmount)); - Value padOffset = b.shl(i32_ty, b.ashr(dstOffset, iVal), pVal); - dstOffset = b.add(dstOffset, padOffset); - } - dstPtr = b.gep(sharedPtrTy, elementType, dstPtr, dstOffset); - - // Update tensor shape and block shape based on offset - Value zero = b.i32_val(0); - tensorShape[0] = b.smax(zero, b.sub(tensorShape[0], offset[0])); - tensorShape[1] = b.smax(zero, b.sub(tensorShape[1], offset[1])); - - blockShape[0] = outerBlockShapePerWarp; + blockShape[0] = ceil(blockShape[0], int64_t(numWarps)); // group0 (128 bits / 4 dwords) effective bit encoding: - // [1:0]: pred - // [63:32]: lds address + // [1:0]: pred (to be filled later) + // [63:32]: lds address (to be filled later) // [120:64]: global address // [127:126]: type - currently always set to 0x2 - SmallVector group0(4, b.i32_val(0)); + SmallVector group0(4, b.i32_val(0)); Value globalAddr = b.ptrtoint(i64_ty, srcPtr); - Value ldsAddr = b.ptrtoint(i32_ty, dstPtr); - group0[0] = b.zext(i32_ty, pred); - group0[1] = ldsAddr; group0[2] = b.trunc(i32_ty, globalAddr); group0[3] = b.trunc(i32_ty, b.lshr(globalAddr, b.i64_val(32))); - group0[3] = b.or_(group0[3], b.i32_val(0x80000000)); - - VectorType vecTy0 = vec_ty(i32_ty, 4); - Value group0Vec = b.undef(vecTy0); - for (unsigned ii = 0; ii < 4; ++ii) { - Value vecIdx = rewriter.create( - loc, typeConverter->getIndexType(), rewriter.getI32IntegerAttr(ii)); - group0Vec = b.insert_element(vecTy0, group0Vec, group0[ii], vecIdx); - } + group0[3] = b.or_(group0[3], b.i32_val(1 << 31)); // group1 (256 bits / 8 dwords) effective bit encoding: // [15:0]: multicast mask @@ -91,7 +81,7 @@ createTDMDescriptor(RewriterBase &rewriter, Location loc, // [127:112]: block shape dim inner // [143:128]: block shape dim outer // [207:160]: tensor stride dim outer (we only use 32 bits) - SmallVector group1(8, b.i32_val(0)); + SmallVector group1(8, b.i32_val(0)); int32_t dataSize = log2(elementSizeInBytes); unsigned dwordSize = 32; auto padIntervalInDwords = padInterval * elementBitWidth / dwordSize; @@ -112,43 +102,72 @@ createTDMDescriptor(RewriterBase &rewriter, Location loc, group1[4] = b.i32_val(blockShape[0] & 0xFFFF); group1[5] = tensorStride[0]; - VectorType vecTy1 = vec_ty(i32_ty, 8); - Value group1Vec = b.undef(vecTy1); - for (unsigned ii = 0; ii < 8; ++ii) { - Value vecIdx = rewriter.create( - loc, typeConverter->getIndexType(), rewriter.getIndexAttr(ii)); - group1Vec = b.insert_element(vecTy1, group1Vec, group1[ii], vecIdx); - } - - return {group0Vec, group1Vec}; + return {group0, group1}; } -Value packTensorDesc(RewriterBase &rewriter, Location loc, - const LLVMTypeConverter *typeConverter, Value base, - ValueRange tensorShape, ValueRange tensorStride, - Type resultTy) { - SmallVector elems; +void fillTDMDescriptor(RewriterBase &rewriter, Location loc, + 
const LLVMTypeConverter *typeConverter, Type elementType, + SmallVector blockShape, int numWarps, + unsigned padInterval, unsigned padAmount, + SmallVector &group0, SmallVector &group1, + SmallVector offset, Value dstPtr, Value pred) { + assert(offset.size() == 2 && "NYI: TDM > 2D cases."); + auto ctx = rewriter.getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + Type globalPtrTy = ptr_ty(ctx, 1); + Type sharedPtrTy = ptr_ty(ctx, 3); - elems.push_back(base); - llvm::append_range(elems, tensorShape); - llvm::append_range(elems, tensorStride); - return packLLElements(loc, typeConverter, elems, rewriter, resultTy); -} + auto [srcPtr, tensorShape, tensorStride] = + decodeTDMDescriptor(rewriter, loc, group0, group1); -std::tuple, SmallVector> -unpackTensorDesc(RewriterBase &rewriter, Location loc, Value desc) { - SmallVector descriptorFields = unpackLLElements(loc, desc, rewriter); - auto length = descriptorFields.size(); - assert(length >= 5 && "invalid tensor descriptor"); - - Value base = descriptorFields[0]; - SmallVector tensorShape; - SmallVector tensorStride; - for (int i = 1; i < (length - 1) / 2 + 1; i++) - tensorShape.push_back(descriptorFields[i]); - for (int i = (length - 1) / 2 + 1; i < length; i++) - tensorStride.push_back(descriptorFields[i]); - return {base, tensorShape, tensorStride}; + auto warpId = getLaneAndWarpId(rewriter, loc).second; + int outerBlockShapePerWarp = ceil(blockShape[0], int64_t(numWarps)); + int outerBlockStride = blockShape[1]; + + // Shift global pointer by offset + Value outerOffset = b.mul(b.i32_val(outerBlockShapePerWarp), warpId); + offset[0] = b.add(offset[0], outerOffset); + + Value baseOffset = b.add(b.mul(tensorStride[0], offset[0]), + b.mul(tensorStride[1], offset[1])); + srcPtr = b.gep(globalPtrTy, elementType, srcPtr, baseOffset); + + // Shift shared pointer by offset + Value dstOffset = b.mul(b.i32_val(outerBlockStride), outerOffset); + if (padInterval > 0 && padAmount > 0) { + Value iVal = b.i32_val(log2(padInterval)); + Value pVal = b.i32_val(log2(padAmount)); + Value padOffset = b.shl(i32_ty, b.ashr(dstOffset, iVal), pVal); + dstOffset = b.add(dstOffset, padOffset); + } + dstPtr = b.gep(sharedPtrTy, elementType, dstPtr, dstOffset); + + // Update tensor shape and block shape based on offset + tensorShape[0] = b.smax(b.i32_val(0), b.sub(tensorShape[0], offset[0])); + tensorShape[1] = b.smax(b.i32_val(0), b.sub(tensorShape[1], offset[1])); + + // group0 changed fields: + // [1:0]: pred + // [63:32]: lds address + // [120:64]: global address + Value globalAddr = b.ptrtoint(i64_ty, srcPtr); + Value ldsAddr = b.ptrtoint(i32_ty, dstPtr); + group0[0] = b.zext(i32_ty, pred); + group0[1] = ldsAddr; + group0[2] = b.trunc(i32_ty, globalAddr); + group0[3] = b.and_(group0[3], b.i32_val(1 << 31)); + group0[3] = + b.or_(group0[3], b.trunc(i32_ty, b.lshr(globalAddr, b.i64_val(32)))); + + // group1 changed fields: + // [79:48]: tensor shape dim inner + // [111:80]: tensor shape dim outer + group1[1] = b.shl(tensorShape[1], b.i32_val(16)); + group1[2] = b.lshr(tensorShape[1], b.i32_val(16)); + group1[2] = b.or_(group1[2], b.shl(tensorShape[0], b.i32_val(16))); + group1[3] = b.and_(group1[3], b.i32_val(0xFFFF << 16)); + group1[3] = b.or_(group1[3], b.lshr(tensorShape[0], b.i32_val(16))); } } // namespace mlir::LLVM::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h index 6432eb044b3a..ad3210b2b967 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h +++ 
b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h @@ -8,26 +8,22 @@ using mlir::triton::AMD::TargetInfo; namespace mlir::LLVM::AMD { -// Create a TDM descriptor, divided into 2 groups. -std::pair createTDMDescriptor( - RewriterBase &rewriter, Location loc, - const LLVMTypeConverter *typeConverter, Type elementType, - SmallVector blockShape, SmallVector tensorShape, - SmallVector tensorStride, SmallVector tensorOffset, - Value srcPtr, Value dstPtr, Value pred, int numWarps, unsigned padInterval, - unsigned padAmount); - -// Pack base pointer, shape, and stride from a tensor descriptor into a single -// llvm struct value. -Value packTensorDesc(RewriterBase &rewriter, Location loc, - const LLVMTypeConverter *typeConverter, Value base, - ValueRange tensorShape, ValueRange tensorStride, - Type resultTy); - -// Unpack a tensor descriptor from a single llvm struct value into -// (base, [shape0, shape1, ...], [stride0, stride1, ...]). -std::tuple, SmallVector> -unpackTensorDesc(RewriterBase &rewriter, Location loc, Value desc); +// Create a TDM descriptor, divided into 2 group vectors. +std::pair, SmallVector> +createTDMDescriptor(RewriterBase &rewriter, Location loc, + const LLVMTypeConverter *typeConverter, Type elementType, + SmallVector blockShape, int numWarps, + unsigned padInterval, unsigned padAmount, + SmallVector tensorShape, + SmallVector tensorStride, Value srcPtr); + +// Fill a TDM descriptor with offset, shared memory address, and pred. +void fillTDMDescriptor(RewriterBase &rewriter, Location loc, + const LLVMTypeConverter *typeConverter, Type elementType, + SmallVector blockShape, int numWarps, + unsigned padInterval, unsigned padAmount, + SmallVector &group0, SmallVector &group1, + SmallVector offset, Value dstPtr, Value pred); } // namespace mlir::LLVM::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp index 170dba874702..58ba103bb420 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp @@ -8,6 +8,7 @@ using namespace mlir; using namespace mlir::triton; +using namespace mlir::triton::gpu; namespace { struct MakeTensorDescOpConversion @@ -18,16 +19,45 @@ struct MakeTensorDescOpConversion LogicalResult matchAndRewrite(triton::MakeTensorDescOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); auto basePtr = adaptor.getBase(); auto tensorShape = adaptor.getShape(); auto tensorStride = adaptor.getStrides(); auto result = op.getResult(); - Value desc = - LLVM::AMD::packTensorDesc(rewriter, loc, getTypeConverter(), basePtr, - tensorShape, tensorStride, result.getType()); + auto tensorDescTy = result.getType(); + auto blockTy = tensorDescTy.getBlockType(); + auto enc = blockTy.getEncoding(); + if (!enc) { + return rewriter.notifyMatchFailure(op, "Descriptor has no layout."); + } + auto paddedEnc = llvm::dyn_cast(enc); + + unsigned padInterval = 0; + unsigned padAmount = 0; + if (paddedEnc) { + if (paddedEnc.getIntervals().size() != 1 || + paddedEnc.getPaddings().size() != 1) + return rewriter.notifyMatchFailure( + op, "NYI: Multiple interval-padding pairs in TDM."); + padInterval = paddedEnc.getIntervals()[0]; + padAmount = paddedEnc.getPaddings()[0]; + } + + Type elementType = + getTypeConverter()->convertType(blockTy.getElementType()); + SmallVector blockShape = llvm::to_vector(blockTy.getShape()); + int numWarps = lookupNumWarps(op); + + auto 
[group0, group1] = LLVM::AMD::createTDMDescriptor(
+        rewriter, loc, getTypeConverter(), elementType, blockShape, numWarps,
+        padInterval, padAmount, tensorShape, tensorStride, basePtr);
+    SmallVector groups;
+    llvm::append_range(groups, group0);
+    llvm::append_range(groups, group1);
+    auto desc =
+        packLLElements(loc, getTypeConverter(), groups, rewriter, tensorDescTy);
+
     rewriter.replaceOp(op, desc);
     return success();
   }
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
index fa5fdb5a9ae0..ff0590597a61 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
@@ -77,18 +77,8 @@ class TritonAMDGPUToLLVMTypeConverter : public TritonGPUToLLVMTypeConverter {
 
   Type convertTensorDescType(triton::TensorDescType type) {
     auto ctx = type.getContext();
-
-    RankedTensorType rankedTensorType = type.getBlockType();
-    auto eleType = rankedTensorType.getElementType();
-    auto shape = rankedTensorType.getShape();
-    SmallVector types;
-    // base ptr
-    types.push_back(LLVM::LLVMPointerType::get(ctx, 1));
-    // 32 bit shapes
-    types.append(shape.size(), IntegerType::get(ctx, 32));
-    // 64 bit strides
-    types.append(shape.size(), IntegerType::get(ctx, 64));
-
+    // 4 dwords for group0, 8 dwords for group1
+    auto types = SmallVector(4 + 8, IntegerType::get(ctx, 32));
     return LLVM::LLVMStructType::getLiteral(ctx, types);
   }
 };

From 668cd1dcbf25944dd3c05a9924f39cea94476c15 Mon Sep 17 00:00:00 2001
From: Pengzhan Zhao
Date: Wed, 8 Oct 2025 18:42:19 -0700
Subject: [PATCH 5/5] update doc

---
 third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h
index ad3210b2b967..7a2239f1853a 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TDMUtility.h
@@ -8,7 +8,9 @@ using mlir::triton::AMD::TargetInfo;
 
 namespace mlir::LLVM::AMD {
 
-// Create a TDM descriptor, divided into 2 group vectors.
+// Create a TDM descriptor, divided into 2 group vectors. This creates a
+// partially filled descriptor with the shared memory address and pred set to
+// zero; the user is expected to fill in these fields later via fillTDMDescriptor.
 std::pair, SmallVector>
 createTDMDescriptor(RewriterBase &rewriter, Location loc,
                     const LLVMTypeConverter *typeConverter, Type elementType,
                     SmallVector blockShape, int numWarps,
                     unsigned padInterval, unsigned padAmount,
                     SmallVector tensorShape,
                     SmallVector tensorStride, Value srcPtr);
 
-// Fill a TDM descriptor with offset, shared memory address, and pred.
+// Update the global memory address with the given offset, and fill in the
+// shared memory address and pred of an existing TDM descriptor.
 void fillTDMDescriptor(RewriterBase &rewriter, Location loc,
                        const LLVMTypeConverter *typeConverter, Type elementType,
                        SmallVector blockShape, int numWarps,
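
Note on the descriptor layout the patch manipulates: the lowering treats a TDM descriptor as 12 raw i32 dwords (group0 = 4, group1 = 8) and sets individual bit fields with shifts and masks. The standalone C++ sketch below packs only the fields that the comments in TDMUtility.cpp spell out, to make the bit positions easier to see. The helper names (setBits, packTdmDescriptor), the host-side framing, and the example values are assumptions for illustration; the remaining fields (multicast mask, data size, padding) are left at zero, and the real code builds these dwords as LLVM IR values rather than in a host program.

// tdm_desc_sketch.cpp -- illustrative only; bit positions follow the comments
// in TDMUtility.cpp, everything else is assumed for the example.
#include <array>
#include <cstdint>
#include <cstdio>

using Dwords = std::array<uint32_t, 12>; // group0 (4 dwords) + group1 (8 dwords)

// Write `width` bits of `value` starting at absolute bit `lo` of a group.
static void setBits(uint32_t *group, unsigned lo, unsigned width, uint64_t value) {
  for (unsigned i = 0; i < width; ++i) {
    unsigned bit = lo + i;
    uint32_t mask = 1u << (bit % 32);
    if ((value >> i) & 1)
      group[bit / 32] |= mask;
    else
      group[bit / 32] &= ~mask;
  }
}

Dwords packTdmDescriptor(uint64_t globalAddr, uint32_t ldsAddr, uint32_t pred,
                         uint32_t shapeOuter, uint32_t shapeInner,
                         uint32_t blockOuter, uint32_t blockInner,
                         uint32_t strideOuter) {
  Dwords d{};                  // zero-initialized: unpacked fields stay 0
  uint32_t *g0 = d.data();     // group0: 128 bits
  uint32_t *g1 = d.data() + 4; // group1: 256 bits

  setBits(g0, 0, 2, pred);           // [1:0]     predicate
  setBits(g0, 32, 32, ldsAddr);      // [63:32]   LDS (shared memory) address
  setBits(g0, 64, 57, globalAddr);   // [120:64]  global address
  setBits(g0, 126, 2, 0x2);          // [127:126] type, currently always 0x2

  setBits(g1, 48, 32, shapeInner);   // [79:48]   tensor shape, inner dim
  setBits(g1, 80, 32, shapeOuter);   // [111:80]  tensor shape, outer dim
  setBits(g1, 112, 16, blockInner);  // [127:112] block shape, inner dim
  setBits(g1, 128, 16, blockOuter);  // [143:128] block shape, outer dim
  setBits(g1, 160, 32, strideOuter); // [207:160] tensor stride, outer dim (32 bits used)
  return d;
}

int main() {
  // Example values loosely based on the 32x128 f16 tensor / 16x64 block tests.
  Dwords d = packTdmDescriptor(/*globalAddr=*/0x7f00deadbeefULL, /*ldsAddr=*/0x1000,
                               /*pred=*/1, /*shapeOuter=*/32, /*shapeInner=*/128,
                               /*blockOuter=*/16, /*blockInner=*/64,
                               /*strideOuter=*/128);
  for (unsigned i = 0; i < d.size(); ++i)
    std::printf("dword[%u] = 0x%08x\n", i, d[i]);
  return 0;
}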
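
A second sketch for the fill step: fillTDMDescriptor splits the outer block dimension across warps, advances the global address by the strided offset, advances the LDS address including the extra elements a PaddedSharedLayout inserts, and clamps the tensor shape so the residual transfer stays in bounds. The code below mirrors that arithmetic with plain integers; fillForWarp, the result struct, and the example numbers (taken loosely from the 32x128 tensor / 16x64 block tests above) are illustrative assumptions, while the patch itself emits the same operations as LLVM IR on mlir Values.

// tdm_fill_sketch.cpp -- integer arithmetic of the fill step, 2D case only.
#include <algorithm>
#include <cstdint>
#include <cstdio>

struct FillResult {
  int64_t globalElemOffset; // elements added to the global base pointer
  int64_t ldsElemOffset;    // elements added to the LDS base pointer
  int64_t shapeOuter;       // residual tensor shape seen by this warp
  int64_t shapeInner;
};

FillResult fillForWarp(int64_t warpId, int64_t numWarps,
                       int64_t blockOuter, int64_t blockInner,
                       int64_t tensorShapeOuter, int64_t tensorShapeInner,
                       int64_t strideOuter, int64_t strideInner,
                       int64_t offsetOuter, int64_t offsetInner,
                       int64_t padInterval, int64_t padAmount) {
  // Each warp handles [ceil(blockOuter / numWarps), blockInner] of the block.
  int64_t outerPerWarp = (blockOuter + numWarps - 1) / numWarps;
  int64_t warpOuterOffset = outerPerWarp * warpId;
  offsetOuter += warpOuterOffset;

  // Global side: plain strided offset.
  int64_t globalOff = strideOuter * offsetOuter + strideInner * offsetInner;

  // LDS side: linear offset within the block, plus the padding a
  // PaddedSharedLayout inserts every padInterval elements (both powers of two).
  int64_t ldsOff = blockInner * warpOuterOffset;
  if (padInterval > 0 && padAmount > 0)
    ldsOff += (ldsOff / padInterval) * padAmount;

  // Clamp the tensor shape so out-of-bounds rows/columns are not transferred.
  FillResult r;
  r.globalElemOffset = globalOff;
  r.ldsElemOffset = ldsOff;
  r.shapeOuter = std::max<int64_t>(0, tensorShapeOuter - offsetOuter);
  r.shapeInner = std::max<int64_t>(0, tensorShapeInner - offsetInner);
  return r;
}

int main() {
  // 32x128 tensor, 16x64 block, strides [128, 1], offsets [0, 2], 4 warps,
  // padding [[32:+4]] as in the load test's PaddedSharedLayout.
  for (int64_t w = 0; w < 4; ++w) {
    FillResult r = fillForWarp(w, 4, 16, 64, 32, 128, 128, 1, 0, 2, 32, 4);
    std::printf("warp %lld: globalOff=%lld ldsOff=%lld shape=[%lld x %lld]\n",
                (long long)w, (long long)r.globalElemOffset,
                (long long)r.ldsElemOffset, (long long)r.shapeOuter,
                (long long)r.shapeInner);
  }
  return 0;
}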