diff --git a/src/Compiler/CompilerOptions.cpp b/src/Compiler/CompilerOptions.cpp
index 6a92557749..f2d903b733 100644
--- a/src/Compiler/CompilerOptions.cpp
+++ b/src/Compiler/CompilerOptions.cpp
@@ -48,6 +48,7 @@ bool disableQuantZeroPoint; // common for both
 bool enableKrnlBufferReuse; // common for both
 bool enableSafeCodeGen; // common for both
 bool disableMemRefPrefetch; // common for both
+bool enableForceF32Cast; // common for both
 uint64_t compilationNumThreads; // common for both
 std::vector<std::string> decomposeOpsInONNX; // common for both
 EmissionTargetType emissionTarget; // onnx-mlir only
@@ -267,6 +268,18 @@ static llvm::cl::opt<bool, true> disableMemRefPrefetchOpt(
     llvm::cl::location(disableMemRefPrefetch), llvm::cl::init(true),
     llvm::cl::cat(OnnxMlirCommonOptions));
 
+static llvm::cl::opt<bool, true> enableForceF32CastOpt("enable-force-f32-cast",
+    llvm::cl::desc(
+        "Enable the transformation of Cast from F16 to F32 (default=false).\n"
+        "Set to 'true' if you want to enable the transformation.\n"
+        "This transformation is a temporary workaround for CastOp errors "
+        "that occur when a float16 onnx model is converted to float32 with "
+        "utils/convertF16ToF32.py but some ONNXCastOp ops in the model are "
+        "not converted. The transformation blindly changes the 'to' "
+        "TypeAttr from F16 to F32."),
+    llvm::cl::location(enableForceF32Cast), llvm::cl::init(false),
+    llvm::cl::cat(OnnxMlirCommonOptions));
+
 static llvm::cl::list<std::string, std::vector<std::string>>
     decomposeOpsInONNXOpt("decompose-op-in-onnx",
         llvm::cl::desc("Specify ONNX operations to decompose.\n"
diff --git a/src/Compiler/CompilerOptions.hpp b/src/Compiler/CompilerOptions.hpp
index 294839eb62..be43f8c1e9 100644
--- a/src/Compiler/CompilerOptions.hpp
+++ b/src/Compiler/CompilerOptions.hpp
@@ -94,6 +94,7 @@ extern bool enableKrnlBufferReuse; // common for both
 extern bool enableSafeCodeGen; // common for both
 extern bool disableMemRefPrefetch; // common for both
 extern uint64_t compilationNumThreads; // common for both
+extern bool enableForceF32Cast; // common for both
 extern std::vector<std::string> decomposeOpsInONNX; // common for both
 extern EmissionTargetType emissionTarget; // onnx-mlir only
 extern bool invokeOnnxVersionConverter; // onnx-mlir only
diff --git a/src/Dialect/ONNX/Transforms/Decompose.cpp b/src/Dialect/ONNX/Transforms/Decompose.cpp
index 658689b059..d6acba6f7b 100644
--- a/src/Dialect/ONNX/Transforms/Decompose.cpp
+++ b/src/Dialect/ONNX/Transforms/Decompose.cpp
@@ -1052,6 +1052,32 @@ struct SumToAddPattern : public OpRewritePattern<ONNXSumOp> {
   }
 };
 
+// This pattern is a temporary workaround for a bug in ONNX model conversion:
+// an ONNX model with dtype=float16 can be converted to float32 with
+// utils/convertF16ToF32.py, but some CastOp operations inside the model may
+// not be converted. This pattern blindly rewrites CastOp(input, to=f16) into
+// CastOp(input, to=f32).
+struct CastF16ToF32Pattern : public OpRewritePattern<ONNXCastOp> {
+  using OpRewritePattern<ONNXCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      ONNXCastOp castOp, PatternRewriter &rewriter) const final {
+    // Only casts whose target type is f16 are rewritten; everything else is
+    // left untouched.
+    if (!castOp.getTo().isF16())
+      return failure();
+
+    Type f32Type = rewriter.getF32Type();
+    castOp.setTo(f32Type);
+    // An unranked tensor type is used because the concrete shape will be
+    // recovered later by shape inference.
+    Value output = castOp.getOutput();
+    output.setType(UnrankedTensorType::get(f32Type));
+
+    return success();
+  }
+};
+
 // =============================================================================
 // Pattern for replacing CastLikeOp by CastOp.
 // =============================================================================
@@ -1242,6 +1268,11 @@ void DecomposeONNXToONNXPass::runOnOperation() {
     return !onnx_mlir::canSequenceAtBeReplaced(op.getResult());
   });
 
+  if (onnx_mlir::enableForceF32Cast) {
+    target.addDynamicallyLegalOp<ONNXCastOp>(
+        [](ONNXCastOp op) { return !op.getTo().isF16(); });
+  }
+
   // Rewrite ONNXConstantOp with scalar values into the one using ElementAttrs.
   target.addDynamicallyLegalOp<ONNXConstantOp>([](ONNXConstantOp op) {
     return !(op.getValueFloatAttr() || op.getValueFloatsAttr() ||
@@ -1299,6 +1330,8 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
   patterns.insert(context);
   patterns.insert(context);
   patterns.insert(context);
+  if (enableForceF32Cast)
+    patterns.insert<CastF16ToF32Pattern>(context);
 
   if (!onnx_mlir::decomposeOpsInONNX.empty()) {
     for (const auto &op : onnx_mlir::decomposeOpsInONNX) {
diff --git a/utils/convertF16ToF32.py b/utils/convertF16ToF32.py
new file mode 100644
index 0000000000..bdbb0cde06
--- /dev/null
+++ b/utils/convertF16ToF32.py
@@ -0,0 +1,78 @@
+import onnx
+import numpy as np
+from onnx import numpy_helper
+
+
+def convert_fp16_to_fp32_with_external_data(
+    input_model_path: str, output_model_path: str
+):
+    # Load the model (with external weights).
+    model = onnx.load_model(
+        input_model_path,
+        load_external_data=True,  # Necessary to read the external tensors.
+    )
+
+    # Convert all initializers (weights) to float32.
+    for tensor in model.graph.initializer:
+        if tensor.data_type == onnx.TensorProto.FLOAT16:
+            array_fp16 = numpy_helper.to_array(tensor)
+            array_fp32 = array_fp16.astype(np.float32)
+            tensor_fp32 = numpy_helper.from_array(array_fp32, tensor.name)
+            tensor.CopyFrom(tensor_fp32)
+
+    # Update the element type in inputs/outputs/value_info.
+    def update_elem_type_to_fp32(value_info):
+        if value_info.type.HasField("tensor_type"):
+            tt = value_info.type.tensor_type
+            if tt.elem_type == onnx.TensorProto.FLOAT16:
+                tt.elem_type = onnx.TensorProto.FLOAT
+
+    for vi in model.graph.input:
+        update_elem_type_to_fp32(vi)
+    for vi in model.graph.output:
+        update_elem_type_to_fp32(vi)
+    for vi in model.graph.value_info:
+        update_elem_type_to_fp32(vi)
+
+    # Update node attributes that carry FP16 tensors.
+    for node in model.graph.node:
+        for attr in node.attribute:
+            if (
+                attr.type == onnx.AttributeProto.TENSOR
+                and attr.t.data_type == onnx.TensorProto.FLOAT16
+            ):
+                arr_fp16 = numpy_helper.to_array(attr.t)
+                arr_fp32 = arr_fp16.astype(np.float32)
+                attr_fp32 = numpy_helper.from_array(arr_fp32, attr.t.name)
+                attr.t.CopyFrom(attr_fp32)
+
+    # Save the model with external data.
+    onnx.save_model(
+        model,
+        output_model_path,
+        save_as_external_data=True,  # Keep the external data format.
+        all_tensors_to_one_file=True,  # Or False for individual files.
+        location="model_fp32.data",  # Change this if needed.
+        size_threshold=1024,  # Keep small weights inline.
+        convert_attribute=True,
+    )
+
+
+# Example usage:
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--input", type=str, required=True, help="Path of the input onnx model (float16)"
+    )
+    parser.add_argument(
+        "--output",
+        type=str,
+        required=True,
+        help="Path of the output onnx model (converted to float32)",
+    )
+    args = parser.parse_args()
+    print(args.input, args.output)
+    convert_fp16_to_fp32_with_external_data(args.input, args.output)
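
Usage sketch for the workflow above, assuming placeholder model file names (model_fp16.onnx, model_fp32.onnx) and that utils/ is on the Python path: convert the float16 model first, then compile the resulting float32 model with the new flag so that any leftover CastOp(to=f16) is rewritten by CastF16ToF32Pattern during decomposition.

# Hypothetical end-to-end usage; paths and file names are placeholders.
from convertF16ToF32 import convert_fp16_to_fp32_with_external_data

# Step 1: rewrite float16 initializers, value_info entries, and tensor
# attributes to float32, keeping the weights as external data.
convert_fp16_to_fp32_with_external_data("model_fp16.onnx", "model_fp32.onnx")

# Step 2 (outside Python): compile with the new option so any ONNXCastOp
# still targeting f16 is forced to f32 during decomposition, e.g.:
#   onnx-mlir --enable-force-f32-cast=true model_fp32.onnx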