Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/src/gluon_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,12 @@ void init_gluon_ir(py::module &&m) {
self.create<ttag::AsyncTDMCopyGlobalToLocalOp>(descPtr, indices,
result, pred);
})
.def("create_async_tdm_copy_local_to_global",
[](GluonOpBuilder &self, Value descPtr, std::vector<Value> &indices,
Value src) {
self.create<ttag::AsyncTDMCopyLocalToGlobalOp>(descPtr, indices,
src);
})
.def("create_async_tdm_wait", [](GluonOpBuilder &self, int num) {
ValueRange tokens;
self.create<ttag::AsyncTDMWait>(tokens, num);
Expand Down
53 changes: 49 additions & 4 deletions python/test/gluon/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2733,7 +2733,7 @@ def kernel():


@gluon.jit
def amd_tdm_kernel(ptr):
def amd_tdm_load_kernel(ptr):
SHARED_LAYOUT: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([[32, 4]], [16, 64], [1, 0])
BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0])

Expand All @@ -2748,17 +2748,17 @@ def amd_tdm_kernel(ptr):


@pytest.mark.parametrize("target", [HIP_TARGET_GFX1250])
def test_amd_tdm(target):
def test_amd_tdm_load(target):

ptr = MockTensor(ttgl.float16)
module = run_parser(amd_tdm_kernel, *make_args(ptr), target)
module = run_parser(amd_tdm_load_kernel, *make_args(ptr), target)
expecttest.assert_expected_inline(
anonymize_ir(module.str_nodebug()), """\
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [16, 64]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @amd_tdm_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
tt.func public @amd_tdm_load_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c32_i32 = arith.constant 32 : i32
%c128_i32 = arith.constant 128 : i32
%c128_i64 = arith.constant 128 : i64
Expand All @@ -2775,3 +2775,48 @@ def test_amd_tdm(target):
}
}
""")


@gluon.jit
def amd_tdm_store_kernel(ptr):
SHARED_LAYOUT: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0])
BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0])

desc = ttgl.amd.gfx1250.tdm.make_tensor_descriptor(base=ptr, shape=(32, 128), strides=(128, 1),
block_shape=(16, 64), layout=SHARED_LAYOUT)

value = ttgl.full([16, 64], 1.0, ttgl.float16, layout=BLOCKED_LAYOUT)
buffer = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout, value)

ttgl.amd.gfx1250.tdm.async_store(desc, offsets=[0, 2], src=buffer)
ttgl.amd.gfx1250.tdm.async_wait(0)


@pytest.mark.parametrize("target", [HIP_TARGET_GFX1250])
def test_amd_tdm_store(target):

ptr = MockTensor(ttgl.float16)
module = run_parser(amd_tdm_store_kernel, *make_args(ptr), target)
expecttest.assert_expected_inline(
anonymize_ir(module.str_nodebug()), """\
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @amd_tdm_store_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c32_i32 = arith.constant 32 : i32
%c128_i32 = arith.constant 128 : i32
%c128_i64 = arith.constant 128 : i64
%c1_i64 = arith.constant 1 : i64
%0 = tt.make_tensor_descriptor %arg0, [%c32_i32, %c128_i32], [%c128_i64, %c1_i64] : <f16>, <tensor<16x64xf16, #shared>>
%cst = arith.constant 1.000000e+00 : f16
%cst_0 = arith.constant dense<1.000000e+00> : tensor<16x64xf16, #blocked>
%1 = ttg.local_alloc %cst_0 : (tensor<16x64xf16, #blocked>) -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable>
%c0_i32 = arith.constant 0 : i32
%c2_i32 = arith.constant 2 : i32
amdgpu.async_tdm_copy_local_to_global %0[%c0_i32, %c2_i32] from %1 : !ttg.memdesc<16x64xf16, #shared, #smem, mutable> -> !tt.tensordesc<tensor<16x64xf16, #shared>>
%2 = amdgpu.async_tdm_wait {num = 0 : i32}
tt.return
}
}
""")
27 changes: 22 additions & 5 deletions python/triton/experimental/gluon/language/amd/gfx1250/tdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass

