29
29
#include < optional>
30
30
#include < random>
31
31
32
+ #include " mlir/Dialect/Tosa/Utils/QuantUtils.h"
33
+
32
34
using namespace mlir ;
33
35
using namespace mlir ::torch;
34
36
using namespace mlir ::torch::Torch;
@@ -2295,7 +2297,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
2295
2297
auto weightTy = cast<RankedTensorType>(weight.getType ());
2296
2298
auto outputTy =
2297
2299
cast<RankedTensorType>(getTypeConverter ()->convertType (op.getType ()));
2298
-
2299
2300
if (!inputTy || !weightTy || !outputTy)
2300
2301
return rewriter.notifyMatchFailure (
2301
2302
op, " Input, weight and output to Convolution must be ranked tensors" );
@@ -2304,6 +2305,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
2304
2305
auto weightElemTy = weightTy.getElementType ();
2305
2306
auto inputShape = makeShapeTorchCompatible (inputTy.getShape ());
2306
2307
auto weightShape = makeShapeTorchCompatible (weightTy.getShape ());
2308
+ auto outputElemTy = outputTy.getElementType ();
2307
2309
2308
2310
if (inputTy.getRank () != 4 )
2309
2311
return rewriter.notifyMatchFailure (
@@ -2316,28 +2318,21 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
2316
2318
// Bias is optional. TOSA mandates a zero tensor here, so construct one if
2317
2319
// required.
2318
2320
auto bias = adaptor.getBias ();
2319
- if (isa<Torch::NoneType>(adaptor.getBias ().getType ())) {
2320
- // TBD: This is only valid for quantized 8-bit. For 16-bit, the bias (and
2321
- // accumulator) are 48-bit and not 32-bit, and requires the use of APInt to
2322
- // define a 48-bit int.
2323
- if (isa<quant::QuantizedType>(inputElemTy)) {
2324
- SmallVector<int32_t > zeroVec (weightShape[0 ], 0 );
2325
- bias = tosa::getConstTensor<int32_t >(
2326
- rewriter, op, zeroVec, {static_cast <int32_t >(weightShape[0 ])})
2327
- .value ();
2328
- } else {
2329
- SmallVector<float > zeroVec (weightShape[0 ], 0 );
2330
- bias = tosa::getConstTensor<float >(rewriter, op, zeroVec,
2331
- {static_cast <int32_t >(weightShape[0 ])})
2332
- .value ();
2333
- }
2321
+
2322
+ if (isa<Torch::NoneType>(bias.getType ())) {
2323
+ auto bias_result = tosa::getConvBiasForNoneType (op, rewriter, inputElemTy,
2324
+ outputElemTy, weightShape);
2325
+ if (failed (bias_result))
2326
+ return rewriter.notifyMatchFailure (
2327
+ op, " Failed to create bias tensor for none type." );
2328
+ bias = bias_result.value ();
2334
2329
} else {
2335
- if (!cast <RankedTensorType>(bias.getType ()))
2330
+ if (!isa <RankedTensorType>(bias.getType ()))
2336
2331
return rewriter.notifyMatchFailure (
2337
2332
op, " Bias provided but not a ranked tensor" );
2338
2333
}
2339
- auto biasElemTy =
2340
- isa<mlir::FloatType>(inputElemTy) ? inputElemTy : rewriter. getI32Type ();
2334
+
2335
+ Type biasElemTy = cast<RankedTensorType>(bias. getType ()). getElementType ();
2341
2336
2342
2337
int64_t groups;
2343
2338
if (!matchPattern (op.getGroups (), m_TorchConstantInt (&groups))) {
@@ -2528,14 +2523,29 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
2528
2523
auto convOpTy =
2529
2524
RankedTensorType::get (makeShapeLLVMCompatible (outputShape), biasElemTy);
2530
2525
2526
+ // create zero-point tensors for input and weight
2527
+ auto zps = tosa::createZPsAsConst (rewriter, input, weight);
2528
+ // for i8 input/weight, zero-points are returned as un-initialized
2529
+ Value inputZp =
2530
+ zps.first
2531
+ ? zps.first
2532
+ : tosa::createZeroPointTensor (rewriter, op->getLoc (), inputElemTy, 0 )
2533
+ .value ();
2534
+
2535
+ Value weightZp =
2536
+ zps.second
2537
+ ? zps.second
2538
+ : tosa::createZeroPointTensor (rewriter, op->getLoc (), weightElemTy, 0 )
2539
+ .value ();
2540
+
2531
2541
Value convOpResult;
2532
2542
if (groups == 1 ) {
2533
2543
// full convolution
2534
2544
convOpResult =
2535
2545
rewriter
2536
2546
.create <tosa::Conv2DOp>(
2537
2547
op->getLoc (), getTypeConverter ()->convertType (convOpTy),
2538
- transposedInput, transformedWeight, bias,
2548
+ transposedInput, transformedWeight, bias, inputZp, weightZp,
2539
2549
rewriter.getDenseI64ArrayAttr (padding),
2540
2550
rewriter.getDenseI64ArrayAttr (stride),
2541
2551
rewriter.getDenseI64ArrayAttr (dilation), accType)
@@ -2546,7 +2556,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
2546
2556
rewriter
2547
2557
.create <tosa::DepthwiseConv2DOp>(
2548
2558
op->getLoc (), getTypeConverter ()->convertType (convOpTy),
2549
- transposedInput, transformedWeight, bias,
2559
+ transposedInput, transformedWeight, bias, inputZp, weightZp,
2550
2560
rewriter.getDenseI64ArrayAttr (padding),
2551
2561
rewriter.getDenseI64ArrayAttr (stride),
2552
2562
rewriter.getDenseI64ArrayAttr (dilation), accType)
@@ -2574,8 +2584,12 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
2574
2584
rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
2575
2585
}
2576
2586
2577
- rewriter.replaceOpWithNewOp <tensor::CastOp>(
2578
- op, getTypeConverter ()->convertType (op.getType ()), rescaledResult);
2587
+ // cast to outputTy is required if convOpTy is not same as outputTy
2588
+ // the difference is not in the shape information, rather the element-type
2589
+ // itself
2590
+ rewriter.replaceOp (
2591
+ op,
2592
+ {tosa::tosaCastTensorToType (rewriter, rescaledResult, outputTy).value ()});
2579
2593
2580
2594
return success ();
2581
2595
}
0 commit comments