@@ -503,6 +503,90 @@ createGenISA2DBlockPrefetch(TritonGEN::Matrix2DBlockPrefetchOp op,
503503 intel::noUnwindWillReturnAttrs);
504504}
505505
506+ template <typename OpTy>
507+ static void
508+ validateMatrix2DBlockParameters (OpTy op,
509+ mlir::ConversionPatternRewriter &rewriter) {
510+ using namespace mlir ;
511+ using namespace mlir ::LLVM;
512+
513+ Location loc = op->getLoc ();
514+ auto b = TritonLLVMOpBuilder (loc, rewriter);
515+ MLIRContext *ctx = rewriter.getContext ();
516+
517+ Value baseWidth = op.getBaseWidth ();
518+ Value baseHeight = op.getBaseHeight ();
519+ Value basePitch = op.getBasePitch ();
520+ Value x = op.getX ();
521+ unsigned elemSize = op.getElemSizeInBits () / 8 ;
522+
523+ if (!baseWidth.getType ().isInteger (32 ))
524+ baseWidth = rewriter.create <ZExtOp>(loc, rewriter.getI32Type (), baseWidth);
525+ if (!baseHeight.getType ().isInteger (32 ))
526+ baseHeight =
527+ rewriter.create <ZExtOp>(loc, rewriter.getI32Type (), baseHeight);
528+ if (!basePitch.getType ().isInteger (32 ))
529+ basePitch = rewriter.create <ZExtOp>(loc, rewriter.getI32Type (), basePitch);
530+ if (!x.getType ().isInteger (32 ))
531+ x = rewriter.create <ZExtOp>(loc, rewriter.getI32Type (), x);
532+
533+ Value c0 = b.i32_val (0 );
534+ Value c4 = b.i32_val (4 );
535+ Value c64 = b.i32_val (64 );
536+ Value c24m1 = b.i32_val ((1u << 24 ) - 1 ); // 2^24 - 1
537+ Value cElemSize = b.i32_val (elemSize);
538+
539+ // ===== validation predicates =====
540+
541+ // width!=0 && width<2^24 && width%4==0
542+ Value wZero = rewriter.create <ICmpOp>(loc, ICmpPredicate::eq, baseWidth, c0);
543+ Value wTooLarge =
544+ rewriter.create <ICmpOp>(loc, ICmpPredicate::ugt, baseWidth, c24m1);
545+ Value wRem = rewriter.create <URemOp>(loc, baseWidth, c4);
546+ Value wNotAligned = rewriter.create <ICmpOp>(loc, ICmpPredicate::ne, wRem, c0);
547+ Value badWidth = rewriter.create <OrOp>(
548+ loc, wZero, rewriter.create <OrOp>(loc, wTooLarge, wNotAligned));
549+
550+ // height!=0 && height<2^24
551+ Value hZero = rewriter.create <ICmpOp>(loc, ICmpPredicate::eq, baseHeight, c0);
552+ Value hTooLarge =
553+ rewriter.create <ICmpOp>(loc, ICmpPredicate::ugt, baseHeight, c24m1);
554+ Value badHeight = rewriter.create <OrOp>(loc, hZero, hTooLarge);
555+
556+ // pitch >= 64
557+ Value badPitch =
558+ rewriter.create <ICmpOp>(loc, ICmpPredicate::ult, basePitch, c64);
559+
560+ // x*elemSize % 4 == 0
561+ Value offsetBytes = rewriter.create <MulOp>(loc, x, cElemSize);
562+ Value offsetRem = rewriter.create <URemOp>(loc, offsetBytes, c4);
563+ Value badOffset =
564+ rewriter.create <ICmpOp>(loc, ICmpPredicate::ne, offsetRem, c0);
565+
566+ // assert on any
567+ Value anyBad = rewriter.create <OrOp>(
568+ loc, badWidth,
569+ rewriter.create <OrOp>(loc, badHeight,
570+ rewriter.create <OrOp>(loc, badPitch, badOffset)));
571+
572+ Block *curBlock = rewriter.getBlock ();
573+ auto ip = rewriter.getInsertionPoint ();
574+ Block *contBlock = rewriter.splitBlock (curBlock, ip);
575+ Region *region = contBlock->getParent ();
576+ Block *trapBlock = rewriter.createBlock (region, Region::iterator (contBlock));
577+
578+ // TODO: use __assert_fail instead of llvm.intr.trap
579+ rewriter.setInsertionPointToStart (trapBlock);
580+ rewriter.create <Trap>(loc);
581+ rewriter.create <UnreachableOp>(loc);
582+
583+ rewriter.setInsertionPointToEnd (curBlock);
584+ rewriter.create <CondBrOp>(loc, anyBad, trapBlock, ValueRange{}, contBlock,
585+ ValueRange{});
586+
587+ rewriter.setInsertionPointToStart (contBlock);
588+ }
589+
506590namespace {
507591
508592// ===----------------------------------------------------------------------===//
@@ -636,6 +720,8 @@ struct TritonMatrix2DBlockLoadLowering
636720 LogicalResult
637721 matchAndRewrite (TritonGEN::Matrix2DBlockLoadOp op, OpAdaptor adaptor,
638722 ConversionPatternRewriter &rewriter) const override {
723+ validateMatrix2DBlockParameters (op, rewriter);
724+
639725 if (!isSPVBuiltinAvailable (op)) {
640726 // Fallback to GenISA interface.
641727 rewriter.replaceOp (op, createGenISA2DBlockRead (op, rewriter));
@@ -711,6 +797,8 @@ struct TritonMatrix2DBlockStoreLowering
711797 LogicalResult
712798 matchAndRewrite (TritonGEN::Matrix2DBlockStoreOp op, OpAdaptor adaptor,
713799 ConversionPatternRewriter &rewriter) const override {
800+ validateMatrix2DBlockParameters (op, rewriter);
801+
714802 if (!isSPVBuiltinAvailable (op)) {
715803 // Fallback to GenISA interface.
716804 rewriter.replaceOp (op, createGenISA2DBlockWrite (op, rewriter));
@@ -785,6 +873,8 @@ struct TritonMatrix2DBlockPrefetchLowering
785873 LogicalResult
786874 matchAndRewrite (TritonGEN::Matrix2DBlockPrefetchOp op, OpAdaptor adaptor,
787875 ConversionPatternRewriter &rewriter) const override {
876+ validateMatrix2DBlockParameters (op, rewriter);
877+
788878 if (!isSPVBuiltinAvailable (op)) {
789879 // Fallback to GenISA interface.
790880 rewriter.replaceOp (op, createGenISA2DBlockPrefetch (op, rewriter));
0 commit comments