import triton.experimental.gluon.language._core as ttgl
from triton.experimental.gluon.language._layouts import PaddedSharedLayout
from triton.experimental.gluon.language._layouts import PaddedSharedLayout, SwizzledSharedLayout
from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr

if TYPE_CHECKING:
Expand All @@ -20,7 +20,7 @@ class tensor_descriptor_type(ttgl.base_type):
block_type: ttgl.block_type
shape_type: ttgl.tuple_type
strides_type: ttgl.tuple_type
layout: PaddedSharedLayout
layout: PaddedSharedLayout | SwizzledSharedLayout

def __str__(self) -> str:
return f"tensor_descriptor<{self.block_type}, {self.layout}>"
Expand Down Expand Up @@ -84,15 +84,15 @@ def layout(self):
@builtin
def make_tensor_descriptor(base: ttgl.tensor, shape: List[ttgl.constexpr | ttgl.tensor],
strides: List[ttgl.constexpr | ttgl.tensor], block_shape: List[ttgl.constexpr],
layout: PaddedSharedLayout, _semantic=None) -> tensor_descriptor:
layout: PaddedSharedLayout | SwizzledSharedLayout, _semantic=None) -> tensor_descriptor:
"""Make a tensor descriptor object.

Args:
base (tensor): base pointer of the tensor in global memory.
shape (List[int]): shape of the tensor.
strides (List[int]): strides of the tensor.
block_shape (List[int]): block shape of the tensor.
layout (PaddedSharedLayout): the layout of the tensor in shared memory.
layout (PaddedSharedLayout | SwizzledSharedLayout): the layout of the tensor in shared memory.

Returns:
tensor_descriptor: the created tensor descriptor object
Expand All @@ -105,7 +105,10 @@ def make_tensor_descriptor(base: ttgl.tensor, shape: List[ttgl.constexpr | ttgl.
assert isinstance(base.dtype, ttgl.pointer_type), "Expected base to be a pointer"

layout = _unwrap_if_constexpr(layout)
assert isinstance(layout, PaddedSharedLayout), "Expected layout to be a PaddedSharedLayout"
assert isinstance(layout, (PaddedSharedLayout, SwizzledSharedLayout)), \
"Expected layout to be a PaddedSharedLayout or SwizzledSharedLayout"
if isinstance(layout, SwizzledSharedLayout):
assert layout.max_phase == 1, "Expected max_phase to be 1 for SwizzledSharedLayout"

base_handle = base.handle
shape_handles = _semantic._convert_to_ir_values(shape, require_i64=False) # i32 shape
Expand Down Expand Up @@ -137,6 +140,20 @@ def async_load(src: tensor_descriptor, offsets: List[ttgl.constexpr | ttgl.tenso
_semantic.builder.create_async_tdm_copy_global_to_local(src.handle, offset_handles, dest.handle)


@builtin
def async_store(dest: tensor_descriptor, offsets: List[ttgl.constexpr | ttgl.tensor], src: shared_memory_descriptor,
_semantic=None) -> None:
"""Store a block of tensor specified in tensor descriptor from shared memory to global memory asynchronously.

