Skip to content

Commit 61ce6d7

Browse files
[MLIR][AArch64] Lower vector.contract to Neon FEAT_BF16 operations (#148198)
This builds upon the framework established by #149810 to add lowering to `bfmmla`.
1 parent c295f05 commit 61ce6d7

File tree

12 files changed

+541
-44
lines changed

12 files changed

+541
-44
lines changed

mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,19 @@ def ApplyArmNeonContractionToI8MMPatternsOp
1717
"apply_patterns.arm_neon.vector_contract_to_i8mm",
1818
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
1919
let description = [{
20-
Indicates that vector.contract operations should be lowered to
21-
finer-grained vector primitives from the ArmNeon dialect.
20+
Indicates that vector contract operations should be lowered to
21+
ArmNeon dialect operations mapping to instructions from FEAT_I8MM.
22+
}];
23+
24+
let assemblyFormat = "attr-dict";
25+
}
26+
27+
def ApplyArmNeonContractionToBFMMLAPatternsOp
28+
: Op<Transform_Dialect, "apply_patterns.arm_neon.vector_contract_to_bfmmla",
29+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
30+
let description = [{
31+
Indicates that vector contract operations should be lowered to
32+
ArmNeon dialect operations mapping to instructions from FEAT_BF16.
2233
}];
2334

2435
let assemblyFormat = "attr-dict";

mlir/include/mlir/Dialect/ArmNeon/Transforms.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ namespace mlir {
1313
class RewritePatternSet;
1414

1515
namespace arm_neon {
16-
void populateLowerContractionToNeonI8MMPatternPatterns(
17-
RewritePatternSet &patterns);
16+
void populateLowerContractionToNeonI8MMPatterns(RewritePatternSet &patterns);
17+
void populateLowerContractionToNeonBFMMLAPatterns(RewritePatternSet &patterns);
1818
} // namespace arm_neon
1919

2020
} // namespace mlir

mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ class RewritePatternSet;
2020
void populateArmSVELegalizeForLLVMExportPatterns(
2121
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
2222

23-
void populateLowerContractionToSVEI8MMPatternPatterns(
24-
RewritePatternSet &patterns);
23+
void populateLowerContractionToSVEI8MMPatterns(RewritePatternSet &patterns);
2524

2625
void populateLowerContractionToSVEBFMMLAPatterns(RewritePatternSet &patterns);
2726

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,16 @@ void ConvertVectorToLLVMPass::runOnOperation() {
9696
populateVectorGatherLoweringPatterns(patterns);
9797
if (armI8MM) {
9898
if (armNeon)
99-
arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
99+
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
100100
if (armSVE)
101-
populateLowerContractionToSVEI8MMPatternPatterns(patterns);
101+
populateLowerContractionToSVEI8MMPatterns(patterns);
102+
}
103+
if (armBF16) {
104+
if (armNeon)
105+
arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns);
106+
if (armSVE)
107+
populateLowerContractionToSVEBFMMLAPatterns(patterns);
102108
}
103-
if (armBF16)
104-
populateLowerContractionToSVEBFMMLAPatterns(patterns);
105-
106109
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
107110
}
108111

mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@ using namespace mlir;
2020

2121
void transform::ApplyArmNeonContractionToI8MMPatternsOp::populatePatterns(
2222
RewritePatternSet &patterns) {
23-
arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
23+
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
24+
}
25+
26+
void transform::ApplyArmNeonContractionToBFMMLAPatternsOp::populatePatterns(
27+
RewritePatternSet &patterns) {
28+
arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns);
2429
}
2530

