Skip to content

Commit 39482bb

Browse files
peterbell10Jokeren
authored andcommitted
[Backend] Fix function attributes after tensor descriptor fallback (#8406)
The tensor descriptor rewriting pass uses 1-to-n type conversion, which will expand a tensordesc argument into multiple new pointer and int arguments. However, the upstream function conversion in mlir doesn't remap the attributes to the corresponding new argument index, which means we may make incorrect assumptions about arguments and miscompile.
1 parent 201c4d8 commit 39482bb

File tree

2 files changed

+92
-4
lines changed

2 files changed

+92
-4
lines changed

lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,91 @@ struct ReturnOpConversion : public OpConversionPattern<ReturnOp> {
7373
}
7474
};
7575

76+
//===----------------------------------------------------------------------===//
77+
// FunctionOpInterfaceSignatureConversion
78+
//===----------------------------------------------------------------------===//
79+
// NOTE: Forked from mlir to support remapping argument attributes correctly in
80+
// a one-to-many type conversion.
81+
82+
SmallVector<Attribute>
83+
convertFuncOpAttrs(FunctionOpInterface funcOp,
84+
TypeConverter::SignatureConversion &sigConv,
85+
FunctionType newType) {
86+
if (newType.getNumInputs() == funcOp.getNumArguments()) {
87+
return {};
88+
}
89+
ArrayAttr allArgAttrs = funcOp.getAllArgAttrs();
90+
if (!allArgAttrs)
91+
return {};
92+
93+
SmallVector<Attribute> newAttrs(newType.getNumInputs());
94+
for (auto i : llvm::seq(allArgAttrs.size())) {
95+
auto mapping = sigConv.getInputMapping(i);
96+
assert(mapping.has_value());
97+
auto outIdx = mapping->inputNo;
98+
newAttrs[outIdx] = allArgAttrs[i];
99+
}
100+
return newAttrs;
101+
}
102+
103+
LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
104+
const TypeConverter &typeConverter,
105+
ConversionPatternRewriter &rewriter) {
106+
FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
107+
if (!type)
108+
return failure();
109+
110+
// Convert the original function types.
111+
TypeConverter::SignatureConversion result(type.getNumInputs());
112+
SmallVector<Type, 1> newResults;
113+
if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
114+
failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
115+
failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
116+
typeConverter, &result)))
117+
return failure();
118+
119+
// Update the function signature in-place.
120+
auto newType = FunctionType::get(rewriter.getContext(),
121+
result.getConvertedTypes(), newResults);
122+
123+
auto newArgAttrs = convertFuncOpAttrs(funcOp, result, newType);
124+
125+
rewriter.modifyOpInPlace(funcOp, [&] {
126+
funcOp.setType(newType);
127+
if (!newArgAttrs.empty()) {
128+
funcOp.setAllArgAttrs(newArgAttrs);
129+
}
130+
});
131+
132+
return success();
133+
}
134+
135+
/// Create a default conversion pattern that rewrites the type signature of a
136+
/// FunctionOpInterface op. This only supports ops which use FunctionType to
137+
/// represent their type.
138+
struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
139+
FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
140+
MLIRContext *ctx,
141+
const TypeConverter &converter,
142+
PatternBenefit benefit = 1)
143+
: ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
144+
145+
LogicalResult
146+
matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
147+
ConversionPatternRewriter &rewriter) const override {
148+
FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
149+
return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
150+
}
151+
};
152+
76153
} // namespace
77154

78155
void populateFunctionTypeConversions(const TypeConverter &converter,
79156
RewritePatternSet &patterns) {
80-
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::triton::FuncOp>(
81-
patterns, converter);
82-
patterns.add<CallOpConversion, ReturnOpConversion>(converter,
83-
patterns.getContext());
157+
auto context = patterns.getContext();
158+
patterns.add<FunctionOpInterfaceSignatureConversion>(
159+
triton::FuncOp::getOperationName(), context, converter);
160+
patterns.add<CallOpConversion, ReturnOpConversion>(converter, context);
84161
}
85162

86163
} // namespace mlir::triton

test/Triton/rewrite-tensor-descriptor-to-pointer.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,14 @@ module {
138138
// CHECK-DAG: %[[c256:.*]] = arith.constant 256 : i64
139139
// CHECK: %{{.*}}:6 = tt.call @callee(%[[PTR]], %[[c256]], %[[c256]], %[[c256]], %[[c1]], %false)
140140
// CHECK-SAME -> (!tt.ptr<f32>, i64, i64, i64, i64, i1)
141+
142+
// -----
143+
144+
module {
145+
tt.func public @arg_attr(%arg0: !tt.tensordesc<tensor<128x128xf32>>, %arg1: i32 {tt.divisibility = 16 : i32}) {
146+
tt.return
147+
}
148+
}
149+
150+
// CHECK-LABEL: @arg_attr
151+
// CHECK-SAME: %arg6: i32 {tt.divisibility = 16 : i32}) {

0 commit comments

Comments
 (0)