Commit 386bba4

[Tosa] : Use output type for bias for creating tosa.conv (#4252)
For a ConvolutionLayer initialized without a bias, a zero tensor for the bias is created when converting to `tosa.conv2d`, since the op always expects a bias tensor. This zero tensor was always initialized as `fp32` irrespective of the input/weight types. That led to a validation error when the input/weights are `fp16`: the `bias` type (fp32) did not match the conv output type (fp16).
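In PyTorch terms, the failing pattern is an fp16 `conv2d` with no bias. A minimal sketch mirroring the e2e test added in this commit (illustrative only; it shows the graph that gets lowered to TOSA, and whether eager CPU execution supports fp16 conv depends on the PyTorch build):

import torch

x = torch.rand(2, 2, 6, 6).to(torch.float16)
w = torch.randn(8, 2, 3, 3).to(torch.float16)

# bias is omitted, so the TOSA lowering must synthesize a zero bias tensor;
# before this fix it was created as fp32 and failed validation against the
# fp16 output of tosa.conv2d.
out = torch.ops.aten.conv2d(
    x, w, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1
)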
1 parent 1e4c605 commit 386bba4

File tree

6 files changed: +187 -44 lines changed

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

Lines changed: 5 additions & 0 deletions
@@ -101,6 +101,11 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
                                 RankedTensorType weightTy,
                                 RankedTensorType outputTy, TypeAttr &accType);
 
+FailureOr<Value> getConvBiasForNoneType(Operation *op,
+                                        PatternRewriter &rewriter,
+                                        Type inputElemTy, Type outputElemTy,
+                                        ArrayRef<int64_t> weightShape);
+
 } // namespace tosa
 } // namespace mlir

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 37 additions & 23 deletions
@@ -29,6 +29,8 @@
 #include <optional>
 #include <random>
 
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
@@ -2295,7 +2297,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   auto weightTy = cast<RankedTensorType>(weight.getType());
   auto outputTy =
       cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
-
   if (!inputTy || !weightTy || !outputTy)
     return rewriter.notifyMatchFailure(
         op, "Input, weight and output to Convolution must be ranked tensors");
@@ -2304,6 +2305,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   auto weightElemTy = weightTy.getElementType();
   auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
   auto weightShape = makeShapeTorchCompatible(weightTy.getShape());
+  auto outputElemTy = outputTy.getElementType();
 
   if (inputTy.getRank() != 4)
     return rewriter.notifyMatchFailure(
@@ -2316,28 +2318,21 @@
   // Bias is optional. TOSA mandates a zero tensor here, so construct one if
   // required.
   auto bias = adaptor.getBias();
-  if (isa<Torch::NoneType>(adaptor.getBias().getType())) {
-    // TBD: This is only valid for quantized 8-bit. For 16-bit, the bias (and
-    // accumulator) are 48-bit and not 32-bit, and requires the use of APInt to
-    // define a 48-bit int.
-    if (isa<quant::QuantizedType>(inputElemTy)) {
-      SmallVector<int32_t> zeroVec(weightShape[0], 0);
-      bias = tosa::getConstTensor<int32_t>(
-                 rewriter, op, zeroVec, {static_cast<int32_t>(weightShape[0])})
-                 .value();
-    } else {
-      SmallVector<float> zeroVec(weightShape[0], 0);
-      bias = tosa::getConstTensor<float>(rewriter, op, zeroVec,
-                                         {static_cast<int32_t>(weightShape[0])})
-                 .value();
-    }
+
+  if (isa<Torch::NoneType>(bias.getType())) {
+    auto bias_result = tosa::getConvBiasForNoneType(op, rewriter, inputElemTy,
+                                                    outputElemTy, weightShape);
+    if (failed(bias_result))
+      return rewriter.notifyMatchFailure(
+          op, "Failed to create bias tensor for none type.");
+    bias = bias_result.value();
   } else {
-    if (!cast<RankedTensorType>(bias.getType()))
+    if (!isa<RankedTensorType>(bias.getType()))
       return rewriter.notifyMatchFailure(
           op, "Bias provided but not a ranked tensor");
   }
-  auto biasElemTy =
-      isa<mlir::FloatType>(inputElemTy) ? inputElemTy : rewriter.getI32Type();
+
+  Type biasElemTy = cast<RankedTensorType>(bias.getType()).getElementType();
 
   int64_t groups;
   if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups))) {
@@ -2528,14 +2523,29 @@
   auto convOpTy =
       RankedTensorType::get(makeShapeLLVMCompatible(outputShape), biasElemTy);
 
+  // create zero-point tensors for input and weight
+  auto zps = tosa::createZPsAsConst(rewriter, input, weight);
+  // for i8 input/weight, zero-points are returned as un-initialized
+  Value inputZp =
+      zps.first
+          ? zps.first
+          : tosa::createZeroPointTensor(rewriter, op->getLoc(), inputElemTy, 0)
+                .value();
+
+  Value weightZp =
+      zps.second
+          ? zps.second
+          : tosa::createZeroPointTensor(rewriter, op->getLoc(), weightElemTy, 0)
+                .value();
+
   Value convOpResult;
   if (groups == 1) {
     // full convolution
     convOpResult =
         rewriter
             .create<tosa::Conv2DOp>(
                 op->getLoc(), getTypeConverter()->convertType(convOpTy),
-                transposedInput, transformedWeight, bias,
+                transposedInput, transformedWeight, bias, inputZp, weightZp,
                 rewriter.getDenseI64ArrayAttr(padding),
                 rewriter.getDenseI64ArrayAttr(stride),
                 rewriter.getDenseI64ArrayAttr(dilation), accType)
@@ -2546,7 +2556,7 @@
         rewriter
             .create<tosa::DepthwiseConv2DOp>(
                 op->getLoc(), getTypeConverter()->convertType(convOpTy),
-                transposedInput, transformedWeight, bias,
+                transposedInput, transformedWeight, bias, inputZp, weightZp,
                 rewriter.getDenseI64ArrayAttr(padding),
                 rewriter.getDenseI64ArrayAttr(stride),
                 rewriter.getDenseI64ArrayAttr(dilation), accType)
@@ -2574,8 +2584,12 @@
         rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
   }
 
-  rewriter.replaceOpWithNewOp<tensor::CastOp>(
-      op, getTypeConverter()->convertType(op.getType()), rescaledResult);
+  // cast to outputTy is required if convOpTy is not same as outputTy
+  // the difference is not in the shape information, rather the element-type
+  // itself
+  rewriter.replaceOp(
+      op,
+      {tosa::tosaCastTensorToType(rewriter, rescaledResult, outputTy).value()});
 
   return success();
 }
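
The zero-point operands added to `tosa.conv2d` above follow a simple fallback rule; a rough Python sketch of that rule (hypothetical names — the real helpers are `tosa::createZPsAsConst` and `tosa::createZeroPointTensor`):

def resolve_zero_points(zps, input_elem_ty, weight_elem_ty, make_zp_tensor):
    # When createZPsAsConst leaves a zero point uninitialized, fall back to
    # an explicit zero-valued zero-point tensor of the matching element type.
    input_zp = zps[0] if zps[0] is not None else make_zp_tensor(input_elem_ty, 0)
    weight_zp = zps[1] if zps[1] is not None else make_zp_tensor(weight_elem_ty, 0)
    return input_zp, weight_zp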

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

Lines changed: 44 additions & 0 deletions
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project
+#include "llvm/ADT/ArrayRef.h"
 
 namespace mlir {
 namespace tosa {
@@ -551,5 +552,48 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
   return success();
 }
 
+FailureOr<Value> getConvBiasForNoneType(Operation *op,
+                                        PatternRewriter &rewriter,
+                                        Type inputElemTy, Type outputElemTy,
+                                        ArrayRef<int64_t> weightShape) {
+
+  Type biasElemTy;
+
+  if (isa<quant::QuantizedType>(outputElemTy)) {
+    auto input_qtype = dyn_cast<mlir::quant::QuantizedType>(inputElemTy);
+    if (!input_qtype) {
+      return rewriter.notifyMatchFailure(op,
+                                         "output is qtype but input is not");
+    }
+    int input_bits = input_qtype.getStorageTypeIntegralWidth();
+    if (input_bits != 8) {
+      // TBD: This is only valid for quantized 8-bit. For 16-bit, the bias (and
+      // accumulator) are 48-bit and not 32-bit, and requires the use of APInt
+      // to define a 48-bit int.
+      return rewriter.notifyMatchFailure(
+          op, "Only int8 input tensor to conv2d is supported.");
+    }
+    // For signed int8 input tensor, int32 bias and output
+    // tensor are generated.
+    int bias_bits = 32;
+    biasElemTy = rewriter.getIntegerType(bias_bits);
+  } else {
+    biasElemTy = outputElemTy;
+  }
+
+  if (biasElemTy.isInteger()) {
+    SmallVector<int32_t> zeroVec(weightShape[0], 0);
+    return tosa::getConstTensor<int32_t>(rewriter, op, zeroVec,
+                                         {static_cast<int32_t>(weightShape[0])})
+        .value();
+  } else {
+    SmallVector<float> zeroVec(weightShape[0], 0);
+    return tosa::getConstTensor<float>(rewriter, op, zeroVec,
+                                       {static_cast<int32_t>(weightShape[0])},
+                                       biasElemTy)
+        .value();
+  }
+}
+
 } // namespace tosa
 } // namespace mlir
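
The type selection in `getConvBiasForNoneType` boils down to a small rule; a Python sketch of it (hypothetical helper, not part of the codebase):

def conv_bias_elem_type(output_is_quantized, input_is_quantized,
                        input_storage_bits, output_elem_ty):
    if output_is_quantized:
        if not input_is_quantized:
            raise ValueError("output is qtype but input is not")
        if input_storage_bits != 8:
            # 16-bit quantized conv would need a 48-bit bias/accumulator.
            raise ValueError("only int8 input tensor to conv2d is supported")
        return "i32"  # signed int8 input -> int32 bias and output
    # Non-quantized path: the bias follows the conv output element type.
    # This is the actual fix; previously the zero bias was hardcoded to fp32.
    return output_elem_ty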

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 2 additions & 0 deletions
@@ -681,6 +681,7 @@
     "ConstantBoolParameterModule_basic",
     "ContainsIntList_False",
     "ContainsIntList_True",
+    "Conv2dFP16NoBiasModule_basic",
     "Conv2dQInt8Module_basic",
     "Conv2dQInt8Module_depthwise",
     "Conv2dQInt8Module_grouped",
@@ -2874,6 +2875,7 @@
     "Conv2dBiasNoPaddingModule_basic",
     "Conv2dModule_basic",
     "Conv2dNoPaddingModule_basic",
+    "Conv2dFP16NoBiasModule_basic",
     "Conv2dQInt8Module_basic",
     "Conv2dQInt8Module_depthwise",
     "Conv2dQInt8Module_grouped",

projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py

Lines changed: 30 additions & 0 deletions
@@ -1259,6 +1259,36 @@ def Conv2dModule_basic(module, tu: TestUtils):
     module.forward(inputVec, weight, bias)
 
 
+class Conv2dFP16NoBiasModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1], torch.float16, True),
+            ([-1, -1, -1, -1], torch.float16, True),
+        ]
+    )
+    def forward(self, inputVec, weight):
+        return torch.ops.aten.conv2d(
+            inputVec,
+            weight,
+            stride=[1, 1],
+            padding=[0, 0],
+            dilation=[1, 1],
+            groups=1,
+        )
+
+
+@register_test_case(module_factory=lambda: Conv2dFP16NoBiasModule())
+def Conv2dFP16NoBiasModule_basic(module, tu: TestUtils):
+    inputVec = tu.rand(2, 2, 6, 6).to(torch.float16)
+    weight = torch.randn(8, 2, 3, 3).to(torch.float16)
+    module.forward(inputVec, weight)
+
+
 class Conv3dModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
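
For reference, the expected output shape of the new test follows the standard convolution size formula; a quick arithmetic check (pure Python, no torch required):

def conv_out_size(n, k, stride=1, pad=0, dil=1):
    # standard convolution output-size formula
    return (n + 2 * pad - dil * (k - 1) - 1) // stride + 1

# input (2, 2, 6, 6), weight (8, 2, 3, 3), stride 1, padding 0, dilation 1
h = conv_out_size(6, 3)
assert (2, 8, h, h) == (2, 8, 4, 4)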