2631
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
add_mlir_dialect_library(MLIRArmNeonTransforms
2-
LowerContractionToNeonI8MMPattern.cpp
2+
LowerContractToNeonPatterns.cpp
33

44
DEPENDS
55
MLIRArmNeonIncGen

mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp renamed to mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp

Lines changed: 99 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- LowerContractionToNeonI8MMPattern.cpp - Contract to I8MM -*- C++ -*-===//
1+
//===- LowerContractToNeonPatterns.cpp - Contract to I8MM/BF16 --*- C++ -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -93,15 +93,20 @@ class VectorContractRewriter {
9393
// multiplications.
9494
enum class MMLA {
9595
Nop,
96-
Signed, // smmla
97-
Unsigned, // ummla
98-
Mixed, // usmmla
99-
MixedSwapped // usmmla with LHS and RHS swapped
96+
SignedInt, // smmla
97+
UnsignedInt, // ummla
98+
MixedInt, // usmmla
99+
Bfloat // bfmmla
100100
};
101101

102102
// Lower-level operation to be emitted.
103103
MMLA mmlaOp = MMLA::Nop;
104104

105+
// Indicate if the operands for the ArmNeon dialect operation need to be
106+
// swapped. Currently this is needed in order to emulate an "summla"
107+
// operation.
108+
bool swapOperands = false;
109+
105110
// The operand tiles. These are not necessarily the operands of
106111
// `vector.contract`, for example they could be operands to `arith.extsi`
107112
// that is in turn fed into `vector.contract`.
@@ -126,21 +131,22 @@ class VectorContractRewriter {
126131
// Create the matrix multiply and accumulate operation according to `mmlaOp`.
127132
Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
128133
Value lhs, Value rhs) {
134+
135+
if (swapOperands)
136+
std::swap(lhs, rhs);
129137
switch (mmlaOp) {
130-
case MMLA::Signed:
138+
case MMLA::SignedInt:
131139
return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
132140
lhs, rhs);
133-
case MMLA::Unsigned:
141+
case MMLA::UnsignedInt:
134142
return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
135143
lhs, rhs);
136-
case MMLA::Mixed:
144+
case MMLA::MixedInt:
137145
return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
138146
lhs, rhs);
139-
case MMLA::MixedSwapped:
140-
// The accumulator comes transposed and the result will be transposed
141-
// later, so all we have to do here is swap the operands.
142-
return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
143-
rhs, lhs);
147+
case MMLA::Bfloat:
148+
return rewriter.create<arm_neon::BfmmlaOp>(loc, acc.getType(), acc, lhs,
149+
rhs);
144150
case MMLA::Nop:
145151
llvm_unreachable("Uninitialized operation type");
146152
}
@@ -273,7 +279,7 @@ class VectorContractRewriter {
273279
// Transpose ACC if doing signed by unsigned multiplication, because we're
274280
// using the instruction for unsigned by signed multiplication with
275281
// reversed operands.
276-
if (mmlaOp == MMLA::MixedSwapped)
282+
if (swapOperands)
277283
tiledAcc = rewriter.create<vector::TransposeOp>(
278284
loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
279285

@@ -302,7 +308,7 @@ class VectorContractRewriter {
302308

303309
// Because of the reversed operands the result is obtained transposed.
304310
// Transpose it back.
305-
if (mmlaOp == MMLA::MixedSwapped)
311+
if (swapOperands)
306312
tiledRes = rewriter.create<vector::TransposeOp>(
307313
loc, tiledRes, ArrayRef<int64_t>({1, 0}));
308314

@@ -339,10 +345,10 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
339345
// values before the extension. All four signed/unsigned combinations for
340346
// input operands are supported, but they are lowered to different
341347
// operations. Determine which is the appropriate operation to lower to.
342-
mmlaOp = MMLA::Signed;
348+
mmlaOp = MMLA::SignedInt;
343349
auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
344350
if (!maybeLhs) {
345-
mmlaOp = MMLA::Unsigned;
351+
mmlaOp = MMLA::UnsignedInt;
346352
maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
347353
}
348354
if (!maybeLhs)
@@ -351,11 +357,13 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
351357

352358
auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
353359
if (maybeRhs) {
354-
if (mmlaOp == MMLA::Unsigned)
355-
mmlaOp = MMLA::Mixed;
360+
if (mmlaOp == MMLA::UnsignedInt)
361+
mmlaOp = MMLA::MixedInt;
356362
} else {
357-
if (mmlaOp == MMLA::Signed)
358-
mmlaOp = MMLA::MixedSwapped;
363+
if (mmlaOp == MMLA::SignedInt) {
364+
mmlaOp = MMLA::MixedInt;
365+
swapOperands = true;
366+
}
359367
maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
360368
}
361369

@@ -372,16 +380,17 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
372380
auto lhsExtInType = cast<VectorType>(lhs.getType());
373381
if (lhsExtInType.getElementTypeBitWidth() < 8)
374382
lhs = extendSmallIntVector(loc, lhsExtInType, lhs,
375-
/* signExt */ mmlaOp == MMLA::Signed ||
376-
mmlaOp == MMLA::Mixed,
383+
/* signExt */
384+
(mmlaOp == MMLA::SignedInt ||
385+
(mmlaOp == MMLA::MixedInt && !swapOperands)),
377386
rewriter);
378387

379388
auto rhsExtInType = cast<VectorType>(rhs.getType());
380389
if (rhsExtInType.getElementTypeBitWidth() < 8)
381-
382390
rhs = extendSmallIntVector(loc, rhsExtInType, rhs,
383-
/* signExt */ mmlaOp != MMLA::Unsigned &&
384-
mmlaOp != MMLA::Mixed,
391+
/* signExt */
392+
(mmlaOp == MMLA::SignedInt ||
393+
(mmlaOp == MMLA::MixedInt && swapOperands)),
385394
rewriter);
386395

387396
// Initialize parameters for unrolling.
@@ -395,6 +404,47 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
395404
}
396405
};
397406

407+
class VectorContractRewriterBFMMLA : public VectorContractRewriter {
408+
public:
409+
LogicalResult matchAndInit(vector::ContractionOp op,
410+
PatternRewriter &rewriter) {
411+
412+
if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
413+
return failure();
414+
415+
// Unrolling patterns can handle any [2, 2, 4] shaped multiple of inputs for
416+
// tiling.
417+
if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 4 != 0)
418+
return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");
419+
420+
// Check the output is a vector of Float32 elements.
421+
auto outTy = dyn_cast<VectorType>(op.getResultType());
422+
if (!outTy || outTy.getElementType() != rewriter.getF32Type())
423+
return rewriter.notifyMatchFailure(op,
424+
"output type is not a vector of f32");
425+
426+
// Check the inputs are vectors of BFloat16 elements.
427+
if (op.getLhsType().getElementType() != rewriter.getBF16Type())
428+
return rewriter.notifyMatchFailure(op,
429+
"input type is not a vector of bf16");
430+
431+
mmlaOp = MMLA::Bfloat;
432+
swapOperands = false;
433+
lhs = op.getLhs();
434+
rhs = op.getRhs();
435+
acc = op.getAcc();
436+
437+
// Initialize parameters for unrolling.
438+
iterationBounds = *op.getShapeForUnroll();
439+
if (iterationBounds.size() == 3)
440+
subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 4});
441+
else
442+
subTileShape = SmallVector<int64_t>({2, 4});
443+
444+
return success();
445+
}
446+
};
447+
398448
/// Lowers a vector::contractOp to arm neon smmla intrinsics. This will tile
399449
/// any vector.contract into multiple smmla instructions with unrolling so long
400450
/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
@@ -416,10 +466,32 @@ class LowerContractionToNeonI8MMPattern
416466
}
417467
};
418468

