@@ -73,14 +73,91 @@ struct ReturnOpConversion : public OpConversionPattern<ReturnOp> {
73
73
}
74
74
};
75
75
76
+ // ===----------------------------------------------------------------------===//
77
+ // FunctionOpInterfaceSignatureConversion
78
+ // ===----------------------------------------------------------------------===//
79
+ // NOTE: Forked from mlir to support remapping argument attributes correctly in
80
+ // a one-to-many type conversion.
81
+
82
+ SmallVector<Attribute>
83
+ convertFuncOpAttrs (FunctionOpInterface funcOp,
84
+ TypeConverter::SignatureConversion &sigConv,
85
+ FunctionType newType) {
86
+ if (newType.getNumInputs () == funcOp.getNumArguments ()) {
87
+ return {};
88
+ }
89
+ ArrayAttr allArgAttrs = funcOp.getAllArgAttrs ();
90
+ if (!allArgAttrs)
91
+ return {};
92
+
93
+ SmallVector<Attribute> newAttrs (newType.getNumInputs ());
94
+ for (auto i : llvm::seq (allArgAttrs.size ())) {
95
+ auto mapping = sigConv.getInputMapping (i);
96
+ assert (mapping.has_value ());
97
+ auto outIdx = mapping->inputNo ;
98
+ newAttrs[outIdx] = allArgAttrs[i];
99
+ }
100
+ return newAttrs;
101
+ }
102
+
103
+ LogicalResult convertFuncOpTypes (FunctionOpInterface funcOp,
104
+ const TypeConverter &typeConverter,
105
+ ConversionPatternRewriter &rewriter) {
106
+ FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType ());
107
+ if (!type)
108
+ return failure ();
109
+
110
+ // Convert the original function types.
111
+ TypeConverter::SignatureConversion result (type.getNumInputs ());
112
+ SmallVector<Type, 1 > newResults;
113
+ if (failed (typeConverter.convertSignatureArgs (type.getInputs (), result)) ||
114
+ failed (typeConverter.convertTypes (type.getResults (), newResults)) ||
115
+ failed (rewriter.convertRegionTypes (&funcOp.getFunctionBody (),
116
+ typeConverter, &result)))
117
+ return failure ();
118
+
119
+ // Update the function signature in-place.
120
+ auto newType = FunctionType::get (rewriter.getContext (),
121
+ result.getConvertedTypes (), newResults);
122
+
123
+ auto newArgAttrs = convertFuncOpAttrs (funcOp, result, newType);
124
+
125
+ rewriter.modifyOpInPlace (funcOp, [&] {
126
+ funcOp.setType (newType);
127
+ if (!newArgAttrs.empty ()) {
128
+ funcOp.setAllArgAttrs (newArgAttrs);
129
+ }
130
+ });
131
+
132
+ return success ();
133
+ }
134
+
135
+ // / Create a default conversion pattern that rewrites the type signature of a
136
+ // / FunctionOpInterface op. This only supports ops which use FunctionType to
137
+ // / represent their type.
138
+ struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
139
+ FunctionOpInterfaceSignatureConversion (StringRef functionLikeOpName,
140
+ MLIRContext *ctx,
141
+ const TypeConverter &converter,
142
+ PatternBenefit benefit = 1 )
143
+ : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
144
+
145
+ LogicalResult
146
+ matchAndRewrite (Operation *op, ArrayRef<Value> /* operands*/ ,
147
+ ConversionPatternRewriter &rewriter) const override {
148
+ FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
149
+ return convertFuncOpTypes (funcOp, *typeConverter, rewriter);
150
+ }
151
+ };
152
+
76
153
} // namespace
77
154
78
155
void populateFunctionTypeConversions (const TypeConverter &converter,
79
156
RewritePatternSet &patterns) {
80
- mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::triton::FuncOp>(
81
- patterns, converter);
82
- patterns. add <CallOpConversion, ReturnOpConversion>(converter,
83
- patterns.getContext () );
157
+ auto context = patterns. getContext ();
158
+ patterns. add <FunctionOpInterfaceSignatureConversion>(
159
+ triton::FuncOp::getOperationName (), context, converter);
160
+ patterns.add <CallOpConversion, ReturnOpConversion>(converter, context );
84
161
}
85
162
86
163
} // namespace mlir::triton
0 commit comments