From d4314152e3dfe79d17babcef369327e1820dfcdb Mon Sep 17 00:00:00 2001
From: Peter Bell
Date: Thu, 9 Oct 2025 01:31:32 +0100
Subject: [PATCH] [Backend] Fix function attributes after tensor descriptor
 fallback

The tensor descriptor rewriting pass uses 1-to-n type conversion, which
will expand a tensordesc argument into multiple new pointer and int
arguments. However, the upstream function conversion in mlir doesn't
remap the attributes to the corresponding new argument index, which
means we may make incorrect assumptions about arguments and miscompile.
---
 .../Transforms/FunctionTypeConversion.cpp     | 85 ++++++++++++++++++-
 .../rewrite-tensor-descriptor-to-pointer.mlir | 11 +++
 2 files changed, 92 insertions(+), 4 deletions(-)

diff --git a/lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp b/lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp
index 0170463cef84..d9f822cbc419 100644
--- a/lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp
+++ b/lib/Dialect/Triton/Transforms/FunctionTypeConversion.cpp
@@ -73,14 +73,91 @@ struct ReturnOpConversion : public OpConversionPattern<ReturnOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// FunctionOpInterfaceSignatureConversion
+//===----------------------------------------------------------------------===//
+// NOTE: Forked from mlir to support remapping argument attributes correctly in
+// a one-to-many type conversion.
+
+SmallVector<Attribute>
+convertFuncOpAttrs(FunctionOpInterface funcOp,
+                   TypeConverter::SignatureConversion &sigConv,
+                   FunctionType newType) {
+  if (newType.getNumInputs() == funcOp.getNumArguments()) {
+    return {};
+  }
+  ArrayAttr allArgAttrs = funcOp.getAllArgAttrs();
+  if (!allArgAttrs)
+    return {};
+
+  SmallVector<Attribute> newAttrs(newType.getNumInputs());
+  for (auto i : llvm::seq(allArgAttrs.size())) {
+    auto mapping = sigConv.getInputMapping(i);
+    assert(mapping.has_value());
+    auto outIdx = mapping->inputNo;
+    newAttrs[outIdx] = allArgAttrs[i];
+  }
+  return newAttrs;
+}
+
+LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
+                                 const TypeConverter &typeConverter,
+                                 ConversionPatternRewriter &rewriter) {
+  FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType());
+  if (!type)
+    return failure();
+
+  // Convert the original function types.
+  TypeConverter::SignatureConversion result(type.getNumInputs());
+  SmallVector<Type, 1> newResults;
+  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
+      failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
+      failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
+                                         typeConverter, &result)))
+    return failure();
+
+  // Update the function signature in-place.
+  auto newType = FunctionType::get(rewriter.getContext(),
+                                   result.getConvertedTypes(), newResults);
+
+  auto newArgAttrs = convertFuncOpAttrs(funcOp, result, newType);
+
+  rewriter.modifyOpInPlace(funcOp, [&] {
+    funcOp.setType(newType);
+    if (!newArgAttrs.empty()) {
+      funcOp.setAllArgAttrs(newArgAttrs);
+    }
+  });
+
+  return success();
+}
+
+/// Create a default conversion pattern that rewrites the type signature of a
+/// FunctionOpInterface op. This only supports ops which use FunctionType to
+/// represent their type.
+struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
+  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
+                                         MLIRContext *ctx,
+                                         const TypeConverter &converter,
+                                         PatternBenefit benefit = 1)
+      : ConversionPattern(converter, functionLikeOpName, benefit, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
+                  ConversionPatternRewriter &rewriter) const override {
+    FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
+    return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
+  }
+};
+
 } // namespace
 
 void populateFunctionTypeConversions(const TypeConverter &converter,
                                      RewritePatternSet &patterns) {
-  mlir::populateFunctionOpInterfaceTypeConversionPattern<triton::FuncOp>(
-      patterns, converter);
-  patterns.add<CallOpConversion, ReturnOpConversion>(converter,
-                                                     patterns.getContext());
+  auto context = patterns.getContext();
+  patterns.add<FunctionOpInterfaceSignatureConversion>(
+      triton::FuncOp::getOperationName(), context, converter);
+  patterns.add<CallOpConversion, ReturnOpConversion>(converter, context);
 }
 
 } // namespace mlir::triton
diff --git a/test/Triton/rewrite-tensor-descriptor-to-pointer.mlir b/test/Triton/rewrite-tensor-descriptor-to-pointer.mlir
index 9290daf8d248..5176d545412c 100644
--- a/test/Triton/rewrite-tensor-descriptor-to-pointer.mlir
+++ b/test/Triton/rewrite-tensor-descriptor-to-pointer.mlir
@@ -138,3 +138,14 @@ module {
 // CHECK-DAG: %[[c256:.*]] = arith.constant 256 : i64
 // CHECK: %{{.*}}:6 = tt.call @callee(%[[PTR]], %[[c256]], %[[c256]], %[[c256]], %[[c1]], %false)
 // CHECK-SAME -> (!tt.ptr, i64, i64, i64, i64, i1)
+
+// -----
+
+module {
+  tt.func public @arg_attr(%arg0: !tt.tensordesc>, %arg1: i32 {tt.divisibility = 16 : i32}) {
+    tt.return
+  }
+}
+
+// CHECK-LABEL: @arg_attr
+// CHECK-SAME: %arg6: i32 {tt.divisibility = 16 : i32}) {
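
Note, not part of the patch: an illustrative sketch of the attribute remapping
this change performs, assuming a 2D tensor descriptor expands into a base
pointer, two shape and two stride i64 values, and an i1 flag (the six-value
expansion the existing test already checks). Argument names, the descriptor's
element type, and the signature-only form below are placeholders, not taken
from the patch.

  // Source: tt.divisibility is attached to argument index 1.
  tt.func public @arg_attr(%arg0: !tt.tensordesc<...>,
                           %arg1: i32 {tt.divisibility = 16 : i32})

  // Before this patch, attributes keep their old indices, so after the 1-to-n
  // expansion the hint lands on one of the descriptor's i64 values:
  tt.func public @arg_attr(%arg0: !tt.ptr, %arg1: i64 {tt.divisibility = 16 : i32},
                           %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i1, %arg6: i32)

  // With convertFuncOpAttrs consulting SignatureConversion::getInputMapping,
  // the attribute follows the i32 to its new index, as the new test expects:
  tt.func public @arg_attr(%arg0: !tt.ptr, %arg1: i64, %arg2: i64, %arg3: i64,
                           %arg4: i64, %arg5: i1, %arg6: i32 {tt.divisibility = 16 : i32})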