Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/src/gluon_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,12 @@ void init_gluon_ir(py::module &&m) {
self.create<ttag::AsyncTDMCopyGlobalToLocalOp>(descPtr, indices,
result, pred);
})
.def("create_async_tdm_copy_local_to_global",
[](GluonOpBuilder &self, Value descPtr, std::vector<Value> &indices,
Value src) {
self.create<ttag::AsyncTDMCopyLocalToGlobalOp>(descPtr, indices,
src);
})
.def("create_async_tdm_wait", [](GluonOpBuilder &self, int num) {
ValueRange tokens;
self.create<ttag::AsyncTDMWait>(tokens, num);
Expand Down
53 changes: 49 additions & 4 deletions python/test/gluon/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2733,7 +2733,7 @@ def kernel():


@gluon.jit
def amd_tdm_kernel(ptr):
def amd_tdm_load_kernel(ptr):
SHARED_LAYOUT: ttgl.constexpr = ttgl.PaddedSharedLayout.with_identity_for([[32, 4]], [16, 64], [1, 0])
BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0])

Expand All @@ -2748,17 +2748,17 @@ def amd_tdm_kernel(ptr):


@pytest.mark.parametrize("target", [HIP_TARGET_GFX1250])
def test_amd_tdm(target):
def test_amd_tdm_load(target):

ptr = MockTensor(ttgl.float16)
module = run_parser(amd_tdm_kernel, *make_args(ptr), target)
module = run_parser(amd_tdm_load_kernel, *make_args(ptr), target)
expecttest.assert_expected_inline(
anonymize_ir(module.str_nodebug()), """\
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [16, 64]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @amd_tdm_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
tt.func public @amd_tdm_load_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c32_i32 = arith.constant 32 : i32
%c128_i32 = arith.constant 128 : i32
%c128_i64 = arith.constant 128 : i64
Expand All @@ -2775,3 +2775,48 @@ def test_amd_tdm(target):
}
}
""")


@gluon.jit
def amd_tdm_store_kernel(ptr):
SHARED_LAYOUT: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0])
BLOCKED_LAYOUT: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [4, 8], [4, 1], [1, 0])

desc = ttgl.amd.gfx1250.tdm.make_tensor_descriptor(base=ptr, shape=(32, 128), strides=(128, 1),
block_shape=(16, 64), layout=SHARED_LAYOUT)

value = ttgl.full([16, 64], 1.0, ttgl.float16, layout=BLOCKED_LAYOUT)
buffer = ttgl.allocate_shared_memory(desc.dtype, desc.block_shape, desc.layout, value)

ttgl.amd.gfx1250.tdm.async_store(desc, offsets=[0, 2], src=buffer)
ttgl.amd.gfx1250.tdm.async_wait(0)


@pytest.mark.parametrize("target", [HIP_TARGET_GFX1250])
def test_amd_tdm_store(target):

ptr = MockTensor(ttgl.float16)
module = run_parser(amd_tdm_store_kernel, *make_args(ptr), target)
expecttest.assert_expected_inline(
anonymize_ir(module.str_nodebug()), """\
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @amd_tdm_store_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c32_i32 = arith.constant 32 : i32
%c128_i32 = arith.constant 128 : i32
%c128_i64 = arith.constant 128 : i64
%c1_i64 = arith.constant 1 : i64
%0 = tt.make_tensor_descriptor %arg0, [%c32_i32, %c128_i32], [%c128_i64, %c1_i64] : <f16>, <tensor<16x64xf16, #shared>>
%cst = arith.constant 1.000000e+00 : f16
%cst_0 = arith.constant dense<1.000000e+00> : tensor<16x64xf16, #blocked>
%1 = ttg.local_alloc %cst_0 : (tensor<16x64xf16, #blocked>) -> !ttg.memdesc<16x64xf16, #shared, #smem, mutable>
%c0_i32 = arith.constant 0 : i32
%c2_i32 = arith.constant 2 : i32
amdgpu.async_tdm_copy_local_to_global %0[%c0_i32, %c2_i32] from %1 : !ttg.memdesc<16x64xf16, #shared, #smem, mutable> -> !tt.tensordesc<tensor<16x64xf16, #shared>>
%2 = amdgpu.async_tdm_wait {num = 0 : i32}
tt.return
}
}
""")
27 changes: 22 additions & 5 deletions python/triton/experimental/gluon/language/amd/gfx1250/tdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass

import triton.experimental.gluon.language._core as ttgl
from triton.experimental.gluon.language._layouts import PaddedSharedLayout
from triton.experimental.gluon.language._layouts import PaddedSharedLayout, SwizzledSharedLayout
from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr

if TYPE_CHECKING:
Expand All @@ -20,7 +20,7 @@ class tensor_descriptor_type(ttgl.base_type):
block_type: ttgl.block_type
shape_type: ttgl.tuple_type
strides_type: ttgl.tuple_type
layout: PaddedSharedLayout
layout: PaddedSharedLayout | SwizzledSharedLayout

def __str__(self) -> str:
return f"tensor_descriptor<{self.block_type}, {self.layout}>"
Expand Down Expand Up @@ -84,15 +84,15 @@ def layout(self):
@builtin
def make_tensor_descriptor(base: ttgl.tensor, shape: List[ttgl.constexpr | ttgl.tensor],
strides: List[ttgl.constexpr | ttgl.tensor], block_shape: List[ttgl.constexpr],
layout: PaddedSharedLayout, _semantic=None) -> tensor_descriptor:
layout: PaddedSharedLayout | SwizzledSharedLayout, _semantic=None) -> tensor_descriptor:
"""Make a tensor descriptor object.

Args:
base (tensor): base pointer of the tensor in global memory.
shape (List[int]): shape of the tensor.
strides (List[int]): strides of the tensor.
block_shape (List[int]): block shape of the tensor.
layout (PaddedSharedLayout): the layout of the tensor in shared memory.
layout (PaddedSharedLayout | SwizzledSharedLayout): the layout of the tensor in shared memory.

Returns:
tensor_descriptor: the created tensor descriptor object
Expand All @@ -105,7 +105,10 @@ def make_tensor_descriptor(base: ttgl.tensor, shape: List[ttgl.constexpr | ttgl.
assert isinstance(base.dtype, ttgl.pointer_type), "Expected base to be a pointer"

layout = _unwrap_if_constexpr(layout)
assert isinstance(layout, PaddedSharedLayout), "Expected layout to be a PaddedSharedLayout"
assert isinstance(layout, (PaddedSharedLayout, SwizzledSharedLayout)), \
"Expected layout to be a PaddedSharedLayout or SwizzledSharedLayout"
if isinstance(layout, SwizzledSharedLayout):
assert layout.max_phase == 1, "Expected max_phase to be 1 for SwizzledSharedLayout"

base_handle = base.handle
shape_handles = _semantic._convert_to_ir_values(shape, require_i64=False) # i32 shape
Expand Down Expand Up @@ -137,6 +140,20 @@ def async_load(src: tensor_descriptor, offsets: List[ttgl.constexpr | ttgl.tenso
_semantic.builder.create_async_tdm_copy_global_to_local(src.handle, offset_handles, dest.handle)


@builtin
def async_store(dest: tensor_descriptor, offsets: List[ttgl.constexpr | ttgl.tensor], src: shared_memory_descriptor,
_semantic=None) -> None:
"""Store a block of tensor specified in tensor descriptor from shared memory to global memory asynchronously.