Args:
dest (tensor_descriptor): the destination tensor descriptor.
offsets (List[int]): the offsets from the base pointer in the tensor descriptor.
src (shared_memory_descriptor): the shared memory source to load the data.
"""
offset_handles = _semantic._convert_to_ir_values(offsets, require_i64=False)
_semantic.builder.create_async_tdm_copy_local_to_global(dest.handle, offset_handles, src.handle)


@builtin
def async_wait(num_outstanding=0, _semantic=None) -> None:
"""Wait for the outstanding asynchronous tensor operations to complete.
Expand Down
46 changes: 36 additions & 10 deletions test/Conversion/amd/tritongpu_tdm_to_llvm.mlir
Original file line number Diff line number Diff line change
@@ -1,25 +1,51 @@
// RUN: triton-opt %s --split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx1250 --convert-builtin-func-to-llvm | FileCheck %s --check-prefixes=GFX1250
// RUN: triton-opt %s --split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx1250 --convert-builtin-func-to-llvm | FileCheck %s

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [64, 64]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
// GFX1250-LABEL: tdm_kernel
tt.func public @tdm_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: tdm_load
tt.func public @tdm_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c_shape = arith.constant 128 : i32
%c_stride0 = arith.constant 128 : i64
%c_stride1 = arith.constant 1 : i64
%c_offset = arith.constant 0 : i32
%c_pred = arith.constant true
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : <f16>, <tensor<64x64xf16>>
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : <f16>, <tensor<64x64xf16, #shared>>
%1 = ttg.local_alloc : () -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
// GFX1250-COUNT-4: llvm.insertelement{{.*}} : vector<4xi32>
// GFX1250-COUNT-8: llvm.insertelement{{.*}} : vector<8xi32>
// GFX1250: llvm.amdgcn.tensor.load.to.lds.d2{{.*}} : (vector<4xi32>, vector<8xi32>, i32) -> ()
%2 = amdgpu.async_tdm_copy_global_to_local %0[%c_offset, %c_offset] into %1, %c_pred : !tt.tensordesc<tensor<64x64xf16>> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
// GFX1250: llvm.amdgcn.s.wait.tensorcnt{{.*}} : (i16) -> ()
// CHECK-COUNT-4: llvm.insertelement{{.*}} : vector<4xi32>
// CHECK-COUNT-8: llvm.insertelement{{.*}} : vector<8xi32>
// CHECK: llvm.amdgcn.tensor.load.to.lds.d2{{.*}} : (vector<4xi32>, vector<8xi32>, i32) -> ()
%2 = amdgpu.async_tdm_copy_global_to_local %0[%c_offset, %c_offset] into %1, %c_pred : !tt.tensordesc<tensor<64x64xf16, #shared>> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
// CHECK: llvm.amdgcn.s.wait.tensorcnt{{.*}} : (i16) -> ()
%3 = amdgpu.async_tdm_wait {num = 0 : i32}
%4 = ttg.local_load %1 : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> tensor<64x64xf16, #blocked>
tt.return
}
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: tdm_store
tt.func public @tdm_store(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c_shape = arith.constant 128 : i32
%c_stride0 = arith.constant 128 : i64
%c_stride1 = arith.constant 1 : i64
%c_offset = arith.constant 0 : i32
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : <f16>, <tensor<64x64xf16, #shared>>
%1 = ttg.local_alloc : () -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
%2 = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #blocked>
ttg.local_store %2, %1 : tensor<64x64xf16, #blocked> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
// CHECK-COUNT-4: llvm.insertelement{{.*}} : vector<4xi32>
// CHECK-COUNT-8: llvm.insertelement{{.*}} : vector<8xi32>
// CHECK: llvm.amdgcn.tensor.store.from.lds.d2{{.*}} : (vector<4xi32>, vector<8xi32>, i32) -> ()
amdgpu.async_tdm_copy_local_to_global %0[%c_offset, %c_offset] from %1: !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> !tt.tensordesc<tensor<64x64xf16, #shared>>
// CHECK: llvm.amdgcn.s.wait.tensorcnt{{.*}} : (i16) -> ()
%3 = amdgpu.async_tdm_wait {num = 0 : i32}
tt.return
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -707,8 +707,9 @@ def AsyncTDMCopyGlobalToLocalOp : TT_AMDGPU_Op<"async_tdm_copy_global_to_local">
This operation copies data from global memory to local memory
asynchronously. This is analogue to tt.load except the data are copied to
local memory pointed by `result` instead of a distributed tensor. The data
copied depends on the global memory descriptor pointed to by `desc`. Set
`pred` to false will disable the copy.
copied depends on the global memory pointed to by `desc`. Set `pred` to
false will disable the copy. This operation does not support shared memory
swizzling.
}];

