diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index c959f06c6a66..3c97c087f3e3 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -2306,9 +2306,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
     return rewriter.notifyMatchFailure(
         op, "Unimplemented: non-constant value for transposed not supported");
-  if (transposed)
-    return rewriter.notifyMatchFailure(
-        op, "Unimplemented: transposed convolution not supported");
 
   auto input = adaptor.getInput();
   auto weight = adaptor.getWeight();
@@ -2340,12 +2337,19 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
 
   auto bias = adaptor.getBias();
   if (isa<Torch::NoneType>(bias.getType())) {
-    auto bias_result = tosa::getConvBiasForNoneType(op, rewriter, inputElemTy,
-                                                    outputElemTy, weightShape);
-    if (failed(bias_result))
+    // ConvTranspose weights use IOHW; the helper expects OIHW, so swap
+    // dims 0/1 before we synthesize the bias.
+    SmallVector<int64_t> biasWeightShape =
+        transposed ? SmallVector<int64_t>{weightShape[1], weightShape[0],
+                                          weightShape[2], weightShape[3]}
+                   : weightShape;
+
+    auto biasResult = tosa::getConvBiasForNoneType(
+        op, rewriter, inputElemTy, outputElemTy, biasWeightShape);
+    if (failed(biasResult))
       return rewriter.notifyMatchFailure(
           op, "Failed to create bias tensor for none type.");
-    bias = bias_result.value();
+    bias = biasResult.value();
   } else {
     if (!isa<RankedTensorType>(bias.getType()))
       return rewriter.notifyMatchFailure(
@@ -2372,8 +2376,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
                     m_TorchListOfConstantInts(padding_2d)))
     return rewriter.notifyMatchFailure(op,
                                        "non-const padding list unsupported");
-  // TOSA uses 4D padding {top, bottom, left, right} while Torch defines 2D
-  // padding {height, width}. The Torch OFM computation uses 2*pad in each
+  // TOSA uses 4D padding {top, bottom, left, right} while PyTorch defines 2D
+  // padding {height, width}. The PyTorch OFM computation uses 2*pad in each
   // spatial direction, implying the same top=bottom=height and left=right=width
   // values for TOSA.
   SmallVector<int64_t> padding(
@@ -2390,19 +2394,126 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "failed to get accumulator type for convolution ops");
 
+  // Weight layout reference:
+  //   Conv       : PyTorch OIHW     -> TOSA OHWI
+  //   Depthwise  : PyTorch OIHW*    -> TOSA HWIM
+  //                (PyTorch depthwise uses out_ch = in_ch * depth_multiplier)
+  //   Grouped    : PyTorch O(I/G)HW -> N/A
+  //   Transposed : PyTorch IOHW     -> TOSA OHWI
   // TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights.
   // Perform the necessary transformations.
   SmallVector<int32_t> nchwToNhwcDims({0, 2, 3, 1});
-  SmallVector<int64_t, 4> transposedInputShape(
-      {inputShape[0], inputShape[2], inputShape[3], inputShape[1]});
+  SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
+  SmallVector<int64_t> transposedInputShape;
+  for (int32_t dim : nchwToNhwcDims)
+    transposedInputShape.push_back(inputShape[dim]);
   auto transposedInputType = RankedTensorType::get(
       makeShapeLLVMCompatible(transposedInputShape), inputElemTy);
-  auto transposedInput =
-      tosa::TransposeOp::create(
-          rewriter, op->getLoc(),
-          getTypeConverter()->convertType(transposedInputType), input,
-          rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
-          .getResult();
+  auto createTransposedInput = [&]() {
+    return tosa::TransposeOp::create(
+               rewriter, op->getLoc(),
+               getTypeConverter()->convertType(transposedInputType), input,
+               rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
+        .getResult();
+  };
+
+  if (transposed) {
+    if (groups != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: grouped transposed convolution not supported by "
+              "TOSA");
+    if (dilation[0] != 1 || dilation[1] != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: dilated transposed convolution not supported by "
+              "TOSA");
+
+    SmallVector<int32_t> iohwToOhwi({1, 2, 3, 0});
+
+    // TOSA 'out_pad' is a 4D array {top, bottom, left, right}.
+    // Map from PyTorch's (padding, output_padding):
+    //   out_pad_total(H/W) = output_padding(H/W) - 2 * padding(H/W)
+    // Negative values are allowed and will be handled by the TOSA
+    // decomposition.
+    SmallVector<int64_t> outPadding2D;
+    if (!matchPattern(adaptor.getOutputPadding(),
+                      m_TorchListOfConstantInts(outPadding2D)))
+      return rewriter.notifyMatchFailure(
+          op, "non-const output_padding list unsupported for transposed conv");
+
+    int64_t outPadH = outPadding2D[0] - 2 * padding_2d[0];
+    int64_t outPadW = outPadding2D[1] - 2 * padding_2d[1];
+    int64_t outPadTop = outPadH / 2;
+    int64_t outPadBottom = outPadH - outPadTop;
+    int64_t outPadLeft = outPadW / 2;
+    int64_t outPadRight = outPadW - outPadLeft;
+    SmallVector<int64_t> outPad(
+        {outPadTop, outPadBottom, outPadLeft, outPadRight});
+
+    Value nhwcInput = createTransposedInput();
+
+    SmallVector<int64_t> ohwiWeightShape;
+    for (int32_t dim : iohwToOhwi)
+      ohwiWeightShape.push_back(weightShape[dim]);
+    auto ohwiWeightType = RankedTensorType::get(
+        makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy);
+    Value transformedWeight =
+        tosa::TransposeOp::create(
+            rewriter, op->getLoc(),
+            getTypeConverter()->convertType(ohwiWeightType), weight,
+            rewriter.getDenseI32ArrayAttr(iohwToOhwi))
+            .getResult();
+
+    // Result type is NHWC (we'll transpose back).
+    auto outNCHW = makeShapeTorchCompatible(outputTy.getShape());
+    SmallVector<int64_t> outNHWC;
+    for (int32_t dim : nchwToNhwcDims)
+      outNHWC.push_back(outNCHW[dim]);
+    auto transConvOpTy =
+        RankedTensorType::get(makeShapeLLVMCompatible(outNHWC), biasElemTy);
+
+    // Zero-points.
+    auto zps = tosa::createZPsAsConst(rewriter, input, weight);
+    Value inputZp = zps.first ? zps.first
+                              : tosa::createZeroPointTensor(
+                                    rewriter, op->getLoc(), inputElemTy, 0)
+                                    .value();
+    Value weightZp = zps.second ? zps.second
+                                : tosa::createZeroPointTensor(
+                                      rewriter, op->getLoc(), weightElemTy, 0)
+                                      .value();
+
+    Value convTOut = tosa::TransposeConv2DOp::create(
+                         rewriter, op->getLoc(),
+                         getTypeConverter()->convertType(transConvOpTy),
+                         nhwcInput, transformedWeight, bias, inputZp, weightZp,
+                         rewriter.getDenseI64ArrayAttr(outPad),
+                         rewriter.getDenseI64ArrayAttr(stride), accType)
+                         .getResult();
+
+    SmallVector<int64_t> transposedOutputShape;
+    for (int32_t dim : nhwcToNchwDims)
+      transposedOutputShape.push_back(outNHWC[dim]);
+    auto transposedOutputType = RankedTensorType::get(
+        makeShapeLLVMCompatible(transposedOutputShape), biasElemTy);
+    Value transposedOutput =
+        tosa::TransposeOp::create(
+            rewriter, op->getLoc(),
+            getTypeConverter()->convertType(transposedOutputType), convTOut,
+            rewriter.getDenseI32ArrayAttr(nhwcToNchwDims))
+            .getResult();
+
+    // Quantized rescale.
+    Value rescaledResult = transposedOutput;
+    if (isa<quant::QuantizedType>(inputElemTy)) {
+      rescaledResult = tosa::buildRescaleOpConvOutput(
+          rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
+    }
+
+    // Final cast to requested output type.
+    rewriter.replaceOp(
+        op, {tosa::tosaCastTensorToType(rewriter, rescaledResult, outputTy)
+                 .value()});
+    return success();
+  }
 
   SmallVector<int64_t> transformedWeightShape;
   RankedTensorType transformedWeightType;
@@ -2427,6 +2538,15 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     SmallVector<int32_t> transposedDims({2, 3, 0, 1});
     SmallVector<int64_t> transposedWeightShape = {
         weightShape[2], weightShape[3], weightShape[0], weightShape[1]};
+
+    // reshape: HWO(I/G) -> HWIM
+    outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
+    if (outputCDim == kUnknownSize) {
+      return rewriter.notifyMatchFailure(
+          op, "number of output channels must be statically known for "
+              "depthwise convolutions");
+    }
+
     auto transposedWeightType = RankedTensorType::get(
         makeShapeLLVMCompatible(transposedWeightShape), weightElemTy);
     auto transposedWeight =
@@ -2436,13 +2556,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
             rewriter.getDenseI32ArrayAttr(transposedDims))
             .getResult();
 
-    // reshape: HWO(I/G) -> HWIM
-    outputCDim = makeShapeTorchCompatible(outputTy.getShape())[1];
-    if (outputCDim == kUnknownSize) {
-      return rewriter.notifyMatchFailure(
-          op, "number of output channels must be statically known for "
-              "depthwise convolutions");
-    }
     transformedWeightShape = {
         transposedWeightShape[0],
         transposedWeightShape[1],
@@ -2463,6 +2576,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     llvm_unreachable("Unhandled convolution type");
   }
 
+  Value transposedInput = createTransposedInput();
+
   int64_t outputHDim, outputWDim;
   int64_t inputHDim = inputShape[2];
   int64_t inputWDim = inputShape[3];
@@ -2485,7 +2600,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   if (remainderHDim != 0) {
    if (remainderHDim > padding[1]) {
      SmallVector<int64_t> startHSlice(inputTy.getRank(), 0);
-      SmallVector<int64_t, 4> sizeHSlice(transposedInputShape);
+      SmallVector<int64_t> sizeHSlice(transposedInputShape);
       // TOSA uses NHWC, so we will slice dim 1 for Height value
       sizeHSlice[1] = inputHDim - (remainderHDim - padding[1]);
       transposedInput = tosa::CreateOpAndInfer<tosa::SliceOp>(
@@ -2579,7 +2694,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     llvm_unreachable("Unhandled convolution type");
   }
 
-  SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
   SmallVector<int64_t> transposedOutputShape(
       {outputShape[0], outputShape[3], outputShape[1], outputShape[2]});
   auto transposedOutputType = RankedTensorType::get(
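Note on the weight-layout table in the hunk above: each perms array is a plain index permutation, with output dimension i drawn from input dimension perms[i]. A minimal Python sketch of how the permutations used by the lowering act on shapes (the example shapes are illustrative, not taken from the patch):

    # Output dim i of tosa.transpose takes input dim perms[i].
    def permute(shape, perms):
        return tuple(shape[p] for p in perms)

    iohw = (64, 32, 3, 3)   # hypothetical ConvTranspose2d weight: (in, out, kH, kW)
    assert permute(iohw, (1, 2, 3, 0)) == (32, 3, 3, 64)    # IOHW -> OHWI (iohwToOhwi)

    oihw = (128, 64, 1, 1)  # hypothetical Conv2d weight: (out, in, kH, kW)
    assert permute(oihw, (0, 2, 3, 1)) == (128, 1, 1, 64)   # OIHW -> OHWI (nchwToNhwcDims reused)
    assert permute(oihw, (2, 3, 0, 1)) == (1, 1, 128, 64)   # OIHW -> HWO(I/G) (depthwise path)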
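The out_pad mapping in the transposed-convolution branch is just algebra over the two output-size formulas. A quick sanity check (my own sketch, not part of the patch; formulas assume dilation = 1, matching the branch's restriction):

    # PyTorch ConvTranspose2d: out = (in - 1)*stride - 2*pad + (k - 1) + output_padding + 1
    # TOSA transpose_conv2d:   out = (in - 1)*stride + out_pad_top + out_pad_bottom + k
    # Equating the two gives   out_pad_top + out_pad_bottom = output_padding - 2*pad.
    def pytorch_out(in_, stride, pad, k, output_padding):
        return (in_ - 1) * stride - 2 * pad + (k - 1) + output_padding + 1

    def tosa_out(in_, stride, out_pad_total, k):
        return (in_ - 1) * stride + out_pad_total + k

    for in_, stride, pad, k, opad in [(1, 2, 1, 3, 1), (100, 2, 1, 3, 1), (7, 3, 2, 5, 0)]:
        out_pad_total = opad - 2 * pad  # the mapping used by the lowering
        assert pytorch_out(in_, stride, pad, k, opad) == tosa_out(in_, stride, out_pad_total, k)

For the lit test at the end of this patch (in = 1x100, stride = 2, pad = 1, k = 3, output_padding = 1) the total is -1 per axis, which the C++ splits with truncating division into top/left = 0 and bottom/right = -1, hence the out_pad = {0, -1, 0, -1} in the CHECK lines.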
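For the non-transposed path, my reading of the remainder handling touched by the later hunks (only partially visible here) is: when the padded extent is not evenly covered by the stride, the lowering either shrinks the bottom/right padding or slices the input so the TOSA conv sees an exactly divisible extent. A sketch under that assumption (dilation = 1; the numbers come from the lit tests below):

    def adjust(in_, k, stride, pad):  # pad = [top, bottom]; returns (new_in, new_pad)
        remainder = (in_ + pad[0] + pad[1] - k) % stride
        if remainder == 0:
            return in_, pad
        if remainder > pad[1]:
            # slice the input and drop the bottom pad entirely
            return in_ - (remainder - pad[1]), [pad[0], 0]
        return in_, [pad[0], pad[1] - remainder]  # just shrink the bottom pad

    assert adjust(225, 3, 3, [1, 1]) == (224, [1, 0])  # "with_sliced_input" test
    assert adjust(224, 3, 2, [1, 1]) == (224, [1, 0])  # "without_sliced_input" test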
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 0e1b3b67d102..77af422bd6a8 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -3581,7 +3581,6 @@
     "AvgPool3dCountIncludePadFalseWithoutPadding_basic",
     "Conv_Transpose1dModule_basic",
     "Conv_Transpose1dStaticModule_basic",
-    "Conv_Transpose2dStaticModule_basic",
     "Conv_Transpose3dModule_basic",
     "Conv_Transpose3dStaticModule_basic",
     "IndexPutWithNoneAndBroadcastModule_basic",
@@ -3706,16 +3705,11 @@
     "Conv3dWithValidPaddingModule_basic",
     "ConvTbcModule_basic",
     "ConvTranspose2DQInt8_basic",
-    "Conv_Transpose2dModule_basic",
     "ConvolutionBackwardModule2DPadded_basic",
-    "ConvolutionBackwardModule2DStatic_basic",
     "ConvolutionBackwardModule2DStrided_basic",
     "ConvolutionBackwardModule2D_basic",
     "ConvolutionModule2DGroups_basic",
     "ConvolutionModule2DTransposeNonUnitOutputPadding_basic",
-    "ConvolutionModule2DTransposeStridedStatic_basic",
-    "ConvolutionModule2DTransposeStrided_basic",
-    "ConvolutionModule2DTranspose_basic",
     "ConvolutionModule2DGroupedTranspose_basic",
     "ConvolutionModule3DGroups_basic",
     "ConvolutionModule3DGroupsStrided_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py
index c9273c1f46c4..f2d148ec466e 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py
@@ -29,6 +29,8 @@
     # that depend on TOSA as well as TOSA-to-Standard.
     "tosa-to-arith",
     "tosa-to-scf",
+    # Required for transposed convolution support (decomposes to conv ops).
+    "tosa-optional-decompositions",
    # Named ops must be legalized prior to general tosa-to-linalg
     "tosa-to-linalg-named",
     # TOSA-to-LinAlg may generate tosa.const() ops, so we want to lower them
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 12f971cc9767..3d5cf50c30f6 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -3646,11 +3646,11 @@ func.func @torch.aten.constant_pad_nd$basic(%arg0: !torch.vtensor<[1,1,20,20,4,4
 // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
 // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<10xf32>}> : () -> tensor<10xf32>
-// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<5x2x10x20xf32>) -> tensor<5x10x20x2xf32>
-// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<10x2x3x3xf32>) -> tensor<10x3x3x2xf32>
+// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<10x2x3x3xf32>) -> tensor<10x3x3x2xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<5x2x10x20xf32>) -> tensor<5x10x20x2xf32>
 // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_12]], %[[VAL_13]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 3, 3, 3, 3>, stride = array<i64: 1, 1>} : (tensor<5x10x20x2xf32>, tensor<10x3x3x2xf32>, tensor<10xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x14x24x10xf32>
+// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_13]], %[[VAL_12]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 3, 3, 3, 3>, stride = array<i64: 1, 1>} : (tensor<5x10x20x2xf32>, tensor<10x3x3x2xf32>, tensor<10xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x14x24x10xf32>
 // CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_16]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<5x14x24x10xf32>) -> tensor<5x10x14x24xf32>
 // CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<5x10x14x24xf32> -> !torch.vtensor<[5,10,14,24],f32>
 // CHECK: return %[[VAL_18]] : !torch.vtensor<[5,10,14,24],f32>
@@ -3685,13 +3685,13 @@ func.func @torch.aten.convolution$basic(%arg0: !torch.vtensor<[5,2,10,20],f32>)
 // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[VAL_11:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
 // CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<4xf32>}> : () -> tensor<4xf32>
-// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<5x4x10x20xf32>) -> tensor<5x10x20x4xf32>
-// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_5]] {perms = array<i32: 2, 3, 0, 1>} : (tensor<4x1x3x3xf32>) -> tensor<3x3x4x1xf32>
-// CHECK: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[3, 3, 4, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
-// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_14]], %[[VAL_15]] : (tensor<3x3x4x1xf32>, !tosa.shape<4>) -> tensor<3x3x4x1xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_5]] {perms = array<i32: 2, 3, 0, 1>} : (tensor<4x1x3x3xf32>) -> tensor<3x3x4x1xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.const_shape {values = dense<[3, 3, 4, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_13]], %[[VAL_14]] : (tensor<3x3x4x1xf32>, !tosa.shape<4>) -> tensor<3x3x4x1xf32>
+// CHECK: %[[VAL_16:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<5x4x10x20xf32>) -> tensor<5x10x20x4xf32>
 // CHECK: %[[VAL_17:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK: %[[VAL_18:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK: %[[VAL_19:.*]] = tosa.depthwise_conv2d %[[VAL_13]], %[[VAL_16]], %[[VAL_12]], %[[VAL_17]], %[[VAL_18]] {acc_type = f32, dilation = array<i64: 3, 3>, pad = array<i64: 3, 3, 3, 3>, stride = array<i64: 2, 2>} : (tensor<5x10x20x4xf32>, tensor<3x3x4x1xf32>, tensor<4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x5x10x4xf32>
+// CHECK: %[[VAL_19:.*]] = tosa.depthwise_conv2d %[[VAL_16]], %[[VAL_15]], %[[VAL_12]], %[[VAL_17]], %[[VAL_18]] {acc_type = f32, dilation = array<i64: 3, 3>, pad = array<i64: 3, 3, 3, 3>, stride = array<i64: 2, 2>} : (tensor<5x10x20x4xf32>, tensor<3x3x4x1xf32>, tensor<4xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x5x10x4xf32>
 // CHECK: %[[VAL_20:.*]] = tosa.transpose %[[VAL_19]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<5x5x10x4xf32>) -> tensor<5x4x5x10xf32>
 // CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<5x4x5x10xf32> -> !torch.vtensor<[5,4,5,10],f32>
 // CHECK: return %[[VAL_21]] : !torch.vtensor<[5,4,5,10],f32>
@@ -3727,17 +3727,17 @@ func.func @torch.aten.convolution$depthwise(%arg0: !torch.vtensor<[5,4,10,20],f3
 // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[VAL_11:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
 // CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<128xf32>}> : () -> tensor<128xf32>
-// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x64x56x56xf32>) -> tensor<1x56x56x64xf32>
-// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_5]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<128x64x1x1xf32>) -> tensor<128x1x1x64xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_5]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<128x64x1x1xf32>) -> tensor<128x1x1x64xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x64x56x56xf32>) -> tensor<1x56x56x64xf32>
 // CHECK-DAG: %[[VAL_15:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK-DAG: %[[VAL_16:.*]] = tosa.const_shape {values = dense<[1, 55, 56, 64]> : tensor<4xindex>} : () -> !tosa.shape<4>
-// CHECK: %[[VAL_17:.*]] = tosa.slice %[[VAL_13]], %[[VAL_15]], %[[VAL_16]] : (tensor<1x56x56x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x55x56x64xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.slice %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] : (tensor<1x56x56x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x55x56x64xf32>
 // CHECK-DAG: %[[VAL_18:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK-DAG: %[[VAL_19:.*]] = tosa.const_shape {values = dense<[1, 55, 55, 64]> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK: %[[VAL_20:.*]] = tosa.slice %[[VAL_17]], %[[VAL_18]], %[[VAL_19]] : (tensor<1x55x56x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x55x55x64xf32>
 // CHECK: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK: %[[VAL_22:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK: %[[VAL_23:.*]] = tosa.conv2d %[[VAL_20]], %[[VAL_14]], %[[VAL_12]], %[[VAL_21]], %[[VAL_22]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<1x55x55x64xf32>, tensor<128x1x1x64xf32>, tensor<128xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x28x28x128xf32>
+// CHECK: %[[VAL_23:.*]] = tosa.conv2d %[[VAL_20]], %[[VAL_13]], %[[VAL_12]], %[[VAL_21]], %[[VAL_22]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<1x55x55x64xf32>, tensor<128x1x1x64xf32>, tensor<128xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x28x28x128xf32>
 // CHECK: %[[VAL_24:.*]] = tosa.transpose %[[VAL_23]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x28x28x128xf32>) -> tensor<1x128x28x28xf32>
 // CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor<1x128x28x28xf32> -> !torch.vtensor<[1,128,28,28],f32>
 // CHECK: return %[[VAL_25]] : !torch.vtensor<[1,128,28,28],f32>
@@ -3772,11 +3772,11 @@ func.func @torch.aten.convolution$zero_pad_with_sliced_input(%arg0: !torch.vtens
 // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
 // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
-// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x3x224x224xf32>) -> tensor<1x224x224x3xf32>
-// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
+// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x3x224x224xf32>) -> tensor<1x224x224x3xf32>
 // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_12]], %[[VAL_13]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 2, 2>} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x112x112x32xf32>
+// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_13]], %[[VAL_12]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 2, 2>} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x112x112x32xf32>
 // CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_16]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x112x112x32xf32>) -> tensor<1x32x112x112xf32>
 // CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<1x32x112x112xf32> -> !torch.vtensor<[1,32,112,112],f32>
 // CHECK: return %[[VAL_18]] : !torch.vtensor<[1,32,112,112],f32>
@@ -3810,17 +3810,17 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_
 // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
 // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
-// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x3x225x225xf32>) -> tensor<1x225x225x3xf32>
-// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
+// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x3x225x225xf32>) -> tensor<1x225x225x3xf32>
 // CHECK-DAG: %[[VAL_14:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK-DAG: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
-// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_12]], %[[VAL_14]], %[[VAL_15]] : (tensor<1x225x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x224x225x3xf32>
+// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_13]], %[[VAL_14]], %[[VAL_15]] : (tensor<1x225x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x224x225x3xf32>
 // CHECK-DAG: %[[VAL_17:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK-DAG: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_16]], %[[VAL_17]], %[[VAL_18]] : (tensor<1x224x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x224x224x3xf32>
 // CHECK: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_13]], %[[VAL_11]], %[[VAL_20]], %[[VAL_21]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 3, 3>} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x75x75x32xf32>
+// CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_12]], %[[VAL_11]], %[[VAL_20]], %[[VAL_21]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 3, 3>} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x75x75x32xf32>
 // CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_22]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x75x75x32xf32>) -> tensor<1x32x75x75xf32>
 // CHECK: %[[VAL_24:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor<1x32x75x75xf32> -> !torch.vtensor<[1,32,75,75],f32>
 // CHECK: return %[[VAL_24]] : !torch.vtensor<[1,32,75,75],f32>
@@ -3855,11 +3855,11 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_with_sliced_inp
 // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
 // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
-// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x224x224xf32>) -> tensor<?x224x224x3xf32>
-// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
+// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x224x224xf32>) -> tensor<?x224x224x3xf32>
 // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_12]], %[[VAL_13]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 2, 2>} : (tensor<?x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x112x112x32xf32>
+// CHECK: %[[VAL_16:.*]] = tosa.conv2d %[[VAL_13]], %[[VAL_12]], %[[VAL_11]], %[[VAL_14]], %[[VAL_15]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 2, 2>} : (tensor<?x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x112x112x32xf32>
 // CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_16]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<?x112x112x32xf32>) -> tensor<?x32x112x112xf32>
 // CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<?x32x112x112xf32> -> !torch.vtensor<[?,32,112,112],f32>
 // CHECK: return %[[VAL_18]]
@@ -3894,17 +3894,17 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_
 // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
 // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
-// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x225x225xf32>) -> tensor<?x225x225x3xf32>
-// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
+// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x225x225xf32>) -> tensor<?x225x225x3xf32>
 // CHECK-DAG: %[[VAL_14:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK-DAG: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[-1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
-// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_12]], %[[VAL_14]], %[[VAL_15]] : (tensor<?x225x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x225x3xf32>
+// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_13]], %[[VAL_14]], %[[VAL_15]] : (tensor<?x225x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x225x3xf32>
 // CHECK-DAG: %[[VAL_17:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK-DAG: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[-1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_16]], %[[VAL_17]], %[[VAL_18]] : (tensor<?x224x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x224x3xf32>
 // CHECK: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
-// CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_13]], %[[VAL_11]], %[[VAL_20]], %[[VAL_21]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 3, 3>} : (tensor<?x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x75x75x32xf32>
+// CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_12]], %[[VAL_11]], %[[VAL_20]], %[[VAL_21]] {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 0, 1, 0>, stride = array<i64: 3, 3>} : (tensor<?x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x75x75x32xf32>
 // CHECK: %[[VAL_23:.*]] = tosa.transpose %[[VAL_22]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<?x75x75x32xf32>) -> tensor<?x32x75x75xf32>
 // CHECK: %[[VAL_24:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor<?x32x75x75xf32> -> !torch.vtensor<[?,32,75,75],f32>
 // CHECK: return %[[VAL_24]]
diff --git a/test/Conversion/TorchToTosa/conv2d_transpose.mlir b/test/Conversion/TorchToTosa/conv2d_transpose.mlir
index 7c24dc896630..ba78ba865d5b 100644
--- a/test/Conversion/TorchToTosa/conv2d_transpose.mlir
+++ b/test/Conversion/TorchToTosa/conv2d_transpose.mlir
@@ -1,8 +1,23 @@
-// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics
+// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file | FileCheck %s
 
-// The following test ensures that a tranposed convolution op is not
-// lowered in the torch-to-tosa conversion pass.
+// The lowering now legalizes transposed convolutions to the TOSA dialect.
+// Verify that we emit tosa.transpose_conv2d with the expected layout
+// transposes and out_pad/stride attributes.
+// CHECK-LABEL: func.func @forward(
+// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
+// CHECK: %[[IN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT]] : !torch.vtensor<[1,64,1,100],f32> -> tensor<1x64x1x100xf32>
+// CHECK: %[[WEIGHT:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<64x64x3x3xf32>}> : () -> tensor<64x64x3x3xf32>
+// CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<64xf32>}> : () -> tensor<64xf32>
+// CHECK: %[[TRANS_IN:.*]] = tosa.transpose %[[IN_TENSOR]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x64x1x100xf32>) -> tensor<1x1x100x64xf32>
+// CHECK: %[[W_OHWI:.*]] = tosa.transpose %[[WEIGHT]] {perms = array<i32: 1, 2, 3, 0>} : (tensor<64x64x3x3xf32>) -> tensor<64x3x3x64xf32>
+// CHECK: %[[ZP0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[ZP1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[TCONV:.*]] = tosa.transpose_conv2d %[[TRANS_IN]], %[[W_OHWI]], %[[BIAS]], %[[ZP0]], %[[ZP1]] {acc_type = f32, out_pad = array<i64: 0, -1, 0, -1>, stride = array<i64: 2, 2>} : (tensor<1x1x100x64xf32>, tensor<64x3x3x64xf32>, tensor<64xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x2x200x64xf32>
+// CHECK: %[[TRANS_OUT:.*]] = tosa.transpose %[[TCONV]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x2x200x64xf32>) -> tensor<1x64x2x200xf32>
+// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[TRANS_OUT]] : tensor<1x64x2x200xf32> -> !torch.vtensor<[1,64,2,200],f32>
+// CHECK: return %[[RESULT]] : !torch.vtensor<[1,64,2,200],f32>
+// CHECK: }
 
 func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
   %true = torch.constant.bool true
   %int1 = torch.constant.int 1
@@ -11,7 +26,6 @@ func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[
   %bias = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
   %stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
   %int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  // expected-error@+1 {{failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal}}
   %output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,64,2,200],f32>
   return %output : !torch.vtensor<[1,64,2,200],f32>
 }
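As a cross-check of the new lit test (assumes a local PyTorch install; mirrors the hyperparameters of @forward above), eager-mode ConvTranspose2d produces exactly the 1x64x2x200 shape the CHECK lines expect from tosa.transpose_conv2d:

    import torch

    # stride=2, padding=1, output_padding=1, 3x3 kernel -- same as the test.
    conv = torch.nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2,
                                    padding=1, output_padding=1, bias=True)
    out = conv(torch.zeros(1, 64, 1, 100))
    assert tuple(out.shape) == (1, 64, 2, 200)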