Skip to content

Commit 493bb33

Browse files
authored
[TorchToLinalg] simplify non-broadcast unit dim indexing maps in elementwise generics (#4107)
This change is made to reduce the pattern-matching load for fusing elementwise generic ops with non-broadcasting unit dims. For example, adding tensors with shapes `[6,1]` and `[1]`, the output shape will be `[6,1]`. Before this change, the indexing maps were inconsistent between the inputs and outputs for the unit-dim (constant 0 for inputs, and a dim expression for the output). --------- Signed-off-by: zjgarvey <zjgarvey@gmail.com>
1 parent 692b2c0 commit 493bb33

File tree

2 files changed

+43
-9
lines changed

2 files changed

+43
-9
lines changed

lib/Conversion/TorchToLinalg/Utils.cpp

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -255,26 +255,47 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
255255
// all sizes along that result dimension are statically 1.
256256
auto c1 = b.create<arith::ConstantIndexOp>(loc, /*value=*/1);
257257
SmallVector<Value> resultShape(resultRank, c1);
258+
259+
// Record whether or not all corresponding input dims are statically 1.
260+
// We don't want to use a constant 0 expression for the input indexing maps in
261+
// this case, since there is no broadcasting. Using the constant 0 expressions
262+
// for the inputs, when they actually do correspond to an output dim, makes
263+
// subsequent optimizations (e.g. fusions) more difficult.
264+
DenseSet<int64_t> nonStaticOneResultDims;
265+
for (int64_t i = 0; i < resultRank; i++) {
266+
for (Value tensorOperand : tensorOperands) {
267+
auto type = cast<RankedTensorType>(tensorOperand.getType());
268+
auto index = i - (resultRank - type.getRank());
269+
if (index < 0)
270+
continue;
271+
int64_t dimSize = makeShapeTorchCompatible(type.getShape())[index];
272+
if (dimSize != 1) {
273+
nonStaticOneResultDims.insert(i);
274+
break;
275+
}
276+
}
277+
}
278+
258279
SmallVector<AffineMap> indexingMaps;
259280
bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(b);
281+
260282
for (Value tensorOperand : tensorOperands) {
261283
SmallVector<AffineExpr> exprs;
262284
auto type = cast<RankedTensorType>(tensorOperand.getType());
263285
for (auto size :
264286
llvm::enumerate(makeShapeTorchCompatible(type.getShape()))) {
265-
// If the size is statically known to be 1, we don't want any
266-
// error guards to be spuriously emitted, since we are specifically
267-
// allowing size-1 broadcasts in this case, as they correspond to a
268-
// constant-0 indexing map.
269-
if (size.value() == 1) {
270-
exprs.push_back(b.getAffineConstantExpr(0));
271-
continue;
272-
}
273287

274288
// The rank of this operand might be smaller than the overall rank of
275289
// the broadcast. Add an offset to correlate it to the correct
276290
// dimension of the result.
277-
auto resultDim = size.index() + (resultRank - type.getRank());
291+
int64_t resultDim = size.index() + (resultRank - type.getRank());
292+
293+
// If the size is statically 1 and we don't know that the result dim is
294+
// statically 1, use an affine constant expression to broadcast.
295+
if (size.value() == 1 && nonStaticOneResultDims.contains(resultDim)) {
296+
exprs.push_back(b.getAffineConstantExpr(0));
297+
continue;
298+
}
278299

279300
// The generated linalg op will now be iterating along the full size
280301
// of this dimension. Record that fact.

test/Conversion/TorchToLinalg/elementwise.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,16 @@ func.func @elementwise_todtype_bf162f16(%arg0: !torch.vtensor<[1,?,32,128],bf16>
118118
%0 = torch.aten.to.dtype %arg0, %int5, %false, %false, %none : !torch.vtensor<[1,?,32,128],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,?,32,128],f16>
119119
return %0 : !torch.vtensor<[1,?,32,128],f16>
120120
}
121+
122+
// -----
123+
124+
// CHECK-LABEL: func.func @elementwise_add_non_broadcast_unit_dims(
125+
// CHECK: linalg.generic {indexing_maps = [
126+
// CHECK-SAME: affine_map<(d0, d1) -> (d0, d1)>,
127+
// CHECK-SAME: affine_map<(d0, d1) -> (d1)>,
128+
// CHECK-SAME: affine_map<(d0, d1) -> (d0, d1)>]
129+
func.func @elementwise_add_non_broadcast_unit_dims(%arg0: !torch.vtensor<[6,1],bf16>, %arg1 : !torch.vtensor<[1],bf16>) -> !torch.vtensor<[6,1],bf16> {
130+
%int1_13 = torch.constant.int 1
131+
%11 = torch.aten.add.Tensor %arg0, %arg1, %int1_13 : !torch.vtensor<[6,1],bf16>, !torch.vtensor<[1],bf16>, !torch.int -> !torch.vtensor<[6,1],bf16>
132+
return %11 : !torch.vtensor<[6,1],bf16>
133+
}

0 commit comments

Comments
 (0)