@@ -2335,6 +2335,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   auto bias = adaptor.getBias();
 
   if (isa<Torch::NoneType>(bias.getType())) {
+    // ConvTranspose weights use IOHW; the helper expects OIHW, so swap
+    // dims 0/1 before we synthesize the bias.
     SmallVector<int64_t, 4> biasWeightShape =
         transposed ? SmallVector<int64_t, 4>{weightShape[1], weightShape[0],
                                              weightShape[2], weightShape[3]}
@@ -2405,13 +2407,13 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     transposedInputShape.push_back(inputShape[dim]);
   auto transposedInputType = RankedTensorType::get(
       makeShapeLLVMCompatible(transposedInputShape), inputElemTy);
-  auto transposedInput =
-      rewriter
-          .create<tosa::TransposeOp>(
-              op->getLoc(),
-              getTypeConverter()->convertType(transposedInputType), input,
-              rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
-          .getResult();
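+  // Both the transposed and the regular/depthwise paths need this NCHW->NHWC
+  // transpose, but only after their failure checks pass, so build it lazily.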
+  auto createTransposedInput = [&]() {
+    return rewriter
+        .create<tosa::TransposeOp>(
+            op->getLoc(), getTypeConverter()->convertType(transposedInputType),
+            input, rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
+        .getResult();
+  };
 
   if (transposed) {
     if (groups != 1)
@@ -2424,17 +2426,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
           "TOSA");
 
     SmallVector<int32_t> iohwToOhwi({1, 2, 3, 0});
-    SmallVector<int64_t, 4> ohwiWeightShape;
-    for (int32_t dim : iohwToOhwi)
-      ohwiWeightShape.push_back(weightShape[dim]);
-    auto ohwiWeightType = RankedTensorType::get(
-        makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy);
-    Value transformedWeight =
-        rewriter
-            .create<tosa::TransposeOp>(
-                op->getLoc(), getTypeConverter()->convertType(ohwiWeightType),
-                weight, rewriter.getDenseI32ArrayAttr(iohwToOhwi))
-            .getResult();
 
     // TOSA 'out_pad' is a 4D array {top,bottom,left,right}.
     // Map from PyTorch's (padding, output_padding):
@@ -2456,6 +2447,19 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     SmallVector<int64_t, 4> outPad(
         {outPadTop, outPadBottom, outPadLeft, outPadRight});
 
+    Value nhwcInput = createTransposedInput();
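+    // TOSA transpose_conv2d takes OHWI weights, while the ConvTranspose
+    // weight arrives as IOHW, hence the {1, 2, 3, 0} permutation below.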
+    SmallVector<int64_t, 4> ohwiWeightShape;
+    for (int32_t dim : iohwToOhwi)
+      ohwiWeightShape.push_back(weightShape[dim]);
+    auto ohwiWeightType = RankedTensorType::get(
+        makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy);
+    Value transformedWeight =
+        rewriter
+            .create<tosa::TransposeOp>(
+                op->getLoc(), getTypeConverter()->convertType(ohwiWeightType),
+                weight, rewriter.getDenseI32ArrayAttr(iohwToOhwi))
+            .getResult();
+
     // Result type is NHWC (we'll transpose back).
     auto outNCHW = makeShapeTorchCompatible(outputTy.getShape());
     SmallVector<int64_t, 4> outNHWC;
@@ -2479,7 +2483,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
         rewriter
             .create<tosa::TransposeConv2DOp>(
                 op->getLoc(), getTypeConverter()->convertType(transConvOpTy),
-                transposedInput, transformedWeight, bias, inputZp, weightZp,
+                nhwcInput, transformedWeight, bias, inputZp, weightZp,
                 rewriter.getDenseI64ArrayAttr(outPad),
                 rewriter.getDenseI64ArrayAttr(stride), accType)
             .getResult();
@@ -2535,6 +2539,15 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     SmallVector<int32_t> transposedDims({2, 3, 0, 1});
     SmallVector<int64_t> transposedWeightShape = {
         weightShape[2], weightShape[3], weightShape[0], weightShape[1]};
+
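+    // Hoisted above the weight transpose so that a match failure here does
+    // not leave a dead tosa.transpose op behind.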
+    // reshape: HWO(I/G) -> HWIM
+    outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
+    if (outputCDim == kUnknownSize) {
+      return rewriter.notifyMatchFailure(
+          op, "number of output channels must be statically known for "
+              "depthwise convolutions");
+    }
+
     auto transposedWeightType = RankedTensorType::get(
         makeShapeLLVMCompatible(transposedWeightShape), weightElemTy);
     auto transposedWeight =
@@ -2545,13 +2558,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
                 rewriter.getDenseI32ArrayAttr(transposedDims))
             .getResult();
 
-    // reshape: HWO(I/G) -> HWIM
-    outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
-    if (outputCDim == kUnknownSize) {
-      return rewriter.notifyMatchFailure(
-          op, "number of output channels must be statically known for "
-              "depthwise convolutions");
-    }
     transformedWeightShape = {
         transposedWeightShape[0],
         transposedWeightShape[1],
@@ -2573,6 +2579,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     llvm_unreachable("Unhandled convolution type");
   }
 
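+  // All convolution-type checks have passed; it is now safe to materialize
+  // the NHWC input for the regular/depthwise paths.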
+  Value transposedInput = createTransposedInput();
+
   int64_t outputHDim, outputWDim;
   int64_t inputHDim = inputShape[2];
   int64_t inputWDim = inputShape[3];