@@ -73,14 +73,91 @@ struct ReturnOpConversion : public OpConversionPattern<ReturnOp> {
7373 }
7474};
7575
76+ // ===----------------------------------------------------------------------===//
77+ // FunctionOpInterfaceSignatureConversion
78+ // ===----------------------------------------------------------------------===//
79+ // NOTE: Forked from mlir to support remapping argument attributes correctly in
80+ // a one-to-many type conversion.
81+
82+ SmallVector<Attribute>
83+ convertFuncOpAttrs (FunctionOpInterface funcOp,
84+ TypeConverter::SignatureConversion &sigConv,
85+ FunctionType newType) {
86+ if (newType.getNumInputs () == funcOp.getNumArguments ()) {
87+ return {};
88+ }
89+ ArrayAttr allArgAttrs = funcOp.getAllArgAttrs ();
90+ if (!allArgAttrs)
91+ return {};
92+
93+ SmallVector<Attribute> newAttrs (newType.getNumInputs ());
94+ for (auto i : llvm::seq (allArgAttrs.size ())) {
95+ auto mapping = sigConv.getInputMapping (i);
96+ assert (mapping.has_value ());
97+ auto outIdx = mapping->inputNo ;
98+ newAttrs[outIdx] = allArgAttrs[i];
99+ }
100+ return newAttrs;
101+ }
102+
103+ LogicalResult convertFuncOpTypes (FunctionOpInterface funcOp,
104+ const TypeConverter &typeConverter,
105+ ConversionPatternRewriter &rewriter) {
106+ FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType ());
107+ if (!type)
108+ return failure ();
109+
110+ // Convert the original function types.
111+ TypeConverter::SignatureConversion result (type.getNumInputs ());
112+ SmallVector<Type, 1 > newResults;
113+ if (failed (typeConverter.convertSignatureArgs (type.getInputs (), result)) ||
114+ failed (typeConverter.convertTypes (type.getResults (), newResults)) ||
115+ failed (rewriter.convertRegionTypes (&funcOp.getFunctionBody (),
116+ typeConverter, &result)))
117+ return failure ();
118+
119+ // Update the function signature in-place.
120+ auto newType = FunctionType::get (rewriter.getContext (),
121+ result.getConvertedTypes (), newResults);
122+
123+ auto newArgAttrs = convertFuncOpAttrs (funcOp, result, newType);
124+
125+ rewriter.modifyOpInPlace (funcOp, [&] {
126+ funcOp.setType (newType);
127+ if (!newArgAttrs.empty ()) {
128+ funcOp.setAllArgAttrs (newArgAttrs);
129+ }
130+ });
131+
132+ return success ();
133+ }
134+
135+ // / Create a default conversion pattern that rewrites the type signature of a
136+ // / FunctionOpInterface op. This only supports ops which use FunctionType to
137+ // / represent their type.
138+ struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
139+ FunctionOpInterfaceSignatureConversion (StringRef functionLikeOpName,
140+ MLIRContext *ctx,
141+ const TypeConverter &converter,
142+ PatternBenefit benefit = 1 )
143+ : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
144+
145+ LogicalResult
146+ matchAndRewrite (Operation *op, ArrayRef<Value> /* operands*/ ,
147+ ConversionPatternRewriter &rewriter) const override {
148+ FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
149+ return convertFuncOpTypes (funcOp, *typeConverter, rewriter);
150+ }
151+ };
152+
76153} // namespace
77154
78155void populateFunctionTypeConversions (const TypeConverter &converter,
79156 RewritePatternSet &patterns) {
80- mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::triton::FuncOp>(
81- patterns, converter);
82- patterns. add <CallOpConversion, ReturnOpConversion>(converter,
83- patterns.getContext () );
157+ auto context = patterns. getContext ();
158+ patterns. add <FunctionOpInterfaceSignatureConversion>(
159+ triton::FuncOp::getOperationName (), context, converter);
160+ patterns.add <CallOpConversion, ReturnOpConversion>(converter, context );
84161}
85162
86163} // namespace mlir::triton
0 commit comments