@@ -4251,3 +4251,32 @@ func.func @torch.aten.matmul$broadcast(%arg0: !torch.vtensor<[10,3,4],f32>, %arg
  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[10,3,4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[10,3],f32>
  return %0 : !torch.vtensor<[10,3],f32>
}
+
+ // -----
+ // CHECK-LABEL: func.func @torch.aten.linear$f16(
+ // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[2,4],f16>,
+ // CHECK-SAME: %[[WTS:.*]]: !torch.vtensor<[3,4],f16>,
+ // CHECK-SAME: %[[BIAS:.*]]: !torch.vtensor<[3],f16>) -> !torch.vtensor<[2,3],f16> {
+ // CHECK: %[[BIAS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[BIAS]] : !torch.vtensor<[3],f16> -> tensor<3xf16>
+ // CHECK: %[[WTS_TENSOR:.*]] = torch_c.to_builtin_tensor %[[WTS]] : !torch.vtensor<[3,4],f16> -> tensor<3x4xf16>
+ // CHECK: %[[INP_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INP]] : !torch.vtensor<[2,4],f16> -> tensor<2x4xf16>
+ // CHECK: %[[WTS_TRANSPOSE:.*]] = tosa.transpose %[[WTS_TENSOR]] {perms = array<i32: 1, 0>} : (tensor<3x4xf16>) -> tensor<4x3xf16>
+ // CHECK: %[[INP_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 2, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: %[[INP_RESHAPE:.*]] = tosa.reshape %[[INP_TENSOR]], %[[INP_SHAPE]] : (tensor<2x4xf16>, !tosa.shape<3>) -> tensor<1x2x4xf16>
+ // CHECK: %[[WTS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 4, 3]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: %[[WTS_RESHAPE:.*]] = tosa.reshape %[[WTS_TRANSPOSE]], %[[WTS_SHAPE]] : (tensor<4x3xf16>, !tosa.shape<3>) -> tensor<1x4x3xf16>
+ // CHECK: %[[INP_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf16>}> : () -> tensor<1xf16>
+ // CHECK: %[[WTS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf16>}> : () -> tensor<1xf16>
+ // CHECK: %[[MATMUL:.*]] = tosa.matmul %[[INP_RESHAPE]], %[[WTS_RESHAPE]], %[[INP_ZP]], %[[WTS_ZP]] : (tensor<1x2x4xf16>, tensor<1x4x3xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x2x3xf32>
+ // CHECK: %[[RES_SHAPE:.*]] = tosa.const_shape {values = dense<[2, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK: %[[RES_RESHAPE:.*]] = tosa.reshape %[[MATMUL]], %[[RES_SHAPE]] : (tensor<1x2x3xf32>, !tosa.shape<2>) -> tensor<2x3xf32>
+ // CHECK: %[[CAST:.*]] = tosa.cast %[[RES_RESHAPE]] : (tensor<2x3xf32>) -> tensor<2x3xf16>
+ // CHECK: %[[BIAS_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ // CHECK: %[[BIAS_RESHAPE:.*]] = tosa.reshape %[[BIAS_TENSOR]], %[[BIAS_SHAPE]] : (tensor<3xf16>, !tosa.shape<2>) -> tensor<1x3xf16>
+ // CHECK: %[[ADD:.*]] = tosa.add %[[CAST]], %[[BIAS_RESHAPE]] : (tensor<2x3xf16>, tensor<1x3xf16>) -> tensor<2x3xf16>
+ // CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[ADD]] : tensor<2x3xf16> -> !torch.vtensor<[2,3],f16>
+ // CHECK: return %[[RES]]
+ func.func @torch.aten.linear$f16(%arg0: !torch.vtensor<[2,4],f16>, %arg1: !torch.vtensor<[3,4],f16>, %arg2: !torch.vtensor<[3],f16>) -> !torch.vtensor<[2,3],f16> {
+   %0 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[2,4],f16>, !torch.vtensor<[3,4],f16>, !torch.vtensor<[3],f16> -> !torch.vtensor<[2,3],f16>
+   return %0 : !torch.vtensor<[2,3],f16>
+ }
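
For reference, a minimal NumPy sketch of the numerics this test pins down (an illustrative equivalence, not part of the patch): with f16 operands, tosa.matmul produces an f32 accumulator (hence the tensor<1x2x3xf32> result in the CHECK lines), which is cast back to f16 before the bias add, matching aten.linear's x @ W^T + b.

```python
import numpy as np

# Shapes mirror the test: %arg0 [2,4], %arg1 [3,4], %arg2 [3], all f16.
inp = np.random.randn(2, 4).astype(np.float16)
wts = np.random.randn(3, 4).astype(np.float16)
bias = np.random.randn(3).astype(np.float16)

# Accumulate in f32, as the tosa.matmul CHECK line requires for f16 inputs.
acc = inp.astype(np.float32) @ wts.T.astype(np.float32)  # [2,3] f32

# tosa.cast back to f16, then tosa.add with the [1,3]-broadcast bias.
out = acc.astype(np.float16) + bias                      # [2,3] f16
```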