1
- // ===- LowerContractionToNeonI8MMPattern .cpp - Contract to I8MM -*- C++ -*-===//
1
+ // ===- LowerContractToNeonPatterns .cpp - Contract to I8MM/BF16 - -*- C++ -*-===//
2
2
//
3
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
4
// See https://llvm.org/LICENSE.txt for license information.
@@ -93,15 +93,20 @@ class VectorContractRewriter {
93
93
// multiplications.
94
94
enum class MMLA {
95
95
Nop,
96
- Signed, // smmla
97
- Unsigned, // ummla
98
- Mixed, // usmmla
99
- MixedSwapped // usmmla with LHS and RHS swapped
96
+ SignedInt, // smmla
97
+ UnsignedInt, // ummla
98
+ MixedInt, // usmmla
99
+ Bfloat // bfmmla
100
100
};
101
101
102
102
// Lower-level operation to be emitted.
103
103
MMLA mmlaOp = MMLA::Nop;
104
104
105
+ // Indicate if the operands for the ArmNeon dialect operation need to be
106
+ // swapped. Currently this is needed in order to emulate an "summla"
107
+ // (signed by unsigned) operation using `usmmla` with reversed operands.
108
+ bool swapOperands = false ;
109
+
105
110
// The operand tiles. These are not necessarily the operands of
106
111
// `vector.contract`, for example they could be operands to `arith.extsi`
107
112
// that is in turn fed into `vector.contract`.
@@ -126,21 +131,22 @@ class VectorContractRewriter {
126
131
// Create the matrix multiply and accumulate operation according to `mmlaOp`.
127
132
Value createMMLA (PatternRewriter &rewriter, Location loc, Value acc,
128
133
Value lhs, Value rhs) {
134
+
135
+ if (swapOperands)
136
+ std::swap (lhs, rhs);
129
137
switch (mmlaOp) {
130
- case MMLA::Signed :
138
+ case MMLA::SignedInt :
131
139
return rewriter.createOrFold <arm_neon::SmmlaOp>(loc, acc.getType (), acc,
132
140
lhs, rhs);
133
- case MMLA::Unsigned :
141
+ case MMLA::UnsignedInt :
134
142
return rewriter.createOrFold <arm_neon::UmmlaOp>(loc, acc.getType (), acc,
135
143
lhs, rhs);
136
- case MMLA::Mixed :
144
+ case MMLA::MixedInt :
137
145
return rewriter.createOrFold <arm_neon::UsmmlaOp>(loc, acc.getType (), acc,
138
146
lhs, rhs);
139
- case MMLA::MixedSwapped:
140
- // The accumulator comes transposed and the result will be transposed
141
- // later, so all we have to do here is swap the operands.
142
- return rewriter.createOrFold <arm_neon::UsmmlaOp>(loc, acc.getType (), acc,
143
- rhs, lhs);
147
+ case MMLA::Bfloat:
148
+ return rewriter.create <arm_neon::BfmmlaOp>(loc, acc.getType (), acc, lhs,
149
+ rhs);
144
150
case MMLA::Nop:
145
151
llvm_unreachable (" Uninitialized operation type" );
146
152
}
@@ -273,7 +279,7 @@ class VectorContractRewriter {
273
279
// Transpose ACC if doing signed by unsigned multiplication, because we're
274
280
// using the instruction for unsigned by signed multiplication with
275
281
// reversed operands.
276
- if (mmlaOp == MMLA::MixedSwapped )
282
+ if (swapOperands )
277
283
tiledAcc = rewriter.create <vector::TransposeOp>(
278
284
loc, tiledAcc, ArrayRef<int64_t >({1 , 0 }));
279
285
@@ -302,7 +308,7 @@ class VectorContractRewriter {
302
308
303
309
// Because of the reversed operands the result is obtained transposed.
304
310
// Transpose it back.
305
- if (mmlaOp == MMLA::MixedSwapped )
311
+ if (swapOperands )
306
312
tiledRes = rewriter.create <vector::TransposeOp>(
307
313
loc, tiledRes, ArrayRef<int64_t >({1 , 0 }));
308
314
@@ -339,10 +345,10 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
339
345
// values before the extension. All four signed/unsigned combinations for
340
346
// input operands are supported, but they are lowered to different
341
347
// operations. Determine which is the appropriate operation to lower to.
342
- mmlaOp = MMLA::Signed ;
348
+ mmlaOp = MMLA::SignedInt ;
343
349
auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs ());
344
350
if (!maybeLhs) {
345
- mmlaOp = MMLA::Unsigned ;
351
+ mmlaOp = MMLA::UnsignedInt ;
346
352
maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs ());
347
353
}
348
354
if (!maybeLhs)
@@ -351,11 +357,13 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
351
357
352
358
auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs ());
353
359
if (maybeRhs) {
354
- if (mmlaOp == MMLA::Unsigned )
355
- mmlaOp = MMLA::Mixed ;
360
+ if (mmlaOp == MMLA::UnsignedInt )
361
+ mmlaOp = MMLA::MixedInt ;
356
362
} else {
357
- if (mmlaOp == MMLA::Signed)
358
- mmlaOp = MMLA::MixedSwapped;
363
+ if (mmlaOp == MMLA::SignedInt) {
364
+ mmlaOp = MMLA::MixedInt;
365
+ swapOperands = true ;
366
+ }
359
367
maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs ());
360
368
}
361
369
@@ -372,16 +380,17 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
372
380
auto lhsExtInType = cast<VectorType>(lhs.getType ());
373
381
if (lhsExtInType.getElementTypeBitWidth () < 8 )
374
382
lhs = extendSmallIntVector (loc, lhsExtInType, lhs,
375
- /* signExt */ mmlaOp == MMLA::Signed ||
376
- mmlaOp == MMLA::Mixed,
383
+ /* signExt */
384
+ (mmlaOp == MMLA::SignedInt ||
385
+ (mmlaOp == MMLA::MixedInt && !swapOperands)),
377
386
rewriter);
378
387
379
388
auto rhsExtInType = cast<VectorType>(rhs.getType ());
380
389
if (rhsExtInType.getElementTypeBitWidth () < 8 )
381
-
382
390
rhs = extendSmallIntVector (loc, rhsExtInType, rhs,
383
- /* signExt */ mmlaOp != MMLA::Unsigned &&
384
- mmlaOp != MMLA::Mixed,
391
+ /* signExt */
392
+ (mmlaOp == MMLA::SignedInt ||
393
+ (mmlaOp == MMLA::MixedInt && swapOperands)),
385
394
rewriter);
386
395
387
396
// Initialize parameters for unrolling.
@@ -395,6 +404,47 @@ class VectorContractRewriterI8MM : public VectorContractRewriter {
395
404
}
396
405
};
397
406
407
// Rewriter specialization that matches `vector.contract` operations with bf16
// inputs and an f32 result, and initializes the state inherited from
// `VectorContractRewriter` so that `MMLA::Bfloat` (bfmmla) operations are
// emitted.
+ class VectorContractRewriterBFMMLA : public VectorContractRewriter {
408
+ public:
409
// Checks whether `op` can be handled and, on success, initializes the
// rewriter state. Returns failure if the base-class match fails, the operand
// shapes are unsupported, the result is not a vector of f32, or the LHS
// element type is not bf16.
+ LogicalResult matchAndInit (vector::ContractionOp op,
410
+ PatternRewriter &rewriter) {
411
+
412
+ if (failed (VectorContractRewriter::matchAndInit (op, rewriter)))
413
+ return failure ();
414
+
415
+ // Unrolling patterns can handle any [2, 2, 4] shaped multiple of inputs for
416
+ // tiling, i.e. dimM must be 1 or even, dimN must be even and dimK must be
// a multiple of 4.
417
+ if ((dimM != 1 && dimM % 2 != 0 ) || dimN % 2 != 0 || dimK % 4 != 0 )
418
+ return rewriter.notifyMatchFailure (op, " Unsupported operand shapes" );
419
+
420
+ // Check the output is a vector of Float32 elements.
421
+ auto outTy = dyn_cast<VectorType>(op.getResultType ());
422
+ if (!outTy || outTy.getElementType () != rewriter.getF32Type ())
423
+ return rewriter.notifyMatchFailure (op,
424
+ " output type is not a vector of f32" );
425
+
426
+ // Check the inputs are vectors of BFloat16 elements.
// NOTE(review): only the LHS element type is inspected here; presumably the
// RHS element type is guaranteed to match by `vector.contract` itself -
// confirm.
427
+ if (op.getLhsType ().getElementType () != rewriter.getBF16Type ())
428
+ return rewriter.notifyMatchFailure (op,
429
+ " input type is not a vector of bf16" );
430
+
// Select the bfmmla lowering. The contraction operands are used directly
// and are never swapped.
431
+ mmlaOp = MMLA::Bfloat;
432
+ swapOperands = false ;
433
+ lhs = op.getLhs ();
434
+ rhs = op.getRhs ();
435
+ acc = op.getAcc ();
436
+
437
+ // Initialize parameters for unrolling.
// NOTE(review): the result of `getShapeForUnroll()` is dereferenced without
// a check; this assumes it always holds a value at this point - confirm.
438
+ iterationBounds = *op.getShapeForUnroll ();
// With three iteration bounds the sub-tile has three entries, the first of
// which is 1 when dimM == 1 and 2 otherwise; in any other case a {2, 4}
// sub-tile is used.
439
+ if (iterationBounds.size () == 3 )
440
+ subTileShape = SmallVector<int64_t >({dimM == 1 ? 1 : 2 , 2 , 4 });
441
+ else
442
+ subTileShape = SmallVector<int64_t >({2 , 4 });
443
+
444
+ return success ();
445
+ }
446
+ };
447
+
398
448
// / Lowering from a vector::contractOp to the arm neon smmla intrinsic. This will tile
399
449
// / any vector.contract into multiple smmla instructions with unrolling so long
400
450
// / as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
@@ -416,10 +466,32 @@ class LowerContractionToNeonI8MMPattern
416
466
}
417
467
};
418
468
469
// Rewrite pattern on `vector.contract` that delegates both matching and
// lowering to `VectorContractRewriterBFMMLA`.
+ class LowerContractionToNeonBFMMLAPattern
470
+ : public OpRewritePattern<vector::ContractionOp> {
471
+ public:
472
+ using OpRewritePattern::OpRewritePattern;
// Returns failure if `VectorContractRewriterBFMMLA::matchAndInit` rejects
// `op`; otherwise calls `lower` on the initialized rewriter and returns
// success.
473
+ LogicalResult matchAndRewrite (vector::ContractionOp op,
474
+ PatternRewriter &rewriter) const override {
475
+
// A fresh rewriter object is created for every invocation.
476
+ VectorContractRewriterBFMMLA vcr;
477
+ if (failed (vcr.matchAndInit (op, rewriter)))
478
+ return failure ();
479
+ vcr.lower (op, rewriter);
480
+
481
+ return success ();
482
+ }
483
+ };
484
+
419
485
} // namespace
420
486
421
// Adds `LowerContractionToNeonI8MMPattern` to `patterns` with a pattern
// benefit of 2, using the MLIR context obtained from `patterns`.
- void mlir::arm_neon::populateLowerContractionToNeonI8MMPatternPatterns (
487
+ void mlir::arm_neon::populateLowerContractionToNeonI8MMPatterns (
422
488
RewritePatternSet &patterns) {
423
489
MLIRContext *context = patterns.getContext ();
424
490
patterns.add <LowerContractionToNeonI8MMPattern>(context, /* benefit=*/ 2 );
425
491
}
492
+
493
// Adds `LowerContractionToNeonBFMMLAPattern` to `patterns` with a pattern
// benefit of 2, using the MLIR context obtained from `patterns`.
+ void mlir::arm_neon::populateLowerContractionToNeonBFMMLAPatterns (
494
+ RewritePatternSet &patterns) {
495
+ MLIRContext *context = patterns.getContext ();
496
+ patterns.add <LowerContractionToNeonBFMMLAPattern>(context, /* benefit=*/ 2 );
497
+ }
0 commit comments