let arguments = (ins
Expand All @@ -724,6 +725,37 @@ def AsyncTDMCopyGlobalToLocalOp : TT_AMDGPU_Op<"async_tdm_copy_global_to_local">
$desc `[` $indices `]` `into` $result `,` $pred
attr-dict `:` qualified(type($desc)) `->` qualified(type($result))
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// AsyncTDMCopyLocalToGlobalOp
//===----------------------------------------------------------------------===//

def AsyncTDMCopyLocalToGlobalOp : TT_AMDGPU_Op<"async_tdm_copy_local_to_global"> {
let summary = "Copy data based on descriptor from local memory to global memory asynchronously";

let description = [{
This operation copies data from local memory to global memory
asynchronously. This is analogue to tt.store except the data are copied from
local memory pointed by `src` instead of a distributed tensor. The copy
destination depends on the global memory pointed to by `desc`. This
operation does not support shared memory padding or swizzling.
}];

let arguments = (ins
Arg<TT_TensorDescType, "", [MemWrite<GlobalMemory>]>:$desc,
Variadic<I32>:$indices,
Arg<TTG_MemDescType, "", [MemRead<SharedMemory>]>:$src
);

let assemblyFormat = [{
$desc `[` $indices `]` `from` $src
attr-dict `:` qualified(type($src)) `->` qualified(type($desc))
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
53 changes: 53 additions & 0 deletions third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -607,4 +607,57 @@ void ConcatOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns,
patterns.add(foldConcatOpFromSingleSource);
}

LogicalResult AsyncTDMCopyGlobalToLocalOp::verify() {
auto tensorDescTy = getDesc().getType();
auto smemTy = getResult().getType();

auto swizzledEnc =
llvm::dyn_cast<gpu::SwizzledSharedEncodingAttr>(smemTy.getEncoding());
if (swizzledEnc && swizzledEnc.getMaxPhase() != 1)
return emitOpError("TDM does not support swizzling");

auto paddedEnc =
llvm::dyn_cast<gpu::PaddedSharedEncodingAttr>(smemTy.getEncoding());
if (!paddedEnc && !swizzledEnc)
return emitOpError("Invalid shared memory layout for TDM");

Type elementType = smemTy.getElementType();
auto elementBitWidth = elementType.getIntOrFloatBitWidth();
if (paddedEnc) {
unsigned dwordSize = 32;
for (auto [interval, padding] :
llvm::zip(paddedEnc.getIntervals(), paddedEnc.getPaddings())) {
auto intervalInDwords = interval * elementBitWidth / dwordSize;
if (intervalInDwords < 2)
return emitOpError("TDM padding interval must be at least 2 dwords");

auto paddingInDwords = padding * elementBitWidth / dwordSize;
if (paddingInDwords < 1)
return emitOpError("TDM padding amount must be at least 1 dword");
}
}

return success();
}

LogicalResult AsyncTDMCopyLocalToGlobalOp::verify() {
auto tensorDescTy = getDesc().getType();
auto smemTy = getSrc().getType();

auto swizzledEnc =
llvm::dyn_cast<gpu::SwizzledSharedEncodingAttr>(smemTy.getEncoding());
if (swizzledEnc && swizzledEnc.getMaxPhase() != 1)
return emitOpError("TDM does not support swizzling");

auto paddedEnc =
llvm::dyn_cast<gpu::PaddedSharedEncodingAttr>(smemTy.getEncoding());
if (paddedEnc)
return emitOpError("TDM store does not support padding");

if (!paddedEnc && !swizzledEnc)
return emitOpError("Invalid shared memory layout for TDM");

return success();
}

} // namespace mlir::triton::amdgpu
1 change: 1 addition & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ add_triton_library(TritonAMDGPUToLLVM
Fp4ToFpOpToLLVM.cpp
MembarUtility.cpp
ScalarizePackedFOps.cpp
TDMUtility.cpp

DEPENDS
TritonAMDGPUConversionPassIncGen
Expand Down
Loading
Loading