@@ -2304,9 +2304,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23042304 if (!matchPattern (op.getTransposed (), m_TorchConstantBool (&transposed)))
23052305 return rewriter.notifyMatchFailure (
23062306 op, " Unimplemented: non-constant value for transposed not supported" );
2307- if (transposed)
2308- return rewriter.notifyMatchFailure (
2309- op, " Unimplemented: transposed convolution not supported" );
23102307
23112308 auto input = adaptor.getInput ();
23122309 auto weight = adaptor.getWeight ();
@@ -2338,12 +2335,17 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23382335 auto bias = adaptor.getBias ();
23392336
23402337 if (isa<Torch::NoneType>(bias.getType ())) {
2341- auto bias_result = tosa::getConvBiasForNoneType (op, rewriter, inputElemTy,
2342- outputElemTy, weightShape);
2343- if (failed (bias_result))
2338+ SmallVector<int64_t , 4 > biasWeightShape =
2339+ transposed ? SmallVector<int64_t , 4 >{weightShape[1 ], weightShape[0 ],
2340+ weightShape[2 ], weightShape[3 ]}
2341+ : weightShape;
2342+
2343+ auto biasResult = tosa::getConvBiasForNoneType (
2344+ op, rewriter, inputElemTy, outputElemTy, biasWeightShape);
2345+ if (failed (biasResult))
23442346 return rewriter.notifyMatchFailure (
23452347 op, " Failed to create bias tensor for none type." );
2346- bias = bias_result .value ();
2348+ bias = biasResult .value ();
23472349 } else {
23482350 if (!isa<RankedTensorType>(bias.getType ()))
23492351 return rewriter.notifyMatchFailure (
@@ -2370,8 +2372,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23702372 m_TorchListOfConstantInts (padding_2d)))
23712373 return rewriter.notifyMatchFailure (op,
23722374 " non-const padding list unsupported" );
2373- // TOSA uses 4D padding {top, bottom, left, right} while Torch defines 2D
2374- // padding {height, width}. The Torch OFM computation uses 2*pad in each
2375+ // TOSA uses 4D padding {top, bottom, left, right} while PyTorch defines 2D
2376+ // padding {height, width}. The PyTorch OFM computation uses 2*pad in each
23752377 // spatial direction, implying the same top=bottom=height and left=right=width
23762378 // values for TOSA.
23772379 SmallVector<int64_t > padding (
@@ -2388,9 +2390,15 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
23882390 return rewriter.notifyMatchFailure (
23892391 op, " failed to get accumulator type for convolution ops" );
23902392
2393+ // Weight layout reference:
2394+ // Conv : PyTorch OIHW -> TOSA OHWI
2395+ // Depthwise : PyTorch OIHW* -> TOSA HWIM (*out = in * multiplier)
2396+ // Grouped : PyTorch O(I/G)HW -> N/A
2397+ // Transposed : PyTorch IOHW -> TOSA OHWI
23912398 // TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights.
23922399 // Perform the necessary transformations.
23932400 SmallVector<int32_t > nchwToNhwcDims ({0 , 2 , 3 , 1 });
2401+ SmallVector<int32_t > nhwcToNchwDims ({0 , 3 , 1 , 2 });
23942402 SmallVector<int64_t > transposedInputShape (
23952403 {inputShape[0 ], inputShape[2 ], inputShape[3 ], inputShape[1 ]});
23962404 auto transposedInputType = RankedTensorType::get (
@@ -2403,6 +2411,101 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
24032411 rewriter.getDenseI32ArrayAttr (nchwToNhwcDims))
24042412 .getResult ();
24052413
2414+ if (transposed) {
2415+ if (groups != 1 )
2416+ return rewriter.notifyMatchFailure (
2417+ op, " Unimplemented: grouped transposed convolution not supported by "
2418+ " TOSA" );
2419+ if (dilation[0 ] != 1 || dilation[1 ] != 1 )
2420+ return rewriter.notifyMatchFailure (
2421+ op, " Unimplemented: dilated transposed convolution not supported by "
2422+ " TOSA" );
2423+
2424+ SmallVector<int32_t > iohwToOhwi ({1 , 2 , 3 , 0 });
2425+ SmallVector<int64_t > ohwiWeightShape (
2426+ {weightShape[1 ], weightShape[2 ], weightShape[3 ], weightShape[0 ]});
2427+ auto ohwiWeightType = RankedTensorType::get (
2428+ makeShapeLLVMCompatible (ohwiWeightShape), weightElemTy);
2429+ Value transformedWeight =
2430+ rewriter
2431+ .create <tosa::TransposeOp>(
2432+ op->getLoc (), getTypeConverter ()->convertType (ohwiWeightType),
2433+ weight, rewriter.getDenseI32ArrayAttr (iohwToOhwi))
2434+ .getResult ();
2435+
2436+ // TOSA 'out_pad' is a 4D array {top,bottom,left,right}.
2437+ // Map from PyTorch's (padding, output_padding):
2438+ // out_pad_total(H/W) = output_padding(H/W) - 2*padding(H/W)
2439+ // Negative values are allowed and will be handled by the TOSA
2440+ // decomposition.
2441+ SmallVector<int64_t , 2 > outPadding2D;
2442+ if (!matchPattern (adaptor.getOutputPadding (),
2443+ m_TorchListOfConstantInts (outPadding2D)))
2444+ return rewriter.notifyMatchFailure (
2445+ op, " non-const output_padding list unsupported for transposed conv" );
2446+
2447+ int64_t outPadH = outPadding2D[0 ] - 2 * padding_2d[0 ];
2448+ int64_t outPadW = outPadding2D[1 ] - 2 * padding_2d[1 ];
2449+ int64_t outPadTop = outPadH / 2 ;
2450+ int64_t outPadBottom = outPadH - outPadTop;
2451+ int64_t outPadLeft = outPadW / 2 ;
2452+ int64_t outPadRight = outPadW - outPadLeft;
2453+ SmallVector<int64_t , 4 > outPad (
2454+ {outPadTop, outPadBottom, outPadLeft, outPadRight});
2455+
2456+ // Result type is NHWC (we'll transpose back).
2457+ auto outNCHW = makeShapeTorchCompatible (outputTy.getShape ());
2458+ SmallVector<int64_t > outNHWC (
2459+ {outNCHW[0 ], outNCHW[2 ], outNCHW[3 ], outNCHW[1 ]});
2460+ auto transConvOpTy =
2461+ RankedTensorType::get (makeShapeLLVMCompatible (outNHWC), biasElemTy);
2462+
2463+ // Zero-points.
2464+ auto zps = tosa::createZPsAsConst (rewriter, input, weight);
2465+ Value inputZp = zps.first ? zps.first
2466+ : tosa::createZeroPointTensor (
2467+ rewriter, op->getLoc (), inputElemTy, 0 )
2468+ .value ();
2469+ Value weightZp = zps.second ? zps.second
2470+ : tosa::createZeroPointTensor (
2471+ rewriter, op->getLoc (), weightElemTy, 0 )
2472+ .value ();
2473+
2474+ Value convTOut =
2475+ rewriter
2476+ .create <tosa::TransposeConv2DOp>(
2477+ op->getLoc (), getTypeConverter ()->convertType (transConvOpTy),
2478+ transposedInput, transformedWeight, bias, inputZp, weightZp,
2479+ rewriter.getDenseI64ArrayAttr (outPad),
2480+ rewriter.getDenseI64ArrayAttr (stride), accType)
2481+ .getResult ();
2482+
2483+ SmallVector<int64_t > transposedOutputShape (
2484+ {outNHWC[0 ], outNHWC[3 ], outNHWC[1 ], outNHWC[2 ]});
2485+ auto transposedOutputType = RankedTensorType::get (
2486+ makeShapeLLVMCompatible (transposedOutputShape), biasElemTy);
2487+ Value transposedOutput =
2488+ rewriter
2489+ .create <tosa::TransposeOp>(
2490+ op->getLoc (),
2491+ getTypeConverter ()->convertType (transposedOutputType), convTOut,
2492+ rewriter.getDenseI32ArrayAttr (nhwcToNchwDims))
2493+ .getResult ();
2494+
2495+ // Quantized rescale.
2496+ Value rescaledResult = transposedOutput;
2497+ if (isa<quant::QuantizedType>(inputElemTy)) {
2498+ rescaledResult = tosa::buildRescaleOpConvOutput (
2499+ rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
2500+ }
2501+
2502+ // Final cast to requested output type.
2503+ rewriter.replaceOp (
2504+ op, {tosa::tosaCastTensorToType (rewriter, rescaledResult, outputTy)
2505+ .value ()});
2506+ return success ();
2507+ }
2508+
24062509 SmallVector<int64_t > transformedWeightShape;
24072510 RankedTensorType transformedWeightType;
24082511 Value transformedWeight;
@@ -2583,7 +2686,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
25832686 llvm_unreachable (" Unhandled convolution type" );
25842687 }
25852688
2586- SmallVector<int32_t > nhwcToNchwDims ({0 , 3 , 1 , 2 });
25872689 SmallVector<int64_t > transposedOutputShape (
25882690 {outputShape[0 ], outputShape[3 ], outputShape[1 ], outputShape[2 ]});
25892691 auto transposedOutputType = RankedTensorType::get (
0 commit comments