Commit cb0363f

[TOSA] Defer input transpose until guards pass
Lazily create the NHWC input transpose so that it is emitted only after the failure guards in the transposed and depthwise convolution rewrites have passed.

Change-Id: Ia362deda898794397107f6da3c44cd89f219f58f
1 parent 6db8cd6 commit cb0363f
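
The rationale, as a minimal sketch (hedged: `SomeAtenOp`, `unsupportedCase`, `transposedInputType`, and the `getInput()` accessor are hypothetical stand-ins; the rewriter calls and `tosa::TransposeOp` usage follow the diff below). A pattern that creates IR and then bails out via `notifyMatchFailure` leaves work for the conversion driver to roll back, and under a greedy driver can leave dead ops behind; moving the creation into a lambda and calling it only on the success paths keeps failed matches side-effect-free.

// Sketch only, not the literal pattern body from TorchToTosa.cpp.
LogicalResult matchAndRewrite(SomeAtenOp op, OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
  Value input = adaptor.getInput(); // hypothetical accessor

  // Capture everything by reference, but create no IR yet.
  auto createTransposedInput = [&]() {
    return rewriter
        .create<tosa::TransposeOp>(
            op->getLoc(), transposedInputType, input,
            rewriter.getDenseI32ArrayAttr({0, 2, 3, 1})) // NCHW -> NHWC
        .getResult();
  };

  if (unsupportedCase) // any guard that can still reject the match
    return rewriter.notifyMatchFailure(op, "unsupported"); // nothing emitted

  Value nhwcInput = createTransposedInput(); // guards passed: emit exactly once
  // ... build the TOSA convolution from nhwcInput ...
  return success();
}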

1 file changed, 34 insertions(+), 26 deletions(-)

lib/Conversion/TorchToTosa/TorchToTosa.cpp

@@ -2335,6 +2335,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   auto bias = adaptor.getBias();
 
   if (isa<Torch::NoneType>(bias.getType())) {
+    // ConvTranspose weights use IOHW; the helper expects OIHW, so swap
+    // dims 0/1 before we synthesize the bias.
     SmallVector<int64_t, 4> biasWeightShape =
         transposed ? SmallVector<int64_t, 4>{weightShape[1], weightShape[0],
                                              weightShape[2], weightShape[3]}
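
A worked example of the layout swap the new comment documents (shape values illustrative): a ConvTranspose2d weight with 8 input and 16 output channels is stored IOHW, so it must become OIHW before the bias-synthesis helper sees it.

// Illustrative values only; the swap mirrors the ternary in the hunk above.
SmallVector<int64_t, 4> weightShape = {8, 16, 3, 3}; // IOHW: I=8, O=16, H=W=3
SmallVector<int64_t, 4> biasWeightShape = {weightShape[1], weightShape[0],
                                           weightShape[2], weightShape[3]};
// biasWeightShape == {16, 8, 3, 3}, i.e. OIHW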
@@ -2405,13 +2407,13 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     transposedInputShape.push_back(inputShape[dim]);
   auto transposedInputType = RankedTensorType::get(
       makeShapeLLVMCompatible(transposedInputShape), inputElemTy);
-  auto transposedInput =
-      rewriter
-          .create<tosa::TransposeOp>(
-              op->getLoc(),
-              getTypeConverter()->convertType(transposedInputType), input,
-              rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
-          .getResult();
+  auto createTransposedInput = [&]() {
+    return rewriter
+        .create<tosa::TransposeOp>(
+            op->getLoc(), getTypeConverter()->convertType(transposedInputType),
+            input, rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
+        .getResult();
+  };
 
   if (transposed) {
     if (groups != 1)
@@ -2424,17 +2426,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
           "TOSA");
 
     SmallVector<int32_t> iohwToOhwi({1, 2, 3, 0});
-    SmallVector<int64_t, 4> ohwiWeightShape;
-    for (int32_t dim : iohwToOhwi)
-      ohwiWeightShape.push_back(weightShape[dim]);
-    auto ohwiWeightType = RankedTensorType::get(
-        makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy);
-    Value transformedWeight =
-        rewriter
-            .create<tosa::TransposeOp>(
-                op->getLoc(), getTypeConverter()->convertType(ohwiWeightType),
-                weight, rewriter.getDenseI32ArrayAttr(iohwToOhwi))
-            .getResult();
 
     // TOSA 'out_pad' is a 4D array {top,bottom,left,right}.
     // Map from PyTorch's (padding, output_padding):
@@ -2456,6 +2447,19 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     SmallVector<int64_t, 4> outPad(
         {outPadTop, outPadBottom, outPadLeft, outPadRight});
 
+    Value nhwcInput = createTransposedInput();
+    SmallVector<int64_t, 4> ohwiWeightShape;
+    for (int32_t dim : iohwToOhwi)
+      ohwiWeightShape.push_back(weightShape[dim]);
+    auto ohwiWeightType = RankedTensorType::get(
+        makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy);
+    Value transformedWeight =
+        rewriter
+            .create<tosa::TransposeOp>(
+                op->getLoc(), getTypeConverter()->convertType(ohwiWeightType),
+                weight, rewriter.getDenseI32ArrayAttr(iohwToOhwi))
+            .getResult();
+
     // Result type is NHWC (we'll transpose back).
     auto outNCHW = makeShapeTorchCompatible(outputTy.getShape());
     SmallVector<int64_t, 4> outNHWC;
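
The re-added weight transpose uses the same permutation idiom as the rest of the file: result dimension i takes source dimension perm[i]. A self-contained example of how iohwToOhwi = {1, 2, 3, 0} rearranges a shape (shape values illustrative):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> weightShape = {8, 16, 3, 3}; // IOHW: I=8, O=16, H=W=3
  std::vector<int32_t> iohwToOhwi = {1, 2, 3, 0};   // same vector as the diff
  std::vector<int64_t> ohwiWeightShape;
  for (int32_t dim : iohwToOhwi)                    // result dim i <- source dim perm[i]
    ohwiWeightShape.push_back(weightShape[dim]);
  for (int64_t d : ohwiWeightShape)
    std::printf("%lld ", static_cast<long long>(d)); // prints: 16 3 3 8 (OHWI)
  return 0;
}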
@@ -2479,7 +2483,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
         rewriter
             .create<tosa::TransposeConv2DOp>(
                 op->getLoc(), getTypeConverter()->convertType(transConvOpTy),
-                transposedInput, transformedWeight, bias, inputZp, weightZp,
+                nhwcInput, transformedWeight, bias, inputZp, weightZp,
                 rewriter.getDenseI64ArrayAttr(outPad),
                 rewriter.getDenseI64ArrayAttr(stride), accType)
             .getResult();
@@ -2535,6 +2539,15 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     SmallVector<int32_t> transposedDims({2, 3, 0, 1});
     SmallVector<int64_t> transposedWeightShape = {
         weightShape[2], weightShape[3], weightShape[0], weightShape[1]};
+
+    // reshape: HWO(I/G) -> HWIM
+    outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
+    if (outputCDim == kUnknownSize) {
+      return rewriter.notifyMatchFailure(
+          op, "number of output channels must be statically known for "
+              "depthwise convolutions");
+    }
+
     auto transposedWeightType = RankedTensorType::get(
         makeShapeLLVMCompatible(transposedWeightShape), weightElemTy);
     auto transposedWeight =
@@ -2545,13 +2558,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
                 rewriter.getDenseI32ArrayAttr(transposedDims))
             .getResult();
 
-    // reshape: HWO(I/G) -> HWIM
-    outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
-    if (outputCDim == kUnknownSize) {
-      return rewriter.notifyMatchFailure(
-          op, "number of output channels must be statically known for "
-              "depthwise convolutions");
-    }
     transformedWeightShape = {
         transposedWeightShape[0],
         transposedWeightShape[1],
@@ -2573,6 +2579,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     llvm_unreachable("Unhandled convolution type");
   }
 
+  Value transposedInput = createTransposedInput();
+
   int64_t outputHDim, outputWDim;
   int64_t inputHDim = inputShape[2];
   int64_t inputWDim = inputShape[3];
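
Taken together, the hunks leave the rewrite with this control flow (a condensed sketch assembled from the diff, with guards and weight handling elided; not literal code from the file):

auto createTransposedInput = [&]() { /* tosa.transpose NCHW -> NHWC */ };

if (transposed) {
  // ... groups/bias guards that may return notifyMatchFailure ...
  Value nhwcInput = createTransposedInput();   // emitted only past the guards
  // ... tosa.transpose_conv2d(nhwcInput, transformedWeight, ...),
  // then the NHWC -> NCHW transpose back; this path returns early.
}

// Regular and depthwise paths: the weight-transform switch runs first
// (including the hoisted static-output-channel guard), and only then:
Value transposedInput = createTransposedInput(); // single shared lambda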
