
Commit 894b05b

Add lit test for aten.linear
1 parent 40d2a24 · commit 894b05b

File tree

2 files changed (+36 -5 lines):

  lib/Conversion/TorchToTosa/TorchToTosa.cpp
  test/Conversion/TorchToTosa/basic.mlir


lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 7 additions & 5 deletions
@@ -2159,10 +2159,6 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
     auto bias = adaptor.getBias();
     auto biasTy = bias.getType();
 
-    if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, bias).failed())
-      return rewriter.notifyMatchFailure(
-          op, "Failed to equalize ranks among operands and result");
-
     // TOSA does not mandate that elementwise op tensors need to be ranked.
     if (!isa<Torch::NoneType>(biasTy) && !isa<TensorType>(biasTy))
       return rewriter.notifyMatchFailure(
@@ -2210,7 +2206,13 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
             .value();
 
     if (!isa<Torch::NoneType>(biasTy)) {
-      // Bias addition broadcasts to the matmul output shape.
+      // Broadcast bias to the matmul output shape for addition
+      if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), matmulPlusBias,
+                                    bias)
+              .failed())
+        return rewriter.notifyMatchFailure(
+            op, "Failed to equalize ranks among operands and result");
+
       matmulPlusBias =
           rewriter
               .create<tosa::AddOp>(op->getLoc(), matmulPlusBias.getType(),
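
Why the move matters: the deleted code equalized the bias's rank against lhs unconditionally, before the pass had even checked whether a bias is present; the new placement runs only inside the bias branch and equalizes against matmulPlusBias, the tensor the bias is actually added to. TOSA elementwise ops such as tosa.add expect operands of matching rank (broadcasting is expressed through size-1 dimensions), so EqualizeRanks reshapes the rank-1 bias up to the rank of the matmul result. A minimal MLIR sketch of the resulting IR, using the shapes from the new test below (the SSA names %acc, %bias, and %out are illustrative, not produced by the pass):

// Sketch: %acc is the rank-2 matmul result, %bias the rank-1 bias operand.
// EqualizeRanks materializes a reshape so tosa.add sees equal-rank operands
// and broadcasts the bias along the leading dimension.
%bias_shape = tosa.const_shape {values = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
%bias_2d = tosa.reshape %bias, %bias_shape : (tensor<3xf16>, !tosa.shape<2>) -> tensor<1x3xf16>
%out = tosa.add %acc, %bias_2d : (tensor<2x3xf16>, tensor<1x3xf16>) -> tensor<2x3xf16>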

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 29 additions & 0 deletions
@@ -4251,3 +4251,32 @@ func.func @torch.aten.matmul$broadcast(%arg0: !torch.vtensor<[10,3,4],f32>, %arg
   %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[10,3,4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[10,3],f32>
   return %0 : !torch.vtensor<[10,3],f32>
 }
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.linear$f16(
+// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[2,4],f16>,
+// CHECK-SAME: %[[WTS:.*]]: !torch.vtensor<[3,4],f16>,
+// CHECK-SAME: %[[BIAS:.*]]: !torch.vtensor<[3],f16>) -> !torch.vtensor<[2,3],f16> {
+// CHECK: %[[BIAS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[BIAS]] : !torch.vtensor<[3],f16> -> tensor<3xf16>
+// CHECK: %[[WTS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[WTS]] : !torch.vtensor<[3,4],f16> -> tensor<3x4xf16>
+// CHECK: %[[INP_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INP]] : !torch.vtensor<[2,4],f16> -> tensor<2x4xf16>
+// CHECK: %[[WTS_TRANSPOSE:.*]] = tosa.transpose %[[WTS_TENSOR]] {perms = array<i32: 1, 0>} : (tensor<3x4xf16>) -> tensor<4x3xf16>
+// CHECK: %[[INP_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 2, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[INP_RESHAPE:.*]] = tosa.reshape %[[INP_TENSOR]], %[[INP_SHAPE]] : (tensor<2x4xf16>, !tosa.shape<3>) -> tensor<1x2x4xf16>
+// CHECK: %[[WTS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 4, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[WTS_RESHAPE:.*]] = tosa.reshape %[[WTS_TRANSPOSE]], %[[WTS_SHAPE]] : (tensor<4x3xf16>, !tosa.shape<3>) -> tensor<1x4x3xf16>
+// CHECK: %[[INP_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf16>}> : () -> tensor<1xf16>
+// CHECK: %[[WTS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf16>}> : () -> tensor<1xf16>
+// CHECK: %[[MATMUL:.*]] = tosa.matmul %[[INP_RESHAPE]], %[[WTS_RESHAPE]], %[[INP_ZP]], %[[WTS_ZP]] : (tensor<1x2x4xf16>, tensor<1x4x3xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x2x3xf32>
+// CHECK: %[[RES_SHAPE:.*]] = tosa.const_shape {values = dense<[2, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK: %[[RES_RESHAPE:.*]] = tosa.reshape %[[MATMUL]], %[[RES_SHAPE]] : (tensor<1x2x3xf32>, !tosa.shape<2>) -> tensor<2x3xf32>
+// CHECK: %[[CAST:.*]] = tosa.cast %[[RES_RESHAPE]] : (tensor<2x3xf32>) -> tensor<2x3xf16>
+// CHECK: %[[BIAS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK: %[[BIAS_RESHAPE:.*]] = tosa.reshape %[[BIAS_TENSOR]], %[[BIAS_SHAPE]] : (tensor<3xf16>, !tosa.shape<2>) -> tensor<1x3xf16>
+// CHECK: %[[ADD:.*]] = tosa.add %[[CAST]], %[[BIAS_RESHAPE]] : (tensor<2x3xf16>, tensor<1x3xf16>) -> tensor<2x3xf16>
+// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[ADD]] : tensor<2x3xf16> -> !torch.vtensor<[2,3],f16>
+// CHECK: return %[[RES]]
+func.func @torch.aten.linear$f16(%arg0: !torch.vtensor<[2,4],f16>, %arg1: !torch.vtensor<[3,4],f16>, %arg2: !torch.vtensor<[3],f16>) -> !torch.vtensor<[2,3],f16> {
+  %0 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[2,4],f16>, !torch.vtensor<[3,4],f16>, !torch.vtensor<[3],f16> -> !torch.vtensor<[2,3],f16>
+  return %0 : !torch.vtensor<[2,3],f16>
+}
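
Two details the CHECK lines pin down are worth noting: the f16 operands are multiplied with a widened f32 accumulator (tosa.matmul returns tensor<1x2x3xf32>), and the result is cast back to f16 before the bias add, which is exactly where the relocated EqualizeRanks call fires. The test itself is driven by the RUN line at the top of basic.mlir; a typical RUN line for this suite looks like the sketch below (the exact flags are an assumption here, verify against the file in the checkout):

// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s

The // ----- separator added before the new function pairs with -split-input-file, so FileCheck matches each case in isolation.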