Args:
dest (tensor_descriptor): the destination tensor descriptor.
offsets (List[int]): the offsets from the base pointer in the tensor descriptor.
src (shared_memory_descriptor): the shared memory source to load the data.
"""
offset_handles = _semantic._convert_to_ir_values(offsets, require_i64=False)
_semantic.builder.create_async_tdm_copy_local_to_global(dest.handle, offset_handles, src.handle)


@builtin
def async_wait(num_outstanding=0, _semantic=None) -> None:
"""Wait for the outstanding asynchronous tensor operations to complete.
Expand Down
46 changes: 36 additions & 10 deletions test/Conversion/amd/tritongpu_tdm_to_llvm.mlir
Original file line number Diff line number Diff line change
@@ -1,25 +1,51 @@
// RUN: triton-opt %s --split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx1250 --convert-builtin-func-to-llvm | FileCheck %s --check-prefixes=GFX1250
// RUN: triton-opt %s --split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx1250 --convert-builtin-func-to-llvm | FileCheck %s

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.padded_shared<[32:+4] {order = [1, 0], shape = [64, 64]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
// GFX1250-LABEL: tdm_kernel
tt.func public @tdm_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: tdm_load
tt.func public @tdm_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c_shape = arith.constant 128 : i32
%c_stride0 = arith.constant 128 : i64
%c_stride1 = arith.constant 1 : i64
%c_offset = arith.constant 0 : i32
%c_pred = arith.constant true
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : <f16>, <tensor<64x64xf16>>
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : <f16>, <tensor<64x64xf16, #shared>>
%1 = ttg.local_alloc : () -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
// GFX1250-COUNT-4: llvm.insertelement{{.*}} : vector<4xi32>
// GFX1250-COUNT-8: llvm.insertelement{{.*}} : vector<8xi32>
// GFX1250: llvm.amdgcn.tensor.load.to.lds.d2{{.*}} : (vector<4xi32>, vector<8xi32>, i32) -> ()
%2 = amdgpu.async_tdm_copy_global_to_local %0[%c_offset, %c_offset] into %1, %c_pred : !tt.tensordesc<tensor<64x64xf16>> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
// GFX1250: llvm.amdgcn.s.wait.tensorcnt{{.*}} : (i16) -> ()
// CHECK-COUNT-4: llvm.insertelement{{.*}} : vector<4xi32>
// CHECK-COUNT-8: llvm.insertelement{{.*}} : vector<8xi32>
// CHECK: llvm.amdgcn.tensor.load.to.lds.d2{{.*}} : (vector<4xi32>, vector<8xi32>, i32) -> ()
%2 = amdgpu.async_tdm_copy_global_to_local %0[%c_offset, %c_offset] into %1, %c_pred : !tt.tensordesc<tensor<64x64xf16, #shared>> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
// CHECK: llvm.amdgcn.s.wait.tensorcnt{{.*}} : (i16) -> ()
%3 = amdgpu.async_tdm_wait {num = 0 : i32}
%4 = ttg.local_load %1 : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> tensor<64x64xf16, #blocked>
tt.return
}
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: tdm_store
tt.func public @tdm_store(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c_shape = arith.constant 128 : i32
%c_stride0 = arith.constant 128 : i64
%c_stride1 = arith.constant 1 : i64
%c_offset = arith.constant 0 : i32
%0 = tt.make_tensor_descriptor %arg0, [%c_shape, %c_shape], [%c_stride0, %c_stride1] : <f16>, <tensor<64x64xf16, #shared>>
%1 = ttg.local_alloc : () -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
%2 = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #blocked>
ttg.local_store %2, %1 : tensor<64x64xf16, #blocked> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>
// CHECK-COUNT-4: llvm.insertelement{{.*}} : vector<4xi32>
// CHECK-COUNT-8: llvm.insertelement{{.*}} : vector<8xi32>
// CHECK: llvm.amdgcn.tensor.store.from.lds.d2{{.*}} : (vector<4xi32>, vector<8xi32>, i32) -> ()
amdgpu.async_tdm_copy_local_to_global %0[%c_offset, %c_offset] from %1: !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> !tt.tensordesc<tensor<64x64xf16, #shared>>
// CHECK: llvm.amdgcn.s.wait.tensorcnt{{.*}} : (i16) -> ()
%3 = amdgpu.async_tdm_wait {num = 0 : i32}
tt.return
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -707,8 +707,9 @@ def AsyncTDMCopyGlobalToLocalOp : TT_AMDGPU_Op<"async_tdm_copy_global_to_local">
This operation copies data from global memory to local memory
asynchronously. This is analogue to tt.load except the data are copied to
local memory pointed by `result` instead of a distributed tensor. The data
copied depends on the global memory descriptor pointed to by `desc`. Set
`pred` to false will disable the copy.
copied depends on the global memory pointed to by `desc`. Set `pred` to
false will disable the copy. This operation does not support shared memory
swizzling.
}];