469+
class LowerContractionToNeonBFMMLAPattern
470+
: public OpRewritePattern<vector::ContractionOp> {
471+
public:
472+
using OpRewritePattern::OpRewritePattern;
473+
LogicalResult matchAndRewrite(vector::ContractionOp op,
474+
PatternRewriter &rewriter) const override {
475+
476+
VectorContractRewriterBFMMLA vcr;
477+
if (failed(vcr.matchAndInit(op, rewriter)))
478+
return failure();
479+
vcr.lower(op, rewriter);
480+
481+
return success();
482+
}
483+
};
484+
419485
} // namespace
420486

421-
void mlir::arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(
487+
void mlir::arm_neon::populateLowerContractionToNeonI8MMPatterns(
422488
RewritePatternSet &patterns) {
423489
MLIRContext *context = patterns.getContext();
424490
patterns.add<LowerContractionToNeonI8MMPattern>(context, /*benefit=*/2);
425491
}
492+
493+
void mlir::arm_neon::populateLowerContractionToNeonBFMMLAPatterns(
494+
RewritePatternSet &patterns) {
495+
MLIRContext *context = patterns.getContext();
496+
patterns.add<LowerContractionToNeonBFMMLAPattern>(context, /*benefit=*/2);
497+
}

mlir/lib/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ using namespace mlir;
2020

2121
void transform::ApplyArmSVELowerContractionToI8MMPatternsOp::populatePatterns(
2222
RewritePatternSet &patterns) {
23-
mlir::populateLowerContractionToSVEI8MMPatternPatterns(patterns);
23+
mlir::populateLowerContractionToSVEI8MMPatterns(patterns);
2424
}
2525

2626
void transform::ApplyArmSVELowerContractionToBFMMLAPatternsOp::populatePatterns(

mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
// TODO: There may be opportunities to unify this with a similar pattern
1313
// for Neon. See:
1414
// https://github.com/llvm/llvm-project/issues/145559
15-
// LowerContractionToNeonI8MMPattern.cpp
15+
// LowerContractToNeonPatterns.cpp
1616
//
1717
//===----------------------------------------------------------------------===//
1818

@@ -580,7 +580,7 @@ class LowerContractionToSVEBFMMLAPattern
580580

581581
} // namespace
582582

583-
void mlir::populateLowerContractionToSVEI8MMPatternPatterns(
583+
void mlir::populateLowerContractionToSVEI8MMPatterns(
584584
RewritePatternSet &patterns) {
585585
MLIRContext *context = patterns.getContext();
586586
patterns.add<LowerContractionToSVEI8MMPattern>(context, /*benefit=*/2);

0 commit comments

Comments
 (0)