diff --git a/lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp b/lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp index 0170463cef84..d9f822cbc419 100644 --- a/lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp +++ b/lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp @@ -73,14 +73,91 @@ struct ReturnOpConversion : public OpConversionPattern { } }; +//===----------------------------------------------------------------------===// +// FunctionOpInterfaceSignatureConversion +//===----------------------------------------------------------------------===// +// NOTE: Forked from mlir to support remapping argument attributes correctly in +// a one-to-many type conversion. + +SmallVector +convertFuncOpAttrs(FunctionOpInterface funcOp, + TypeConverter::SignatureConversion &sigConv, + FunctionType newType) { + if (newType.getNumInputs() == funcOp.getNumArguments()) { + return {}; + } + ArrayAttr allArgAttrs = funcOp.getAllArgAttrs(); + if (!allArgAttrs) + return {}; + + SmallVector newAttrs(newType.getNumInputs()); + for (auto i : llvm::seq(allArgAttrs.size())) { + auto mapping = sigConv.getInputMapping(i); + assert(mapping.has_value()); + auto outIdx = mapping->inputNo; + newAttrs[outIdx] = allArgAttrs[i]; + } + return newAttrs; +} + +LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, + const TypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { + FunctionType type = dyn_cast(funcOp.getFunctionType()); + if (!type) + return failure(); + + // Convert the original function types. + TypeConverter::SignatureConversion result(type.getNumInputs()); + SmallVector newResults; + if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) || + failed(typeConverter.convertTypes(type.getResults(), newResults)) || + failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(), + typeConverter, &result))) + return failure(); + + // Update the function signature in-place. + auto newType = FunctionType::get(rewriter.getContext(), + result.getConvertedTypes(), newResults); + + auto newArgAttrs = convertFuncOpAttrs(funcOp, result, newType); + + rewriter.modifyOpInPlace(funcOp, [&] { + funcOp.setType(newType); + if (!newArgAttrs.empty()) { + funcOp.setAllArgAttrs(newArgAttrs); + } + }); + + return success(); +} + +/// Create a default conversion pattern that rewrites the type signature of a +/// FunctionOpInterface op. This only supports ops which use FunctionType to +/// represent their type. +struct FunctionOpInterfaceSignatureConversion : public ConversionPattern { + FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName, + MLIRContext *ctx, + const TypeConverter &converter, + PatternBenefit benefit = 1) + : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + FunctionOpInterface funcOp = cast(op); + return convertFuncOpTypes(funcOp, *typeConverter, rewriter); + } +}; + } // namespace void populateFunctionTypeConversions(const TypeConverter &converter, RewritePatternSet &patterns) { - mlir::populateFunctionOpInterfaceTypeConversionPattern( - patterns, converter); - patterns.add(converter, - patterns.getContext()); + auto context = patterns.getContext(); + patterns.add( + triton::FuncOp::getOperationName(), context, converter); + patterns.add(converter, context); } } // namespace mlir::triton diff --git a/test/Triton/rewrite-tensor-descriptor-to-pointer.mlir b/test/Triton/rewrite-tensor-descriptor-to-pointer.mlir index 9290daf8d248..5176d545412c 100644 --- a/test/Triton/rewrite-tensor-descriptor-to-pointer.mlir +++ b/test/Triton/rewrite-tensor-descriptor-to-pointer.mlir @@ -138,3 +138,14 @@ module { // CHECK-DAG: %[[c256:.*]] = arith.constant 256 : i64 // CHECK: %{{.*}}:6 = tt.call @callee(%[[PTR]], %[[c256]], %[[c256]], %[[c256]], %[[c1]], %false) // CHECK-SAME -> (!tt.ptr, i64, i64, i64, i64, i1) + +// ----- + +module { + tt.func public @arg_attr(%arg0: !tt.tensordesc>, %arg1: i32 {tt.divisibility = 16 : i32}) { + tt.return + } +} + +// CHECK-LABEL: @arg_attr +// CHECK-SAME: %arg6: i32 {tt.divisibility = 16 : i32}) {