let arguments = (ins
Expand All @@ -724,6 +725,37 @@ def AsyncTDMCopyGlobalToLocalOp : TT_AMDGPU_Op<"async_tdm_copy_global_to_local">
$desc `[` $indices `]` `into` $result `,` $pred
attr-dict `:` qualified(type($desc)) `->` qualified(type($result))
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// AsyncTDMCopyLocalToGlobalOp
//===----------------------------------------------------------------------===//

def AsyncTDMCopyLocalToGlobalOp : TT_AMDGPU_Op<"async_tdm_copy_local_to_global"> {
let summary = "Copy data based on descriptor from local memory to global memory asynchronously";

let description = [{
This operation copies data from local memory to global memory
asynchronously. This is analogue to tt.store except the data are copied from
local memory pointed by `src` instead of a distributed tensor. The copy
destination depends on the global memory pointed to by `desc`. This
operation does not support shared memory padding or swizzling.
}];

let arguments = (ins
Arg<TT_TensorDescType, "", [MemWrite<GlobalMemory>]>:$desc,
Variadic<I32>:$indices,
Arg<TTG_MemDescType, "", [MemRead<SharedMemory>]>:$src
);

let assemblyFormat = [{
$desc `[` $indices `]` `from` $src
attr-dict `:` qualified(type($src)) `->` qualified(type($desc))
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
53 changes: 53 additions & 0 deletions third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -607,4 +607,57 @@ void ConcatOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns,
patterns.add(foldConcatOpFromSingleSource);
}

LogicalResult AsyncTDMCopyGlobalToLocalOp::verify() {
auto tensorDescTy = getDesc().getType();
auto smemTy = getResult().getType();

auto swizzledEnc =
llvm::dyn_cast<gpu::SwizzledSharedEncodingAttr>(smemTy.getEncoding());
if (swizzledEnc && swizzledEnc.getMaxPhase() != 1)
return emitOpError("TDM does not support swizzling");

auto paddedEnc =
llvm::dyn_cast<gpu::PaddedSharedEncodingAttr>(smemTy.getEncoding());
if (!paddedEnc && !swizzledEnc)
return emitOpError("Invalid shared memory layout for TDM");

Type elementType = smemTy.getElementType();
auto elementBitWidth = elementType.getIntOrFloatBitWidth();
if (paddedEnc) {
unsigned dwordSize = 32;
for (auto [interval, padding] :
llvm::zip(paddedEnc.getIntervals(), paddedEnc.getPaddings())) {
auto intervalInDwords = interval * elementBitWidth / dwordSize;
if (intervalInDwords < 2)
return emitOpError("TDM padding interval must be at least 2 dwords");

auto paddingInDwords = padding * elementBitWidth / dwordSize;
if (paddingInDwords < 1)
return emitOpError("TDM padding amount must be at least 1 dword");
}
}

return success();
}

LogicalResult AsyncTDMCopyLocalToGlobalOp::verify() {
auto tensorDescTy = getDesc().getType();
auto smemTy = getSrc().getType();

auto swizzledEnc =
llvm::dyn_cast<gpu::SwizzledSharedEncodingAttr>(smemTy.getEncoding());
if (swizzledEnc && swizzledEnc.getMaxPhase() != 1)
return emitOpError("TDM does not support swizzling");

auto paddedEnc =
llvm::dyn_cast<gpu::PaddedSharedEncodingAttr>(smemTy.getEncoding());
if (paddedEnc)
return emitOpError("TDM store does not support padding");

if (!paddedEnc && !swizzledEnc)
return emitOpError("Invalid shared memory layout for TDM");

return success();
}

} // namespace mlir::triton::amdgpu
1 change: 1 addition & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ add_triton_library(TritonAMDGPUToLLVM
Fp4ToFpOpToLLVM.cpp
MembarUtility.cpp
ScalarizePackedFOps.cpp
TDMUtility.cpp

DEPENDS
TritonAMDGPUConversionPassIncGen
Expand Down
Loading
Loading