From 64b1a7fd565eecd71885c650606aaea58bb68b1a Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Thu, 29 May 2025 17:28:33 +0000 Subject: [PATCH 01/14] Add pass to convert block load to subgroup 2d block encoding types --- .../optimize-block-io-encoding.mlir | 65 ++++ third_party/intel/backend/compiler.py | 1 + .../TritonIntelGPU/Transforms/Passes.td | 11 + .../TritonIntelGPUTransforms/CMakeLists.txt | 1 + .../OptimizeBlockIOEncoding.cpp | 319 ++++++++++++++++++ third_party/intel/triton_xpu.cc | 3 + 6 files changed, 400 insertions(+) create mode 100644 test/TritonIntelGPU/optimize-block-io-encoding.mlir create mode 100644 third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp diff --git a/test/TritonIntelGPU/optimize-block-io-encoding.mlir b/test/TritonIntelGPU/optimize-block-io-encoding.mlir new file mode 100644 index 0000000000..68174d5d90 --- /dev/null +++ b/test/TritonIntelGPU/optimize-block-io-encoding.mlir @@ -0,0 +1,65 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --tritonintelgpu-optimize-block-io-encoding | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}> +// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}> +// CHECK: #mma1 = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}> +// CHECK: #mma2 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} { + tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) attributes {noinline = false} { + %c4_i32 = arith.constant 4 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1024_i64 = arith.constant 1024 : i64 + %c5120_i64 = arith.constant 5120 : i64 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %c4096_i64 = arith.constant 4096 : i64 + %c32_i32 = arith.constant 32 : i32 + %c64_i32 = arith.constant 64 : i32 + %c5120_i32 = arith.constant 5120 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked> + %0 = tt.get_program_id x : i32 + %1 = arith.divsi %0, %c64_i32 : i32 + %2 = arith.muli %1, %c4_i32 : i32 + %3 = arith.subi %c4_i32, %2 : i32 + %4 = arith.minsi %3, %c4_i32 : i32 + %5 = arith.remsi %0, %4 : i32 + %6 = arith.addi %2, %5 : i32 + %7 = arith.remsi %0, %c64_i32 : i32 + %8 = arith.divsi %7, %4 : i32 + %9 = arith.muli %6, %c256_i32 : i32 + // CHECK: tt.make_tensor_ptr {{.*}} : > + %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array} : > + %11 = arith.muli %8, %c256_i32 : i32 + // CHECK: tt.make_tensor_ptr {{.*}} : > + %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %11] {order = array} : > + %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr>) : i32 { + %17 = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: %[[A_LOAD:.*]] = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: {{.*}} = ttg.convert_layout %[[A_LOAD]] : tensor<256x32xf16, #mma> -> tensor<256x32xf16, #blocked1> + %18 = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: %[[B_LOAD:.*]] = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: {{.*}} = ttg.convert_layout %[[B_LOAD]] : tensor<32x256xf16, #mma1> -> tensor<32x256xf16, #blocked2> + %19 = ttg.convert_layout %17 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %20 = ttg.convert_layout %18 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %21 = ttg.convert_layout %arg4 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma> + %22 = ttg.convert_layout %19 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %23 = ttg.convert_layout %20 : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 2}>> -> tensor<256x256xf32, #mma2> + %24 = tt.dot %22, %23, %21, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %25 = ttg.convert_layout %24 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> + // CHECK: tt.advance {{.*}} : > + %26 = tt.advance %arg5, [%c0_i32, %c32_i32] : > + // CHECK: tt.advance {{.*}} : > + %27 = tt.advance %arg6, [%c32_i32, %c0_i32] : > + scf.yield %25, %26, %27 : tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr> + } + %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array} : > + %15 = arith.truncf %13#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked> + %16 = ttg.convert_layout %15 : tensor<256x256xf16, #blocked> -> tensor<256x256xf16, #blocked2> + tt.store %14, %16 {boundaryCheck = array} : !tt.ptr> + tt.return + } +} diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 3ceff03a17..91dd358217 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -280,6 +280,7 @@ def make_ttgir(mod, metadata, opt, properties): intel.passes.ttgpuir.add_accelerate_matmul(pm) intel.passes.ttgpuir.add_materialize_block_pointer(pm) + intel.passes.ttgpuir.add_optimize_block_load_encoding(pm) intel.passes.ttgpuir.add_remove_layout_conversions(pm) intel.passes.ttgpuir.add_optimize_dot_operands(pm) intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, XPUBackend.get_split_barrier_scope(opt)) diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td index c20224aaee..91625ba0be 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td @@ -409,4 +409,15 @@ def TritonIntelGPUReduceVariableLiveness "mlir::scf::SCFDialect", "mlir::arith::ArithDialect"]; } + +def TritonIntelGPUOptimizeBlockIOEncodingPass : Pass<"tritonintelgpu-optimize-block-io-encoding", "mlir::ModuleOp"> { + let summary = "Set encodings on candidates for Subgroup 2D Block IO ops"; + + let description = [{ + Set the Subgroup2DBlock encoding on tensor ptr types that are candidates for Subgroup 2D Block IO lowering. The goal is to change the tensor ptr type to use the new encoding so the LoadOp will use the new encoding, allowing the encoding to be an anchor layout during RemoveLayoutConversions. To avoid duplicating work in RemoveLayoutConversions, a ConvertLayout op to the existing encoding replaces the result of the LoadOp. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", "mlir::triton::gpu::intel::TritonIntelGPUDialect", "mlir::triton::TritonDialect"]; +} + #endif // TRITON_INTEL_GPU_PASSES diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt index b8cb96cfa0..bb32041127 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt +++ b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt @@ -5,6 +5,7 @@ add_triton_library(TritonIntelGPUTransforms DistributeToWarps.cpp MatchTargetSize.cpp MaterializeBlockPointer.cpp + OptimizeBlockIOEncoding.cpp OptimizeDotOperands.cpp OptimizeReductionLocality.cpp Pipeliner/MatmulLoopPipeline.cpp diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp new file mode 100644 index 0000000000..14ca36f31b --- /dev/null +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp @@ -0,0 +1,319 @@ +#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "llvm/ADT/PriorityWorklist.h" + +namespace ttg = mlir::triton::gpu; +namespace ttgi = mlir::triton::gpu::intel; + +namespace mlir { +namespace triton { +namespace gpu::intel { + +#define DEBUG_TYPE "tritongpu-optimize-block-encoding" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace { + +SmallVector getTiedArgs(Operation *op, int resultIdx) { + if (auto forOp = dyn_cast(op)) { + auto iterArg = forOp.getRegionIterArg(resultIdx); + auto result = forOp.getResult(resultIdx); + auto yieldVal = forOp.getBody()->getTerminator()->getOperand(resultIdx); + auto initVal = forOp.getInitArgs()[resultIdx]; + return {iterArg, result, yieldVal, initVal}; + } else if (auto whileOp = dyn_cast(op)) { + auto iterArg = whileOp.getBeforeArguments()[resultIdx]; + auto result = whileOp.getResults()[resultIdx]; + auto yieldVal = + whileOp.getBeforeBody()->getTerminator()->getOperand(resultIdx); + auto initVal = whileOp.getOperands()[resultIdx]; + return {iterArg, result, iterArg, initVal}; + } else if (auto ifOp = dyn_cast(op)) { + SmallVector values; + for (auto &block : ifOp.getThenRegion().getBlocks()) { + auto terminator = block.getTerminator(); + if (isa(terminator)) + values.push_back(terminator->getOperands()[resultIdx]); + } + for (auto &block : ifOp.getElseRegion().getBlocks()) { + auto terminator = block.getTerminator(); + if (isa(terminator)) + values.push_back(terminator->getOperands()[resultIdx]); + } + values.push_back(ifOp->getResults()[resultIdx]); + return values; + } + return {}; +} + +Type getNewType(Type type, Attribute encoding) { + RankedTensorType tensorType = cast(type); + return RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); +} + +Type getNewPointerType(Type type, Attribute encoding) { + assert(isa(type) && "expected a ptr type!"); + auto oldPointerType = cast(type); + return PointerType::get(getNewType(oldPointerType.getPointeeType(), encoding), + oldPointerType.getAddressSpace()); +} + +struct EncodingInfo { + Attribute desiredEncoding; + bool requiresConvert = false; + + bool operator==(const EncodingInfo &other) const { + return desiredEncoding == other.desiredEncoding && + requiresConvert == other.requiresConvert; + } +}; + +/** + * The algorithm here takes inspiration from + * TritonNVIDIAGPU::OptimizeDescriptorEncoding. The idea is to iterate the + * def-use chain in both directions starting from the Load Op. We store the + * values that need to be updated along with the new encoding in the + * `valueToEncodingInfo` MapVector. After all value/encoding pairs have been + * determined, we update the encoding for each value, adding aa conversion to + * the existing Load Op result layout for users of the load. + */ +void rewriteTensorLayoutsForOp(Attribute encoding, Operation *op) { + auto loadOp = cast(op); + auto loadPtrType = cast(loadOp->getOperand(0).getType()); + auto addressSpace = loadPtrType.getAddressSpace(); + + llvm::MapVector, EncodingInfo> valueToEncodingInfo; + llvm::PriorityWorklist> worklist; + + auto updateEncoding = [&](ArrayRef ptrValues, EncodingInfo info) { + for (auto value : ptrValues) { + bool requiresConvert = llvm::any_of( + value.getUsers(), [](auto user) { return isa(user); }); + info.requiresConvert = requiresConvert; + + auto typedVal = cast>(value); + auto itr = valueToEncodingInfo.find(typedVal); + if (itr == valueToEncodingInfo.end()) { + LLVM_DEBUG(DBGS() << "Add encoding " << info.desiredEncoding + << " for value " << typedVal << "\n"); + valueToEncodingInfo[typedVal] = info; + worklist.insert(typedVal); + } else { + LLVM_DEBUG(DBGS() << "Found existing encoding info " + << itr->second.desiredEncoding << " for value " + << typedVal << ". Ensure new encoding " + << info.desiredEncoding << " matches.\n"); + assert(itr->second == info && "already visited encoding info for " + "value, expected them to be equal!"); + continue; + } + } + }; + + worklist.insert(cast>(loadOp->getOperand(0))); + + // 1. Starting from the Load Op, propagate encoding info up and down the + // def-use chain. + while (!worklist.empty()) { + auto crtValue = worklist.pop_back_val(); + + // Propagate to users + for (OpOperand &use : crtValue.getUses()) { + auto op = use.getOwner(); + if (isa(op)) { + auto offset = 3 * isa(op); + auto vals = getTiedArgs(op, use.getOperandNumber() - offset); + updateEncoding(vals, EncodingInfo{encoding}); + } else if (isa(op)) { + auto vals = getTiedArgs(op->getParentOp(), use.getOperandNumber()); + updateEncoding(vals, EncodingInfo{encoding}); + } + } + + // Propagate to defining ops + if (auto opResult = dyn_cast(crtValue)) { + auto definingOp = opResult.getOwner(); + if (isa(definingOp)) { + auto vals = getTiedArgs(definingOp, opResult.getResultNumber()); + updateEncoding(vals, EncodingInfo{encoding}); + } + } else if (auto blockArg = dyn_cast(crtValue)) { + auto parentOp = blockArg.getOwner()->getParentOp(); + if (isa(parentOp)) { + auto offset = isa(parentOp); + auto vals = getTiedArgs(parentOp, blockArg.getArgNumber() - offset); + updateEncoding(vals, EncodingInfo{encoding}); + } + } + } + + // 2. Update the type for each value in-place. Add a ConvertLayout Op after + // any loads which require conversion to the existing layout for the loaded + // value. + for (auto &[val, einfo] : valueToEncodingInfo) { + Attribute newEncoding = einfo.desiredEncoding; + LLVM_DEBUG(DBGS() << "Rewrite encoding to " << newEncoding << " for value " + << val << "\n"); + + PointerType oldType = val.getType(); + auto oldTensorTy = cast(oldType.getPointeeType()); + auto newTensorTy = RankedTensorType::get( + oldTensorTy.getShape(), oldTensorTy.getElementType(), newEncoding); + + val.setType(PointerType::get(newTensorTy, oldType.getAddressSpace())); + if (einfo.requiresConvert) { + for (auto user : val.getUsers()) { + if (auto loadOp = dyn_cast(user)) { + + OpBuilder builder(loadOp); + auto oldLoadType = loadOp.getType(); + Value result = loadOp.getResult(); + + builder.setInsertionPointAfter(loadOp); + auto cvt = builder.create(loadOp.getLoc(), + result.getType(), result); + LLVM_DEBUG(DBGS() << "Added convert Op:\n" + << cvt << " after Load Op:\n" + << loadOp << "\n"); + result.setType(newTensorTy); + + result.replaceAllUsesExcept(cvt.getResult(), cvt.getOperation()); + } + } + } + } +} + +} // namespace + +#define GEN_PASS_DEF_TRITONINTELGPUOPTIMIZEBLOCKIOENCODINGPASS +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc" + +class TritonIntelGPUOptimizeBlockIOEncodingPass + : public impl::TritonIntelGPUOptimizeBlockIOEncodingPassBase< + TritonIntelGPUOptimizeBlockIOEncodingPass> { + + void getSubgroup2DBlockLayoutForOperand( + Value operand, DpasEncodingAttr dpasLayout, + llvm::MapVector &layoutMap) { + auto isCandidateLoad = [](Value v) -> LoadOp { + // Peel out the original cvt dot_op<..., #blocked> + // and any other potential cvt/trans ops + while (true) { + if (auto cvtOp = v.getDefiningOp()) { + v = cvtOp.getSrc(); + continue; + } + if (auto transOp = v.getDefiningOp()) { + v = transOp.getSrc(); + continue; + } + break; + } + return isa(v.getDefiningOp()) ? cast(v.getDefiningOp()) + : nullptr; + }; + + LoadOp loadOp = isCandidateLoad(operand); + if (!loadOp) + return; + + auto dotOperandType = cast(operand.getType()); + auto dotOperandEncoding = + cast(dotOperandType.getEncoding()); + // layout width is determined by the DPAS operand encoding width + const int kWidth = dotOperandEncoding.getKWidth(); + + Attribute blockIOAttr = + loadOp->getAttr(TritonIntelGPUDialect::getBlockIOAttrName()); + if (!blockIOAttr) + return; + + // get the MakeTensorPtr Op for the load + Value ptr = loadOp.getPtr(); + if (!isTensorPointerType(ptr.getType())) { + // TODO: support tensor of pointer loads + LLVM_DEBUG(DBGS() << "Ptr\n" + << ptr << " for Load Op:\n" + << loadOp + << "\nincompatible with Subgroup 2D Block Layout.\n"); + return; + } + MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(ptr); + assert(makeTensorPtrOp && + "expecting a tensor pointer parent to block io load " + "with tensor pointer type"); + + auto oldTensorPtrType = cast(makeTensorPtrOp.getType()); + auto oldTensorType = + cast(oldTensorPtrType.getPointeeType()); + // Note: we need the old layout to get the order for the load, but it is not + // clear the layout will always be Blocked. Is there a better way to get + // this info? + auto oldLayout = cast(oldTensorType.getEncoding()); + + auto CTALayout = getCTALayout(dpasLayout); + const unsigned elemSizeInBits = + oldTensorType.getElementType().getIntOrFloatBitWidth(); + + auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( + cast(dotOperandEncoding), + oldTensorType.getShape(), + blockIOAttr == StringAttr::get(&getContext(), "row_major"), + elemSizeInBits / 8, &getContext()); + SmallVector instrShape{tileParams[0], tileParams[1]}; + const unsigned vBlocks = tileParams[2]; + + auto subgroup2DBlockEncoding = Subgroup2DBlockEncodingAttr::get( + &getContext(), dpasLayout.getWarpsPerCTA(), CTALayout, instrShape, + tileParams[2], + getOrderForDotOperand(dotOperandEncoding.getOpIdx(), /*rank*/ 2, + /*kContig*/ true), + kWidth, dpasLayout.getThreadsPerWarp()); + + LLVM_DEBUG(DBGS() << "Generated new encoding: " << subgroup2DBlockEncoding + << " for op : " << loadOp << "\n"); + + layoutMap[loadOp] = subgroup2DBlockEncoding; + } + +public: + void runOnOperation() override { + ModuleOp m = getOperation(); + + // Step 1. Find all loads which are candidates for conversion to Subgroup 2D + // Block Encoding. To be a candidate load, a load must be consumed by a Dot + // Op and the load operand must be a block ptr (produced by a MakeTensorPtr + // Op). Currently we look for loads with the "block_io" attribute but we + // could consider moving that logic to this pass later. We place the load + // and the candidate encoding into the layout map for propagation in step 2 + llvm::MapVector layoutMap; + m.walk([&](DotOp dotOp) { + auto dotOpType = cast(dotOp.getResult().getType()); + auto dpasLayout = dyn_cast(dotOpType.getEncoding()); + if (!dpasLayout) + return; + + getSubgroup2DBlockLayoutForOperand(dotOp.getA(), dpasLayout, layoutMap); + getSubgroup2DBlockLayoutForOperand(dotOp.getB(), dpasLayout, layoutMap); + }); + + // Step 2. Rewrite MakeTensorPtr to use the new layout and propagate the + // change through the def-use chain, terminating at the Load Op. We add a + // ConvertLayout Op after the Load Op to convert back to the original + // layout. Subgroup2DBlockEncoding layouts will be chosen as anchor layouts + // in RemoveLayoutConversions, and a subsequent run of + // RemoveLayoutConversions after this pass cleans up intermediate layout + // conversions and removes the original Load Op encoding. + for (auto &kv : layoutMap) { + rewriteTensorLayoutsForOp(kv.second, kv.first); + } + } +}; + +} // namespace gpu::intel +} // namespace triton +} // namespace mlir diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index 1aeae8f4d3..b26c757072 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -119,6 +119,9 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) { gpu::intel::createTritonIntelGPUReduceDataDuplication); ADD_PASS_WRAPPER_0("add_materialize_block_pointer", gpu::intel::createTritonIntelGPUMaterializeBlockPointer); + ADD_PASS_WRAPPER_0( + "add_optimize_block_load_encoding", + gpu::intel::createTritonIntelGPUOptimizeBlockIOEncodingPass); ADD_PASS_WRAPPER_0("add_optimize_reduction_locality", gpu::intel::createTritonIntelGPUOptimizeReductionLocality); ADD_PASS_WRAPPER_0("add_reduce_variable_liveness", From 34f6aa54f0b9fefddfc53487ee4a3b7a93c0c7c4 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Mon, 23 Jun 2025 19:10:36 +0000 Subject: [PATCH 02/14] validate scf for loop type changes --- .../optimize-block-io-encoding.mlir | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/test/TritonIntelGPU/optimize-block-io-encoding.mlir b/test/TritonIntelGPU/optimize-block-io-encoding.mlir index 68174d5d90..3ea474114d 100644 --- a/test/TritonIntelGPU/optimize-block-io-encoding.mlir +++ b/test/TritonIntelGPU/optimize-block-io-encoding.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file --allocate-shared-memory --tritonintelgpu-optimize-block-io-encoding | FileCheck %s +// RUN: triton-opt %s -split-input-file --tritonintelgpu-optimize-block-io-encoding | FileCheck %s #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}> @@ -8,7 +8,7 @@ // CHECK: #mma2 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} { - tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) attributes {noinline = false} { + tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) { %c4_i32 = arith.constant 4 : i32 %c256_i32 = arith.constant 256 : i32 %c1024_i64 = arith.constant 1024 : i64 @@ -30,17 +30,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.tar %7 = arith.remsi %0, %c64_i32 : i32 %8 = arith.divsi %7, %4 : i32 %9 = arith.muli %6, %c256_i32 : i32 - // CHECK: tt.make_tensor_ptr {{.*}} : > + // CHECK: %[[MAKE_TENSOR_PTR_A:.*]] = tt.make_tensor_ptr {{.*}} : > %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array} : > %11 = arith.muli %8, %c256_i32 : i32 - // CHECK: tt.make_tensor_ptr {{.*}} : > + // CHECK: %[[MAKE_TENSOR_PTR_B:.*]] = tt.make_tensor_ptr {{.*}} : > %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %11] {order = array} : > + // CHECK: scf.for {{.*}} iter_args({{.*}} = {{.*}}, %[[ARG5:.*]] = %[[MAKE_TENSOR_PTR_A]], %[[ARG6:.*]] = %[[MAKE_TENSOR_PTR_B]]) %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr>) : i32 { %17 = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> - // CHECK: %[[A_LOAD:.*]] = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: %[[A_LOAD:.*]] = tt.load %[[ARG5]] {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> // CHECK: {{.*}} = ttg.convert_layout %[[A_LOAD]] : tensor<256x32xf16, #mma> -> tensor<256x32xf16, #blocked1> %18 = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> - // CHECK: %[[B_LOAD:.*]] = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: %[[B_LOAD:.*]] = tt.load %[[ARG6]] {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> // CHECK: {{.*}} = ttg.convert_layout %[[B_LOAD]] : tensor<32x256xf16, #mma1> -> tensor<32x256xf16, #blocked2> %19 = ttg.convert_layout %17 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> %20 = ttg.convert_layout %18 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> @@ -50,10 +51,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.tar // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 2}>> -> tensor<256x256xf32, #mma2> %24 = tt.dot %22, %23, %21, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> %25 = ttg.convert_layout %24 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> - // CHECK: tt.advance {{.*}} : > + // CHECK: %[[ADVANCE_A:.*]] = tt.advance {{.*}} : > %26 = tt.advance %arg5, [%c0_i32, %c32_i32] : > - // CHECK: tt.advance {{.*}} : > + // CHECK: %[[ADVANCE_B:.*]] = tt.advance {{.*}} : > %27 = tt.advance %arg6, [%c32_i32, %c0_i32] : > + // CHECK: scf.yield {{.*}}, %[[ADVANCE_A]], %[[ADVANCE_B]] scf.yield %25, %26, %27 : tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr> } %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array} : > From a53f8d0e8b79e1bb6311c9897070878cf8b09284 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Mon, 23 Jun 2025 19:12:51 +0000 Subject: [PATCH 03/14] reduce set of module attributes --- test/TritonIntelGPU/optimize-block-io-encoding.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/TritonIntelGPU/optimize-block-io-encoding.mlir b/test/TritonIntelGPU/optimize-block-io-encoding.mlir index 3ea474114d..e162724cf1 100644 --- a/test/TritonIntelGPU/optimize-block-io-encoding.mlir +++ b/test/TritonIntelGPU/optimize-block-io-encoding.mlir @@ -7,7 +7,7 @@ // CHECK: #mma1 = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}> // CHECK: #mma2 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} { +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_dpas, ttig.support_sg_2d_block} { tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) { %c4_i32 = arith.constant 4 : i32 %c256_i32 = arith.constant 256 : i32 From 80c41d75829ef0139205cf0ac59ad05416a6dbd2 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Mon, 23 Jun 2025 20:53:57 +0000 Subject: [PATCH 04/14] fixup typo, cleanup load cast --- .../lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp index 14ca36f31b..214fd6b86a 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp @@ -77,7 +77,7 @@ struct EncodingInfo { * def-use chain in both directions starting from the Load Op. We store the * values that need to be updated along with the new encoding in the * `valueToEncodingInfo` MapVector. After all value/encoding pairs have been - * determined, we update the encoding for each value, adding aa conversion to + * determined, we update the encoding for each value, adding a conversion to * the existing Load Op result layout for users of the load. */ void rewriteTensorLayoutsForOp(Attribute encoding, Operation *op) { @@ -213,8 +213,7 @@ class TritonIntelGPUOptimizeBlockIOEncodingPass } break; } - return isa(v.getDefiningOp()) ? cast(v.getDefiningOp()) - : nullptr; + return dyn_cast(v.getDefiningOp()); }; LoadOp loadOp = isCandidateLoad(operand); From 272fc3e20369d91c17851d10c9253ae0cf9bbae6 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Mon, 23 Jun 2025 20:59:55 +0000 Subject: [PATCH 05/14] break long lines --- .../Dialect/TritonIntelGPU/Transforms/Passes.td | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td index 91625ba0be..c350c13bc7 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td @@ -410,14 +410,21 @@ def TritonIntelGPUReduceVariableLiveness "mlir::arith::ArithDialect"]; } -def TritonIntelGPUOptimizeBlockIOEncodingPass : Pass<"tritonintelgpu-optimize-block-io-encoding", "mlir::ModuleOp"> { +def TritonIntelGPUOptimizeBlockIOEncodingPass + : Pass<"tritonintelgpu-optimize-block-io-encoding", "mlir::ModuleOp"> { let summary = "Set encodings on candidates for Subgroup 2D Block IO ops"; let description = [{ - Set the Subgroup2DBlock encoding on tensor ptr types that are candidates for Subgroup 2D Block IO lowering. The goal is to change the tensor ptr type to use the new encoding so the LoadOp will use the new encoding, allowing the encoding to be an anchor layout during RemoveLayoutConversions. To avoid duplicating work in RemoveLayoutConversions, a ConvertLayout op to the existing encoding replaces the result of the LoadOp. + Set the Subgroup2DBlock encoding on tensor ptr types that are candidates for Subgroup 2D Block IO lowering. + + The goal is to change the tensor ptr type to use the new encoding so the LoadOp will use the new encoding, allowing the + encoding to be an anchor layout during RemoveLayoutConversions. To avoid duplicating work in RemoveLayoutConversions, a + ConvertLayout op to the existing encoding replaces the result of the LoadOp. }]; - let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", "mlir::triton::gpu::intel::TritonIntelGPUDialect", "mlir::triton::TritonDialect"]; + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::gpu::intel::TritonIntelGPUDialect", + "mlir::triton::TritonDialect"]; } #endif // TRITON_INTEL_GPU_PASSES From 1e364c31c2fb0cf865b2e1be472a2c455b7a521e Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Mon, 23 Jun 2025 21:05:14 +0000 Subject: [PATCH 06/14] remove dead functions --- .../OptimizeBlockIOEncoding.cpp | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp index 214fd6b86a..d77d07ce05 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp @@ -48,19 +48,6 @@ SmallVector getTiedArgs(Operation *op, int resultIdx) { return {}; } -Type getNewType(Type type, Attribute encoding) { - RankedTensorType tensorType = cast(type); - return RankedTensorType::get(tensorType.getShape(), - tensorType.getElementType(), encoding); -} - -Type getNewPointerType(Type type, Attribute encoding) { - assert(isa(type) && "expected a ptr type!"); - auto oldPointerType = cast(type); - return PointerType::get(getNewType(oldPointerType.getPointeeType(), encoding), - oldPointerType.getAddressSpace()); -} - struct EncodingInfo { Attribute desiredEncoding; bool requiresConvert = false; From d9f8e6af4fa7cb941131c9bb519e7ae26531a293 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Mon, 23 Jun 2025 21:05:29 +0000 Subject: [PATCH 07/14] use attr variable in lit test --- .../optimize-block-io-encoding.mlir | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/test/TritonIntelGPU/optimize-block-io-encoding.mlir b/test/TritonIntelGPU/optimize-block-io-encoding.mlir index e162724cf1..ba795634f6 100644 --- a/test/TritonIntelGPU/optimize-block-io-encoding.mlir +++ b/test/TritonIntelGPU/optimize-block-io-encoding.mlir @@ -3,9 +3,9 @@ #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}> #blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}> -// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}> -// CHECK: #mma1 = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}> -// CHECK: #mma2 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +// CHECK: #[[$SUBGROUP_BLOCK_A:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}> +// CHECK: #[[$SUBGROUP_BLOCK_B:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}> +// CHECK: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_dpas, ttig.support_sg_2d_block} { tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) { @@ -30,30 +30,30 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th %7 = arith.remsi %0, %c64_i32 : i32 %8 = arith.divsi %7, %4 : i32 %9 = arith.muli %6, %c256_i32 : i32 - // CHECK: %[[MAKE_TENSOR_PTR_A:.*]] = tt.make_tensor_ptr {{.*}} : > + // CHECK: %[[MAKE_TENSOR_PTR_A:.*]] = tt.make_tensor_ptr {{.*}} : > %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array} : > %11 = arith.muli %8, %c256_i32 : i32 - // CHECK: %[[MAKE_TENSOR_PTR_B:.*]] = tt.make_tensor_ptr {{.*}} : > + // CHECK: %[[MAKE_TENSOR_PTR_B:.*]] = tt.make_tensor_ptr {{.*}} : > %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %11] {order = array} : > // CHECK: scf.for {{.*}} iter_args({{.*}} = {{.*}}, %[[ARG5:.*]] = %[[MAKE_TENSOR_PTR_A]], %[[ARG6:.*]] = %[[MAKE_TENSOR_PTR_B]]) %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr>) : i32 { %17 = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> - // CHECK: %[[A_LOAD:.*]] = tt.load %[[ARG5]] {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> - // CHECK: {{.*}} = ttg.convert_layout %[[A_LOAD]] : tensor<256x32xf16, #mma> -> tensor<256x32xf16, #blocked1> + // CHECK: %[[A_LOAD:.*]] = tt.load %[[ARG5]] {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: {{.*}} = ttg.convert_layout %[[A_LOAD]] : tensor<256x32xf16, #[[$SUBGROUP_BLOCK_A]]> -> tensor<256x32xf16, #blocked1> %18 = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> - // CHECK: %[[B_LOAD:.*]] = tt.load %[[ARG6]] {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> - // CHECK: {{.*}} = ttg.convert_layout %[[B_LOAD]] : tensor<32x256xf16, #mma1> -> tensor<32x256xf16, #blocked2> + // CHECK: %[[B_LOAD:.*]] = tt.load %[[ARG6]] {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: {{.*}} = ttg.convert_layout %[[B_LOAD]] : tensor<32x256xf16, #[[$SUBGROUP_BLOCK_B]]> -> tensor<32x256xf16, #blocked2> %19 = ttg.convert_layout %17 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> %20 = ttg.convert_layout %18 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> %21 = ttg.convert_layout %arg4 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma> %22 = ttg.convert_layout %19 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> %23 = ttg.convert_layout %20 : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 2}>> -> tensor<256x256xf32, #mma2> + // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[$DPAS]]> %24 = tt.dot %22, %23, %21, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> %25 = ttg.convert_layout %24 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> - // CHECK: %[[ADVANCE_A:.*]] = tt.advance {{.*}} : > + // CHECK: %[[ADVANCE_A:.*]] = tt.advance {{.*}} : > %26 = tt.advance %arg5, [%c0_i32, %c32_i32] : > - // CHECK: %[[ADVANCE_B:.*]] = tt.advance {{.*}} : > + // CHECK: %[[ADVANCE_B:.*]] = tt.advance {{.*}} : > %27 = tt.advance %arg6, [%c32_i32, %c0_i32] : > // CHECK: scf.yield {{.*}}, %[[ADVANCE_A]], %[[ADVANCE_B]] scf.yield %25, %26, %27 : tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr> From 716f589ab4517166de361f91682273708139fdf1 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Tue, 24 Jun 2025 02:04:58 +0000 Subject: [PATCH 08/14] add complex control flow test --- .../optimize-block-io-encoding.mlir | 79 ++++++++++++++++++- 1 file changed, 75 insertions(+), 4 deletions(-) diff --git a/test/TritonIntelGPU/optimize-block-io-encoding.mlir b/test/TritonIntelGPU/optimize-block-io-encoding.mlir index ba795634f6..63b406b588 100644 --- a/test/TritonIntelGPU/optimize-block-io-encoding.mlir +++ b/test/TritonIntelGPU/optimize-block-io-encoding.mlir @@ -1,11 +1,11 @@ -// RUN: triton-opt %s -split-input-file --tritonintelgpu-optimize-block-io-encoding | FileCheck %s +// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect --tritonintelgpu-optimize-block-io-encoding | FileCheck %s #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}> #blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}> -// CHECK: #[[$SUBGROUP_BLOCK_A:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}> -// CHECK: #[[$SUBGROUP_BLOCK_B:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}> -// CHECK: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +// CHECK-DAG: #[[$SUBGROUP_BLOCK_A:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}> +// CHECK-DAG: #[[$SUBGROUP_BLOCK_B:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}> +// CHECK-DAG: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_dpas, ttig.support_sg_2d_block} { tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) { @@ -65,3 +65,74 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th tt.return } } + +// ----- + +// COM: test complex control flow +// COM: Note that instead of using tt.advance we make a new tensor ptr each time. This is nice, because it lets us test that we can find MakeTensorPtr op inside the scf.if. +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}> +// CHECK-DAG: #[[$SUBGROUP_BLOCK_A:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}> +// CHECK-DAG: #[[$SUBGROUP_BLOCK_B:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}> +// CHECK-DAG: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_dpas, ttig.support_sg_2d_block} { +// CHECK-LABEL: @matmul_change_block_ptr_in_prologue +tt.func @matmul_change_block_ptr_in_prologue(%a_base: !tt.ptr, + %b_base: !tt.ptr) { + %c0_i64 = arith.constant 0 : i64 + %c1_i64 = arith.constant 1 : i64 + %k_tiles = arith.constant 32 : i64 + %true = arith.constant true + %false = arith.constant false + + %zero = arith.constant dense<0.0> : tensor<128x128xf32, #blocked> + + // CHECK: %[[A_UNDEF:.*]] = ub.poison : !tt.ptr> + // CHECK: %[[B_UNDEF:.*]] = ub.poison : !tt.ptr> + %a_ptr_undef = ub.poison : !tt.ptr> + %b_ptr_undef = ub.poison : !tt.ptr> + // CHECK: scf.for {{.*}} iter_args({{.*}} = {{.*}}, %[[A_PTR:.*]] = %[[A_UNDEF]], %[[B_PTR:.*]] = %[[B_UNDEF]]) + scf.for %k = %c0_i64 to %k_tiles step %c1_i64 iter_args(%acc = %zero, %flag = %true, %a_ptr = %a_ptr_undef, %b_ptr = %b_ptr_undef) -> (tensor<128x128xf32, #blocked>, i1, !tt.ptr>, !tt.ptr>) : i64 { + %do_prologue = "prologue_cond"(%k) : (i64) -> i1 + // CHECK: %[[PTRS:.*]]:2 = scf.if {{.*}} -> (!tt.ptr>, !tt.ptr>) + %cur_a_ptr, %cur_b_ptr = scf.if %do_prologue -> (!tt.ptr>, !tt.ptr>) { + %off_m, %off_n, %off_k = "get_offsets"(%k) : (i64) -> (i32, i32, i32) + // CHECK tt.make_tensor_ptr {{.*}} : > + %next_a_ptr = tt.make_tensor_ptr %a_base, [%k, %k], [%c1_i64, %c1_i64], [%off_m, %off_k] {order = array} : > + // CHECK tt.make_tensor_ptr {{.*}} : > + %next_b_ptr = tt.make_tensor_ptr %b_base, [%k, %k], [%c1_i64, %c1_i64], [%off_n, %off_k] {order = array} : > + // CHECK: scf.yield {{.*}} : !tt.ptr>, !tt.ptr> + scf.yield %next_a_ptr, %next_b_ptr : !tt.ptr>, !tt.ptr> + } else { + // CHECK: scf.yield {{.*}} : !tt.ptr>, !tt.ptr> + scf.yield %a_ptr, %b_ptr : !tt.ptr>, !tt.ptr> + } + + // CHECK: %[[A:.*]] = tt.load %[[PTRS]]#0 {{.*}} : !tt.ptr> + %a = tt.load %cur_a_ptr {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: ttg.convert_layout %[[A]] : tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]> -> tensor<128x64xf16, #blocked1> + // CHECK: %[[B:.*]] = tt.load %[[PTRS]]#1 {{.*}} : !tt.ptr> + %b = tt.load %cur_b_ptr {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: {{.*}} = ttg.convert_layout %[[B]] : tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]> -> tensor<64x128xf16, #blocked2> + %a_dot = ttg.convert_layout %a : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %b_dot = ttg.convert_layout %b : tensor<64x128xf16, #blocked2> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %a_dot_dpas = ttg.convert_layout %a_dot : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %b_dot_dpas = ttg.convert_layout %b_dot : tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %accum = ttg.convert_layout %acc : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma> + %c = tt.dot %a_dot_dpas, %b_dot_dpas, %accum, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> + %c_out = ttg.convert_layout %c : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked> + + %do_epilogue = arith.cmpi eq, %k, %c0_i64 : i64 + %use_acc = arith.select %do_epilogue, %false, %true : i1 + scf.if %do_epilogue { + "acc_user"(%c_out) : (tensor<128x128xf32, #blocked>) -> () + } + // CHECK: scf.yield {{.*}} : {{.*}}, i1, !tt.ptr>, !tt.ptr> + scf.yield %c_out, %use_acc, %cur_a_ptr, %cur_b_ptr : tensor<128x128xf32, #blocked>, i1, !tt.ptr>, !tt.ptr> + } + + tt.return + } +} From 32418054d383fae3cbc417fbee6b776562bee5cc Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Tue, 24 Jun 2025 02:05:13 +0000 Subject: [PATCH 09/14] remove unnecessary assert --- .../lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp index d77d07ce05..c20bc1105b 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp @@ -229,9 +229,6 @@ class TritonIntelGPUOptimizeBlockIOEncodingPass return; } MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(ptr); - assert(makeTensorPtrOp && - "expecting a tensor pointer parent to block io load " - "with tensor pointer type"); auto oldTensorPtrType = cast(makeTensorPtrOp.getType()); auto oldTensorType = From c117b37cfcd407702d66523a129c687f3458e925 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Tue, 24 Jun 2025 13:38:07 +0000 Subject: [PATCH 10/14] further reduce unecessary test values --- .../optimize-block-io-encoding.mlir | 21 ++++++------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/test/TritonIntelGPU/optimize-block-io-encoding.mlir b/test/TritonIntelGPU/optimize-block-io-encoding.mlir index 63b406b588..9b7ea8b85b 100644 --- a/test/TritonIntelGPU/optimize-block-io-encoding.mlir +++ b/test/TritonIntelGPU/optimize-block-io-encoding.mlir @@ -20,22 +20,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th %c64_i32 = arith.constant 64 : i32 %c5120_i32 = arith.constant 5120 : i32 %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked> - %0 = tt.get_program_id x : i32 - %1 = arith.divsi %0, %c64_i32 : i32 - %2 = arith.muli %1, %c4_i32 : i32 - %3 = arith.subi %c4_i32, %2 : i32 - %4 = arith.minsi %3, %c4_i32 : i32 - %5 = arith.remsi %0, %4 : i32 - %6 = arith.addi %2, %5 : i32 - %7 = arith.remsi %0, %c64_i32 : i32 - %8 = arith.divsi %7, %4 : i32 - %9 = arith.muli %6, %c256_i32 : i32 + // CHECK: %[[MAKE_TENSOR_PTR_A:.*]] = tt.make_tensor_ptr {{.*}} : > - %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array} : > - %11 = arith.muli %8, %c256_i32 : i32 + %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%c256_i32, %c0_i32] {order = array} : > // CHECK: %[[MAKE_TENSOR_PTR_B:.*]] = tt.make_tensor_ptr {{.*}} : > - %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %11] {order = array} : > - // CHECK: scf.for {{.*}} iter_args({{.*}} = {{.*}}, %[[ARG5:.*]] = %[[MAKE_TENSOR_PTR_A]], %[[ARG6:.*]] = %[[MAKE_TENSOR_PTR_B]]) + %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %c256_i32] {order = array} : > + // CHECK: %[[RES:.*]]:3 = scf.for {{.*}} iter_args({{.*}} = {{.*}}, %[[ARG5:.*]] = %[[MAKE_TENSOR_PTR_A]], %[[ARG6:.*]] = %[[MAKE_TENSOR_PTR_B]]) %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr>) : i32 { %17 = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> // CHECK: %[[A_LOAD:.*]] = tt.load %[[ARG5]] {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> @@ -58,7 +48,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th // CHECK: scf.yield {{.*}}, %[[ADVANCE_A]], %[[ADVANCE_B]] scf.yield %25, %26, %27 : tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr> } - %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array} : > + %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %c256_i32] {order = array} : > + // CHECK aritch.truncf %[[RES]]#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked> %15 = arith.truncf %13#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked> %16 = ttg.convert_layout %15 : tensor<256x256xf16, #blocked> -> tensor<256x256xf16, #blocked2> tt.store %14, %16 {boundaryCheck = array} : !tt.ptr> From 0cf543569590adf8c849f1a5f8092c56b42857de Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Wed, 25 Jun 2025 01:04:01 +0000 Subject: [PATCH 11/14] properly handle while loops --- .../optimize-block-io-encoding.mlir | 52 +++++++++++++++++++ .../OptimizeBlockIOEncoding.cpp | 9 ++-- 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/test/TritonIntelGPU/optimize-block-io-encoding.mlir b/test/TritonIntelGPU/optimize-block-io-encoding.mlir index 9b7ea8b85b..e110d449f4 100644 --- a/test/TritonIntelGPU/optimize-block-io-encoding.mlir +++ b/test/TritonIntelGPU/optimize-block-io-encoding.mlir @@ -1,5 +1,6 @@ // RUN: triton-opt %s -split-input-file -allow-unregistered-dialect --tritonintelgpu-optimize-block-io-encoding | FileCheck %s +// COM: test complete example #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}> #blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}> @@ -59,6 +60,57 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th // ----- +// COM: Test while loop / tt.advance before tt.load (TODO) +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}> +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +// CHECK-DAG: #[[$BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}> +// CHECK-DAG: #[[$SUBGROUP_2D_BLOCK:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}> +// CHECK-DAG: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_dpas, ttig.support_sg_2d_block} { + tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr) { + %c1024_i64 = arith.constant 1024 : i64 + %c5120_i64 = arith.constant 5120 : i64 + %c1_i64 = arith.constant 1 : i64 + %c256_i32 = arith.constant 256 : i32 + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + + // CHECK: %[[A_PTR:.*]] = tt.make_tensor_ptr %arg0, {{.*}} : > + %a_ptr = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%c256_i32, %c0_i32] {order = array} : > + + // CHECK: scf.while {{.*}} : (!tt.ptr>) -> !tt.ptr> + %1 = scf.while (%a_ptr_crt = %a_ptr) : (!tt.ptr>) -> (!tt.ptr>) { + %2 = "dummy.evaluate_condition"() : () -> i1 + // CHECK: scf.condition({{.*}}) {{.*}} : !tt.ptr> + scf.condition(%2) %a_ptr_crt : !tt.ptr> + } do { + ^bb0(%a_ptr_crt: !tt.ptr>): + // CHECK: ^bb0({{.*}}: !tt.ptr>): + + // CHECK: %[[A_LOAD:.*]] = tt.load {{.*}} : !tt.ptr> + %3 = tt.load %a_ptr_crt {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: ttg.convert_layout %[[A_LOAD]] : tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]> -> tensor<256x32xf16, #[[$BLOCKED]]> + // CHECK: ttg.convert_layout {{.*}} : tensor<256x32xf16, #[[$BLOCKED]]> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> + %4 = ttg.convert_layout %3 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + + %cstB = arith.constant dense<0.000000e+00> : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + + // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[$DPAS]]> + %5 = tt.dot %4, %cstB, %cst, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %6 = ttg.convert_layout %5 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked1> + // COM: TODO: support nested tt.advance + // %3 = tt.advance %a_ptr_crt, [%c0_i32, %c32_i32] : > + + // CHECK: scf.yield {{.*}} : !tt.ptr> + scf.yield %a_ptr_crt : !tt.ptr> + } + tt.return + } +} + +// ----- + // COM: test complex control flow // COM: Note that instead of using tt.advance we make a new tensor ptr each time. This is nice, because it lets us test that we can find MakeTensorPtr op inside the scf.if. #blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp index c20bc1105b..2d0b406d2f 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp @@ -26,10 +26,10 @@ SmallVector getTiedArgs(Operation *op, int resultIdx) { } else if (auto whileOp = dyn_cast(op)) { auto iterArg = whileOp.getBeforeArguments()[resultIdx]; auto result = whileOp.getResults()[resultIdx]; - auto yieldVal = - whileOp.getBeforeBody()->getTerminator()->getOperand(resultIdx); + auto yieldVal = whileOp.getConditionOp().getArgs()[resultIdx]; auto initVal = whileOp.getOperands()[resultIdx]; - return {iterArg, result, iterArg, initVal}; + auto bodyArg = whileOp.getAfterArguments()[resultIdx]; + return {iterArg, result, yieldVal, initVal, bodyArg}; } else if (auto ifOp = dyn_cast(op)) { SmallVector values; for (auto &block : ifOp.getThenRegion().getBlocks()) { @@ -228,7 +228,10 @@ class TritonIntelGPUOptimizeBlockIOEncodingPass << "\nincompatible with Subgroup 2D Block Layout.\n"); return; } + LLVM_DEBUG(DBGS() << "Retrieving tensor ptr op for ptr " << ptr << "\n"); MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(ptr); + LLVM_DEBUG(DBGS() << "Rerwrite encoding for block ptr op " + << makeTensorPtrOp << "\n"); auto oldTensorPtrType = cast(makeTensorPtrOp.getType()); auto oldTensorType = From 56bc07515f633b4da12982c304cd2ea4baff4a7c Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Wed, 25 Jun 2025 02:08:00 +0000 Subject: [PATCH 12/14] propagate layout to tt.advance --- test/TritonIntelGPU/optimize-block-io-encoding.mlir | 6 +++--- .../TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp | 6 ++++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/test/TritonIntelGPU/optimize-block-io-encoding.mlir b/test/TritonIntelGPU/optimize-block-io-encoding.mlir index e110d449f4..beb421548e 100644 --- a/test/TritonIntelGPU/optimize-block-io-encoding.mlir +++ b/test/TritonIntelGPU/optimize-block-io-encoding.mlir @@ -60,7 +60,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th // ----- -// COM: Test while loop / tt.advance before tt.load (TODO) +// COM: Test while loop / nested tt.advance #blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}> #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> // CHECK-DAG: #[[$BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}> @@ -99,8 +99,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[$DPAS]]> %5 = tt.dot %4, %cstB, %cst, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> %6 = ttg.convert_layout %5 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked1> - // COM: TODO: support nested tt.advance - // %3 = tt.advance %a_ptr_crt, [%c0_i32, %c32_i32] : > + // CHECK: tt.advance {{.*}} : > + %7 = tt.advance %a_ptr_crt, [%c0_i32, %c32_i32] : > // CHECK: scf.yield {{.*}} : !tt.ptr> scf.yield %a_ptr_crt : !tt.ptr> diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp index 2d0b406d2f..f3f3dd670c 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp @@ -117,6 +117,12 @@ void rewriteTensorLayoutsForOp(Attribute encoding, Operation *op) { } else if (isa(op)) { auto vals = getTiedArgs(op->getParentOp(), use.getOperandNumber()); updateEncoding(vals, EncodingInfo{encoding}); + } else if (isa(op)) { + // The operand will be updated when the MakeTensorPtr op result is + // updated. Make sure the result type matches. + for (auto result : op->getResults()) + if (auto desc = dyn_cast>(result)) + updateEncoding(desc, EncodingInfo{encoding}); } } From 8a9b5b4f566753c80e09ab5a7cb7de704bfda06a Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Fri, 27 Jun 2025 00:47:57 +0000 Subject: [PATCH 13/14] remove unused var --- .../lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp index f3f3dd670c..bdb988a67a 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp @@ -242,10 +242,6 @@ class TritonIntelGPUOptimizeBlockIOEncodingPass auto oldTensorPtrType = cast(makeTensorPtrOp.getType()); auto oldTensorType = cast(oldTensorPtrType.getPointeeType()); - // Note: we need the old layout to get the order for the load, but it is not - // clear the layout will always be Blocked. Is there a better way to get - // this info? - auto oldLayout = cast(oldTensorType.getEncoding()); auto CTALayout = getCTALayout(dpasLayout); const unsigned elemSizeInBits = From 7cbe3fe07ff4a07fdcb00c7934eaa0e9b2eed098 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Mon, 14 Jul 2025 17:38:09 +0000 Subject: [PATCH 14/14] address review comments --- .../OptimizeBlockIOEncoding.cpp | 75 +++++-------------- 1 file changed, 18 insertions(+), 57 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp index bdb988a67a..29c48b00b9 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp @@ -1,6 +1,7 @@ #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h" #include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "llvm/ADT/PriorityWorklist.h" namespace ttg = mlir::triton::gpu; @@ -16,45 +17,11 @@ namespace gpu::intel { namespace { -SmallVector getTiedArgs(Operation *op, int resultIdx) { - if (auto forOp = dyn_cast(op)) { - auto iterArg = forOp.getRegionIterArg(resultIdx); - auto result = forOp.getResult(resultIdx); - auto yieldVal = forOp.getBody()->getTerminator()->getOperand(resultIdx); - auto initVal = forOp.getInitArgs()[resultIdx]; - return {iterArg, result, yieldVal, initVal}; - } else if (auto whileOp = dyn_cast(op)) { - auto iterArg = whileOp.getBeforeArguments()[resultIdx]; - auto result = whileOp.getResults()[resultIdx]; - auto yieldVal = whileOp.getConditionOp().getArgs()[resultIdx]; - auto initVal = whileOp.getOperands()[resultIdx]; - auto bodyArg = whileOp.getAfterArguments()[resultIdx]; - return {iterArg, result, yieldVal, initVal, bodyArg}; - } else if (auto ifOp = dyn_cast(op)) { - SmallVector values; - for (auto &block : ifOp.getThenRegion().getBlocks()) { - auto terminator = block.getTerminator(); - if (isa(terminator)) - values.push_back(terminator->getOperands()[resultIdx]); - } - for (auto &block : ifOp.getElseRegion().getBlocks()) { - auto terminator = block.getTerminator(); - if (isa(terminator)) - values.push_back(terminator->getOperands()[resultIdx]); - } - values.push_back(ifOp->getResults()[resultIdx]); - return values; - } - return {}; -} - struct EncodingInfo { Attribute desiredEncoding; - bool requiresConvert = false; bool operator==(const EncodingInfo &other) const { - return desiredEncoding == other.desiredEncoding && - requiresConvert == other.requiresConvert; + return desiredEncoding == other.desiredEncoding; } }; @@ -77,10 +44,6 @@ void rewriteTensorLayoutsForOp(Attribute encoding, Operation *op) { auto updateEncoding = [&](ArrayRef ptrValues, EncodingInfo info) { for (auto value : ptrValues) { - bool requiresConvert = llvm::any_of( - value.getUsers(), [](auto user) { return isa(user); }); - info.requiresConvert = requiresConvert; - auto typedVal = cast>(value); auto itr = valueToEncodingInfo.find(typedVal); if (itr == valueToEncodingInfo.end()) { @@ -157,24 +120,22 @@ void rewriteTensorLayoutsForOp(Attribute encoding, Operation *op) { oldTensorTy.getShape(), oldTensorTy.getElementType(), newEncoding); val.setType(PointerType::get(newTensorTy, oldType.getAddressSpace())); - if (einfo.requiresConvert) { - for (auto user : val.getUsers()) { - if (auto loadOp = dyn_cast(user)) { - - OpBuilder builder(loadOp); - auto oldLoadType = loadOp.getType(); - Value result = loadOp.getResult(); - - builder.setInsertionPointAfter(loadOp); - auto cvt = builder.create(loadOp.getLoc(), - result.getType(), result); - LLVM_DEBUG(DBGS() << "Added convert Op:\n" - << cvt << " after Load Op:\n" - << loadOp << "\n"); - result.setType(newTensorTy); - - result.replaceAllUsesExcept(cvt.getResult(), cvt.getOperation()); - } + for (auto user : val.getUsers()) { + if (auto loadOp = dyn_cast(user)) { + + OpBuilder builder(loadOp); + auto oldLoadType = loadOp.getType(); + Value result = loadOp.getResult(); + + builder.setInsertionPointAfter(loadOp); + auto cvt = builder.create(loadOp.getLoc(), + result.getType(), result); + LLVM_DEBUG(DBGS() << "Added convert Op:\n" + << cvt << " after Load Op:\n" + << loadOp << "\n"); + result.setType(newTensorTy); + + result.replaceAllUsesExcept(cvt.getResult(), cvt.getOperation()); } } }