diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 5efa2002de..c640d50215 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -44,6 +44,7 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "TRITON_F32_DEFAULT", "TRITON_PREFER_TMEM_16x256_LAYOUT", "TRITON_ENABLE_EXPERIMENTAL_CONSAN", + "TRITON_INTEL_2DBLOCK_ASSERT", "TRITON_INTEL_AGGRESSIVE_DPAS_REUSE", "TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS", "TRITON_INTEL_ENABLE_DPAS_FOR_WARP_SIZE_32", diff --git a/python/test/unit/intel/block_load_helper.py b/python/test/unit/intel/block_load_helper.py new file mode 100644 index 0000000000..3f8865b18b --- /dev/null +++ b/python/test/unit/intel/block_load_helper.py @@ -0,0 +1,50 @@ +import torch +import triton + +import ctypes +import sys + + +def run_load_ir(temp_file, elem_size, *args): + out_type = f"i{int(elem_size) * 4}" + ir = f""" + module attributes {{ttg.target = "xpu", "ttg.num-warps" = 32 : i32, + "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = 16 : i32}} {{ + tt.func @dyn_block( + %iptr: i64, %base_width: i32, %base_height: i32, %base_pitch: i32, + %x: i32, %y: i32) {{ + %p0 = llvm.inttoptr %iptr : i64 to !llvm.ptr + + %v = triton_gen.2Dblockload %p0, %base_width, %base_height, + %base_pitch, %x, %y + {{ elem_size_in_bits = {elem_size}, tile_width = 8, tile_height = 8, + v_blocks = 1, transpose = false, + vnni_transform = false, cache_control = Default }} + : (!llvm.ptr, i32, i32, i32, i32, i32) + -> vector<1x{out_type}> + + // To prevent gluon-inline from removing the unused 2Dblockload call. + %v_cast = llvm.bitcast %v : vector<1x{out_type}> to {out_type} + llvm.inline_asm has_side_effects asm_dialect = att + "", "r" %v_cast : ({out_type}) -> () + + tt.return + }} + }} + """ + + with open(temp_file, "w", encoding="utf-8") as f: + f.write(ir) + + kernel = triton.compile(temp_file) + + a = torch.zeros((256, 64), dtype=torch.float32, device="xpu") + + addr = ctypes.c_int64(a.data_ptr()).value + + kernel[(1, 1, 1)](addr, *map(int, args), 0) + + +if __name__ == "__main__": + fn = globals()[sys.argv[1]] + fn(*sys.argv[2:]) diff --git a/python/test/unit/intel/test_block_load.py b/python/test/unit/intel/test_block_load.py index 4d4144946b..de74417863 100644 --- a/python/test/unit/intel/test_block_load.py +++ b/python/test/unit/intel/test_block_load.py @@ -1,5 +1,10 @@ import pytest import torch + +import os +import signal +import subprocess +import sys import pathlib from functools import partial @@ -207,3 +212,45 @@ def triton_mm(X, Y, b=None, transpose_x=False, transpose_y=False): result_tor = fn_tor() result_tri = fn_tri() torch.testing.assert_close(result_tri, result_tor, atol=1e-2, rtol=1e-3) + + +@pytest.mark.parametrize("elem_size, width, height, pitch, x", + [[8, 16777216, 64, 16777216, 0], # width <= 24 bits + [8, 32, 64, 128, 0], # width >= 64 + [8, 66, 64, 128, 0], # width % max(4,elemSize) == 0 + [8, 128, 16777216, 128, 0], # height <= 24 bits + [8, 128, 64, 16777216, 0], # pitch <= 24 bits + [8, 128, 64, 32, 0], # pitch >= 64 + [8, 128, 64, 70, 0], # pitch % 16 == 0 + [8, 128, 64, 120, 0], # pitch >= width + [8, 128, 64, 128, 1], # x*elemSize % 4 == 0 (alignment for 8-bit) + [16, 128, 64, 128, 1], # x*elemSize % 4 == 0 (alignment for 16-bit) + ]) +@pytest.mark.skipif(not is_xpu(), reason="Block load tests are specific to the XPU backend") +@pytest.mark.xfail( + not (torch.xpu.get_device_capability()['has_subgroup_2d_block_io'] + and torch.xpu.get_device_capability()['has_subgroup_matrix_multiply_accumulate']), + reason="Block loads and/or DPAS not supported on this architecture", run=False) +def test_block_load_asserts(elem_size, width, height, pitch, x, monkeypatch, tmp_path: pathlib.Path): + monkeypatch.setenv("TRITON_INTEL_2DBLOCK_ASSERT", "1") + + dir_path = os.path.dirname(os.path.realpath(__file__)) + helper_path = os.path.join(dir_path, "block_load_helper.py") + + temp_file = tmp_path / "test_block_load_asserts.ttgir" + + proc = subprocess.run( + [ + sys.executable, helper_path, "run_load_ir", + str(temp_file), + str(elem_size), + str(width), + str(height), + str(pitch), + str(x) + ], + capture_output=True, + ) + + rc = proc.returncode + assert rc == -signal.SIGABRT diff --git a/test/TritonGEN/tritongen-2Dblockload-to-llvm-asserts.mlir b/test/TritonGEN/tritongen-2Dblockload-to-llvm-asserts.mlir new file mode 100644 index 0000000000..dd7e0783a8 --- /dev/null +++ b/test/TritonGEN/tritongen-2Dblockload-to-llvm-asserts.mlir @@ -0,0 +1,33 @@ +// RUN: env TRITON_INTEL_2DBLOCK_ASSERT=1 triton-opt -convert-tritongen-to-llvm -split-input-file %s | FileCheck %s --check-prefix=ASSERT +// RUN: triton-opt -convert-tritongen-to-llvm -split-input-file %s | FileCheck %s --check-prefix=NOASSERT + +module attributes {"ttg.threads-per-warp" = 16 : i32} { +llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // ASSERT: llvm.call spir_funccc @__assert_fail + // NOASSERT-NOT: __assert_fail + %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=8, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<2xi16> + llvm.return +} +} + +// ----- + +module attributes {"ttg.threads-per-warp" = 16 : i32} { +llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // ASSERT: llvm.call spir_funccc @__assert_fail + // NOASSERT-NOT: __assert_fail + triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=16, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} +} + +// ----- + +module attributes {"ttg.threads-per-warp" = 16 : i32} { +llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) { + // ASSERT: llvm.call spir_funccc @__assert_fail + // NOASSERT-NOT: __assert_fail + triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>) + llvm.return +} +} diff --git a/third_party/intel/include/TritonGENToLLVM/TritonGENToLLVMPass.h b/third_party/intel/include/TritonGENToLLVM/TritonGENToLLVMPass.h index 6b06feff40..065d9814b1 100644 --- a/third_party/intel/include/TritonGENToLLVM/TritonGENToLLVMPass.h +++ b/third_party/intel/include/TritonGENToLLVM/TritonGENToLLVMPass.h @@ -21,8 +21,13 @@ namespace triton { #define GEN_PASS_DECL #include "intel/include/TritonGENToLLVM/Passes.h.inc" -void populateTritonGENToLLVMConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns); +namespace gpu::intel { +class LibCallEmitter; +} // namespace gpu::intel + +void populateTritonGENToLLVMConversionPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns, + const mlir::triton::gpu::intel::LibCallEmitter &emitter); void registerConvertTritonGENToLLVMInterface(DialectRegistry ®istry); diff --git a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp index cf25e8ba28..4a52a1fd35 100644 --- a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp +++ b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp @@ -8,6 +8,7 @@ #include "Attributes.h" #include "Utils/LLVMIntr.h" +#include "Utils/LibCallEmitter.h" #include "Utils/Mangling.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" @@ -35,6 +36,7 @@ #include "llvm/Support/ErrorHandling.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" #include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h" #include "intel/include/TritonGENToLLVM/TritonGENToLLVMPass.h" @@ -508,6 +510,131 @@ createGenISA2DBlockPrefetch(TritonGEN::Matrix2DBlockPrefetchOp op, intel::noUnwindWillReturnAttrs); } +static void +createAssertNot(ConversionPatternRewriter &rewriter, + const mlir::triton::gpu::intel::LibCallEmitter &emitter, + Value condition, StringRef message) { + + auto *ctx = rewriter.getContext(); + auto loc = rewriter.getInsertionPoint() != rewriter.getBlock()->end() + ? rewriter.getInsertionPoint()->getLoc() + : UnknownLoc::get(ctx); + + StringRef file = "unknown"; + StringRef func = "unknown"; + int line = 0; + + while (auto callLoc = dyn_cast(loc)) + loc = callLoc.getCallee(); + + while (auto nameLoc = dyn_cast(loc)) + loc = nameLoc.getChildLoc(); + + if (auto fileLineColLoc = dyn_cast(loc)) { + file = fileLineColLoc.getFilename(); + line = fileLineColLoc.getLine(); + } + + Block *prevBlock = rewriter.getBlock(); + auto insertPt = rewriter.getInsertionPoint(); + + Block *thenBlock = rewriter.splitBlock(prevBlock, insertPt); + + Block *ifBlock = rewriter.createBlock(prevBlock->getParent()); + rewriter.setInsertionPointToStart(ifBlock); + emitter.assertFail(rewriter, loc, message, file, func, line); + rewriter.create(loc, thenBlock); + + rewriter.setInsertionPointToEnd(prevBlock); + rewriter.create(loc, condition, ifBlock, thenBlock); + + rewriter.setInsertionPointToStart(thenBlock); +} + +static void create2DBlockAssertsImpl( + const mlir::Value &baseWidth, const mlir::Value &baseHeight, + const mlir::Value &basePitch, const mlir::Value &x, unsigned int elemSize, + const mlir::Location &loc, mlir::ConversionPatternRewriter &rewriter, + const mlir::triton::gpu::intel::LibCallEmitter &emitter) { + using namespace mlir; + using namespace mlir::LLVM; + + auto b = TritonLLVMOpBuilder(loc, rewriter); + + Value c0 = b.i32_val(0); + Value c4 = b.i32_val(4); + Value c64 = b.i32_val(64); + Value c16 = b.i32_val(16); + Value c24m1 = b.i32_val((1u << 24) - 1); + Value cElemSize = b.i32_val(elemSize); + Value cMaxAlign = b.i32_val(std::max(4u, elemSize)); + + Value wTooLarge = + rewriter.create(loc, ICmpPredicate::ugt, baseWidth, c24m1); + createAssertNot(rewriter, emitter, wTooLarge, + "2nd operand (base width) should be <= 24 bits"); + + Value wTooSmall = + rewriter.create(loc, ICmpPredicate::ult, baseWidth, c64); + createAssertNot(rewriter, emitter, wTooSmall, + "2nd operand (base width) should be >= 64"); + + Value wRem = rewriter.create(loc, baseWidth, cMaxAlign); + Value wNotAligned = rewriter.create(loc, ICmpPredicate::ne, wRem, c0); + createAssertNot( + rewriter, emitter, wNotAligned, + "2nd operand (base width) should be aligned to MAX(4, element_size)"); + + Value hTooLarge = + rewriter.create(loc, ICmpPredicate::ugt, baseHeight, c24m1); + createAssertNot(rewriter, emitter, hTooLarge, + "3rd operand (base height) should be <= 24 bits"); + + Value pTooLarge = + rewriter.create(loc, ICmpPredicate::ugt, basePitch, c24m1); + createAssertNot(rewriter, emitter, pTooLarge, + "4th operand (base pitch) should be <= 24 bits"); + + Value pTooSmall = + rewriter.create(loc, ICmpPredicate::ult, basePitch, c64); + createAssertNot(rewriter, emitter, pTooSmall, + "4th operand (base pitch) should be >= 64"); + + Value pRem = rewriter.create(loc, basePitch, c16); + Value pNotAligned = rewriter.create(loc, ICmpPredicate::ne, pRem, c0); + createAssertNot(rewriter, emitter, pNotAligned, + "4th operand (base pitch) should be a multiple of 16 bytes"); + + Value pLessThanWidth = + rewriter.create(loc, ICmpPredicate::ult, basePitch, baseWidth); + createAssertNot( + rewriter, emitter, pLessThanWidth, + "4th operand (base pitch) should be >= 2nd operand (base width)"); + + Value offsetBytes = rewriter.create(loc, x, cElemSize); + Value offsetRem = rewriter.create(loc, offsetBytes, c4); + Value badOffset = + rewriter.create(loc, ICmpPredicate::ne, offsetRem, c0); + createAssertNot( + rewriter, emitter, badOffset, + "5th operand (x) should be properly aligned for the element size"); +} + +template +static void +create2DBlockAsserts(OpTy op, mlir::ConversionPatternRewriter &rewriter, + const mlir::triton::gpu::intel::LibCallEmitter &emitter) { + + if (!triton::tools::getBoolEnv("TRITON_INTEL_2DBLOCK_ASSERT")) { + return; + } + + // put implementation in a separate function to avoid template bloat + create2DBlockAssertsImpl( + op.getBaseWidth(), op.getBaseHeight(), op.getBasePitch(), op.getX(), + op.getElemSizeInBits() / 8, op->getLoc(), rewriter, emitter); +} + namespace { //===----------------------------------------------------------------------===// @@ -638,9 +765,17 @@ struct TritonMatrix2DBlockLoadLowering using ConvertOpToLLVMPattern< TritonGEN::Matrix2DBlockLoadOp>::ConvertOpToLLVMPattern; + explicit TritonMatrix2DBlockLoadLowering( + LLVMTypeConverter &typeConverter, + const mlir::triton::gpu::intel::LibCallEmitter &emitter) + : ConvertOpToLLVMPattern(typeConverter), + emitter(emitter) {} + LogicalResult matchAndRewrite(TritonGEN::Matrix2DBlockLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + create2DBlockAsserts(op, rewriter, emitter); + if (!isSPVBuiltinAvailable(op)) { // Fallback to GenISA interface. rewriter.replaceOp(op, createGenISA2DBlockRead(op, rewriter)); @@ -706,6 +841,9 @@ struct TritonMatrix2DBlockLoadLowering rewriter.replaceOp(op, rewriter.create(loc, resType, dest)); return success(); } + +private: + const mlir::triton::gpu::intel::LibCallEmitter &emitter; }; struct TritonMatrix2DBlockStoreLowering @@ -713,9 +851,17 @@ struct TritonMatrix2DBlockStoreLowering using ConvertOpToLLVMPattern< TritonGEN::Matrix2DBlockStoreOp>::ConvertOpToLLVMPattern; + explicit TritonMatrix2DBlockStoreLowering( + LLVMTypeConverter &typeConverter, + const mlir::triton::gpu::intel::LibCallEmitter &emitter) + : ConvertOpToLLVMPattern(typeConverter), + emitter(emitter) {} + LogicalResult matchAndRewrite(TritonGEN::Matrix2DBlockStoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + create2DBlockAsserts(op, rewriter, emitter); + if (!isSPVBuiltinAvailable(op)) { // Fallback to GenISA interface. rewriter.replaceOp(op, createGenISA2DBlockWrite(op, rewriter)); @@ -780,6 +926,9 @@ struct TritonMatrix2DBlockStoreLowering rewriter.replaceOp(op, call); return success(); } + +protected: + const mlir::triton::gpu::intel::LibCallEmitter &emitter; }; struct TritonMatrix2DBlockPrefetchLowering @@ -787,9 +936,18 @@ struct TritonMatrix2DBlockPrefetchLowering using ConvertOpToLLVMPattern< TritonGEN::Matrix2DBlockPrefetchOp>::ConvertOpToLLVMPattern; + explicit TritonMatrix2DBlockPrefetchLowering( + LLVMTypeConverter &typeConverter, + const mlir::triton::gpu::intel::LibCallEmitter &emitter) + : ConvertOpToLLVMPattern( + typeConverter), + emitter(emitter) {} + LogicalResult matchAndRewrite(TritonGEN::Matrix2DBlockPrefetchOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + create2DBlockAsserts(op, rewriter, emitter); + if (!isSPVBuiltinAvailable(op)) { // Fallback to GenISA interface. rewriter.replaceOp(op, createGenISA2DBlockPrefetch(op, rewriter)); @@ -846,6 +1004,9 @@ struct TritonMatrix2DBlockPrefetchLowering rewriter.replaceOp(op, call); return success(); } + +private: + const mlir::triton::gpu::intel::LibCallEmitter &emitter; }; template loadDialect(); - } - - /// Hook for derived dialect interface to provide conversion patterns - /// and mark dialect legal for the conversion target. - void populateConvertToLLVMConversionPatterns( - ConversionTarget &target, LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns) const final { - populateTritonGENToLLVMConversionPatterns(typeConverter, patterns); - } -}; - -} // namespace - //===----------------------------------------------------------------------===// // Pattern Population and Registration //===----------------------------------------------------------------------===// void mlir::triton::populateTritonGENToLLVMConversionPatterns( - LLVMTypeConverter &converter, RewritePatternSet &patterns) { + LLVMTypeConverter &converter, RewritePatternSet &patterns, + const mlir::triton::gpu::intel::LibCallEmitter &emitter) { patterns - .add(converter); -} + .add(converter, emitter); -void registerConvertTritonTritonGENToLLVMInterface(DialectRegistry ®istry) { - registry.addExtension( - +[](MLIRContext *ctx, TritonGEN::TritonGENDialect *dialect) { - dialect->addInterfaces(); - }); + patterns.add( + converter); } diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp index 788d746f85..3a466f9874 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp @@ -150,143 +150,22 @@ std::string TargetInfo::getMulhiFuncName(Type resultElementTy) const { return funcName; } -Value printfPromoteValue(RewriterBase &rewriter, Value value, bool isSigned) { - auto type = value.getType(); - if (isa(type) && type.getIntOrFloatBitWidth() == 1) { - // FIXME: There is some problem when using i1 type now, - // remove this code once IGC fix the problem. - TritonLLVMOpBuilder b(rewriter.getUnknownLoc(), rewriter); - return b.zext(i8_ty, value); - } else if (type.isIntOrIndex() && type.getIntOrFloatBitWidth() < 32) { - TritonLLVMOpBuilder b(rewriter.getUnknownLoc(), rewriter); - if (isSigned) { - return b.sext(i32_ty, value); - } else { - return b.zext(i32_ty, value); - } - } else { - return value; - } -} - -// declare __spirv_ocl_printf(i8*, ...) as external function -static LLVM::LLVMFuncOp getSpirvPrintfDeclaration(RewriterBase &rewriter) { - auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); - StringRef funcName("_Z18__spirv_ocl_printf"); - Operation *funcOp = moduleOp.lookupSymbol(funcName); - if (funcOp) - return cast(*funcOp); - - MLIRContext *context = rewriter.getContext(); - auto ptrTy = LLVM::LLVMPointerType::get( - context, TritonGEN::TritonGENMemorySpace::kUniformConstant); - SmallVector argsType{ptrTy}; - auto retType = i32_ty; - auto funcType = - LLVM::LLVMFunctionType::get(retType, argsType, /*isVarArg*/ true); - - ConversionPatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(moduleOp.getBody()); - - auto printFunc = rewriter.create( - UnknownLoc::get(context), funcName, funcType, LLVM::Linkage::External, - /*dsoLocal*/ false, LLVM::CConv::SPIR_FUNC, /*comdat=*/SymbolRefAttr{}); - printFunc->setAttr("nounwind", rewriter.getUnitAttr()); - - return printFunc; -} - void TargetInfo::printf(RewriterBase &rewriter, Value formatStrStart, int /*formatStrByteCount*/, ValueRange args, ArrayRef isSigned) const { - auto *ctx = rewriter.getContext(); - Type ptr = ptr_ty(ctx); - auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); - auto funcOp = getSpirvPrintfDeclaration(rewriter); - auto loc = UnknownLoc::get(ctx); - auto b = TritonLLVMOpBuilder(loc, rewriter); - - SmallVector operands; - operands.push_back(formatStrStart); - for (auto [i, arg] : llvm::enumerate(args)) { - operands.push_back(printfPromoteValue( - rewriter, arg, isSigned.empty() ? true : isSigned[i])); - } - auto callOp = b.call(funcOp, operands); - callOp.setCConv(triton::gpu::intel::getRequiredCConv(callOp)); + emitter.printf(rewriter, formatStrStart, /*formatStrByteCount*/ 0, args, + isSigned); } void TargetInfo::printf(RewriterBase &rewriter, StringRef msg, ValueRange args, ArrayRef isSigned) const { - assert(!msg.empty() && "printf with empty string not supported"); - llvm::SmallString<64> msgNewline(msg); - msgNewline.push_back('\n'); - msgNewline.push_back('\0'); - Value msgValue = getGlobalStringStart( - rewriter.getUnknownLoc(), rewriter, "printfFormat_", msgNewline, - /*addressSpace=*/TritonGEN::kUniformConstant); - printf(rewriter, msgValue, msgNewline.size_in_bytes(), args, isSigned); -} - -static LLVM::LLVMFuncOp getAssertfailDeclaration(RewriterBase &rewriter) { - auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); - StringRef funcName = "__assert_fail"; - Operation *funcOp = moduleOp.lookupSymbol(funcName); - if (funcOp) - return cast(*funcOp); - - // void __assert_fail(const char * assertion, const char * file, unsigned - // int line, const char * function); - auto *ctx = rewriter.getContext(); - SmallVector argsType; - argsType = {ptr_ty(ctx, TritonGEN::TritonGENMemorySpace::kGeneric), - ptr_ty(ctx, TritonGEN::TritonGENMemorySpace::kGeneric), i32_ty, - ptr_ty(ctx, TritonGEN::TritonGENMemorySpace::kGeneric)}; - auto funcType = LLVM::LLVMFunctionType::get(void_ty(ctx), argsType); - - RewriterBase::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(moduleOp.getBody()); - - auto func = rewriter.create(UnknownLoc::get(ctx), funcName, - funcType); - func.setCConv(LLVM::cconv::CConv::SPIR_FUNC); - return func; + emitter.printf(rewriter, msg, args, isSigned); } void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, StringRef message, StringRef file, StringRef func, int line) const { - auto b = TritonLLVMOpBuilder(loc, rewriter); - auto funcOp = getAssertfailDeclaration(rewriter); - auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); - unsigned addrSpace = TritonGEN::TritonGENMemorySpace::kCrossWorkgroup; - llvm::SmallString<64> messageString(message), fileString(file), - funcString(func); - messageString.push_back('\0'); - fileString.push_back('\0'); - funcString.push_back('\0'); - Value messageStringVal = - getGlobalStringStart(loc, rewriter, "assertMessage_", messageString, - /*addressSpace=*/TritonGEN::kCrossWorkgroup); - Value fileStringVal = - getGlobalStringStart(loc, rewriter, "assertFile_", fileString, - /*addressSpace=*/TritonGEN::kCrossWorkgroup); - Value funcStringVal = - getGlobalStringStart(loc, rewriter, "assertFunc_", funcString, - /*addressSpace=*/TritonGEN::kCrossWorkgroup); - Value lineNumber = b.i32_val(line); - - auto *ctx = rewriter.getContext(); - SmallVector operands; - Value messageStringPtr = b.addrspacecast( - ptr_ty(ctx, TritonGEN::TritonGENMemorySpace::kGeneric), messageStringVal); - Value fileStringPtr = b.addrspacecast( - ptr_ty(ctx, TritonGEN::TritonGENMemorySpace::kGeneric), fileStringVal); - Value funcStringPtr = b.addrspacecast( - ptr_ty(ctx, TritonGEN::TritonGENMemorySpace::kGeneric), funcStringVal); - operands = {messageStringPtr, fileStringPtr, lineNumber, funcStringPtr}; - auto ret = b.call(funcOp, operands); - ret.setCConv(LLVM::cconv::CConv::SPIR_FUNC); + return emitter.assertFail(rewriter, loc, message, file, func, line); } int TargetInfo::getSharedAddressSpace() const { @@ -312,47 +191,7 @@ int TargetInfo::getAddressSpace(Attribute addressSpace) const { Value TargetInfo::getGlobalStringStart(Location loc, RewriterBase &rewriter, StringRef name, StringRef value, unsigned addressSpace) const { - auto b = TritonLLVMOpBuilder(loc, rewriter); - LLVM::GlobalOp global = - getGlobalString(loc, rewriter, name, value, addressSpace); - MLIRContext *ctx = rewriter.getContext(); - Type globalPtrType = ptr_ty(ctx, addressSpace); - Value globalPtr = rewriter.create(loc, global); - return b.gep(globalPtrType, i8_ty, globalPtr, LLVM::GEPArg{0}); -} - -LLVM::GlobalOp TargetInfo::getGlobalString(Location loc, RewriterBase &rewriter, - StringRef name, StringRef value, - unsigned addressSpace) const { - StringAttr valueAttr = rewriter.getStringAttr(value); - std::pair cacheKey{addressSpace, valueAttr}; - auto pos = globals.find(cacheKey); - if (pos != globals.end()) - return pos->second; - - ModuleOp moduleOp = rewriter.getInsertionPoint()->getParentOfType(); - - llvm::SmallString<64> contentStr(value); - size_t contentSize = contentStr.size_in_bytes(); - auto globalType = LLVM::LLVMArrayType::get(i8_ty, contentSize); - - auto createGlobal = [&](StringRef name) { - RewriterBase::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(moduleOp.getBody()); - return rewriter.create( - rewriter.getUnknownLoc(), globalType, - /*isConstant=*/true, LLVM::Linkage::Internal, name, valueAttr, - /*alignment=*/0, addressSpace); - }; - - LLVM::GlobalOp global = - moduleOp.lookupSymbol(name) - ? createGlobal(Twine{name}.concat(Twine{globals.size()}).str()) - : createGlobal(name); - - globals.try_emplace(cacheKey, global); - - return global; + return emitter.getGlobalStringStart(loc, rewriter, name, value, addressSpace); } std::unique_ptr createTargetInfo(ModuleOp mod) { diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h index f63c55c7f2..414e59ce12 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h @@ -9,6 +9,7 @@ #ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOINTEL_H #define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOINTEL_H +#include "Utils/LibCallEmitter.h" #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" #include @@ -87,8 +88,7 @@ class TargetInfo : public mlir::triton::TargetInfoBase { StringRef name, StringRef value, unsigned addressSpace) const; - mutable llvm::DenseMap, LLVM::GlobalOp> - globals; + const mlir::triton::gpu::intel::LibCallEmitter emitter; }; std::unique_ptr createTargetInfo(ModuleOp mod); diff --git a/third_party/intel/lib/Utils/CMakeLists.txt b/third_party/intel/lib/Utils/CMakeLists.txt index 0731c57759..da651add27 100644 --- a/third_party/intel/lib/Utils/CMakeLists.txt +++ b/third_party/intel/lib/Utils/CMakeLists.txt @@ -1,5 +1,6 @@ add_triton_library(TritonIntelUtils DefUseChain.cpp + LibCallEmitter.cpp LLVMIntr.cpp Mangling.cpp Utility.cpp diff --git a/third_party/intel/lib/Utils/LibCallEmitter.cpp b/third_party/intel/lib/Utils/LibCallEmitter.cpp new file mode 100644 index 0000000000..35297bc26f --- /dev/null +++ b/third_party/intel/lib/Utils/LibCallEmitter.cpp @@ -0,0 +1,206 @@ +#include "LibCallEmitter.h" + +#include "Dialect/TritonIntelGPU/IR/Utils.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/IR/BuiltinOps.h" + +using namespace mlir; + +namespace mlir::triton::gpu::intel { + +static Value printfPromoteValue(RewriterBase &rewriter, Value value, + bool isSigned) { + auto type = value.getType(); + if (isa(type) && type.getIntOrFloatBitWidth() == 1) { + // FIXME: There is some problem when using i1 type now, + // remove this code once IGC fix the problem. + TritonLLVMOpBuilder b(rewriter.getUnknownLoc(), rewriter); + return b.zext(i8_ty, value); + } else if (type.isIntOrIndex() && type.getIntOrFloatBitWidth() < 32) { + TritonLLVMOpBuilder b(rewriter.getUnknownLoc(), rewriter); + if (isSigned) { + return b.sext(i32_ty, value); + } else { + return b.zext(i32_ty, value); + } + } else { + return value; + } +} + +// declare __spirv_ocl_printf(i8*, ...) as external function +static LLVM::LLVMFuncOp getSpirvPrintfDeclaration(RewriterBase &rewriter) { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + StringRef funcName("_Z18__spirv_ocl_printf"); + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + + MLIRContext *context = rewriter.getContext(); + auto ptrTy = LLVM::LLVMPointerType::get( + context, TritonGEN::TritonGENMemorySpace::kUniformConstant); + SmallVector argsType{ptrTy}; + auto retType = i32_ty; + auto funcType = + LLVM::LLVMFunctionType::get(retType, argsType, /*isVarArg*/ true); + + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + auto printFunc = rewriter.create( + UnknownLoc::get(context), funcName, funcType, LLVM::Linkage::External, + /*dsoLocal*/ false, LLVM::CConv::SPIR_FUNC, /*comdat=*/SymbolRefAttr{}); + printFunc->setAttr("nounwind", rewriter.getUnitAttr()); + + return printFunc; +} + +static LLVM::LLVMFuncOp getAssertfailDeclaration(RewriterBase &rewriter) { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + StringRef funcName = "__assert_fail"; + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + + // void __assert_fail(const char * assertion, const char * file, unsigned + // int line, const char * function); + auto *ctx = rewriter.getContext(); + SmallVector argsType; + argsType = {ptr_ty(ctx, TritonGEN::TritonGENMemorySpace::kGeneric), + ptr_ty(ctx, TritonGEN::TritonGENMemorySpace::kGeneric), i32_ty, + ptr_ty(ctx, TritonGEN::TritonGENMemorySpace::kGeneric)}; + auto funcType = LLVM::LLVMFunctionType::get(void_ty(ctx), argsType); + + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + auto func = rewriter.create(UnknownLoc::get(ctx), funcName, + funcType); + func.setCConv(LLVM::cconv::CConv::SPIR_FUNC); + return func; +} + +Value LibCallEmitter::getGlobalStringStart(Location loc, RewriterBase &rewriter, + StringRef name, StringRef value, + unsigned addressSpace) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + LLVM::GlobalOp global = + getGlobalString(loc, rewriter, name, value, addressSpace); + MLIRContext *ctx = rewriter.getContext(); + Type globalPtrType = ptr_ty(ctx, addressSpace); + Value globalPtr = rewriter.create(loc, global); + return b.gep(globalPtrType, i8_ty, globalPtr, LLVM::GEPArg{0}); +} + +LLVM::GlobalOp LibCallEmitter::getGlobalString(Location loc, + RewriterBase &rewriter, + StringRef name, StringRef value, + unsigned addressSpace) const { + StringAttr valueAttr = rewriter.getStringAttr(value); + std::pair cacheKey{addressSpace, valueAttr}; + auto pos = globals.find(cacheKey); + if (pos != globals.end()) + return pos->second; + + ModuleOp moduleOp = + rewriter.getBlock()->getParent()->getParentOfType(); + + llvm::SmallString<64> contentStr(value); + size_t contentSize = contentStr.size_in_bytes(); + auto globalType = LLVM::LLVMArrayType::get(i8_ty, contentSize); + + auto createGlobal = [&](StringRef name) { + RewriterBase::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + return rewriter.create( + rewriter.getUnknownLoc(), globalType, + /*isConstant=*/true, LLVM::Linkage::Internal, name, valueAttr, + /*alignment=*/0, addressSpace); + }; + + LLVM::GlobalOp global = + moduleOp.lookupSymbol(name) + ? createGlobal(Twine{name}.concat(Twine{globals.size()}).str()) + : createGlobal(name); + + globals.try_emplace(cacheKey, global); + + return global; +} + +//===----------------------------------------------------------------------===// +// Public API +//===----------------------------------------------------------------------===// + +void LibCallEmitter::printf(RewriterBase &rewriter, Value formatStrStart, + int /*formatStrByteCount*/, ValueRange args, + ArrayRef isSigned) const { + auto *ctx = rewriter.getContext(); + Type ptr = ptr_ty(ctx); + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + auto funcOp = getSpirvPrintfDeclaration(rewriter); + auto loc = UnknownLoc::get(ctx); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + SmallVector operands; + operands.push_back(formatStrStart); + for (auto [i, arg] : llvm::enumerate(args)) { + operands.push_back(printfPromoteValue( + rewriter, arg, isSigned.empty() ? true : isSigned[i])); + } + auto callOp = b.call(funcOp, operands); + callOp.setCConv(triton::gpu::intel::getRequiredCConv(callOp)); +} + +void LibCallEmitter::printf(RewriterBase &rewriter, StringRef msg, + ValueRange args, ArrayRef isSigned) const { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = getGlobalStringStart( + rewriter.getUnknownLoc(), rewriter, "printfFormat_", msgNewline, + /*addressSpace=*/TritonGEN::kUniformConstant); + printf(rewriter, msgValue, msgNewline.size_in_bytes(), args, isSigned); +} + +void LibCallEmitter::assertFail(RewriterBase &rewriter, Location loc, + StringRef message, StringRef file, + StringRef func, int line) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto funcOp = getAssertfailDeclaration(rewriter); + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + unsigned addrSpace = TritonGEN::TritonGENMemorySpace::kCrossWorkgroup; + llvm::SmallString<64> messageString(message), fileString(file), + funcString(func); + messageString.push_back('\0'); + fileString.push_back('\0'); + funcString.push_back('\0'); + Value messageStringVal = + getGlobalStringStart(loc, rewriter, "assertMessage_", messageString, + /*addressSpace=*/TritonGEN::kCrossWorkgroup); + Value fileStringVal = + getGlobalStringStart(loc, rewriter, "assertFile_", fileString, + /*addressSpace=*/TritonGEN::kCrossWorkgroup); + Value funcStringVal = + getGlobalStringStart(loc, rewriter, "assertFunc_", funcString, + /*addressSpace=*/TritonGEN::kCrossWorkgroup); + Value lineNumber = b.i32_val(line); + + auto *ctx = rewriter.getContext(); + SmallVector operands; + Value messageStringPtr = b.addrspacecast( + ptr_ty(ctx, TritonGEN::TritonGENMemorySpace::kGeneric), messageStringVal); + Value fileStringPtr = b.addrspacecast( + ptr_ty(ctx, TritonGEN::TritonGENMemorySpace::kGeneric), fileStringVal); + Value funcStringPtr = b.addrspacecast( + ptr_ty(ctx, TritonGEN::TritonGENMemorySpace::kGeneric), funcStringVal); + operands = {messageStringPtr, fileStringPtr, lineNumber, funcStringPtr}; + auto ret = b.call(funcOp, operands); + ret.setCConv(LLVM::cconv::CConv::SPIR_FUNC); +} + +} // namespace mlir::triton::gpu::intel diff --git a/third_party/intel/lib/Utils/LibCallEmitter.h b/third_party/intel/lib/Utils/LibCallEmitter.h new file mode 100644 index 0000000000..8ee732916f --- /dev/null +++ b/third_party/intel/lib/Utils/LibCallEmitter.h @@ -0,0 +1,45 @@ +//===- LibCallEmitter.h - Emit library calls for Intel backend --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_INTEL_UTILS_LIBCALLEMITTER_H +#define TRITON_INTEL_UTILS_LIBCALLEMITTER_H + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" + +namespace mlir::triton::gpu::intel { + +class LibCallEmitter { +public: + LibCallEmitter() = default; + + void printf(RewriterBase &rewriter, Value formatStrStart, + int formatStrByteCount, ValueRange args, + ArrayRef isSigned = {}) const; + + void printf(RewriterBase &rewriter, StringRef msg, ValueRange args, + ArrayRef isSigned = {}) const; + + void assertFail(RewriterBase &rewriter, Location loc, StringRef message, + StringRef file, StringRef func, int line) const; + + Value getGlobalStringStart(Location loc, RewriterBase &rewriter, + StringRef name, StringRef value, + unsigned addressSpace) const; + +private: + LLVM::GlobalOp getGlobalString(Location loc, RewriterBase &rewriter, + StringRef name, StringRef value, + unsigned addressSpace) const; + + mutable llvm::DenseMap, LLVM::GlobalOp> + globals; +}; + +} // namespace mlir::triton::gpu::intel + +#endif // TRITON_INTEL_UTILS_LIBCALLEMITTER_H