From d345c260c885e04daa59f746b58ca60afb407df0 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Sun, 5 Oct 2025 22:31:01 -0700 Subject: [PATCH 01/23] update --- onnxruntime/core/graph/graph.cc | 205 +++++++++++++++++++++++++++++++- 1 file changed, 201 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 3f6443aa73d4c..7befdfbcf05b8 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2628,6 +2628,51 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { graph_(graph), options_(options) { node_output_types_.resize(node.OutputDefs().size()); + + // If it's a Reshape operator, its shape input could come from another op's output and that output + // could be shape-inferred via ONNX operators' type and shape inference. + // In the following two cases, it returns a TensorProto built from the shape input's TensorShapeProto: + // - The shape of the shape input is [1] + // - The shape of the shape input has rank > 1 and all the dimensions have values, e.g. [1, 3, 64, 64] + if (node_.OpType() == "Reshape") { + ORT_ENFORCE(node.InputDefs().size() == 2); + auto def = node_.InputDefs()[1]; // "shape" input + if (def && def->Shape() && def->Shape()->dim_size() > 0) { + auto tensor_shape_proto = def->Shape(); + + // The shape of the "shape" input is [1] + if (tensor_shape_proto->dim_size() == 1 && + tensor_shape_proto->dim()[0].has_dim_value() && + tensor_shape_proto->dim()[0].dim_value() == 1) { + TensorProto tensor_proto; + tensor_proto.set_data_type(TensorProto_DataType_INT64); + tensor_proto.add_dims(1); + tensor_proto.add_int64_data(tensor_shape_proto->dim()[0].dim_value()); + tensor_proto_for_shape_.push_back(std::move(tensor_proto)); + } + + // The shape of the shape input has rank > 1 and all the dimensions have values (not symbolic) + if (tensor_shape_proto->dim_size() > 1) { + TensorProto tensor_proto; + tensor_proto.set_data_type(TensorProto_DataType_INT64); + tensor_proto.add_dims(tensor_shape_proto->dim_size()); + bool all_values = true; + for (const auto& dim : tensor_shape_proto->dim()) { + if (dim.has_dim_value()) { + tensor_proto.add_int64_data(dim.dim_value()); + } else { + all_values = false; + break; + } + } + + if (all_values) { + tensor_proto_for_shape_.push_back(std::move(tensor_proto)); + } + } + } + } + } void RunInferencing() { @@ -2675,10 +2720,25 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { if (!def) return nullptr; - // only return data if it's for a constant initializer. checks for outer scope initializers - // if this is a subgraph and the name isn't found locally. + // Returns the data only if it's a constant initializer. + // Checks for outer scope initializers if this is a subgraph and the name isn't found locally. const TensorProto* initializer = graph_.GetConstantInitializer(def->Name(), true); - return initializer; + if (initializer) { + return initializer; + } + + // If it's a Reshape operator, its shape input could come from another op's output and that output + // could be shape-inferred via ONNX operators' type and shape inference. + // In the following two cases, it returns a TensorProto built from the shape input's TensorShapeProto: + // - The shape of the shape input is [1] + // - The shape of the shape input has rank > 1, e.g. [1, 3, 64, 64] + if (node_.OpType() == "Reshape") { + if (tensor_proto_for_shape_.size() == 1) { + return &tensor_proto_for_shape_[0]; + } + } + + return nullptr; } // ORT does not implement partial data propagation yet so just return nullptr.
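For reference, the conversion performed in the constructor above can be summarized by this minimal standalone sketch (the helper name MakeInt64TensorFromShape is hypothetical and not part of the patch; it assumes <optional> is included):

// Turn a fully-known TensorShapeProto such as [1, 3, 64, 64] into the 1-D INT64
// TensorProto that ONNX shape inference can consume as if it were a constant
// initializer. Returns std::nullopt if any dimension is symbolic.
static std::optional<ONNX_NAMESPACE::TensorProto> MakeInt64TensorFromShape(
    const ONNX_NAMESPACE::TensorShapeProto& shape) {
  ONNX_NAMESPACE::TensorProto t;
  t.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
  t.add_dims(shape.dim_size());  // 1-D tensor, one entry per dimension
  for (const auto& dim : shape.dim()) {
    if (!dim.has_dim_value()) return std::nullopt;  // symbolic dimension, give up
    t.add_int64_data(dim.dim_value());
  }
  return t;  // e.g. dims = [4], int64_data = {1, 3, 64, 64}
}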
@@ -2713,12 +2773,98 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { Node& node_; // node_output_types_ will be populated by the operator-specific shape inference. std::vector<TypeProto> node_output_types_; + std::vector<TensorProto> tensor_proto_for_shape_; SubgraphInferencingFunc subgraph_inferencing_func_; std::vector<std::unique_ptr<GraphInferencerImpl>> graph_inferencers_; const Graph& graph_; const Graph::ResolveOptions& options_; }; +class DataPropagationContextImpl : public ONNX_NAMESPACE::DataPropagationContext { + public: + DataPropagationContextImpl(Node& node, const Graph& graph) noexcept : node_(node), graph_(graph) { + node_output_types_.resize(node.OutputDefs().size()); + } + + const AttributeProto* getAttribute(const std::string& name) const override { + auto& attribute_value_map = node_.GetAttributes(); + auto iter = attribute_value_map.find(name); + if (iter == attribute_value_map.end()) { + return nullptr; + } + return &iter->second; + } + + size_t getNumInputs() const noexcept override { + return node_.InputDefs().size(); + } + + const TypeProto* getInputType(size_t index) const override { + if (index >= getNumInputs()) { + return nullptr; + } + + const TypeProto* type = nullptr; + auto p_node_arg = node_.InputDefs().at(index); + if ((nullptr != p_node_arg) && p_node_arg->Exists()) { + type = p_node_arg->TypeAsProto(); + } + + return type; + } + + size_t getNumOutputs() const noexcept override { + return node_output_types_.size(); + } + + const TypeProto* getOutputType(size_t index) const override { + if (index >= getNumOutputs()) { + return nullptr; + } + + return &node_output_types_[index]; + } + + const TensorShapeProto* getInputData(size_t index) override { + if (index >= getNumInputs()) { + return nullptr; + } + + auto def = node_.InputDefs()[index]; + if (!def) + return nullptr; + + auto type_proto = def->TypeAsProto(); + + if (!type_proto->has_tensor_type() || !type_proto->tensor_type().has_shape()) { + return nullptr; + } + + return &type_proto->tensor_type().shape(); + } + + void addOutputData(size_t index, TensorShapeProto&& tsp) { + if (index >= node_output_types_.size()) return; + + TypeProto& type_proto = node_output_types_[index]; + *type_proto.mutable_tensor_type()->mutable_shape() = std::move(tsp); + } + + void RunInferencing() { + auto schema = node_.Op(); + if (nullptr != schema) { + schema->GetDataPropagationFunction()(*this); + } + } + + std::vector<TypeProto> InferredOutputTypes() const { return node_output_types_; } + + private: + Node& node_; + const Graph& graph_; + std::vector<TypeProto> node_output_types_; +}; + Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph, const std::vector<const TypeProto*>& input_types, std::vector<const TypeProto*>& output_types, const Graph::ResolveOptions& options) @@ -2924,7 +3070,58 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso ORT_RETURN_IF_ERROR(status); } - const auto& onnx_inferred_types(context.InferredOutputTypes()); + auto onnx_inferred_types(context.InferredOutputTypes()); + + // If it's a "Shape" operator, after running the "Shape" op's shape inference, only the rank of the node's output is + // resolved, not the real shape/dimension values. + // Need to run data propagation to get the shape values. + if (node.OpType() == "Shape") { + // This function should be called after executing the operator's TypeAndShapeInferenceFunction() defined in the ONNX OpSchema.
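// A condensed sketch of the intended call order for the two passes (both
// accessors below exist on ONNX_NAMESPACE::OpSchema; error handling elided):
//
//   InferenceContextImpl ctx(node, func, graph, options);
//   node.Op()->GetTypeAndShapeInferenceFunction()(ctx);   // infers type and rank
//   DataPropagationContextImpl dp_ctx(node, graph);
//   node.Op()->GetDataPropagationFunction()(dp_ctx);      // infers dimension values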
+ DataPropagationContextImpl data_propagation_context(node, *this); + + { + auto status = Status::OK(); + ORT_TRY { + data_propagation_context.RunInferencing(); + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node (", node.Name(), ") Op (", node.OpType(), ") ", ex.what()); + }); + } + ORT_RETURN_IF_ERROR(status); + } + + const auto& onnx_inferred_data(data_propagation_context.InferredOutputTypes()); + + if (!onnx_inferred_data.empty()) { + ORT_ENFORCE(onnx_inferred_data.size() == onnx_inferred_types.size(), "Mismatch between the number of inferred output types and propagated outputs."); + + for (size_t idx = 0; idx < onnx_inferred_types.size(); ++idx) { + auto& type_proto = onnx_inferred_types[idx]; + auto& type_proto_with_shape_data_inferred = onnx_inferred_data[idx]; + + auto dim_size = type_proto_with_shape_data_inferred.tensor_type().shape().dim_size(); + bool all_dimensions_have_values = dim_size > 0; + + for (int i = 0; i < dim_size; ++i) { + if (!type_proto_with_shape_data_inferred.tensor_type().shape().dim(i).has_dim_value()) { + all_dimensions_have_values = false; + break; + } + } + + if (all_dimensions_have_values) { + type_proto.mutable_tensor_type()->clear_shape(); + + for (int i = 0; i < dim_size; ++i) { + auto value = type_proto_with_shape_data_inferred.tensor_type().shape().dim(i).dim_value(); + type_proto.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(value); + } + } + } + } + } // Infer and verify node output arg type information. int i = -1; From cb52305412fac4b4077f8f73d8555de3cdab5b0c Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 6 Oct 2025 12:25:36 -0700 Subject: [PATCH 02/23] Modify to make it work correctly --- include/onnxruntime/core/graph/node_arg.h | 10 ++ onnxruntime/core/graph/graph.cc | 153 +++++++++++----------- 2 files changed, 89 insertions(+), 74 deletions(-) diff --git a/include/onnxruntime/core/graph/node_arg.h b/include/onnxruntime/core/graph/node_arg.h index 0ddf1a2b9d3de..5cf8f8acd370a 100644 --- a/include/onnxruntime/core/graph/node_arg.h +++ b/include/onnxruntime/core/graph/node_arg.h @@ -107,6 +107,9 @@ class NodeArg { /** Gets this NodeArg as a NodeArgInfo, AKA ValueInfoProto. */ const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; } + /** Gets the inferred shape as a TensorShapeProto. */ + const ONNX_NAMESPACE::TensorShapeProto& GetInferredTensorShapeProto() const noexcept { return tensor_shape_proto_after_shape_inferred_; } + /** Gets a flag indicating whether this NodeArg exists or not. Optional inputs are allowed in ONNX and an empty #Name represents a non-existent input argument. */ bool Exists() const noexcept; @@ -128,6 +131,13 @@ class NodeArg { // Node arg name, type and shape. NodeArgInfo node_arg_info_; + // If this node has a Tensor type, this is the TensorShapeProto of the node's TypeProto_Tensor. + // Calling the Op's TypeAndShapeInferenceFunction() is not enough for some operators during type and shape inference, + // e.g. the Shape op, as it only gets the inferred rank/dimension size value. + // The Op's PartialDataPropagationFunction() defined in the ONNX Op Schema should also be called to get the shape dimension values. + // The variable is used for storing those shape dimension values so that the inferred shape can be correctly propagated throughout the graph. + ONNX_NAMESPACE::TensorShapeProto tensor_shape_proto_after_shape_inferred_; + // Flag indicates whether <*this> node arg exists or not.
bool exists_; }; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 7befdfbcf05b8..64978bf6bd95d 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2629,35 +2629,50 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { options_(options) { node_output_types_.resize(node.OutputDefs().size()); - // If it's a Reshape operator, its shape input could come from another op's output and that output - // could be shape-inferred via ONNX operators' type and shape inference. - // In the following two cases, it returns a TensorProto built from the shape input's TensorShapeProto: - // - The shape of the shape input is [1] - // - The shape of the shape input has rank > 1 and all the dimensions have values, e.g. [1, 3, 64, 64] + // The following code handles cases where a node needs to retrieve a TensorShapeProto + // that stores the actual or inferred shape information in its NodeArg. + // + // Focus on the Reshape operator for now. + // + // For the Reshape operator, its input shape may come from a producer node such as a Shape operator, + // which can store the inferred shape as a TensorShapeProto within the NodeArg. + // In such cases, the Reshape operator should convert this TensorShapeProto into a TensorProto. + // The resulting TensorProto will then be treated as an initializer during ONNX shape inference, + // allowing the real dimension values to be correctly used. + // See https://github.com/onnx/onnx/blob/main/onnx/defs/shape_inference.cc#L467 for details. + // + // Note: Only certain TensorShapeProto cases will be converted to TensorProto, specifically when: + // - The shape input has shape [1], meaning the shape only has one dimension value, or + // - The shape input has rank > 1 and all dimensions have known (concrete) values, + // e.g., [1, 3, 64, 64]. + // + // Note: The following code is placed here instead of in getInputData() because + // getInputData() is a const function that returns a pointer. As a result, + // we cannot create a new TensorProto and store it in this class within a const function. if (node_.OpType() == "Reshape") { - ORT_ENFORCE(node.InputDefs().size() == 2); - auto def = node_.InputDefs()[1]; // "shape" input + // Get the "shape" input. + auto def = node_.InputDefs()[1]; // It's safe to access this; the Reshape operator is guaranteed to have two inputs if (def && def->Shape() && def->Shape()->dim_size() > 0) { - auto tensor_shape_proto = def->Shape(); + auto tensor_shape_proto = def->GetInferredTensorShapeProto(); // The shape of the "shape" input is [1] - if (tensor_shape_proto->dim_size() == 1 && - tensor_shape_proto->dim()[0].has_dim_value() && - tensor_shape_proto->dim()[0].dim_value() == 1) { + if (tensor_shape_proto.dim_size() == 1 && + tensor_shape_proto.dim()[0].has_dim_value() && + tensor_shape_proto.dim()[0].dim_value() == 1) { TensorProto tensor_proto; tensor_proto.set_data_type(TensorProto_DataType_INT64); tensor_proto.add_dims(1); - tensor_proto.add_int64_data(tensor_shape_proto->dim()[0].dim_value()); + tensor_proto.add_int64_data(tensor_shape_proto.dim()[0].dim_value()); tensor_proto_for_shape_.push_back(std::move(tensor_proto)); } - // The shape of the shape input has rank > 1 and all the dimensions have values (not symbolic) - if (tensor_shape_proto->dim_size() > 1) { + // The shape of the "shape" input has rank > 1 and all the dimensions have values (not symbolic) + if (tensor_shape_proto.dim_size() > 1) { TensorProto tensor_proto; tensor_proto.set_data_type(TensorProto_DataType_INT64); - tensor_proto.add_dims(tensor_shape_proto->dim_size()); + tensor_proto.add_dims(tensor_shape_proto.dim_size()); bool all_values = true; - for (const auto& dim : tensor_shape_proto->dim()) { + for (const auto& dim : tensor_shape_proto.dim()) { if (dim.has_dim_value()) { tensor_proto.add_int64_data(dim.dim_value()); } else { @@ -2672,7 +2687,6 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { } } } - } void RunInferencing() { @@ -2727,11 +2741,8 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { return initializer; } - // If it's a Reshape operator, its shape input could come from another op's output and that output - // could be shape-inferred via ONNX operators' type and shape inference. - // In the following two cases, it returns a TensorProto built from the shape input's TensorShapeProto: - // - The shape of the shape input is [1] - // - The shape of the shape input has rank > 1, e.g. [1, 3, 64, 64] + // If it's a Reshape operator, simply return the TensorProto + // that was previously created in the InferenceContextImpl's constructor. if (node_.OpType() == "Reshape") { if (tensor_proto_for_shape_.size() == 1) { return &tensor_proto_for_shape_[0]; @@ -2780,6 +2791,8 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { const Graph::ResolveOptions& options_; }; +// An implementation of the DataPropagationContext interface used by operator-specific +// shape inference for onnxruntime graphs. class DataPropagationContextImpl : public ONNX_NAMESPACE::DataPropagationContext { public: DataPropagationContextImpl(Node& node, const Graph& graph) noexcept : node_(node), graph_(graph) { @@ -3056,11 +3069,26 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso // returned here. SubgraphInferencingFunc func(Graph::InferAndVerifySubgraphTypes); InferenceContextImpl context(node, func, *this, options); + DataPropagationContextImpl data_propagation_context(node, *this); { auto status = Status::OK(); ORT_TRY { context.RunInferencing(); + + // The following code handles cases where calling the operator-specific shape inference alone + // (i.e., the TypeAndShapeInferenceFunction() defined in the ONNX operator schema) + // is insufficient to fully infer the output shape. + // + // Focus on the Shape operator for now. + // + // For a Shape operator, running its shape inference only determines the rank (number of dimensions) + // of the output tensor, not the actual dimension values. To infer those values and propagate them + // throughout the graph, the PartialDataPropagationFunction() defined in the ONNX operator schema + // must also be executed. + if (node.OpType() == "Shape") { + data_propagation_context.RunInferencing(); + } } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { @@ -3070,58 +3098,8 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso ORT_RETURN_IF_ERROR(status); } - auto onnx_inferred_types(context.InferredOutputTypes()); - - // If it's a "Shape" operator, after running the "Shape" op's shape inference, only the rank of the node's output is - // resolved, not the real shape/dimension values. - // Need to run data propagation to get the shape values. - if (node.OpType() == "Shape") { - // This function should be called after executing the operator's TypeAndShapeInferenceFunction() defined in the ONNX OpSchema. - DataPropagationContextImpl data_propagation_context(node, *this); - - { - auto status = Status::OK(); - ORT_TRY { - data_propagation_context.RunInferencing(); - } - ORT_CATCH(const std::exception& ex) { - ORT_HANDLE_EXCEPTION([&]() { - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node (", node.Name(), ") Op (", node.OpType(), ") ", ex.what()); - }); - } - ORT_RETURN_IF_ERROR(status); - } - - const auto& onnx_inferred_data(data_propagation_context.InferredOutputTypes()); - - if (!onnx_inferred_data.empty()) { - ORT_ENFORCE(onnx_inferred_data.size() == onnx_inferred_types.size(), "Mismatch between the number of inferred output types and propagated outputs."); - - for (size_t idx = 0; idx < onnx_inferred_types.size(); ++idx) { - auto& type_proto = onnx_inferred_types[idx]; - auto& type_proto_with_shape_data_inferred = onnx_inferred_data[idx]; - - auto dim_size = type_proto_with_shape_data_inferred.tensor_type().shape().dim_size(); - bool all_dimensions_have_values = dim_size > 0; - - for (int i = 0; i < dim_size; ++i) { - if (!type_proto_with_shape_data_inferred.tensor_type().shape().dim(i).has_dim_value()) { - all_dimensions_have_values = false; - break; - } - } - - if (all_dimensions_have_values) { - type_proto.mutable_tensor_type()->clear_shape(); - - for (int i = 0; i < dim_size; ++i) { - auto value = type_proto_with_shape_data_inferred.tensor_type().shape().dim(i).dim_value(); - type_proto.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(value); - } - } - } - } - } + const auto& onnx_inferred_types(context.InferredOutputTypes()); + const auto& onnx_inferred_data(data_propagation_context.InferredOutputTypes()); // Infer and verify node output arg type information.
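// Note on the two vectors above: they are index-aligned with the node's output
// defs. onnx_inferred_types[i] carries the type/rank that shape inference produced
// for output i, while onnx_inferred_data[i] carries any dimension values that data
// propagation recovered for that same output. For example (values assumed), for a
// Shape node whose input is [1, 3, 64, 64], shape inference alone yields a rank-1
// output of size 4, and data propagation fills in the values {1, 3, 64, 64}.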
int i = -1; @@ -3139,6 +3117,33 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso auto op_formal_parameter = op.outputs().at(operand_index); const TypeProto& onnx_inferred_type = onnx_inferred_types[i]; + + const TypeProto& type_proto_with_shape_data_inferred = onnx_inferred_data[i]; + + // If the actual dimension values are inferred after executing the operator's + // PartialDataPropagationFunction(), save them to the output NodeArg as a + // TensorShapeProto so that downstream (consumer) nodes can use them later. + auto save_tensor_shape_proto_with_inferred_data = [&]() -> void { + auto dim_size = type_proto_with_shape_data_inferred.tensor_type().shape().dim_size(); + if (dim_size <= 0) return; + + // Each dimension must have a concrete value + for (int i = 0; i < dim_size; ++i) { + if (!type_proto_with_shape_data_inferred.tensor_type().shape().dim(i).has_dim_value()) { + return; + } + } + + output_def->tensor_shape_proto_after_shape_inferred_.clear_dim(); + + for (int i = 0; i < dim_size; ++i) { + auto value = type_proto_with_shape_data_inferred.tensor_type().shape().dim(i).dim_value(); + output_def->tensor_shape_proto_after_shape_inferred_.add_dim()->set_dim_value(value); + } + }; + + save_tensor_shape_proto_with_inferred_data(); + DataType existing_type = output_def->Type(); DataType inferred_type = nullptr; From bd1423549a97c29719ca836c5dcf16123be5163e Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 7 Oct 2025 13:47:37 -0700 Subject: [PATCH 03/23] Save scalar value after data propagation --- include/onnxruntime/core/graph/graph.h | 3 + include/onnxruntime/core/graph/node_arg.h | 10 +- onnxruntime/core/graph/graph.cc | 239 +++++++++++++++------- 3 files changed, 175 insertions(+), 77 deletions(-) diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 9a0708d72b4f8..5eea9514c2a77 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1760,6 +1760,9 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi std::vector<const TypeProto*>& output_types, const Graph::ResolveOptions& options); + common::Status SaveValuesFromDataPropagation(Node&, NodeArg& output_def, + const ONNX_NAMESPACE::TypeProto& propagated_value_as_type_proto); + // Apply type-inference and type-checking to all inputs and initializers: common::Status TypeCheckInputsAndInitializers(); diff --git a/include/onnxruntime/core/graph/node_arg.h b/include/onnxruntime/core/graph/node_arg.h index 5cf8f8acd370a..7420e619b6649 100644 --- a/include/onnxruntime/core/graph/node_arg.h +++ b/include/onnxruntime/core/graph/node_arg.h @@ -108,7 +108,7 @@ class NodeArg { const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; } /** Gets the inferred shape as a TensorShapeProto. */ - const ONNX_NAMESPACE::TensorShapeProto& GetInferredTensorShapeProto() const noexcept { return tensor_shape_proto_after_shape_inferred_; } + const ONNX_NAMESPACE::TensorShapeProto& GetInferredTensorShapeProto() const noexcept { return tensor_shape_proto_after_data_propagation_; } /** Gets a flag indicating whether this NodeArg exists or not. Optional inputs are allowed in ONNX and an empty #Name represents a non-existent input argument. */ bool Exists() const noexcept; @@ -128,6 +131,13 @@ class NodeArg { // Node arg name, type and shape. NodeArgInfo node_arg_info_; - // If this node has a Tensor type, this is the TensorShapeProto of the node's TypeProto_Tensor.
+ // This TensorShapeProto is for saving the inferred shape data as a TensorShapeProto from the Op's PartialDataPropagationFunction(). + // // Calling the Op's TypeAndShapeInferenceFunction() is not enough for some operators during type and shape inference, // e.g. the Shape op, as it only gets the inferred rank/dimension size value. // The Op's PartialDataPropagationFunction() defined in the ONNX Op Schema should also be called to get the shape dimension values. - // The variable is used for storing those shape dimension values so that the inferred shape can be correctly propagated throughout the graph. - ONNX_NAMESPACE::TensorShapeProto tensor_shape_proto_after_shape_inferred_; + // The variable is used for storing those shape dimension values so that the inferred shape values can be correctly propagated throughout the graph. + ONNX_NAMESPACE::TensorShapeProto tensor_shape_proto_after_data_propagation_; + int64_t scalar_value_after_data_propagation_; // Flag indicates whether <*this> node arg exists or not. bool exists_; }; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 64978bf6bd95d..a9c1c0765d097 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -44,6 +44,7 @@ #include "core/graph/schema_registry.h" #include "onnx/checker.h" #include "onnx/defs/parser.h" +#include "onnx/defs/tensor_proto_util.h" using namespace ONNX_NAMESPACE::checker; #endif @@ -2630,17 +2631,17 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { options_(options) { node_output_types_.resize(node.OutputDefs().size()); // The following code handles cases where a node needs to retrieve a TensorShapeProto - // that stores the actual or inferred shape information in its NodeArg. + // that stores the inferred shape values in its NodeArg. // // Focus on the Reshape operator for now. // // For the Reshape operator, its input shape may come from a producer node such as a Shape operator, // which can store the inferred shape as a TensorShapeProto within the NodeArg. // In such cases, the Reshape operator should convert this TensorShapeProto into a TensorProto. // The resulting TensorProto will then be treated as an initializer during ONNX shape inference, // allowing the real dimension values to be correctly used. // See https://github.com/onnx/onnx/blob/main/onnx/defs/shape_inference.cc#L467 for details. // - // Note: Only certain TensorShapeProto cases will be converted to TensorProto, specifically when: - // - The shape input has shape [1], meaning the shape only has one dimension value, or - // - The shape input has rank > 1 and all dimensions have known (concrete) values, - // e.g., [1, 3, 64, 64]. - // // Note: The following code is placed here instead of in getInputData() because // getInputData() is a const function that returns a pointer. As a result, // we cannot create a new TensorProto and store it in this class within a const function. - if (node_.OpType() == "Reshape") { - // Get the "shape" input. - auto def = node_.InputDefs()[1]; // It's safe to access this; the Reshape operator is guaranteed to have two inputs - if (def && def->Shape() && def->Shape()->dim_size() > 0) { - auto tensor_shape_proto = def->GetInferredTensorShapeProto(); - - // The shape of the "shape" input is [1] - if (tensor_shape_proto.dim_size() == 1 && - tensor_shape_proto.dim()[0].has_dim_value() && - tensor_shape_proto.dim()[0].dim_value() == 1) { - TensorProto tensor_proto; - tensor_proto.set_data_type(TensorProto_DataType_INT64); - tensor_proto.add_dims(1); - tensor_proto.add_int64_data(tensor_shape_proto.dim()[0].dim_value()); - tensor_proto_for_shape_.push_back(std::move(tensor_proto)); - } + for (const NodeArg* def : node_.InputDefs()) { + // Skip initializers as they will be handled in getInputData() + if (graph_.GetConstantInitializer(def->Name(), true)) { + continue; + } + + const auto& tensor_shape_proto = def->GetInferredTensorShapeProto(); + + // The shape of the "shape" input is [1] + if (tensor_shape_proto.dim_size() == 1 && + tensor_shape_proto.dim()[0].has_dim_value() && + tensor_shape_proto.dim()[0].dim_value() == 1) { + TensorProto tensor_proto; + tensor_proto.set_data_type(TensorProto_DataType_INT64); + tensor_proto.add_dims(1); + tensor_proto.add_int64_data(tensor_shape_proto.dim()[0].dim_value()); + tensor_proto_for_shape_[def->Name()] = std::move(tensor_proto); + } - // The shape of the "shape" input has rank > 1 and all the dimensions have values (not symbolic) - if (tensor_shape_proto.dim_size() > 1) { - TensorProto tensor_proto; - tensor_proto.set_data_type(TensorProto_DataType_INT64); - tensor_proto.add_dims(tensor_shape_proto.dim_size()); - bool all_values = true; - for (const auto& dim : tensor_shape_proto.dim()) { - if (dim.has_dim_value()) { - tensor_proto.add_int64_data(dim.dim_value()); - } else { - all_values = false; - break; - } - } - - if (all_values) { - tensor_proto_for_shape_.push_back(std::move(tensor_proto)); - } - } - } - } + // The shape of the "shape" input has rank > 1 and all the dimensions have values (not symbolic) + if (tensor_shape_proto.dim_size() > 1) { + TensorProto tensor_proto; + tensor_proto.set_data_type(TensorProto_DataType_INT64); + tensor_proto.add_dims(tensor_shape_proto.dim_size()); + bool all_values = true; + for (const auto& dim : tensor_shape_proto.dim()) { + if (dim.has_dim_value()) { + tensor_proto.add_int64_data(dim.dim_value()); + } else { + all_values = false; + break; + } + } + + if (all_values) { + tensor_proto_for_shape_[def->Name()] = std::move(tensor_proto); + } } } } @@ -2741,12 +2743,11 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { return initializer; } - // If it's a Reshape operator, simply return the TensorProto - // that was previously created in the InferenceContextImpl's constructor. - if (node_.OpType() == "Reshape") { - if (tensor_proto_for_shape_.size() == 1) { - return &tensor_proto_for_shape_[0]; - } + // Return the TensorProto that was previously created in + // the InferenceContextImpl's constructor if there is any. + auto tensor_proto_for_shape = tensor_proto_for_shape_.find(def->Name()); + if (tensor_proto_for_shape != tensor_proto_for_shape_.end()) { + return &tensor_proto_for_shape->second; } return nullptr; @@ -2785,7 +2785,7 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { Node& node_; // node_output_types_ will be populated by the operator-specific shape inference. std::vector<TypeProto> node_output_types_; - std::vector<TensorProto> tensor_proto_for_shape_; + std::unordered_map<std::string, TensorProto> tensor_proto_for_shape_; SubgraphInferencingFunc subgraph_inferencing_func_; std::vector<std::unique_ptr<GraphInferencerImpl>> graph_inferencers_; const Graph& graph_; const Graph::ResolveOptions& options_; }; @@ -2847,13 +2848,47 @@ class DataPropagationContextImpl : public ONNX_NAMESPACE::DataPropagationContext if (!def) return nullptr; - auto type_proto = def->TypeAsProto(); + auto check_shape_has_values = [&](const TensorShapeProto& tensor_shape_proto) -> bool { + // The shape of the "shape" input is 1-D (rank 1) with only one element, + // e.g. the "shape" input is [2] or [12] .... + if (tensor_shape_proto.dim_size() == 1 && + tensor_shape_proto.dim()[0].has_dim_value() && + tensor_shape_proto.dim()[0].dim_value() == 1) { + return true; + } - if (!type_proto->has_tensor_type() || !type_proto->tensor_type().has_shape()) { - return nullptr; + // The shape of the "shape" input has rank > 1 and all the dimensions have values (not symbolic) + if (tensor_shape_proto.dim_size() > 1) { + for (const auto& dim : tensor_shape_proto.dim()) { + if (!dim.has_dim_value()) { + return false; + } + } + return true; + } + + return false; + }; + + const TensorShapeProto* tensor_shape_proto = nullptr; + + // First check the TensorShapeProto from the NodeArg's TypeProto_Tensor + if (def->TypeAsProto() && + def->TypeAsProto()->has_tensor_type() && + def->TypeAsProto()->tensor_type().has_shape()) { + tensor_shape_proto = &def->TypeAsProto()->tensor_type().shape(); + if (check_shape_has_values(*tensor_shape_proto)) { + return tensor_shape_proto; + } + } + + // Then check the tensor_shape_proto_after_data_propagation_ stored in the NodeArg. + tensor_shape_proto = &def->GetInferredTensorShapeProto(); + if (check_shape_has_values(*tensor_shape_proto)) { + return tensor_shape_proto; } - return &type_proto->tensor_type().shape(); + return nullptr; } void addOutputData(size_t index, TensorShapeProto&& tsp) { @@ -2878,6 +2913,82 @@ class DataPropagationContextImpl : public ONNX_NAMESPACE::DataPropagationContext std::vector<TypeProto> node_output_types_; }; +// If the dimension values are inferred after executing the operator's +// PartialDataPropagationFunction(), save them to the output NodeArg as a +// TensorShapeProto so that downstream (consumer) nodes can use them later. +Status Graph::SaveValuesFromDataPropagation(Node& node, + NodeArg& output_def, + const TypeProto& propagated_value_as_type_proto) { + + auto dim_size = propagated_value_as_type_proto.tensor_type().shape().dim_size(); + if (dim_size < 0) return Status::OK(); + + // If dim size is 0 it means the inferred output is a scalar integer. + if (dim_size == 0) { + // The Gather operator's PartialDataPropagationFunction() doesn't handle the case where + // the output is a scalar. + // Try to do data propagation and save the inferred scalar value. + if (node.OpType() == "Gather") { + // Try to get the "data" input + const auto* input_0 = node.GetDefinitions().input_defs[0]; + auto& tensor_shape_proto = input_0->tensor_shape_proto_after_data_propagation_; + + // Try to get the "indices" input as an initializer + // Note: The "indices" input should be a scalar value, otherwise, if it's a tensor with dimension size > 0, + // the operator's type and shape inference should have inferred the shape value. + const auto* input_1 = node.GetDefinitions().input_defs[1]; + const TensorProto* initializer = this->GetConstantInitializer(input_1->Name(), true); + auto data_type = initializer->data_type(); + int64_t dim_index = -1; + + if (utils::HasRawData(*initializer)) { + const std::string& raw = initializer->raw_data(); + const int64_t* data = reinterpret_cast<const int64_t*>(raw.data()); + dim_index = *data; + } else { + if (initializer->data_type() == TensorProto_DataType_INT32) { + std::vector<int32_t> values; + values.assign(initializer->int32_data().begin(), initializer->int32_data().end()); + if (values.size() == 1) { + dim_index = static_cast<int64_t>(values[0]); + } + } else if (initializer->data_type() == TensorProto_DataType_INT64) { + std::vector<int64_t> values; + values.assign(initializer->int64_data().begin(), initializer->int64_data().end()); + if (values.size() == 1) { + dim_index = values[0]; + } + } + } + + if (dim_index >= 0 && (dim_index < tensor_shape_proto.dim_size())) { + auto& dim = tensor_shape_proto.dim(static_cast<int>(dim_index)); + if (dim.has_dim_value()) { + output_def.scalar_value_after_data_propagation_ = dim.dim_value(); + } + } + + return Status::OK(); + } + } + + // If dim size > 0, this function only saves the shape if each dimension has a concrete value + for (int i = 0; i < dim_size; ++i) { + if (!propagated_value_as_type_proto.tensor_type().shape().dim(i).has_dim_value()) { + return Status::OK(); + } + } + + output_def.tensor_shape_proto_after_data_propagation_.clear_dim(); + for (int i = 0; i < dim_size; ++i) { + auto value = propagated_value_as_type_proto.tensor_type().shape().dim(i).dim_value(); + output_def.tensor_shape_proto_after_data_propagation_.add_dim()->set_dim_value(value); + } + + return Status::OK(); +} + Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph, const std::vector<const TypeProto*>& input_types, std::vector<const TypeProto*>& output_types, @@ -3182,14 +3190,18 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso // The following code handles cases where calling the operator-specific shape inference alone // (i.e., the TypeAndShapeInferenceFunction() defined in the ONNX operator schema) // is insufficient to fully infer the output shape. // // Focus on the Shape operator for now. // // For a Shape operator, running its shape inference only determines the rank (number of dimensions) // of the output tensor, not the actual dimension values. To infer those values and propagate them // throughout the graph, the PartialDataPropagationFunction() defined in the ONNX operator schema // must also be executed. - if (node.OpType() == "Shape") { + if (node.OpType() == "Shape" || + node.OpType() == "Unsqueeze" || + node.OpType() == "Squeeze" || + node.OpType() == "Gather" || + node.OpType() == "Slice") { data_propagation_context.RunInferencing(); } } @@ -3235,29 +3247,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso const TypeProto& type_proto_with_shape_data_inferred = onnx_inferred_data[i]; - // If the actual dimension values are inferred after executing the operator's - // PartialDataPropagationFunction(), save them to the output NodeArg as a - // TensorShapeProto so that downstream (consumer) nodes can use them later. - auto save_tensor_shape_proto_with_inferred_data = [&]() -> void { - auto dim_size = type_proto_with_shape_data_inferred.tensor_type().shape().dim_size(); - if (dim_size <= 0) return; - - // Each dimension must have a concrete value - for (int i = 0; i < dim_size; ++i) { - if (!type_proto_with_shape_data_inferred.tensor_type().shape().dim(i).has_dim_value()) { - return; - } - } - - output_def->tensor_shape_proto_after_shape_inferred_.clear_dim(); - - for (int i = 0; i < dim_size; ++i) { - auto value = type_proto_with_shape_data_inferred.tensor_type().shape().dim(i).dim_value(); - output_def->tensor_shape_proto_after_shape_inferred_.add_dim()->set_dim_value(value); - } - }; - - save_tensor_shape_proto_with_inferred_data(); + SaveValuesFromDataPropagation(node, *output_def, type_proto_with_shape_data_inferred); DataType existing_type = output_def->Type(); DataType inferred_type = nullptr; From bd1423549a97c29719ca836c5dcf16123be5163e Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 9 Oct 2025 10:55:16 -0700 Subject: [PATCH 04/23] correctly save the inferred shape values in NodeArg for other ops' data propagation to work correctly --- include/onnxruntime/core/graph/node_arg.h | 2 +- onnxruntime/core/graph/graph.cc | 136 ++++++++++++++-------- 2 files changed, 90 insertions(+), 48 deletions(-) diff --git a/include/onnxruntime/core/graph/node_arg.h b/include/onnxruntime/core/graph/node_arg.h index 7420e619b6649..f6e2369caeddd 100644 --- a/include/onnxruntime/core/graph/node_arg.h +++ b/include/onnxruntime/core/graph/node_arg.h @@ -138,7 +138,7 @@ class NodeArg { // The variable is used for storing those shape dimension values so that the inferred shape values can be correctly propagated throughout the graph. ONNX_NAMESPACE::TensorShapeProto tensor_shape_proto_after_data_propagation_; - int64_t scalar_value_after_data_propagation_; + int64_t scalar_value_after_data_propagation_ = std::numeric_limits<int64_t>::min(); // Flag indicates whether <*this> node arg exists or not. bool exists_; }; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index a9c1c0765d097..4cc94eec8d178 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2755,6 +2755,7 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { // ORT does not implement partial data propagation yet so just return nullptr. const TensorShapeProto* getSymbolicInput(size_t) const override { + // TODO: Add symbolic input return nullptr; } @@ -2848,17 +2849,13 @@ class DataPropagationContextImpl : public ONNX_NAMESPACE::DataPropagationContext if (!def) return nullptr; - auto check_shape_has_values = [&](const TensorShapeProto& tensor_shape_proto) -> bool { - // The shape of the "shape" input is 1-D (rank 1) with only one element, - // e.g. the "shape" input is [2] or [12] .... - if (tensor_shape_proto.dim_size() == 1 && - tensor_shape_proto.dim()[0].has_dim_value() && - tensor_shape_proto.dim()[0].dim_value() == 1) { - return true; - } + // Try to get the previously inferred shape values that are stored in the NodeArg's tensor_shape_proto_after_data_propagation_ + + const TensorShapeProto* tensor_shape_proto = nullptr; + + auto has_shape_values = [&](const TensorShapeProto& tensor_shape_proto) -> bool { + // The TensorShapeProto (inferred shape values) should have rank > 0 and all the dimensions should have values (not symbolic) + if (tensor_shape_proto.dim_size() > 0) { - // The shape of the "shape" input has rank > 1 and all the dimensions have values (not symbolic) - if (tensor_shape_proto.dim_size() > 1) { for (const auto& dim : tensor_shape_proto.dim()) { if (!dim.has_dim_value()) { return false; } } return true; } return false; }; - const TensorShapeProto* tensor_shape_proto = nullptr; - - // First check the TensorShapeProto from the NodeArg's TypeProto_Tensor - if (def->TypeAsProto() && - def->TypeAsProto()->has_tensor_type() && - def->TypeAsProto()->tensor_type().has_shape()) { - tensor_shape_proto = &def->TypeAsProto()->tensor_type().shape(); - if (check_shape_has_values(*tensor_shape_proto)) { - return tensor_shape_proto; - } - } - - // Then check the tensor_shape_proto_after_data_propagation_ stored in the NodeArg. + // Get the NodeArg's tensor_shape_proto_after_data_propagation_ if any tensor_shape_proto = &def->GetInferredTensorShapeProto(); - if (check_shape_has_values(*tensor_shape_proto)) { + if (has_shape_values(*tensor_shape_proto)) { return tensor_shape_proto; } return nullptr; @@ -2898,73 +2908,23 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, NodeArg& output_def, const TypeProto& propagated_value_as_type_proto) { auto dim_size = propagated_value_as_type_proto.tensor_type().shape().dim_size(); if (dim_size < 0) return Status::OK(); - // If dim size is 0 it means the inferred output is a scalar integer. + // If dim size is 0, it means one of the following cases: + // 1. The inferred output is a scalar. + // 2. The node's input is a scalar and the Op's PartialDataPropagationFunction() can't handle it. + // + // In other words, some operators' PartialDataPropagationFunction() doesn't handle the case where the input or output is a scalar. + // Try to do data propagation and save the inferred scalar value if needed. if (dim_size == 0) { - // The Gather operator's PartialDataPropagationFunction() doesn't handle the case where - // the output is a scalar. - // Try to do data propagation and save the inferred scalar value. if (node.OpType() == "Gather") { // Try to get the "data" input const auto* input_0 = node.GetDefinitions().input_defs[0]; auto& tensor_shape_proto = input_0->tensor_shape_proto_after_data_propagation_; - // Try to get the "indices" input as an initializer - // Note: The "indices" input should be a scalar value, otherwise, if it's a tensor with dimension size > 0, - // the operator's type and shape inference should have inferred the shape value. + // Try to get the "indices" input as an initializer + // Note: The "indices" input should be a scalar, otherwise, if it's a tensor with dimension size > 0, + // the operator's type and shape inference should have inferred the output shape value. const auto* input_1 = node.GetDefinitions().input_defs[1]; const TensorProto* initializer = this->GetConstantInitializer(input_1->Name(), true); - auto data_type = initializer->data_type(); int64_t dim_index = -1; if (utils::HasRawData(*initializer)) { const std::string& raw = initializer->raw_data(); const int64_t* data = reinterpret_cast<const int64_t*>(raw.data()); dim_index = *data; } else { if (initializer->data_type() == TensorProto_DataType_INT32) { std::vector<int32_t> values; values.assign(initializer->int32_data().begin(), initializer->int32_data().end()); if (values.size() == 1) { dim_index = static_cast<int64_t>(values[0]); } } else if (initializer->data_type() == TensorProto_DataType_INT64) { std::vector<int64_t> values; values.assign(initializer->int64_data().begin(), initializer->int64_data().end()); if (values.size() == 1) { dim_index = values[0]; } } } if (dim_index >= 0 && (dim_index < tensor_shape_proto.dim_size())) { auto& dim = tensor_shape_proto.dim(static_cast<int>(dim_index)); if (dim.has_dim_value()) { output_def.scalar_value_after_data_propagation_ = dim.dim_value(); } } + } else if (node.OpType() == "Unsqueeze") { + // Try to get the "data" input + const auto* input_0 = node.GetDefinitions().input_defs[0]; + + // Only handle the case where the "data" input is a scalar + if (input_0->scalar_value_after_data_propagation_ > 0) { + // Try to get the "axes" input as an initializer + const auto* input_1 = node.GetDefinitions().input_defs[1]; + const TensorProto* initializer = this->GetConstantInitializer(input_1->Name(), true); + int64_t axis = -1; + + if (utils::HasRawData(*initializer)) { + const std::string& raw = initializer->raw_data(); + const int64_t* data = reinterpret_cast<const int64_t*>(raw.data()); + axis = *data; + } else { + if (initializer->data_type() == TensorProto_DataType_INT32) { + std::vector<int32_t> values; + values.assign(initializer->int32_data().begin(), initializer->int32_data().end()); + if (values.size() == 1) { + axis = static_cast<int64_t>(values[0]); + } + } else if (initializer->data_type() == TensorProto_DataType_INT64) { + std::vector<int64_t> values; + values.assign(initializer->int64_data().begin(), initializer->int64_data().end()); + if (values.size() == 1) { + axis = values[0]; + } + } + } + + if (axis == 0) { + output_def.tensor_shape_proto_after_data_propagation_.clear_dim(); + output_def.tensor_shape_proto_after_data_propagation_.add_dim()->set_dim_value(input_0->scalar_value_after_data_propagation_); + } + } + } else if (node.OpType() == "Mul") { + // Try to get the "A" input + const auto* input_0 = node.GetDefinitions().input_defs[0]; + // Try to get the "B" input + const auto* input_1 = node.GetDefinitions().input_defs[1]; + + if (input_0->scalar_value_after_data_propagation_ > 0 && + input_1->scalar_value_after_data_propagation_ > 0) { + output_def.scalar_value_after_data_propagation_ = input_0->scalar_value_after_data_propagation_ * input_1->scalar_value_after_data_propagation_; + } + } + } + // If the dimension size is greater than 0, only save the inferred output when all the dimensions + // have concrete values (not symbolic), e.g. the inferred output is [2] or [1, 3, 64, 64]. + else if (dim_size > 0) { + for (int i = 0; i < dim_size; ++i) { - if (!propagated_value_as_type_proto.tensor_type().shape().dim(i).has_dim_value()) { + if (!propagated_value_as_type_proto.tensor_type().shape().dim(i).has_dim_value()) { return Status::OK(); } } - output_def.tensor_shape_proto_after_data_propagation_.clear_dim(); - for (int i = 0; i < dim_size; ++i) { - auto value = propagated_value_as_type_proto.tensor_type().shape().dim(i).dim_value(); - output_def.tensor_shape_proto_after_data_propagation_.add_dim()->set_dim_value(value); + output_def.tensor_shape_proto_after_data_propagation_.clear_dim(); + for (int i = 0; i < dim_size; ++i) { + auto value = propagated_value_as_type_proto.tensor_type().shape().dim(i).dim_value(); + output_def.tensor_shape_proto_after_data_propagation_.add_dim()->set_dim_value(value); + } } return Status::OK(); } @@ -3212,6 +3216,10 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso InferenceContextImpl context(node, func, *this, options); DataPropagationContextImpl data_propagation_context(node, *this); { auto status = Status::OK(); ORT_TRY { context.RunInferencing(); // The following code handles cases where calling the operator-specific shape inference alone // (i.e., the TypeAndShapeInferenceFunction() defined in the ONNX operator schema) // is insufficient to fully infer the output shape. // - // Focus on the Shape operator for now. + // Focus on tensor-related operators. // - // For a Shape operator, running its shape inference only determines the rank (number of dimensions) + // For example, for a Shape operator, running its shape inference only determines the rank (number of dimensions) // of the output tensor, not the actual dimension values. To infer those values and propagate them // throughout the graph, the PartialDataPropagationFunction() defined in the ONNX operator schema // must also be executed. + /* if (node.OpType() == "Shape" || node.OpType() == "Unsqueeze" || node.OpType() == "Squeeze" || node.OpType() == "Gather" || - node.OpType() == "Slice") { + node.OpType() == "Slice" || + node.OpType() == "Concat") { data_propagation_context.RunInferencing(); } + */ + data_propagation_context.RunInferencing(); } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { @@ -3271,7 +3279,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso const TypeProto& type_proto_with_shape_data_inferred = onnx_inferred_data[i]; - SaveValuesFromDataPropagation(node, *output_def, type_proto_with_shape_data_inferred); + ORT_RETURN_IF_ERROR(SaveValuesFromDataPropagation(node, *output_def, type_proto_with_shape_data_inferred)); DataType existing_type = output_def->Type(); DataType inferred_type = nullptr; From e2c532f759419b278e3ff339357b26c7b9a79cf5 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 9 Oct 2025 12:22:47 -0700 Subject: [PATCH 05/23] refactor the code and add comments --- include/onnxruntime/core/graph/graph.h | 7 +- include/onnxruntime/core/graph/node_arg.h | 25 ++- onnxruntime/core/graph/graph.cc | 223 ++++++++++------------ 3 files changed, 122 insertions(+), 133 deletions(-) diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 5eea9514c2a77..02ce84e0b1f88 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1760,8 +1760,11 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi std::vector<const TypeProto*>& output_types, const Graph::ResolveOptions& options); - common::Status SaveValuesFromDataPropagation(Node&, NodeArg& output_def, - const ONNX_NAMESPACE::TypeProto& propagated_value_as_type_proto); + // If the dimension values are inferred after executing the ONNX operator's PartialDataPropagationFunction(), + // save them to the output NodeArg as a TensorShapeProto or a scalar value so that downstream (consumer) nodes + // can use them later for their TypeAndShapeInferenceFunction() and PartialDataPropagationFunction(). + common::Status SaveValuesFromDataPropagation(Node& node, NodeArg& output_def, + const ONNX_NAMESPACE::TypeProto& propagated_value_as_type_proto) const; // Apply type-inference and type-checking to all inputs and initializers: common::Status TypeCheckInputsAndInitializers(); diff --git a/include/onnxruntime/core/graph/node_arg.h b/include/onnxruntime/core/graph/node_arg.h index f6e2369caeddd..d033a0bf73397 100644 --- a/include/onnxruntime/core/graph/node_arg.h +++ b/include/onnxruntime/core/graph/node_arg.h @@ -108,7 +108,7 @@ class NodeArg { const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; } /** Gets the inferred shape as a TensorShapeProto. */ - const ONNX_NAMESPACE::TensorShapeProto& GetInferredTensorShapeProto() const noexcept { return tensor_shape_proto_after_data_propagation_; } + const ONNX_NAMESPACE::TensorShapeProto& GetValuesAfterDataPropagation() const noexcept { return values_after_data_propagation_; } /** Gets a flag indicating whether this NodeArg exists or not. Optional inputs are allowed in ONNX and an empty #Name represents a non-existent input argument. */ bool Exists() const noexcept; @@ -131,14 +131,21 @@ class NodeArg { // Node arg name, type and shape. NodeArgInfo node_arg_info_; - // This TensorShapeProto is for saving the inferred shape data as a TensorShapeProto from the Op's PartialDataPropagationFunction().
- // - // Calling Op's TypeAndShapeInferenceFunction() is not enough for some operators during type and shape inference, - // e.g. Shape op, as it only gets the inferred rank/dimension size value. - // The Op's PartialDataPropagationFunction() defined in ONNX Op Schema should also be called to get the shape dimension values. - // The variable is used for storing that shape dimension values so that inferred shape values can be correctly propagated through out the graph. - ONNX_NAMESPACE::TensorShapeProto tensor_shape_proto_after_data_propagation_; - int64_t scalar_value_after_data_propagation_ = std::numeric_limits::min(); + // This variable stores the inferred shape data as a TensorShapeProto after executing + // the ONNX operator's PartialDataPropagationFunction(). + // + // Calling an operator's TypeAndShapeInferenceFunction() alone is sometimes insufficient + // for complete shape inference. For example, the Shape operator only provides the + // output's rank (1-dimensional) but not its actual dimension values. + // The PartialDataPropagationFunction(), defined in the ONNX operator schema, must also + // be executed to obtain the concrete output shape values, allowing accurate propagation + // of shape information throughout the graph. + ONNX_NAMESPACE::TensorShapeProto values_after_data_propagation_; + + // This variable stores the inferred scalar output. + // It is also used for shape inference and data propagation to ensure consistent shape and + // value information throughout the graph. + int64_t scalar_value_after_data_propagation_ = std::numeric_limits::min(); // Flag indicates whether <*this> node arg exists or not. bool exists_; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 4cc94eec8d178..99ae3aae2ce23 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2630,46 +2630,28 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { options_(options) { node_output_types_.resize(node.OutputDefs().size()); - // The following code handles cases where a node needs to retrieve a ShapeTensorProto - // that stores the inferred shape values in its NodeArg. + // The following code handles cases where a node stores the previously inferred shape values in its NodeArg. // - // Focus on Reshape operator for now. + // For example, the Reshape operator, its input shape may come from a producer node such as a Shape operator, + // and the inferred shape value is already stored as a ShapeTensorProto in corresponding NodeArg. // - // For the Reshape operator, its input shape may come from a producer node such as a Shape operator, - // which can store the inferred shape as a ShapeTensorProto within the NodeArg. // In such cases, the Reshape operator should convert this ShapeTensorProto into a TensorProto. // The resulting TensorProto will then be treated as an initializer during ONNX shape inference, // allowing the real dimension values to be correctly used. // See https://github.com/onnx/onnx/blob/main/onnx/defs/shape_inference.cc#L467 for details. // - // Note: Only certain ShapeTensorProto cases will be converted to TensorProto, specifically when: - // - The shape input has shape [1], meaning the shape only has one dimension value, or - // - The shape input has rank > 1 and all dimensions have known (concrete) values, - // e.g., [1, 3, 64, 64]. - // // Note: The following code is placed here instead of in getInputData() because - // getInputData() is a const function that returns a pointer. 
As a result, - // we cannot create a new ShapeTensorProto and store it in this class within a const function. + // getInputData() is a const function. As a result, we cannot create a new ShapeTensorProto + // and store it in this class within a const function. for (const NodeArg* def : node_.InputDefs()) { // Skip initializer as it will be handled in getInputData() if (graph_.GetConstantInitializer(def->Name(), true)) { continue; } - const auto& tensor_shape_proto = def->GetInferredTensorShapeProto(); - - // The shape of the "shape" input is [1] - if (tensor_shape_proto.dim_size() == 1 && - tensor_shape_proto.dim()[0].has_dim_value() && - tensor_shape_proto.dim()[0].dim_value() == 1) { - TensorProto tensor_proto; - tensor_proto.set_data_type(TensorProto_DataType_INT64); - tensor_proto.add_dims(1); - tensor_proto.add_int64_data(tensor_shape_proto.dim()[0].dim_value()); - tensor_proto_for_shape_[def->Name()] = std::move(tensor_proto); - } + const auto& tensor_shape_proto = def->GetValuesAfterDataPropagation(); - // The shape of the "shape" input has rank > 1 and the all the dimensions have value (not symbolic) + // The inferred data has rank > 1 and the all the dimensions have values (not symbolic) if (tensor_shape_proto.dim_size() > 1) { TensorProto tensor_proto; tensor_proto.set_data_type(TensorProto_DataType_INT64); @@ -2755,7 +2737,6 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { // ORT does not implement partial data propagation yet so just return nullptr. const TensorShapeProto* getSymbolicInput(size_t) const override { - //TODO:// Add symbolic input return nullptr; } @@ -2849,12 +2830,12 @@ class DataPropagationContextImpl : public ONNX_NAMESPACE::DataPropagationContext if (!def) return nullptr; - // Try to get the previously inferred shape values that stored in NodeArg's tensor_shape_proto_after_shape_inferred_ + // Try to get the previously inferred shape values that stored in NodeArg's values_after_data_propagation_ const TensorShapeProto* tensor_shape_proto = nullptr; auto has_shape_values = [&](const TensorShapeProto& tensor_shape_proto) -> bool { - // The TensorShapeProto (inferred shape values) should have rank > 0 and the all the dimensions have value (not symbolic) + // The TensorShapeProto (inferred shape values) should have rank > 0 and the all the dimensions have values (not symbolic) if (tensor_shape_proto.dim_size() > 0) { for (const auto& dim : tensor_shape_proto.dim()) { if (!dim.has_dim_value()) { @@ -2867,8 +2848,8 @@ class DataPropagationContextImpl : public ONNX_NAMESPACE::DataPropagationContext return false; }; - // Get NodeArg's tensor_shape_proto_after_shape_inferred_ if any - tensor_shape_proto = &def->GetInferredTensorShapeProto(); + // Get NodeArg's values_after_data_propagation_ if applicable + tensor_shape_proto = &def->GetValuesAfterDataPropagation(); if (has_shape_values(*tensor_shape_proto)) { return tensor_shape_proto; } @@ -2898,73 +2879,91 @@ class DataPropagationContextImpl : public ONNX_NAMESPACE::DataPropagationContext std::vector node_output_types_; }; -// If the dimension values are inferred after executing the operator's -// PartialDataPropagationFunction(), save them to the output NodeArg as a -// TensorShapeProto so that downstream (consumer) nodes can use them later. 
Status Graph::SaveValuesFromDataPropagation(Node& node, NodeArg& output_def, - const TypeProto& propagated_value_as_type_proto) { + const TypeProto& onnx_inferred_types_after_data_propagation) const { + auto dim_size = onnx_inferred_types_after_data_propagation.tensor_type().shape().dim_size(); - auto dim_size = propagated_value_as_type_proto.tensor_type().shape().dim_size(); - if (dim_size < 0) return Status::OK(); + if (dim_size < 0) { + return Status::OK(); + } - // If dim size is 0, it means one of the cases: - // 1. Inferred output is a scalar. - // 2. Node's input is a scalar and Op's PartialDataPropagationFunction() can't handle. - // - // In other words, some operators' PartialDataPropagationFunction() don't handle the case where the input or output is a scalar. - // Try to do data propagation and save the inferred scalar value if needed. + // If the dimension size is 0, it indicates one of the following cases: + // 1. The inferred output is a scalar. + // 2. The node's input is a scalar, and the operator's PartialDataPropagationFunction() cannot handle it. + // + // In other words, some operators' PartialDataPropagationFunction() implementations do not support + // scalar inputs or outputs. In such cases, attempt data propagation manually and store the inferred + // scalar value in the NodeArg if applicable. if (dim_size == 0) { - if (node.OpType() == "Gather") { - // Try to get the "data" input - const auto* input_0 = node.GetDefinitions().input_defs[0]; - auto& tensor_shape_proto = input_0->tensor_shape_proto_after_data_propagation_; - - // Try to get the "indices" input as an initializer - // Note: The "indices" input should be a scalar, otherwise, if it's a tensor with dimension size > 0, - // the operator's type and shape inference should have inferred the output shape value. - const auto* input_1 = node.GetDefinitions().input_defs[1]; - const TensorProto* initializer = this->GetConstantInitializer(input_1->Name(), true); - int64_t dim_index = -1; - + if (node.OpType() == "Gather") { + // Following code extracts an element from a 1D array. + // e.g. + // shape data is [1, 3, 64, 64] -> gets 64 if the index is 2. + // shape data is [1, 3, 64, 64] -> gets 3 if the index is 1. + + // Try to get the "data" input + // Note: The "data" input should be a one dimension array in this case. + const auto* input_0 = node.GetDefinitions().input_defs[0]; + + // Try to get the "indices" input + // Note: The "indices" input should be a scalar, otherwise, if it's a tensor with dimension size > 0, + // the operator's type and shape inference should have inferred the output shape value. + const auto* input_1 = node.GetDefinitions().input_defs[1]; + + // The "indices" should be an initializer because it's a scalar in this case. 
+ const TensorProto* initializer = this->GetConstantInitializer(input_1->Name(), true); + int64_t index = std::numeric_limits<int64_t>::min(); + + if (initializer) { if (utils::HasRawData(*initializer)) { const std::string& raw = initializer->raw_data(); const int64_t* data = reinterpret_cast<const int64_t*>(raw.data()); - dim_index = *data; + index = *data; } else { if (initializer->data_type() == TensorProto_DataType_INT32) { std::vector<int32_t> values; values.assign(initializer->int32_data().begin(), initializer->int32_data().end()); if (values.size() == 1) { - dim_index = static_cast<int64_t>(values[0]); + index = static_cast<int64_t>(values[0]); } } else if (initializer->data_type() == TensorProto_DataType_INT64) { std::vector<int64_t> values; values.assign(initializer->int64_data().begin(), initializer->int64_data().end()); if (values.size() == 1) { - dim_index = values[0]; + index = values[0]; } - } } + } - if (dim_index >= 0 && (dim_index < tensor_shape_proto.dim_size())) { - auto& dim = tensor_shape_proto.dim(static_cast<int>(dim_index)); - if (dim.has_dim_value()) { - output_def.scalar_value_after_data_propagation_ = dim.dim_value(); - } + // Get the previously inferred dimension values + auto& tensor_shape_proto = input_0->values_after_data_propagation_; + + // Save the dimension value in the NodeArg + if (index != std::numeric_limits<int64_t>::min() && (index < tensor_shape_proto.dim_size())) { + auto& dim = tensor_shape_proto.dim(static_cast<int>(index)); + if (dim.has_dim_value()) { + output_def.scalar_value_after_data_propagation_ = dim.dim_value(); } - } else if (node.OpType() == "Unsqueeze") { - // Try to get the "data" input - const auto* input_0 = node.GetDefinitions().input_defs[0]; + } + + } else if (node.OpType() == "Unsqueeze") { + // Following code expands a scalar to a one-dimensional array. + // e.g.
+ // shape data is 64 -> gets [64] - // Only handle the "data" input is a scalar - if (input_0->scalar_value_after_data_propagation_ > 0) { - // Try to get the "Axes" input as an initializer - const auto* input_1 = node.GetDefinitions().input_defs[1]; - const TensorProto* initializer = this->GetConstantInitializer(input_1->Name(), true); - int64_t axis = -1; + // Try to get the "data" input + const auto* input_0 = node.GetDefinitions().input_defs[0]; + + // Only handle the "data" input is previously inferred from data propagation and is a scalar + if (input_0->scalar_value_after_data_propagation_ != std::numeric_limits<int64_t>::min()) { + // Try to get the "Axes" input as an initializer + const auto* input_1 = node.GetDefinitions().input_defs[1]; + const TensorProto* initializer = this->GetConstantInitializer(input_1->Name(), true); + int64_t axis = -1; + if (initializer) { if (utils::HasRawData(*initializer)) { const std::string& raw = initializer->raw_data(); const int64_t* data = reinterpret_cast<const int64_t*>(raw.data()); @@ -2984,39 +2983,38 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, } } } - - if (axis == 0) { - output_def.tensor_shape_proto_after_data_propagation_.clear_dim(); - output_def.tensor_shape_proto_after_data_propagation_.add_dim()->set_dim_value(input_0->scalar_value_after_data_propagation_); - } } - } else if (node.OpType() == "Mul") { - // Try to get the "A" input - const auto* input_0 = node.GetDefinitions().input_defs[0]; - // Try to get the "B" input - const auto* input_1 = node.GetDefinitions().input_defs[1]; - if (input_0->scalar_value_after_data_propagation_ > 0 && - input_1->scalar_value_after_data_propagation_ > 0) { - output_def.scalar_value_after_data_propagation_ = input_0->scalar_value_after_data_propagation_ * input_1->scalar_value_after_data_propagation_; + if (axis == 0) { + output_def.values_after_data_propagation_.clear_dim(); + output_def.values_after_data_propagation_.add_dim()->set_dim_value(input_0->scalar_value_after_data_propagation_); } } - } - // If dim size > 0, only saves inferred output if one of the conditions is true: - // 1. The shape of the inferred output is 1D (rank = 0) with only one element, e.g. inferred output is [2], [3], .... - // 2. The shape of the inferred output has rank > 1 and the all the dimensions have value (not symbolic) + } else if (node.OpType() == "Mul") { + // Try to get the "A" input + const auto* input_0 = node.GetDefinitions().input_defs[0]; + // Try to get the "B" input + const auto* input_1 = node.GetDefinitions().input_defs[1]; + + if (input_0->scalar_value_after_data_propagation_ != std::numeric_limits<int64_t>::min() && + input_1->scalar_value_after_data_propagation_ != std::numeric_limits<int64_t>::min()) { + output_def.scalar_value_after_data_propagation_ = input_0->scalar_value_after_data_propagation_ * input_1->scalar_value_after_data_propagation_; + }
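For reference, the arithmetic these per-operator fallbacks perform is easy to see in isolation. The following is an illustrative Python sketch only (not code from this patch), assuming shape data [1, 3, 64, 64] produced by an upstream Shape node:

# Illustrative sketch of the scalar propagation the fallbacks above implement.
shape_data = [1, 3, 64, 64]        # values inferred for a Shape node's output

gathered_h = shape_data[2]         # Gather(indices=2)   -> 64, a scalar
gathered_w = shape_data[3]         # Gather(indices=3)   -> 64, a scalar
mul_out = gathered_h * gathered_w  # Mul                 -> 4096, still a scalar
unsqueezed = [mul_out]             # Unsqueeze(axes=[0]) -> [4096], a 1-D value
print(unsqueezed)                  # [4096]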
+ } + } + // If the dimension size is greater than 0, only save the inferred data from data propagation + // when the data has rank > 1 and all dimensions have concrete (non-symbolic) values. else if (dim_size > 0) { - for (int i = 0; i < dim_size; ++i) { - if (!propagated_value_as_type_proto.tensor_type().shape().dim(i).has_dim_value()) { + if (!onnx_inferred_types_after_data_propagation.tensor_type().shape().dim(i).has_dim_value()) { return Status::OK(); } } - output_def.tensor_shape_proto_after_data_propagation_.clear_dim(); + output_def.values_after_data_propagation_.clear_dim(); for (int i = 0; i < dim_size; ++i) { - auto value = propagated_value_as_type_proto.tensor_type().shape().dim(i).dim_value(); - output_def.tensor_shape_proto_after_data_propagation_.add_dim()->set_dim_value(value); + auto value = onnx_inferred_types_after_data_propagation.tensor_type().shape().dim(i).dim_value(); + output_def.values_after_data_propagation_.add_dim()->set_dim_value(value); } } @@ -3216,35 +3214,17 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso InferenceContextImpl context(node, func, *this, options); DataPropagationContextImpl data_propagation_context(node, *this); - if (node.OpType() == "Reshape") { - std::cout << "Reshape" << std::endl; - } - { auto status = Status::OK(); ORT_TRY { context.RunInferencing(); - // The following code handles cases where calling the operator-specific shape inference alone - // (i.e., the TypeAndShapeInferenceFunction() defined in the ONNX operator schema) - // is insufficient to fully infer the output shape. - // - // Focus on tensor replated operators. - // - // For exmaple, for a Shape operator, running its shape inference only determines the rank (number of dimensions) - // of the output tensor, not the actual dimension values. To infer those values and propagate them - // throughout the graph, the PartialDataPropagationFunction() defined in the ONNX operator schema - // must also be executed. - /* - if (node.OpType() == "Shape" || - node.OpType() == "Unsqueeze" || - node.OpType() == "Squeeze" || - node.OpType() == "Gather" || - node.OpType() == "Slice" || - node.OpType() == "Concat") { - data_propagation_context.RunInferencing(); - } - */ + // Calling an operator's TypeAndShapeInferenceFunction() alone is sometimes insufficient + // for complete shape inference. For example, the Shape operator's inference only yields the + // rank of its 1-D output, not the actual dimension values. + // The PartialDataPropagationFunction(), defined in the ONNX operator schema, must also + // be executed to obtain the concrete output shape values, allowing accurate propagation + // of shape information throughout the graph. data_propagation_context.RunInferencing(); } ORT_CATCH(const std::exception& ex) { @@ -3256,7 +3236,8 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso } const auto& onnx_inferred_types(context.InferredOutputTypes()); - const auto& onnx_inferred_data(data_propagation_context.InferredOutputTypes()); + + const auto& onnx_inferred_types_after_data_propagation(data_propagation_context.InferredOutputTypes());
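As context for the comment above, the ONNX reference implementation exposes the same two-step flow through its shape inference API. A minimal sketch, assuming onnx >= 1.10 (illustrative only, not part of this PR):

import onnx
from onnx import TensorProto, helper

# Shape feeds Reshape's "shape" input; with data_prop=True the concrete
# dims are propagated, not just the rank.
inp = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 64, 64])
out = helper.make_tensor_value_info("output", TensorProto.FLOAT, None)
shape_node = helper.make_node("Shape", ["input"], ["shape_out"])
reshape_node = helper.make_node("Reshape", ["input", "shape_out"], ["output"])
graph = helper.make_graph([shape_node, reshape_node], "shape_data_prop", [inp], [out])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])

inferred = onnx.shape_inference.infer_shapes(model, data_prop=True)
print(inferred.graph.output[0].type.tensor_type.shape)  # dims 1, 3, 64, 64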
// Infer and verify node output arg type information. int i = -1; @@ -3275,9 +3256,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso const TypeProto& onnx_inferred_type = onnx_inferred_types[i]; - const TypeProto& type_proto_with_shape_data_inferred = onnx_inferred_data[i]; - - ORT_RETURN_IF_ERROR(SaveValuesFromDataPropagation(node, *output_def, type_proto_with_shape_data_inferred)); + ORT_RETURN_IF_ERROR(SaveValuesFromDataPropagation(node, *output_def, onnx_inferred_types_after_data_propagation[i])); DataType existing_type = output_def->Type(); DataType inferred_type = nullptr; From f79f65b1ff68a534b8c7bf68057c1ff0ba34b680 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 9 Oct 2025 16:11:37 -0700 Subject: [PATCH 06/23] add tests --- onnxruntime/core/graph/graph.cc | 30 ++++++++ .../test/framework/shape_inference_test.cc | 59 +++++++++++++++ ..._propagation_with_shape_related_nodes.onnx | Bin 0 -> 359 bytes ...ta_propagation_with_shape_related_nodes.py | 68 +++++++++++++++++ ...opagation_with_shape_related_nodes_v2.onnx | Bin 0 -> 810 bytes ...propagation_with_shape_related_nodes_v2.py | 69 ++++++++++++++++++ 6 files changed, 226 insertions(+) create mode 100644 onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.onnx create mode 100644 onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.py create mode 100644 onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.onnx create mode 100644 onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.py diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 99ae3aae2ce23..daae51c4ea82f 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2990,6 +2990,26 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, output_def.values_after_data_propagation_.add_dim()->set_dim_value(input_0->scalar_value_after_data_propagation_); } } + } else if (node.OpType() == "Add") { + // Try to get the "A" input + const auto* input_0 = node.GetDefinitions().input_defs[0]; + // Try to get the "B" input + const auto* input_1 = node.GetDefinitions().input_defs[1]; + + if (input_0->scalar_value_after_data_propagation_ != std::numeric_limits<int64_t>::min() && + input_1->scalar_value_after_data_propagation_ != std::numeric_limits<int64_t>::min()) { + output_def.scalar_value_after_data_propagation_ = input_0->scalar_value_after_data_propagation_ + input_1->scalar_value_after_data_propagation_; + } + } else if (node.OpType() == "Sub") { + // Try to get the "A" input + const auto* input_0 = node.GetDefinitions().input_defs[0]; + // Try to get the "B" input + const auto* input_1 = node.GetDefinitions().input_defs[1]; + + if (input_0->scalar_value_after_data_propagation_ != std::numeric_limits<int64_t>::min() && + input_1->scalar_value_after_data_propagation_ != std::numeric_limits<int64_t>::min()) { + output_def.scalar_value_after_data_propagation_ = input_0->scalar_value_after_data_propagation_ - input_1->scalar_value_after_data_propagation_; + } } else if (node.OpType() == "Mul") { // Try to get the "A" input const auto* input_0 = node.GetDefinitions().input_defs[0]; @@ -3000,6 +3020,16 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, input_1->scalar_value_after_data_propagation_ != std::numeric_limits<int64_t>::min()) { output_def.scalar_value_after_data_propagation_ = input_0->scalar_value_after_data_propagation_ * input_1->scalar_value_after_data_propagation_; } + } else if (node.OpType() == "Div") { + // Try to get the "A" input + const auto* input_0 =
node.GetDefinitions().input_defs[0]; + // Try to get the "B" input + const auto* input_1 = node.GetDefinitions().input_defs[1]; + + if (input_0->scalar_value_after_data_propagation_ != std::numeric_limits<int64_t>::min() && + input_1->scalar_value_after_data_propagation_ != std::numeric_limits<int64_t>::min()) { + output_def.scalar_value_after_data_propagation_ = input_0->scalar_value_after_data_propagation_ / input_1->scalar_value_after_data_propagation_; + } } } // If the dimension size is greater than 0, only save the inferred data from data propagation diff --git a/onnxruntime/test/framework/shape_inference_test.cc b/onnxruntime/test/framework/shape_inference_test.cc index f5258760eb20d..f70ba25301a37 100644 --- a/onnxruntime/test/framework/shape_inference_test.cc +++ b/onnxruntime/test/framework/shape_inference_test.cc @@ -16,6 +16,8 @@ using namespace ONNX_NAMESPACE; +extern std::unique_ptr<Ort::Env> ort_env; + namespace onnxruntime { namespace test { @@ -76,6 +78,63 @@ TEST_F(ShapeInferenceTest, BasicTest) { CheckShapeEquality(InputShape(node), OutputShape(node)); } +TEST(ShapeInferenceV2Test, PartialDataPropagationTest) { + { + auto model_path = ORT_TSTR("testdata/test_shape_data_propagation_with_shape_related_nodes.onnx"); + + Ort::SessionOptions session_options{}; + session_options.SetGraphOptimizationLevel(ORT_DISABLE_ALL); + + const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ORT_ENFORCE(g_ort->AddFreeDimensionOverrideByName(session_options, "batch", 1) == nullptr); + ORT_ENFORCE(g_ort->AddFreeDimensionOverrideByName(session_options, "width", 64) == nullptr); + ORT_ENFORCE(g_ort->AddFreeDimensionOverrideByName(session_options, "height", 64) == nullptr); + + // Even though all graph optimizations are disabled, the free dimension overrides are still applied. + // The shape of the graph's output should be correctly inferred by shape inference and data propagation. + Ort::Session session(*ort_env, model_path, session_options); + + // This graph only has one output + ORT_ENFORCE(session.GetOutputCount() == 1); + + Ort::TypeInfo type_info = session.GetOutputTypeInfo(0); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + std::vector<int64_t> output_shape = tensor_info.GetShape(); + EXPECT_TRUE(output_shape.size() == 4) << "The output shape should have 4 dimensions"; + EXPECT_TRUE(output_shape[0] == 1) << "The first dimension should be 1"; + EXPECT_TRUE(output_shape[1] == 3) << "The second dimension should be 3"; + EXPECT_TRUE(output_shape[2] == 64) << "The second dimension should be 64"; + EXPECT_TRUE(output_shape[3] == 64) << "The second dimension should be 64"; + } + + { + auto model_path = ORT_TSTR("testdata/test_shape_data_propagation_with_shape_related_nodes_v2.onnx"); + + Ort::SessionOptions session_options{}; + session_options.SetGraphOptimizationLevel(ORT_DISABLE_ALL); + + const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ORT_ENFORCE(g_ort->AddFreeDimensionOverrideByName(session_options, "batch", 1) == nullptr); + ORT_ENFORCE(g_ort->AddFreeDimensionOverrideByName(session_options, "width", 64) == nullptr); + ORT_ENFORCE(g_ort->AddFreeDimensionOverrideByName(session_options, "height", 64) == nullptr);
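The same overrides are available from the onnxruntime Python API as well. A minimal sketch of the equivalent setup (illustrative only, assuming the test model has already been generated on disk):

import onnxruntime as ort

so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
# Pin the symbolic dimensions before the session is created.
so.add_free_dimension_override_by_name("batch", 1)
so.add_free_dimension_override_by_name("width", 64)
so.add_free_dimension_override_by_name("height", 64)

sess = ort.InferenceSession(
    "testdata/test_shape_data_propagation_with_shape_related_nodes.onnx",
    sess_options=so,
)
print(sess.get_outputs()[0].shape)  # expected: [1, 3, 64, 64]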
+ + // Even though all graph optimizations are disabled, the free dimension overrides are still applied. + // The shape of the graph's output should be correctly inferred by shape inference and data propagation. + Ort::Session session(*ort_env, model_path, session_options); + + // This graph only has one output + ORT_ENFORCE(session.GetOutputCount() == 1); + + Ort::TypeInfo type_info = session.GetOutputTypeInfo(0); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + std::vector<int64_t> output_shape = tensor_info.GetShape(); + EXPECT_TRUE(output_shape.size() == 3) << "The output shape should have 3 dimensions"; + EXPECT_TRUE(output_shape[0] == 1) << "The first dimension should be 1"; + EXPECT_TRUE(output_shape[1] == 3) << "The second dimension should be 3"; + EXPECT_TRUE(output_shape[2] == 4096) << "The second dimension should be 4096"; + } +} + namespace { struct MyCustomKernelWithOptionalInput { MyCustomKernelWithOptionalInput(const OrtKernelInfo* /*info*/) { diff --git a/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.onnx b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e18aa31e414e56724085f5ac3579db7a4db5f277 GIT binary patch literal 359 zcmZ8cyAFat5EKD9*B4o%CRRj^g^97Up_5(Krbh?Wl&zxQ`a K!M7Bs==cO0oNeC# literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.py b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.py new file mode 100644 index 0000000000000..ce7d33d2f93b6 --- /dev/null +++ b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.py @@ -0,0 +1,68 @@ +import onnx +from onnx import helper, TensorProto + +# 1. Define graph input with symbolic shape ['batch', 3, 'width', 'height'] +input_tensor = helper.make_tensor_value_info( + 'input', TensorProto.FLOAT, ['batch', 3, 'width', 'height'] +) + +# 2. Define intermediate and output tensors +shape_out = helper.make_tensor_value_info('shape_out', TensorProto.INT64, [4]) # Shape output +reshape_a_out = helper.make_tensor_value_info('reshape_a_out', TensorProto.FLOAT, None) +output_tensor = helper.make_tensor_value_info('output', TensorProto.FLOAT, None) + +# 3. Create the initializer for Reshape A's 'shape' input: [0, 32, -1] +shape_initializer = helper.make_tensor( + name='reshape_a_shape', + data_type=TensorProto.INT64, + dims=[3], + vals=[0, 32, -1], +) + +# 4. Create nodes: +# Shape node +shape_node = helper.make_node( + 'Shape', + inputs=['input'], + outputs=['shape_out'], + name='ShapeNode' +) + +# Reshape A node: takes input + constant shape +reshape_a_node = helper.make_node( + 'Reshape', + inputs=['input', 'reshape_a_shape'], + outputs=['reshape_a_out'], + name='ReshapeA' +) + +# Reshape B node: takes Shape + ReshapeA outputs, outputs final output +reshape_b_node = helper.make_node( + 'Reshape', + inputs=['reshape_a_out', 'shape_out'], + outputs=['output'], + name='ReshapeB' +) + +# 5. Assemble the graph +graph = helper.make_graph( + nodes=[shape_node, reshape_a_node, reshape_b_node], + name='Shape_Reshape_Model', + inputs=[input_tensor], + outputs=[output_tensor], + initializer=[shape_initializer], + value_info=[shape_out, reshape_a_out], +) + +# 6. Define the model (set IR and opset) +model = helper.make_model( + graph, + opset_imports=[helper.make_operatorsetid('', 18)], + producer_name='onnx-example-generator', +) +model.ir_version = onnx.IR_VERSION + +# 7.
Save the model +onnx.save(model, 'test_shape_data_propagation_with_shape_related_nodes.onnx') + +print("Model saved to test_shape_data_propagation_with_shape_related_nodes.onnx") diff --git a/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.onnx b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ff41075ff64cca07e1284f8af6f48c281530ff90 GIT binary patch literal 810 zcmaiy&rZTX5XQR{+BzVT@xsMFph0Xxq~#4SZC$ zyNkAj#MA7z` zzaBx3MU4%@TyN`-lsktf%yO68>Iy?#QO%WrS#6MZSy4W39O9V!4Sh%yM|vqFVUimr zWQj>`ijXRk{Nc?*gNW?hJMP-Fc5T`(SMtMR2hI(G?3ZF?^F{jAcw^>Xt9i=Vn!WnJi`>YQwF7hKa6bXssI20 literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.py b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.py new file mode 100644 index 0000000000000..6620f79f0350c --- /dev/null +++ b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.py @@ -0,0 +1,69 @@ +import onnx +from onnx import helper, TensorProto + +# === Graph input/output === +input_tensor = helper.make_tensor_value_info( + 'input', TensorProto.FLOAT, ['batch', 3, 'width', 'height'] +) +output_tensor = helper.make_tensor_value_info( + 'output', TensorProto.FLOAT, ['batch', 3, 'width*height'] +) + +# === Initializers === +B = helper.make_tensor('B', TensorProto.FLOAT, [], [1.0]) + +# Gather indices +g0_idx = helper.make_tensor('g0_idx', TensorProto.INT64, [], [0]) +g1_idx = helper.make_tensor('g1_idx', TensorProto.INT64, [], [1]) +g2_idx = helper.make_tensor('g2_idx', TensorProto.INT64, [], [2]) +g3_idx = helper.make_tensor('g3_idx', TensorProto.INT64, [], [3]) + +# Unsqueeze axes tensors +axes_unsq0 = helper.make_tensor('axes_unsq0', TensorProto.INT64, [1], [0]) +axes_unsq1 = helper.make_tensor('axes_unsq1', TensorProto.INT64, [1], [0]) +axes_unsq2 = helper.make_tensor('axes_unsq2', TensorProto.INT64, [1], [0]) + +# === Nodes === +div = helper.make_node('Div', ['input', 'B'], ['div_out']) + +# Two Shape nodes from Div +shape_left = helper.make_node('Shape', ['div_out'], ['shape_left_out']) +shape_right = helper.make_node('Shape', ['div_out'], ['shape_right_out']) + +# Left Shape path +gather0 = helper.make_node('Gather', ['shape_left_out', 'g0_idx'], ['g0_out']) +gather1 = helper.make_node('Gather', ['shape_left_out', 'g1_idx'], ['g1_out']) +unsq0 = helper.make_node('Unsqueeze', ['g0_out', 'axes_unsq0'], ['u0_out']) +unsq1 = helper.make_node('Unsqueeze', ['g1_out', 'axes_unsq1'], ['u1_out']) + +# Right Shape path +gather2 = helper.make_node('Gather', ['shape_right_out', 'g2_idx'], ['g2_out']) +gather3 = helper.make_node('Gather', ['shape_right_out', 'g3_idx'], ['g3_out']) +mul = helper.make_node('Mul', ['g2_out', 'g3_out'], ['mul_out']) +unsq2 = helper.make_node('Unsqueeze', ['mul_out', 'axes_unsq2'], ['u2_out']) + +# Combine +concat = helper.make_node('Concat', ['u0_out', 'u1_out', 'u2_out'], ['concat_out'], axis=0) +reshape = helper.make_node('Reshape', ['div_out', 'concat_out'], ['output']) + +# === Graph === +graph = helper.make_graph( + [div, shape_left, shape_right, + gather0, gather1, gather2, gather3, + mul, unsq0, unsq1, unsq2, + concat, reshape], + 'Div_Shape_Gather_Concat_Reshape', + [input_tensor], + [output_tensor], + initializer=[ + B, g0_idx, g1_idx, g2_idx, g3_idx, + axes_unsq0, axes_unsq1, axes_unsq2 + ] +) + +# === Model === +model = helper.make_model(graph, 
opset_imports=[helper.make_opsetid("", 18)], producer_name='onnx-example') onnx.checker.check_model(model) onnx.save(model, 'test_shape_data_propagation_with_shape_related_nodes_v2.onnx') print("✅ Model saved as test_shape_data_propagation_with_shape_related_nodes_v2.onnx") From dbe36a8256339573e347aa10b44e57cfd8a3ef6e Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 9 Oct 2025 16:14:40 -0700 Subject: [PATCH 07/23] update error message --- onnxruntime/test/framework/shape_inference_test.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/onnxruntime/test/framework/shape_inference_test.cc b/onnxruntime/test/framework/shape_inference_test.cc index f70ba25301a37..09c1d9f5a97e9 100644 --- a/onnxruntime/test/framework/shape_inference_test.cc +++ b/onnxruntime/test/framework/shape_inference_test.cc @@ -101,10 +101,10 @@ TEST(ShapeInferenceV2Test, PartialDataPropagationTest) { auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); std::vector<int64_t> output_shape = tensor_info.GetShape(); EXPECT_TRUE(output_shape.size() == 4) << "The output shape should have 4 dimensions"; - EXPECT_TRUE(output_shape[0] == 1) << "The first dimension should be 1"; - EXPECT_TRUE(output_shape[1] == 3) << "The second dimension should be 3"; - EXPECT_TRUE(output_shape[2] == 64) << "The second dimension should be 64"; - EXPECT_TRUE(output_shape[3] == 64) << "The second dimension should be 64"; + EXPECT_TRUE(output_shape[0] == 1) << "The first dimension should have 1 as value"; + EXPECT_TRUE(output_shape[1] == 3) << "The second dimension should have 3 as value"; + EXPECT_TRUE(output_shape[2] == 64) << "The third dimension should have 64 as value"; + EXPECT_TRUE(output_shape[3] == 64) << "The fourth dimension should have 64 as value"; } { @@ -129,9 +129,9 @@ TEST(ShapeInferenceV2Test, PartialDataPropagationTest) { auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); std::vector<int64_t> output_shape = tensor_info.GetShape(); EXPECT_TRUE(output_shape.size() == 3) << "The output shape should have 3 dimensions"; - EXPECT_TRUE(output_shape[0] == 1) << "The first dimension should be 1"; - EXPECT_TRUE(output_shape[1] == 3) << "The second dimension should be 3"; - EXPECT_TRUE(output_shape[2] == 4096) << "The second dimension should be 4096"; + EXPECT_TRUE(output_shape[0] == 1) << "The first dimension should have 1 as value"; + EXPECT_TRUE(output_shape[1] == 3) << "The second dimension should have 3 as value"; + EXPECT_TRUE(output_shape[2] == 4096) << "The third dimension should have 4096 as value"; } } From d310a26f884ce583df5bcc2fd3227b00468853c0 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 10 Oct 2025 15:48:18 -0700 Subject: [PATCH 08/23] fix warnings from pipelines --- onnxruntime/core/graph/graph.cc | 11 ++- ...ta_propagation_with_shape_related_nodes.py | 41 ++++------- ...propagation_with_shape_related_nodes_v2.py | 68 ++++++++----------- 3 files changed, 47 insertions(+), 73 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index daae51c4ea82f..ec383ecc4ad7e 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2651,8 +2651,8 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { const auto& tensor_shape_proto = def->GetValuesAfterDataPropagation(); - // The inferred data has rank > 1 and the all the dimensions have values (not symbolic) - if (tensor_shape_proto.dim_size() > 1) {
+ // The inferred data has rank > 0 and the all the dimensions have values (not symbolic) + if (tensor_shape_proto.dim_size() > 0) { TensorProto tensor_proto; tensor_proto.set_data_type(TensorProto_DataType_INT64); tensor_proto.add_dims(tensor_shape_proto.dim_size()); @@ -2778,7 +2778,7 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { // shape inference for onnxruntime graphs. class DataPropagationContextImpl : public ONNX_NAMESPACE::DataPropagationContext { public: - DataPropagationContextImpl(Node& node, const Graph& graph) noexcept : node_(node), graph_(graph) { + DataPropagationContextImpl(Node& node) noexcept : node_(node) { node_output_types_.resize(node.OutputDefs().size()); } @@ -2857,7 +2857,7 @@ class DataPropagationContextImpl : public ONNX_NAMESPACE::DataPropagationContext return nullptr; } - void addOutputData(size_t index, TensorShapeProto&& tsp) { + void addOutputData(size_t index, TensorShapeProto&& tsp) override { if (index >= node_output_types_.size()) return; TypeProto& type_proto = node_output_types_[index]; @@ -2875,7 +2875,6 @@ class DataPropagationContextImpl : public ONNX_NAMESPACE::DataPropagationContext private: Node& node_; - const Graph& graph_; std::vector<TypeProto> node_output_types_; }; @@ -3242,7 +3241,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso // returned here. SubgraphInferencingFunc func(Graph::InferAndVerifySubgraphTypes); InferenceContextImpl context(node, func, *this, options); - DataPropagationContextImpl data_propagation_context(node, *this); + DataPropagationContextImpl data_propagation_context(node); { auto status = Status::OK(); diff --git a/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.py b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.py index ce7d33d2f93b6..6537a3cd357c3 100644 --- a/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.py +++ b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.py @@ -1,19 +1,17 @@ import onnx -from onnx import helper, TensorProto +from onnx import TensorProto, helper # 1. Define graph input with symbolic shape ['batch', 3, 'width', 'height'] -input_tensor = helper.make_tensor_value_info( - 'input', TensorProto.FLOAT, ['batch', 3, 'width', 'height'] -) +input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", 3, "width", "height"]) # 2. Define intermediate and output tensors -shape_out = helper.make_tensor_value_info('shape_out', TensorProto.INT64, [4]) # Shape output -reshape_a_out = helper.make_tensor_value_info('reshape_a_out', TensorProto.FLOAT, None) -output_tensor = helper.make_tensor_value_info('output', TensorProto.FLOAT, None) +shape_out = helper.make_tensor_value_info("shape_out", TensorProto.INT64, [4]) # Shape output +reshape_a_out = helper.make_tensor_value_info("reshape_a_out", TensorProto.FLOAT, None) +output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, None) # 3. Create the initializer for Reshape A's 'shape' input: [0, 32, -1] shape_initializer = helper.make_tensor( - name='reshape_a_shape', + name="reshape_a_shape", data_type=TensorProto.INT64, dims=[3], vals=[0, 32, -1], @@ -21,33 +19,20 @@ # 4.
Create nodes: # Shape node -shape_node = helper.make_node( - 'Shape', - inputs=['input'], - outputs=['shape_out'], - name='ShapeNode' -) +shape_node = helper.make_node("Shape", inputs=["input"], outputs=["shape_out"], name="ShapeNode") # Reshape A node: takes input + constant shape reshape_a_node = helper.make_node( - 'Reshape', - inputs=['input', 'reshape_a_shape'], - outputs=['reshape_a_out'], - name='ReshapeA' + "Reshape", inputs=["input", "reshape_a_shape"], outputs=["reshape_a_out"], name="ReshapeA" ) # Reshape B node: takes Shape + ReshapeA outputs, outputs final output -reshape_b_node = helper.make_node( - 'Reshape', - inputs=['reshape_a_out', 'shape_out'], - outputs=['output'], - name='ReshapeB' -) +reshape_b_node = helper.make_node("Reshape", inputs=["reshape_a_out", "shape_out"], outputs=["output"], name="ReshapeB") # 5. Assemble the graph graph = helper.make_graph( nodes=[shape_node, reshape_a_node, reshape_b_node], - name='Shape_Reshape_Model', + name="Shape_Reshape_Model", inputs=[input_tensor], outputs=[output_tensor], initializer=[shape_initializer], @@ -57,12 +42,12 @@ # 6. Define the model (set IR and opset) model = helper.make_model( graph, - opset_imports=[helper.make_operatorsetid('', 18)], - producer_name='onnx-example-generator', + opset_imports=[helper.make_operatorsetid("", 18)], + producer_name="onnx-example-generator", ) model.ir_version = onnx.IR_VERSION # 7. Save the model -onnx.save(model, 'test_shape_data_propagation_with_shape_related_nodes.onnx') +onnx.save(model, "test_shape_data_propagation_with_shape_related_nodes.onnx") print("Model saved to test_shape_data_propagation_with_shape_related_nodes.onnx") diff --git a/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.py b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.py index 6620f79f0350c..7cfbcca8d4d03 100644 --- a/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.py +++ b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.py @@ -1,69 +1,59 @@ import onnx -from onnx import helper, TensorProto +from onnx import TensorProto, helper # === Graph input/output === -input_tensor = helper.make_tensor_value_info( - 'input', TensorProto.FLOAT, ['batch', 3, 'width', 'height'] -) -output_tensor = helper.make_tensor_value_info( - 'output', TensorProto.FLOAT, ['batch', 3, 'width*height'] -) +input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", 3, "width", "height"]) +output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", 3, "width*height"]) # === Initializers === -B = helper.make_tensor('B', TensorProto.FLOAT, [], [1.0]) +B = helper.make_tensor("B", TensorProto.FLOAT, [], [1.0]) # Gather indices -g0_idx = helper.make_tensor('g0_idx', TensorProto.INT64, [], [0]) -g1_idx = helper.make_tensor('g1_idx', TensorProto.INT64, [], [1]) -g2_idx = helper.make_tensor('g2_idx', TensorProto.INT64, [], [2]) -g3_idx = helper.make_tensor('g3_idx', TensorProto.INT64, [], [3]) +g0_idx = helper.make_tensor("g0_idx", TensorProto.INT64, [], [0]) +g1_idx = helper.make_tensor("g1_idx", TensorProto.INT64, [], [1]) +g2_idx = helper.make_tensor("g2_idx", TensorProto.INT64, [], [2]) +g3_idx = helper.make_tensor("g3_idx", TensorProto.INT64, [], [3]) # Unsqueeze axes tensors -axes_unsq0 = helper.make_tensor('axes_unsq0', TensorProto.INT64, [1], [0]) -axes_unsq1 = helper.make_tensor('axes_unsq1', TensorProto.INT64, [1], [0]) -axes_unsq2 = 
helper.make_tensor('axes_unsq2', TensorProto.INT64, [1], [0]) +axes_unsq0 = helper.make_tensor("axes_unsq0", TensorProto.INT64, [1], [0]) +axes_unsq1 = helper.make_tensor("axes_unsq1", TensorProto.INT64, [1], [0]) +axes_unsq2 = helper.make_tensor("axes_unsq2", TensorProto.INT64, [1], [0]) # === Nodes === -div = helper.make_node('Div', ['input', 'B'], ['div_out']) +div = helper.make_node("Div", ["input", "B"], ["div_out"]) # Two Shape nodes from Div -shape_left = helper.make_node('Shape', ['div_out'], ['shape_left_out']) -shape_right = helper.make_node('Shape', ['div_out'], ['shape_right_out']) +shape_left = helper.make_node("Shape", ["div_out"], ["shape_left_out"]) +shape_right = helper.make_node("Shape", ["div_out"], ["shape_right_out"]) # Left Shape path -gather0 = helper.make_node('Gather', ['shape_left_out', 'g0_idx'], ['g0_out']) -gather1 = helper.make_node('Gather', ['shape_left_out', 'g1_idx'], ['g1_out']) -unsq0 = helper.make_node('Unsqueeze', ['g0_out', 'axes_unsq0'], ['u0_out']) -unsq1 = helper.make_node('Unsqueeze', ['g1_out', 'axes_unsq1'], ['u1_out']) +gather0 = helper.make_node("Gather", ["shape_left_out", "g0_idx"], ["g0_out"]) +gather1 = helper.make_node("Gather", ["shape_left_out", "g1_idx"], ["g1_out"]) +unsq0 = helper.make_node("Unsqueeze", ["g0_out", "axes_unsq0"], ["u0_out"]) +unsq1 = helper.make_node("Unsqueeze", ["g1_out", "axes_unsq1"], ["u1_out"]) # Right Shape path -gather2 = helper.make_node('Gather', ['shape_right_out', 'g2_idx'], ['g2_out']) -gather3 = helper.make_node('Gather', ['shape_right_out', 'g3_idx'], ['g3_out']) -mul = helper.make_node('Mul', ['g2_out', 'g3_out'], ['mul_out']) -unsq2 = helper.make_node('Unsqueeze', ['mul_out', 'axes_unsq2'], ['u2_out']) +gather2 = helper.make_node("Gather", ["shape_right_out", "g2_idx"], ["g2_out"]) +gather3 = helper.make_node("Gather", ["shape_right_out", "g3_idx"], ["g3_out"]) +mul = helper.make_node("Mul", ["g2_out", "g3_out"], ["mul_out"]) +unsq2 = helper.make_node("Unsqueeze", ["mul_out", "axes_unsq2"], ["u2_out"]) # Combine -concat = helper.make_node('Concat', ['u0_out', 'u1_out', 'u2_out'], ['concat_out'], axis=0) -reshape = helper.make_node('Reshape', ['div_out', 'concat_out'], ['output']) +concat = helper.make_node("Concat", ["u0_out", "u1_out", "u2_out"], ["concat_out"], axis=0) +reshape = helper.make_node("Reshape", ["div_out", "concat_out"], ["output"]) # === Graph === graph = helper.make_graph( - [div, shape_left, shape_right, - gather0, gather1, gather2, gather3, - mul, unsq0, unsq1, unsq2, - concat, reshape], - 'Div_Shape_Gather_Concat_Reshape', + [div, shape_left, shape_right, gather0, gather1, gather2, gather3, mul, unsq0, unsq1, unsq2, concat, reshape], + "Div_Shape_Gather_Concat_Reshape", [input_tensor], [output_tensor], - initializer=[ - B, g0_idx, g1_idx, g2_idx, g3_idx, - axes_unsq0, axes_unsq1, axes_unsq2 - ] + initializer=[B, g0_idx, g1_idx, g2_idx, g3_idx, axes_unsq0, axes_unsq1, axes_unsq2], ) # === Model === -model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)], producer_name='onnx-example') +model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)], producer_name="onnx-example") onnx.checker.check_model(model) -onnx.save(model, 'test_shape_data_propagation_with_shape_related_nodes_v2.onnx') +onnx.save(model, "test_shape_data_propagation_with_shape_related_nodes_v2.onnx") print("✅ Model saved as test_shape_data_propagation_with_shape_related_nodes_v2.onnx") From 3fa22a0b793184cc939c926928cc96075e152f74 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 
10 Oct 2025 23:41:42 -0700 Subject: [PATCH 09/23] handle Unsqueeze opset 11 and earlier, where axes is a node attribute --- onnxruntime/core/graph/graph.cc | 68 ++++++++++++++++++++------------- 1 file changed, 42 insertions(+), 26 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index ec383ecc4ad7e..00b54a93223d1 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2887,7 +2887,7 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, return Status::OK(); } - // If the dimension size is 0, it indicates one of the following cases: + // If the dimension size is 0, it could indicate one of the following cases: // 1. The inferred output is a scalar. // 2. The node's input is a scalar, and the operator's PartialDataPropagationFunction() cannot handle it. // @@ -2896,7 +2896,7 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, // scalar value in the NodeArg if applicable. if (dim_size == 0) { if (node.OpType() == "Gather") { - // Following code extracts an element from a 1D array. + // Following code extracts an element from a 1D array if all conditions are met. // e.g. // shape data is [1, 3, 64, 64] -> gets 64 if the index is 2. // shape data is [1, 3, 64, 64] -> gets 3 if the index is 1. @@ -2948,42 +2948,58 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, } } else if (node.OpType() == "Unsqueeze") { - // Following code expands a scalar to a one-dimensional array. - // e.g. - // shape data is 64 -> gets [64] + // Following code expands a scalar to a one-dimensional array if all conditions are met. + // e.g. shape data is 64 -> it becomes [64] // Try to get the "data" input const auto* input_0 = node.GetDefinitions().input_defs[0]; - // Only handle the "data" input is previously inferred from data propagation and is a scalar + // Only handle the "data" input which is previously inferred from data propagation and is a scalar if (input_0->scalar_value_after_data_propagation_ != std::numeric_limits<int64_t>::min()) { - // Try to get the "Axes" input as an initializer - const auto* input_1 = node.GetDefinitions().input_defs[1]; - const TensorProto* initializer = this->GetConstantInitializer(input_1->Name(), true); + // Try to get the "axes" value int64_t axis = -1;
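For illustration, the two Unsqueeze variants this code must account for can be built with onnx.helper as follows (a hedged sketch, not part of the patch):

from onnx import TensorProto, helper

# Opset 13 and later: "axes" is a second input (here an initializer).
axes_init = helper.make_tensor("axes", TensorProto.INT64, [1], [0])
unsq_13 = helper.make_node("Unsqueeze", ["data", "axes"], ["expanded"])

# Opset 11 and earlier: "axes" is a node attribute instead.
unsq_11 = helper.make_node("Unsqueeze", ["data"], ["expanded"], axes=[0])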
+ // Note: Starting from opset 13, "axes" is provided as a second input to the Unsqueeze operator. + // In opset 11 and earlier, "axes" is defined as a node attribute instead. + if (node.GetDefinitions().input_defs.size() > 1) { + const auto* input_1 = node.GetDefinitions().input_defs[1]; + // Only check the case that "axes" input is an initializer + const TensorProto* initializer = this->GetConstantInitializer(input_1->Name(), true); + + if (initializer && initializer->dims_size() == 1) { + if (utils::HasRawData(*initializer)) { + const std::string& raw = initializer->raw_data(); + const int64_t* data = reinterpret_cast<const int64_t*>(raw.data()); + axis = *data; + } else { + if (initializer->data_type() == TensorProto_DataType_INT32) { + std::vector<int32_t> values; + values.assign(initializer->int32_data().begin(), initializer->int32_data().end()); + if (values.size() == 1) { + axis = static_cast<int64_t>(values[0]); + } + } else if (initializer->data_type() == TensorProto_DataType_INT64) { + std::vector<int64_t> values; + values.assign(initializer->int64_data().begin(), initializer->int64_data().end()); + if (values.size() == 1) { + axis = values[0]; + } } } } } + const ONNX_NAMESPACE::AttributeProto* axes_attr = node.GetAttributes().count("axes") + ? &node.GetAttributes().at("axes") + : nullptr; + if (axes_attr) { + for (auto v : axes_attr->ints()) { + axis = v; + break; + } + } + + // In this case, the axis should be 0 if (axis == 0) { output_def.values_after_data_propagation_.clear_dim(); output_def.values_after_data_propagation_.add_dim()->set_dim_value(input_0->scalar_value_after_data_propagation_); From 92c71bedc3e4932fe40961ff6d7cbe9316f14ea3 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 13 Oct 2025 09:14:51 -0700 Subject: [PATCH 10/23] fix issue where the 'indices' input to Gather has a negative value --- onnxruntime/core/graph/graph.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 00b54a93223d1..b1a78be45e323 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2940,7 +2940,13 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, auto& tensor_shape_proto = input_0->values_after_data_propagation_; // Save the dimension value in the NodeArg - if (index != std::numeric_limits<int64_t>::min() && (index < tensor_shape_proto.dim_size())) { + // Index value is expected to be within bounds [-s, s-1] along axis of size s + if (index != std::numeric_limits<int64_t>::min() && + index < tensor_shape_proto.dim_size() && index >= -tensor_shape_proto.dim_size()) { + if (index < 0) { + index = tensor_shape_proto.dim_size() + index; + } + auto& dim = tensor_shape_proto.dim(static_cast<int>(index)); if (dim.has_dim_value()) { output_def.scalar_value_after_data_propagation_ = dim.dim_value(); From f652eec6c621a456883aa55569c4058e1bad3c15 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 13 Oct 2025 18:39:40 -0700 Subject: [PATCH 11/23] address issue from pipeline --- onnxruntime/core/graph/graph.cc | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index b1a78be45e323..f4f12021b131e 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2830,8 +2830,6 @@ class DataPropagationContextImpl : public ONNX_NAMESPACE::DataPropagationContext if (!def) return nullptr; - // Try to get the previously inferred shape values that stored in NodeArg's values_after_data_propagation_ - const TensorShapeProto* tensor_shape_proto = nullptr; auto has_shape_values = [&](const TensorShapeProto& tensor_shape_proto) -> bool { @@ -2848,6 +2846,7 @@ class DataPropagationContextImpl :
public ONNX_NAMESPACE::DataPropagationContext return false; }; + // Try to get the previously inferred shape values that stored in NodeArg's values_after_data_propagation_ // Get NodeArg's values_after_data_propagation_ if applicable tensor_shape_proto = &def->GetValuesAfterDataPropagation(); if (has_shape_values(*tensor_shape_proto)) { @@ -2887,6 +2886,30 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, return Status::OK(); } + // Size operator generates a scalar output and a scalar has 0 rank. + // But its PartialDataPropagationFunction() has the chance to generate a shape data with rank > 0. + // So, handle it here. + if (node.OpType() == "Size") { + const auto* input_0 = node.GetDefinitions().input_defs[0]; + auto& tensor_shape_proto = input_0->values_after_data_propagation_; + auto get_num_elements = [&](const TensorShapeProto& tensor_shape_proto) -> void { + int64_t num_elements = 1; + // The TensorShapeProto (inferred shape values) should have rank > 0 and the all the dimensions have values (not symbolic) + if (tensor_shape_proto.dim_size() > 0) { + for (const auto& dim : tensor_shape_proto.dim()) { + if (!dim.has_dim_value()) { + return; + } + num_elements *= dim.dim_value(); + } + output_def.scalar_value_after_data_propagation_ = num_elements; + } + }; + get_num_elements(tensor_shape_proto); + + return Status::OK(); + } + // If the dimension size is 0, it could indicate one of the following cases: // 1. The inferred output is a scalar. // 2. The node's input is a scalar, and the operator's PartialDataPropagationFunction() cannot handle it. @@ -3054,7 +3077,7 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, } } // If the dimension size is greater than 0, only save the inferred data from data propagation - // when the data has rank > 1 and all dimensions have concrete (non-symbolic) values. + // when the data has rank > 0 and all dimensions have concrete (non-symbolic) values. else if (dim_size > 0) { for (int i = 0; i < dim_size; ++i) { if (!onnx_inferred_types_after_data_propagation.tensor_type().shape().dim(i).has_dim_value()) { From 78ed993a6d546829121b94b58101c184ac7fd857 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 13 Oct 2025 22:49:11 -0700 Subject: [PATCH 12/23] fix pipeline warning --- onnxruntime/core/graph/graph.cc | 34 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index f4f12021b131e..3a513054569e0 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2882,19 +2882,17 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, const TypeProto& onnx_inferred_types_after_data_propagation) const { auto dim_size = onnx_inferred_types_after_data_propagation.tensor_type().shape().dim_size(); - if (dim_size < 0) { - return Status::OK(); - } - - // Size operator generates a scalar output and a scalar has 0 rank. - // But its PartialDataPropagationFunction() has the chance to generate a shape data with rank > 0. + // Size operator generates a scalar output and a scalar has rank equals zero. + // But Size operator's PartialDataPropagationFunction() has the chance to + // generate output data with rank > 0. // So, handle it here. 
if (node.OpType() == "Size") { const auto* input_0 = node.GetDefinitions().input_defs[0]; auto& tensor_shape_proto = input_0->values_after_data_propagation_; auto get_num_elements = [&](const TensorShapeProto& tensor_shape_proto) -> void { int64_t num_elements = 1; - // The TensorShapeProto (inferred shape values) should have rank > 0 and the all the dimensions have values (not symbolic) + // The TensorShapeProto (inferred shape values) should have rank > 0 and + // all the dimensions have values (not symbolic) if (tensor_shape_proto.dim_size() > 0) { for (const auto& dim : tensor_shape_proto.dim()) { if (!dim.has_dim_value()) { @@ -2910,13 +2908,13 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, return Status::OK(); } - // If the dimension size is 0, it could indicate one of the following cases: + // If dimension size is 0, it could indicate one of the following cases: // 1. The inferred output is a scalar. // 2. The node's input is a scalar, and the operator's PartialDataPropagationFunction() cannot handle it. // // In other words, some operators' PartialDataPropagationFunction() implementations do not support // scalar inputs or outputs. In such cases, attempt data propagation manually and store the inferred - // scalar value in the NodeArg if applicable. + // scalar value in the NodeArg if any. if (dim_size == 0) { if (node.OpType() == "Gather") { // Following code extracts an element from a 1D array if all conditions are met. @@ -3016,15 +3014,15 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, } } } - } - - const ONNX_NAMESPACE::AttributeProto* axes_attr = node.GetAttributes().count("axes") - ? &node.GetAttributes().at("axes") - : nullptr; - if (axes_attr) { - for (auto v : axes_attr->ints()) { - axis = v; - break; + } else { + const auto& attrs = node.GetAttributes(); + auto it = attrs.find("axes"); + if (it != attrs.end()) { + const auto& axes_attr = it->second; + for (auto v : axes_attr.ints()) { + axis = v; + break; // get first value + } } } From 035a66e69f179bcd98383dc6e5fac486aaebd4b7 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 14 Oct 2025 09:18:09 -0700 Subject: [PATCH 13/23] fix warning in pipeline --- onnxruntime/core/graph/graph.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 3a513054569e0..9c303aff55284 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -3019,9 +3019,8 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, auto it = attrs.find("axes"); if (it != attrs.end()) { const auto& axes_attr = it->second; - for (auto v : axes_attr.ints()) { - axis = v; - break; // get first value + if (axes_attr.ints_size()) { + axis = axes_attr.ints()[0]; } } } From e3e26a0b377547aab87ac095b154b1c58f668d58 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 14 Oct 2025 10:48:41 -0700 Subject: [PATCH 14/23] refactor the code --- onnxruntime/core/graph/graph.cc | 81 ++++++++++++++------------------- 1 file changed, 33 insertions(+), 48 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 9c303aff55284..d629e3c1166e4 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2629,48 +2629,6 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { graph_(graph), options_(options) { node_output_types_.resize(node.OutputDefs().size()); - - // The following code handles cases where a node stores the previously inferred shape values in its 
NodeArg. - // - // For example, the Reshape operator, its input shape may come from a producer node such as a Shape operator, - // and the inferred shape value is already stored as a ShapeTensorProto in corresponding NodeArg. - // - // In such cases, the Reshape operator should convert this ShapeTensorProto into a TensorProto. - // The resulting TensorProto will then be treated as an initializer during ONNX shape inference, - // allowing the real dimension values to be correctly used. - // See https://github.com/onnx/onnx/blob/main/onnx/defs/shape_inference.cc#L467 for details. - // - // Note: The following code is placed here instead of in getInputData() because - // getInputData() is a const function. As a result, we cannot create a new ShapeTensorProto - // and store it in this class within a const function. - for (const NodeArg* def : node_.InputDefs()) { - // Skip initializer as it will be handled in getInputData() - if (graph_.GetConstantInitializer(def->Name(), true)) { - continue; - } - - const auto& tensor_shape_proto = def->GetValuesAfterDataPropagation(); - - // The inferred data has rank > 0 and the all the dimensions have values (not symbolic) - if (tensor_shape_proto.dim_size() > 0) { - TensorProto tensor_proto; - tensor_proto.set_data_type(TensorProto_DataType_INT64); - tensor_proto.add_dims(tensor_shape_proto.dim_size()); - bool all_values = true; - for (const auto& dim : tensor_shape_proto.dim()) { - if (dim.has_dim_value()) { - tensor_proto.add_int64_data(dim.dim_value()); - } else { - all_values = false; - break; - } - } - - if (all_values) { - tensor_proto_for_shape_[def->Name()] = std::move(tensor_proto); - } - } - } } void RunInferencing() { @@ -2725,11 +2683,38 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { return initializer; } - // Return the ShapeTensorProto that was previously created in - // the InferenceContextImpl's constructor if there is any. - auto tensor_proto_for_shape = tensor_proto_for_shape_.find(def->Name()); - if (tensor_proto_for_shape != tensor_proto_for_shape_.end()) { - return &tensor_proto_for_shape->second; + // The following code handles cases where a node stores the previously inferred output shape values in its NodeArg. + // + // For example, for the Reshape operator, the shape input may come from a producer node such as a Shape operator, + // and the inferred output shape value is already stored as a TensorShapeProto in the corresponding NodeArg. + // + // In such cases, the Reshape operator should convert this TensorShapeProto into a TensorProto. + // The resulting TensorProto will then be treated as an initializer during ONNX shape inference, + // allowing the real dimension values to be correctly used. + // See https://github.com/onnx/onnx/blob/main/onnx/defs/shape_inference.cc#L467 for details.
+ + const auto& tensor_shape_proto = def->GetValuesAfterDataPropagation(); + + // Make sure the returned shape tensor as a TensorProto has rank > 0 and all the dimensions + // have values (not symbolic) + if (tensor_shape_proto.dim_size() > 0) { + TensorProto tensor_proto; + tensor_proto.set_data_type(TensorProto_DataType_INT64); + tensor_proto.add_dims(tensor_shape_proto.dim_size()); + bool all_values = true; + for (const auto& dim : tensor_shape_proto.dim()) { + if (dim.has_dim_value()) { + tensor_proto.add_int64_data(dim.dim_value()); + } else { + all_values = false; + break; + } + } + + if (all_values) { + tensor_proto_for_shape_.push_back(std::make_unique<TensorProto>(std::move(tensor_proto))); + return tensor_proto_for_shape_.back().get(); + } } return nullptr; @@ -2767,7 +2752,7 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { Node& node_; // node_output_types_ will be populated by the operator-specific shape inference. std::vector<TypeProto> node_output_types_; - std::unordered_map<std::string, TensorProto> tensor_proto_for_shape_; + mutable InlinedVector<std::unique_ptr<TensorProto>> tensor_proto_for_shape_; SubgraphInferencingFunc subgraph_inferencing_func_; std::vector<std::unique_ptr<GraphInferencerImpl>> graph_inferencers_; const Graph& graph_; From 9dc51aa01c71dc65bf5bc38312d614c241198644 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 14 Oct 2025 13:01:39 -0700 Subject: [PATCH 15/23] address reviewer's comments --- include/onnxruntime/core/graph/graph.h | 8 +- include/onnxruntime/core/graph/node_arg.h | 8 +- onnxruntime/core/graph/graph.cc | 98 +++++++++++------------ 3 files changed, 56 insertions(+), 58 deletions(-) diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 02ce84e0b1f88..f08bff9876494 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1760,11 +1760,11 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi std::vector<const ONNX_NAMESPACE::TypeProto*>& output_types, const Graph::ResolveOptions& options); - // If the dimension values are inferred after executing ONNX operator's PartialDataPropagationFunction(), + // If the shape values are inferred after executing ONNX operator's PartialDataPropagationFunction(), // save them to the output NodeArg as a TensorShapeProto or a scalar value so that downstream (consumer) nodes - // can use them later for their PartialDataPropagationFunction() and PartialDataPropagationFunction(). - common::Status SaveValuesFromDataPropagation(Node& node, NodeArg& output_def, - const ONNX_NAMESPACE::TypeProto& propagated_value_as_type_proto) const; + // can use them later for their TypeAndShapeInferenceFunction() and PartialDataPropagationFunction(). + common::Status SaveShapeValuesFromDataPropagation(Node& node, NodeArg& output_def, + const ONNX_NAMESPACE::TypeProto& propagated_value_as_type_proto) const; // Apply type-inference and type-checking to all inputs and initializers: common::Status TypeCheckInputsAndInitializers(); diff --git a/include/onnxruntime/core/graph/node_arg.h b/include/onnxruntime/core/graph/node_arg.h index d033a0bf73397..9e474d340d002 100644 --- a/include/onnxruntime/core/graph/node_arg.h +++ b/include/onnxruntime/core/graph/node_arg.h @@ -107,8 +107,8 @@ class NodeArg { /** Gets this NodeArg as a NodeArgInfo, AKA ValueInfoProto. */ const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; } - /** Gets the inferred shape as a TensorShapeProto.
*/ - const ONNX_NAMESPACE::TensorShapeProto& GetValuesAfterDataPropagation() const noexcept { return values_after_data_propagation_; } + /** Gets the inferred shape values as a TensorShapeProto. */ + const std::optional<ONNX_NAMESPACE::TensorShapeProto> GetInferredShapeValues() const noexcept { return inferred_shape_values_; } /** Gets a flag indicating whether this NodeArg exists or not. Optional inputs are allowed in ONNX and an empty #Name represents a non-existent input argument. */ @@ -140,12 +140,12 @@ class NodeArg { // The PartialDataPropagationFunction(), defined in the ONNX operator schema, must also // be executed to obtain the concrete output shape values, allowing accurate propagation // of shape information throughout the graph. - ONNX_NAMESPACE::TensorShapeProto values_after_data_propagation_; + std::optional<ONNX_NAMESPACE::TensorShapeProto> inferred_shape_values_; // This variable stores the inferred scalar output. // It is also used for shape inference and data propagation to ensure consistent shape and // value information throughout the graph. - int64_t scalar_value_after_data_propagation_ = std::numeric_limits<int64_t>::min(); + std::optional<int64_t> inferred_scalar_value_; // Flag indicates whether <*this> node arg exists or not. bool exists_; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 376138f27f24d..e70abbde5e8f4 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2699,7 +2699,7 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { " has in-memory external data but cannot get OrtValue during shape inference"); } } - + return initializer; } @@ -2712,17 +2712,16 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { // The resulting TensorProto will then be treated as an initializer during ONNX shape inference, // allowing the real dimension values to be correctly used. // See https://github.com/onnx/onnx/blob/main/onnx/defs/shape_inference.cc#L467 for details.
- - const auto& tensor_shape_proto = def->GetValuesAfterDataPropagation(); + const auto& tensor_shape_proto = def->GetInferredShapeValues(); // Make sure the returning shape tensor as a TensorProto has rank > 0 and all the dimensions // have values (not symbolic) - if (tensor_shape_proto.dim_size() > 0) { + if (tensor_shape_proto.has_value() && tensor_shape_proto->dim_size() > 0) { TensorProto tensor_proto; tensor_proto.set_data_type(TensorProto_DataType_INT64); - tensor_proto.add_dims(tensor_shape_proto.dim_size()); + tensor_proto.add_dims(tensor_shape_proto->dim_size()); bool all_values = true; - for (const auto& dim : tensor_shape_proto.dim()) { + for (const auto& dim : tensor_shape_proto->dim()) { if (dim.has_dim_value()) { tensor_proto.add_int64_data(dim.dim_value()); } else { @@ -2840,10 +2839,9 @@ class DataPropagationContextImpl : public ONNX_NAMESPACE::DataPropagationContext if (!def) return nullptr; - const TensorShapeProto* tensor_shape_proto = nullptr; - auto has_shape_values = [&](const TensorShapeProto& tensor_shape_proto) -> bool { - // The TensorShapeProto (inferred shape values) should have rank > 0 and the all the dimensions have values (not symbolic) + // The TensorShapeProto (inferred shape values) should have rank > 0 and + // all the dimensions have values (not symbolic) if (tensor_shape_proto.dim_size() > 0) { for (const auto& dim : tensor_shape_proto.dim()) { if (!dim.has_dim_value()) { @@ -2856,11 +2854,10 @@ class DataPropagationContextImpl : public ONNX_NAMESPACE::DataPropagationContext return false; }; - // Try to get the previously inferred shape values that stored in NodeArg's values_after_data_propagation_ - // Get NodeArg's values_after_data_propagation_ if applicable - tensor_shape_proto = &def->GetValuesAfterDataPropagation(); - if (has_shape_values(*tensor_shape_proto)) { - return tensor_shape_proto; + // Try to get the previously inferred shape values that are stored in NodeArg's inferred_shape_values_ + auto& tensor_shape_proto = def->GetInferredShapeValues(); + if (tensor_shape_proto.has_value() && has_shape_values(*tensor_shape_proto)) { + return &*tensor_shape_proto; } return nullptr; @@ -2887,18 +2884,18 @@ class DataPropagationContextImpl : public ONNX_NAMESPACE::DataPropagationContext std::vector<TypeProto> node_output_types_; }; -Status Graph::SaveValuesFromDataPropagation(Node& node, - NodeArg& output_def, - const TypeProto& onnx_inferred_types_after_data_propagation) const { +Status Graph::SaveShapeValuesFromDataPropagation(Node& node, + NodeArg& output_def, + const TypeProto& onnx_inferred_types_after_data_propagation) const { auto dim_size = onnx_inferred_types_after_data_propagation.tensor_type().shape().dim_size(); // Size operator generates a scalar output and a scalar has rank equals zero. // But Size operator's PartialDataPropagationFunction() has the chance to // generate output data with rank > 0. - // So, handle it here. + // So, ignore its PartialDataPropagationFunction() and infer the output value here.
if (node.OpType() == "Size") { const auto* input_0 = node.GetDefinitions().input_defs[0]; - auto& tensor_shape_proto = input_0->values_after_data_propagation_; + auto& tensor_shape_proto = input_0->inferred_shape_values_; auto get_num_elements = [&](const TensorShapeProto& tensor_shape_proto) -> void { int64_t num_elements = 1; // The TensorShapeProto (inferred shape values) should have rank > 0 and @@ -2910,10 +2907,13 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, } num_elements *= dim.dim_value(); } - output_def.scalar_value_after_data_propagation_ = num_elements; + output_def.inferred_scalar_value_ = num_elements; } }; - get_num_elements(tensor_shape_proto); + + if (tensor_shape_proto.has_value()) { + get_num_elements(*tensor_shape_proto); + } return Status::OK(); } @@ -2968,19 +2968,21 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, } // Get the previously inferred dimension values - auto& tensor_shape_proto = input_0->values_after_data_propagation_; + auto& tensor_shape_proto = input_0->inferred_shape_values_; - // Save the dimension value in the NodeArg + // Save the dimension value in the NodeArg. // Index value is expected to be within bounds [-s, s-1] along axis of size s - if (index != std::numeric_limits::min() && - index < tensor_shape_proto.dim_size() && index >= -tensor_shape_proto.dim_size()) { - if (index < 0) { - index = tensor_shape_proto.dim_size() + index; - } + if (tensor_shape_proto.has_value()) { + if (index != std::numeric_limits::min() && + index < tensor_shape_proto->dim_size() && index >= -tensor_shape_proto->dim_size()) { + if (index < 0) { + index = tensor_shape_proto->dim_size() + index; + } - auto& dim = tensor_shape_proto.dim(static_cast(index)); - if (dim.has_dim_value()) { - output_def.scalar_value_after_data_propagation_ = dim.dim_value(); + auto& dim = tensor_shape_proto->dim(static_cast(index)); + if (dim.has_dim_value()) { + output_def.inferred_scalar_value_ = dim.dim_value(); + } } } @@ -2992,7 +2994,7 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, const auto* input_0 = node.GetDefinitions().input_defs[0]; // Only handle the "data" input which is previously inferred from data propagation and is a scalar - if (input_0->scalar_value_after_data_propagation_ != std::numeric_limits::min()) { + if (input_0->inferred_scalar_value_.has_value()) { // Try to get the "axes" value int64_t axis = -1; @@ -3037,8 +3039,8 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, // In this case, the axis should be 0 if (axis == 0) { - output_def.values_after_data_propagation_.clear_dim(); - output_def.values_after_data_propagation_.add_dim()->set_dim_value(input_0->scalar_value_after_data_propagation_); + output_def.inferred_shape_values_->clear_dim(); + output_def.inferred_shape_values_->add_dim()->set_dim_value(*input_0->inferred_scalar_value_); } } } else if (node.OpType() == "Add") { @@ -3047,9 +3049,8 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, // Try to get the "B" input const auto* input_1 = node.GetDefinitions().input_defs[1]; - if (input_0->scalar_value_after_data_propagation_ != std::numeric_limits::min() && - input_1->scalar_value_after_data_propagation_ != std::numeric_limits::min()) { - output_def.scalar_value_after_data_propagation_ = input_0->scalar_value_after_data_propagation_ + input_1->scalar_value_after_data_propagation_; + if (input_0->inferred_scalar_value_.has_value() && input_1->inferred_scalar_value_.has_value()) { + output_def.inferred_scalar_value_ = *input_0->inferred_scalar_value_ + 
*input_1->inferred_scalar_value_; } } else if (node.OpType() == "Sub") { // Try to get the "A" input @@ -3057,9 +3058,8 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, // Try to get the "B" input const auto* input_1 = node.GetDefinitions().input_defs[1]; - if (input_0->scalar_value_after_data_propagation_ != std::numeric_limits::min() && - input_1->scalar_value_after_data_propagation_ != std::numeric_limits::min()) { - output_def.scalar_value_after_data_propagation_ = input_0->scalar_value_after_data_propagation_ - input_1->scalar_value_after_data_propagation_; + if (input_0->inferred_scalar_value_.has_value() && input_1->inferred_scalar_value_.has_value()) { + output_def.inferred_scalar_value_ = *input_0->inferred_scalar_value_ - *input_1->inferred_scalar_value_; } } else if (node.OpType() == "Mul") { // Try to get the "A" input @@ -3067,9 +3067,8 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, // Try to get the "B" input const auto* input_1 = node.GetDefinitions().input_defs[1]; - if (input_0->scalar_value_after_data_propagation_ != std::numeric_limits::min() && - input_1->scalar_value_after_data_propagation_ != std::numeric_limits::min()) { - output_def.scalar_value_after_data_propagation_ = input_0->scalar_value_after_data_propagation_ * input_1->scalar_value_after_data_propagation_; + if (input_0->inferred_scalar_value_.has_value() && input_1->inferred_scalar_value_.has_value()) { + output_def.inferred_scalar_value_ = *input_0->inferred_scalar_value_ * (*input_1->inferred_scalar_value_); } } else if (node.OpType() == "Div") { // Try to get the "A" input @@ -3077,13 +3076,12 @@ Status Graph::SaveValuesFromDataPropagation(Node& node, // Try to get the "B" input const auto* input_1 = node.GetDefinitions().input_defs[1]; - if (input_0->scalar_value_after_data_propagation_ != std::numeric_limits::min() && - input_1->scalar_value_after_data_propagation_ != std::numeric_limits::min()) { - output_def.scalar_value_after_data_propagation_ = input_0->scalar_value_after_data_propagation_ / input_1->scalar_value_after_data_propagation_; + if (input_0->inferred_scalar_value_.has_value() && input_1->inferred_scalar_value_.has_value()) { + output_def.inferred_scalar_value_ = *input_0->inferred_scalar_value_ / *input_1->inferred_scalar_value_; } } } - // If the dimension size is greater than 0, only save the inferred data from data propagation + // If the dimension size is greater than 0, only save the inferred shape values from data propagation // when the data has rank > 0 and all dimensions have concrete (non-symbolic) values. 
else if (dim_size > 0) { for (int i = 0; i < dim_size; ++i) { if (!onnx_inferred_types_after_data_propagation.tensor_type().shape().dim(i).has_dim_value()) { return Status::OK(); } } - output_def.values_after_data_propagation_.clear_dim(); + output_def.inferred_shape_values_->clear_dim(); for (int i = 0; i < dim_size; ++i) { auto value = onnx_inferred_types_after_data_propagation.tensor_type().shape().dim(i).dim_value(); - output_def.values_after_data_propagation_.add_dim()->set_dim_value(value); + output_def.inferred_shape_values_->add_dim()->set_dim_value(value); } } @@ -3337,7 +3335,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso const TypeProto& onnx_inferred_type = onnx_inferred_types[i]; - ORT_RETURN_IF_ERROR(SaveValuesFromDataPropagation(node, *output_def, onnx_inferred_types_after_data_propagation[i])); + ORT_RETURN_IF_ERROR(SaveShapeValuesFromDataPropagation(node, *output_def, onnx_inferred_types_after_data_propagation[i])); DataType existing_type = output_def->Type(); DataType inferred_type = nullptr; From 1ec4bf98f9d440d99c8200c63a67bb5f039896f4 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 14 Oct 2025 14:15:32 -0700 Subject: [PATCH 16/23] fix bug after using std::optional to store value --- include/onnxruntime/core/graph/node_arg.h | 2 +- onnxruntime/core/graph/graph.cc | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/include/onnxruntime/core/graph/node_arg.h b/include/onnxruntime/core/graph/node_arg.h index 9e474d340d002..325e9f50e4ea2 100644 --- a/include/onnxruntime/core/graph/node_arg.h +++ b/include/onnxruntime/core/graph/node_arg.h @@ -108,7 +108,7 @@ class NodeArg { const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; } /** Gets the inferred shape values as a TensorShapeProto. */ - const std::optional<ONNX_NAMESPACE::TensorShapeProto> GetInferredShapeValues() const noexcept { return inferred_shape_values_; } + const std::optional<ONNX_NAMESPACE::TensorShapeProto>& GetInferredShapeValues() const noexcept { return inferred_shape_values_; } /** Gets a flag indicating whether this NodeArg exists or not. Optional inputs are allowed in ONNX and an empty #Name represents a non-existent input argument.
*/ diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index e70abbde5e8f4..83a71e260c858 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -3039,6 +3039,9 @@ Status Graph::SaveShapeValuesFromDataPropagation(Node& node, // In this case, the axis should be 0 if (axis == 0) { + if (!output_def.inferred_shape_values_.has_value()) { + output_def.inferred_shape_values_.emplace(); + } output_def.inferred_shape_values_->clear_dim(); output_def.inferred_shape_values_->add_dim()->set_dim_value(*input_0->inferred_scalar_value_); } @@ -3090,6 +3093,10 @@ Status Graph::SaveShapeValuesFromDataPropagation(Node& node, } } + if (!output_def.inferred_shape_values_.has_value()) { + output_def.inferred_shape_values_.emplace(); + } + output_def.inferred_shape_values_->clear_dim(); for (int i = 0; i < dim_size; ++i) { auto value = onnx_inferred_types_after_data_propagation.tensor_type().shape().dim(i).dim_value(); From 1391a15e53eb702130591f8c1e75ad5fbc3f0f99 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 14 Oct 2025 15:09:22 -0700 Subject: [PATCH 17/23] address reviewer's comments --- onnxruntime/core/graph/graph.cc | 9 +++++---- onnxruntime/test/framework/shape_inference_test.cc | 2 ++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 83a71e260c858..95c198ce8cfe5 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2711,7 +2711,6 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { // In such cases, the Reshape operator should convert this TensorShapeProto into a TensorProto. // The resulting TensorProto will then be treated as an initializer during ONNX shape inference, // allowing the real dimension values to be correctly used. - // See https://github.com/onnx/onnx/blob/main/onnx/defs/shape_inference.cc#L467 for details. const auto& tensor_shape_proto = def->GetInferredShapeValues(); // Make sure the returning shape tensor as a TensorProto has rank > 0 and all the dimensions @@ -2731,8 +2730,8 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { } if (all_values) { - tensor_proto_for_shape_.push_back(std::make_unique<TensorProto>(std::move(tensor_proto))); - return tensor_proto_for_shape_.back().get(); + temp_tensor_protos_.push_back(std::make_unique<TensorProto>(std::move(tensor_proto))); + return temp_tensor_protos_.back().get(); } } @@ -2771,7 +2770,6 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { Node& node_; // node_output_types_ will be populated by the operator-specific shape inference. std::vector<TypeProto> node_output_types_; - mutable InlinedVector<std::unique_ptr<TensorProto>> tensor_proto_for_shape_; SubgraphInferencingFunc subgraph_inferencing_func_; std::vector<std::unique_ptr<GraphInferencerImpl>> graph_inferencers_; const Graph& graph_; @@ -2780,11 +2778,14 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext { // These need to outlive the shape inference call, so we store them here // Inference is per node and the instance of this context is on the stack, // so this is safe. + // It can also be used to temporarily save the inferred shape values as a TensorProto. mutable InlinedVector<std::unique_ptr<TensorProto>> temp_tensor_protos_; }; // An implementation of the DataPropagationContext interface optionally used by operator-specific // shape inference for onnxruntime graphs.
+// Please see the description and usage of ONNX's data propagation here: +// https://github.com/onnx/onnx/blob/main/onnx/defs/shape_inference.h#L117-L127 class DataPropagationContextImpl : public ONNX_NAMESPACE::DataPropagationContext { public: DataPropagationContextImpl(Node& node) noexcept : node_(node) { diff --git a/onnxruntime/test/framework/shape_inference_test.cc b/onnxruntime/test/framework/shape_inference_test.cc index 4b1c62ca5c310..a7910e28b1d6d 100644 --- a/onnxruntime/test/framework/shape_inference_test.cc +++ b/onnxruntime/test/framework/shape_inference_test.cc @@ -80,6 +80,7 @@ TEST_F(ShapeInferenceTest, BasicTest) { TEST(ShapeInferenceV2Test, PartialDataPropagationTest) { { + // This model contains "Shape" and "Reshape" operators. auto model_path = ORT_TSTR("testdata/test_shape_data_propagation_with_shape_related_nodes.onnx"); Ort::SessionOptions session_options{}; @@ -108,6 +109,7 @@ TEST(ShapeInferenceV2Test, PartialDataPropagationTest) { } { + // This model contains "Shape", "Reshape", "Gather" and "Unsqueeze" operators. auto model_path = ORT_TSTR("testdata/test_shape_data_propagation_with_shape_related_nodes_v2.onnx"); Ort::SessionOptions session_options{}; From 21d5e335bb9c6d2667f07252320e94b7df6530e7 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 14 Oct 2025 15:34:07 -0700 Subject: [PATCH 18/23] add check for get initializer as in-memory external --- onnxruntime/core/graph/graph.cc | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 95c198ce8cfe5..8bf07e0a029a7 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2947,7 +2947,20 @@ Status Graph::SaveShapeValuesFromDataPropagation(Node& node, int64_t index = std::numeric_limits::min(); if (initializer) { - if (utils::HasRawData(*initializer)) { + // Check if this is in-memory external data (data stored in OrtValue) + if (utils::HasExternalDataInMemory(*initializer)) { + // Try to get the OrtValue for this initializer + OrtValue ort_value; + if (this->GetOrtValueInitializer(input_1->Name(), ort_value, true)) { + const Tensor& tensor = ort_value.Get(); + const int64_t* data = tensor.Data(); + index = *data; + } else { + // If we can't get the OrtValue, it is a bug + ORT_THROW("Initializer ", input_1->Name(), + " has in-memory external data but cannot get OrtValue during shape inference"); + } + } else if (utils::HasRawData(*initializer)) { const std::string& raw = initializer->raw_data(); const int64_t* data = reinterpret_cast(raw.data()); index = *data; @@ -3007,7 +3020,20 @@ Status Graph::SaveShapeValuesFromDataPropagation(Node& node, const TensorProto* initializer = this->GetConstantInitializer(input_1->Name(), true); if (initializer && initializer->dims_size() == 1) { - if (utils::HasRawData(*initializer)) { + // Check if this is in-memory external data (data stored in OrtValue) + if (utils::HasExternalDataInMemory(*initializer)) { + // Try to get the OrtValue for this initializer + OrtValue ort_value; + if (this->GetOrtValueInitializer(input_1->Name(), ort_value, true)) { + const Tensor& tensor = ort_value.Get(); + const int64_t* data = tensor.Data(); + axis = *data; + } else { + // If we can't get the OrtValue, it is a bug + ORT_THROW("Initializer ", input_1->Name(), + " has in-memory external data but cannot get OrtValue during shape inference"); + } + } else if (utils::HasRawData(*initializer)) { const std::string& raw = initializer->raw_data(); const int64_t* 
data = reinterpret_cast<const int64_t*>(raw.data()); axis = *data; From e59216baf7eda0a03e6f8a7ab3f25b9af049140a Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 15 Oct 2025 16:23:22 -0700 Subject: [PATCH 19/23] fix type issue --- onnxruntime/core/graph/graph.cc | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 8bf07e0a029a7..6dcfbc17e32fa 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2953,8 +2953,13 @@ Status Graph::SaveShapeValuesFromDataPropagation(Node& node, OrtValue ort_value; if (this->GetOrtValueInitializer(input_1->Name(), ort_value, true)) { const Tensor& tensor = ort_value.Get<Tensor>(); - const int64_t* data = tensor.Data<int64_t>(); - index = *data; + if (initializer->data_type() == TensorProto_DataType_INT32) { + const int32_t* data = tensor.Data<int32_t>(); + index = static_cast<int64_t>(*data); + } else if (initializer->data_type() == TensorProto_DataType_INT64) { + const int64_t* data = tensor.Data<int64_t>(); + index = *data; + } } else { // If we can't get the OrtValue, it is a bug ORT_THROW("Initializer ", input_1->Name(), @@ -2962,8 +2967,13 @@ Status Graph::SaveShapeValuesFromDataPropagation(Node& node, } } else if (utils::HasRawData(*initializer)) { const std::string& raw = initializer->raw_data(); - const int64_t* data = reinterpret_cast<const int64_t*>(raw.data()); - index = *data; + if (initializer->data_type() == TensorProto_DataType_INT32) { + const int32_t* data = reinterpret_cast<const int32_t*>(raw.data()); + index = static_cast<int64_t>(*data); + } else if (initializer->data_type() == TensorProto_DataType_INT64) { + const int64_t* data = reinterpret_cast<const int64_t*>(raw.data()); + index = *data; + } } else { if (initializer->data_type() == TensorProto_DataType_INT32) { std::vector<int32_t> values; From 2616c6fc3a5a722c5b6ad6f209c4ae1d0318e766 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 20 Oct 2025 15:23:30 -0700 Subject: [PATCH 20/23] refactor code and address corner case --- onnxruntime/core/graph/graph.cc | 374 ++++++++++++++++++++------------ 1 file changed, 238 insertions(+), 136 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 6dcfbc17e32fa..17fcfa0a01e6c 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2887,14 +2887,71 @@ Status Graph::SaveShapeValuesFromDataPropagation(Node& node, NodeArg& output_def, - const TypeProto& onnx_inferred_types_after_data_propagation) const { - auto dim_size = onnx_inferred_types_after_data_propagation.tensor_type().shape().dim_size(); + const TypeProto& onnx_inferred_type_after_data_propagation) const { + auto dim_size = onnx_inferred_type_after_data_propagation.tensor_type().shape().dim_size(); - // Size operator generates a scalar output and a scalar has rank equals zero. - // But Size operator's PartialDataPropagationFunction() has the chance to - // generate output data with rank > 0. - // So, ignore its PartialDataPropagationFunction() and infer the output value here. + // Helper function to get the input values if the input is an initializer. + auto get_initialized_input_values = [&](const std::string& input_name, std::vector<int64_t>& input_values) -> void { + const TensorProto* initializer = this->GetConstantInitializer(input_name, true); + + if (initializer) { + // Get shape from TensorProto and calculate element counts. + // If shape has dimension size equals zero, it means it's a scalar and has only one value.
+ int64_t element_cnt = 1; + for (auto& dim : initializer->dims()) { + element_cnt *= dim; + } + + // Check if this is in-memory external data (data stored in OrtValue) + if (utils::HasExternalDataInMemory(*initializer)) { + // Try to get the OrtValue for this initializer + OrtValue ort_value; + if (this->GetOrtValueInitializer(input_name, ort_value, true)) { + const Tensor& tensor = ort_value.Get<Tensor>(); + if (initializer->data_type() == TensorProto_DataType_INT32) { + input_values.resize(element_cnt); + for (int64_t i = 0; i < element_cnt; ++i) { + input_values[i] = static_cast<int64_t>(tensor.Data<int32_t>()[i]); + } + } else if (initializer->data_type() == TensorProto_DataType_INT64) { + input_values.resize(element_cnt); + for (int64_t i = 0; i < element_cnt; ++i) { + input_values[i] = tensor.Data<int64_t>()[i]; + } + } + } else { + // If we can't get the OrtValue, it is a bug + ORT_THROW("Initializer ", input_name, + " has in-memory external data but cannot get OrtValue during shape inference"); + } + } else if (utils::HasRawData(*initializer)) { + const std::string& raw = initializer->raw_data(); + const int64_t* data = reinterpret_cast<const int64_t*>(raw.data()); + input_values.resize(element_cnt); + for (int64_t i = 0; i < element_cnt; ++i) { + input_values[i] = data[i]; + } + } else { + if (initializer->data_type() == TensorProto_DataType_INT32) { + std::vector<int32_t> values; + values.assign(initializer->int32_data().begin(), initializer->int32_data().end()); + input_values.resize(element_cnt); + for (int64_t i = 0; i < element_cnt; ++i) { + input_values[i] = static_cast<int64_t>(values[i]); + } + } else if (initializer->data_type() == TensorProto_DataType_INT64) { + input_values.resize(element_cnt); + input_values.assign(initializer->int64_data().begin(), initializer->int64_data().end()); + } + } + } + }; + + // Following operators, e.g. Size, Squeeze and Unsqueeze, calling their PartialDataPropagationFunction() alone + // to get the inferred shape values is still not enough to get the proper values. + // So, ignore the inferred values from PartialDataPropagationFunction() and infer the output values here. if (node.OpType() == "Size") { + // Size operator generates a scalar output const auto* input_0 = node.GetDefinitions().input_defs[0]; auto& tensor_shape_proto = input_0->inferred_shape_values_; auto get_num_elements = [&](const TensorShapeProto& tensor_shape_proto) -> void { @@ -2916,6 +2973,157 @@ Status Graph::SaveShapeValuesFromDataPropagation(Node& node, get_num_elements(*tensor_shape_proto); } + return Status::OK(); + + } else if (node.OpType() == "Squeeze") { + const auto* input_0 = node.GetDefinitions().input_defs[0]; + auto& tensor_shape_proto = input_0->inferred_shape_values_; + + auto update_output_shape_values = [&](const TensorShapeProto& tensor_shape_proto) -> void { + // The TensorShapeProto (inferred shape values) should have rank > 0 and + // all the dimensions have values (not symbolic) + if (tensor_shape_proto.dim_size() > 0) { + for (const auto& dim : tensor_shape_proto.dim()) { + if (!dim.has_dim_value()) { + return; + } + } + } + + if (tensor_shape_proto.dim_size() == 1) { + output_def.inferred_scalar_value_ = tensor_shape_proto.dim(0).dim_value(); + } else if (tensor_shape_proto.dim_size() > 1) { + // Get axes value + std::vector<int64_t> axes; + std::unordered_set<int64_t> axes_set; + + // Note: Starting from opset 13, "axes" is provided as a second input to the Squeeze operator. + // In opset 11 and earlier, "axes" is defined as a node attribute instead.
+ if (node.GetDefinitions().input_defs.size() > 1) { + const auto* input_1 = node.GetDefinitions().input_defs[1]; + get_initialized_input_values(input_1->Name(), axes); + } else { + const auto& attrs = node.GetAttributes(); + auto it = attrs.find("axes"); + if (it != attrs.end()) { + const auto& axes_attr = it->second; + for (const auto& i : axes_attr.ints()) { + axes.push_back(i); + } + } + } + + for (size_t i = 0; i < axes.size(); ++i) { + // Negative value means counting dimensions from the back. + if (axes[i] < 0) { + axes_set.insert(axes[i] + tensor_shape_proto.dim_size()); + } else { + axes_set.insert(axes[i]); + } + } + + if (!output_def.inferred_shape_values_.has_value()) { + output_def.inferred_shape_values_.emplace(); + } + output_def.inferred_shape_values_->clear_dim(); + + for (const auto& dim : tensor_shape_proto.dim()) { + auto value = dim.dim_value(); + if (axes_set.find(value) == axes_set.end() && value != 1) { + output_def.inferred_shape_values_->add_dim()->set_dim_value(value); + } + } + } + }; + + if (tensor_shape_proto.has_value()) { + update_output_shape_values(*tensor_shape_proto); + } + + return Status::OK(); + + } else if (node.OpType() == "Unsqueeze") { + const auto* input_0 = node.GetDefinitions().input_defs[0]; + auto& tensor_shape_proto = input_0->inferred_shape_values_; + + auto update_output_shape_values = [&](const TensorShapeProto& tensor_shape_proto) -> void { + // The TensorShapeProto (inferred shape values) should have rank > 0 and + // all the dimensions have values (not symbolic) + if (tensor_shape_proto.dim_size() > 0) { + for (const auto& dim : tensor_shape_proto.dim()) { + if (!dim.has_dim_value()) { + return; + } + } + } + + if (tensor_shape_proto.dim_size() > 0) { + // Get axes value + std::vector<int64_t> axes; + std::unordered_set<int64_t> axes_set; + + // Note: Starting from opset 13, "axes" is provided as a second input to the Unsqueeze operator. + // In opset 11 and earlier, "axes" is defined as a node attribute instead. + if (node.GetDefinitions().input_defs.size() > 1) { + const auto* input_1 = node.GetDefinitions().input_defs[1]; + get_initialized_input_values(input_1->Name(), axes); + } else { + const auto& attrs = node.GetAttributes(); + auto it = attrs.find("axes"); + if (it != attrs.end()) { + const auto& axes_attr = it->second; + for (const auto& i : axes_attr.ints()) { + axes.push_back(i); + } + } + } + + // axes is required; if not provided, just do nothing and return. + if (axes.empty()) { + return; + } + + for (size_t i = 0; i < axes.size(); ++i) { + // Negative value means counting dimensions from the back. + if (axes[i] < 0) { + axes_set.insert(axes[i] + tensor_shape_proto.dim_size()); + } else { + axes_set.insert(axes[i]); + } + } + + if (!output_def.inferred_shape_values_.has_value()) { + output_def.inferred_shape_values_.emplace(); + } + output_def.inferred_shape_values_->clear_dim(); + + int64_t axis = 0; + for (const auto& dim : tensor_shape_proto.dim()) { + if (axes_set.find(axis) != axes_set.end()) { + output_def.inferred_shape_values_->add_dim()->set_dim_value(1); + } + + auto value = dim.dim_value(); + output_def.inferred_shape_values_->add_dim()->set_dim_value(value); + + axis += 1; + } + } + }; + + if (dim_size == 0 && input_0->inferred_scalar_value_.has_value()) { + // Following code expands a scalar to a one-dimensional array, e.g.
shape data is 64 -> it becomes [64] + // In this case, the axis should be 0 + if (!output_def.inferred_shape_values_.has_value()) { + output_def.inferred_shape_values_.emplace(); + } + output_def.inferred_shape_values_->clear_dim(); + output_def.inferred_shape_values_->add_dim()->set_dim_value(*input_0->inferred_scalar_value_); + + } else if (tensor_shape_proto.has_value()) { + update_output_shape_values(*tensor_shape_proto); + } + return Status::OK(); } @@ -2943,53 +3151,8 @@ Status Graph::SaveShapeValuesFromDataPropagation(Node& node, const auto* input_1 = node.GetDefinitions().input_defs[1]; // The "indices" should be an initializer because it's a scalar in this case. - const TensorProto* initializer = this->GetConstantInitializer(input_1->Name(), true); - int64_t index = std::numeric_limits::min(); - - if (initializer) { - // Check if this is in-memory external data (data stored in OrtValue) - if (utils::HasExternalDataInMemory(*initializer)) { - // Try to get the OrtValue for this initializer - OrtValue ort_value; - if (this->GetOrtValueInitializer(input_1->Name(), ort_value, true)) { - const Tensor& tensor = ort_value.Get(); - if (initializer->data_type() == TensorProto_DataType_INT32) { - const int32_t* data = tensor.Data(); - index = static_cast(*data); - } else if (initializer->data_type() == TensorProto_DataType_INT64) { - const int64_t* data = tensor.Data(); - index = *data; - } - } else { - // If we can't get the OrtValue, it is a bug - ORT_THROW("Initializer ", input_1->Name(), - " has in-memory external data but cannot get OrtValue during shape inference"); - } - } else if (utils::HasRawData(*initializer)) { - const std::string& raw = initializer->raw_data(); - if (initializer->data_type() == TensorProto_DataType_INT32) { - const int32_t* data = reinterpret_cast(raw.data()); - index = static_cast(*data); - } else if (initializer->data_type() == TensorProto_DataType_INT64) { - const int64_t* data = reinterpret_cast(raw.data()); - index = *data; - } - } else { - if (initializer->data_type() == TensorProto_DataType_INT32) { - std::vector values; - values.assign(initializer->int32_data().begin(), initializer->int32_data().end()); - if (values.size() == 1) { - index = static_cast(values[0]); - } - } else if (initializer->data_type() == TensorProto_DataType_INT64) { - std::vector values; - values.assign(initializer->int64_data().begin(), initializer->int64_data().end()); - if (values.size() == 1) { - index = values[0]; - } - } - } - } + std::vector indices; + get_initialized_input_values(input_1->Name(), indices); // Get the previously inferred dimension values auto& tensor_shape_proto = input_0->inferred_shape_values_; @@ -2997,92 +3160,19 @@ Status Graph::SaveShapeValuesFromDataPropagation(Node& node, // Save the dimension value in the NodeArg. 
// Index value is expected to be within bounds [-s, s-1] along axis of size s if (tensor_shape_proto.has_value()) { - if (index != std::numeric_limits::min() && - index < tensor_shape_proto->dim_size() && index >= -tensor_shape_proto->dim_size()) { - if (index < 0) { - index = tensor_shape_proto->dim_size() + index; + if (indices.size() == 1 && + indices[0] < tensor_shape_proto->dim_size() && indices[0] >= -tensor_shape_proto->dim_size()) { + if (indices[0] < 0) { + indices[0] = tensor_shape_proto->dim_size() + indices[0]; } - auto& dim = tensor_shape_proto->dim(static_cast(index)); + auto& dim = tensor_shape_proto->dim(static_cast(indices[0])); if (dim.has_dim_value()) { output_def.inferred_scalar_value_ = dim.dim_value(); } } } - } else if (node.OpType() == "Unsqueeze") { - // Following code expands a scalr to one dimension array if all conditions are met. - // e.g. shape data is 64 -> it becomes [64] - - // Try to get the "data" input - const auto* input_0 = node.GetDefinitions().input_defs[0]; - - // Only handle the "data" input which is previously inferred from data propagation and is a scalar - if (input_0->inferred_scalar_value_.has_value()) { - // Try to get the "axes" value - int64_t axis = -1; - - // Note: Starting from opset 13, "axes" is provided as a second input to the Unsqueeze operator. - // In opset 11 and earlier, "axes" is defined as a node attribute instead. - if (node.GetDefinitions().input_defs.size() > 1) { - const auto* input_1 = node.GetDefinitions().input_defs[1]; - // Only check the case that "axes" input is an initializer - const TensorProto* initializer = this->GetConstantInitializer(input_1->Name(), true); - - if (initializer && initializer->dims_size() == 1) { - // Check if this is in-memory external data (data stored in OrtValue) - if (utils::HasExternalDataInMemory(*initializer)) { - // Try to get the OrtValue for this initializer - OrtValue ort_value; - if (this->GetOrtValueInitializer(input_1->Name(), ort_value, true)) { - const Tensor& tensor = ort_value.Get(); - const int64_t* data = tensor.Data(); - axis = *data; - } else { - // If we can't get the OrtValue, it is a bug - ORT_THROW("Initializer ", input_1->Name(), - " has in-memory external data but cannot get OrtValue during shape inference"); - } - } else if (utils::HasRawData(*initializer)) { - const std::string& raw = initializer->raw_data(); - const int64_t* data = reinterpret_cast(raw.data()); - axis = *data; - } else { - if (initializer->data_type() == TensorProto_DataType_INT32) { - std::vector values; - values.assign(initializer->int32_data().begin(), initializer->int32_data().end()); - if (values.size() == 1) { - axis = static_cast(values[0]); - } - } else if (initializer->data_type() == TensorProto_DataType_INT64) { - std::vector values; - values.assign(initializer->int64_data().begin(), initializer->int64_data().end()); - if (values.size() == 1) { - axis = values[0]; - } - } - } - } - } else { - const auto& attrs = node.GetAttributes(); - auto it = attrs.find("axes"); - if (it != attrs.end()) { - const auto& axes_attr = it->second; - if (axes_attr.ints_size()) { - axis = axes_attr.ints()[0]; - } - } - } - - // In this case, the axis should be 0 - if (axis == 0) { - if (!output_def.inferred_shape_values_.has_value()) { - output_def.inferred_shape_values_.emplace(); - } - output_def.inferred_shape_values_->clear_dim(); - output_def.inferred_shape_values_->add_dim()->set_dim_value(*input_0->inferred_scalar_value_); - } - } } else if (node.OpType() == "Add") { // Try to get the "A" input const 
auto* input_0 = node.GetDefinitions().input_defs[0]; @@ -3121,11 +3211,12 @@ Status Graph::SaveShapeValuesFromDataPropagation(Node& node, } } } - // If the dimension size is greater than 0, only save the inferred shape values from data propagation - // when the data has rank > 0 and all dimensions have concrete (non-symbolic) values. + // Handle the case where the dimension size is greater than 0. else if (dim_size > 0) { + // Only save the inferred shape values from data propagation if the data has rank > 0 + // and all dimensions have concrete (non-symbolic) values. for (int i = 0; i < dim_size; ++i) { - if (!onnx_inferred_types_after_data_propagation.tensor_type().shape().dim(i).has_dim_value()) { + if (!onnx_inferred_type_after_data_propagation.tensor_type().shape().dim(i).has_dim_value()) { return Status::OK(); } } @@ -3136,7 +3227,7 @@ Status Graph::SaveShapeValuesFromDataPropagation(Node& node, output_def.inferred_shape_values_->clear_dim(); for (int i = 0; i < dim_size; ++i) { - auto value = onnx_inferred_types_after_data_propagation.tensor_type().shape().dim(i).dim_value(); + auto value = onnx_inferred_type_after_data_propagation.tensor_type().shape().dim(i).dim_value(); output_def.inferred_shape_values_->add_dim()->set_dim_value(value); } } @@ -3326,6 +3417,14 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso } } + if (node.OpType() == "Shape") { + std::cout << "Shape" << std::endl; + } + + if (node.OpType() == "Size") { + std::cout << "Size" << std::endl; + } + // Apply ONNX's type/shape inference to this node. // This will call InferAndVerifySubgraphTypes if the ONNX level type/shape inferencing for the Node attempts // to do subgraph type/shape inferencing (Scan/If/Loop nodes). @@ -3378,8 +3477,11 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso auto op_formal_parameter = op.outputs().at(operand_index); const TypeProto& onnx_inferred_type = onnx_inferred_types[i]; + const TypeProto& onnx_inferred_type_after_data_propagation = onnx_inferred_types_after_data_propagation[i]; - ORT_RETURN_IF_ERROR(SaveShapeValuesFromDataPropagation(node, *output_def, onnx_inferred_types_after_data_propagation[i])); + ORT_RETURN_IF_ERROR(SaveShapeValuesFromDataPropagation(node, + *output_def, + onnx_inferred_type_after_data_propagation)); DataType existing_type = output_def->Type(); DataType inferred_type = nullptr; From 73e38adfdc49701438cc4cb2c71fd6c347dfe110 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 21 Oct 2025 14:17:02 -0700 Subject: [PATCH 21/23] Add clean up for inferred shape values and fix bugs --- include/onnxruntime/core/framework/tensor.h | 7 ++- include/onnxruntime/core/graph/graph.h | 3 + onnxruntime/core/graph/graph.cc | 65 ++++++++++++++------- 3 files changed, 52 insertions(+), 23 deletions(-) diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h index c7f7f23f70334..11ff427597ed7 100644 --- a/include/onnxruntime/core/framework/tensor.h +++ b/include/onnxruntime/core/framework/tensor.h @@ -217,8 +217,11 @@ class Tensor final { template <typename T> const T* Data() const { // Type check - ORT_ENFORCE(utils::IsPrimitiveDataType<T>(dtype_), "Tensor type mismatch. ", - "T ", "!=", dtype_); + do { + if (!(utils::IsPrimitiveDataType<T>(dtype_))) { + throw ::onnxruntime::OnnxRuntimeException(::onnxruntime::CodeLocation("C:\\Users\\lochi\\repos\\ort\\include\\onnxruntime\\core\\framework\\tensor.h", 221, static_cast<const char*>(__FUNCTION__), ::onnxruntime::GetStackTrace()), "utils::IsPrimitiveDataType<T>(dtype_)", ::onnxruntime::MakeString("Tensor type mismatch. ", "T ", "!=", dtype_)); + } + } while (false); return reinterpret_cast<const T*>(static_cast<const char*>(p_data_) + byte_offset_); } diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index f08bff9876494..15bb4ae802660 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -1766,6 +1766,9 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi common::Status SaveShapeValuesFromDataPropagation(Node& node, NodeArg& output_def, const ONNX_NAMESPACE::TypeProto& propagated_value_as_type_proto) const; + // Remove intermediate inferred shape values stored in all NodeArgs to reduce memory usage. + common::Status CleanUpShapeValuesFromDataPropagation(); + // Apply type-inference and type-checking to all inputs and initializers: common::Status TypeCheckInputsAndInitializers(); diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 17fcfa0a01e6c..86e53736179a1 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2896,8 +2896,8 @@ Status Graph::SaveShapeValuesFromDataPropagation(Node& node, if (initializer) { // Get shape from TensorProto and calculate element counts. - // If shape has dimension size equals zero, it means it's a scalar and has only one value. - int64_t element_cnt = 1; + // If shape has dimension size equals zero, it means it's a scalar and has only one element. + size_t element_cnt = 1; for (auto& dim : initializer->dims()) { element_cnt *= dim; } @@ -2910,12 +2910,12 @@ Status Graph::SaveShapeValuesFromDataPropagation(Node& node, const Tensor& tensor = ort_value.Get<Tensor>(); if (initializer->data_type() == TensorProto_DataType_INT32) { input_values.resize(element_cnt); - for (int64_t i = 0; i < element_cnt; ++i) { + for (size_t i = 0; i < element_cnt; ++i) { input_values[i] = static_cast<int64_t>(tensor.Data<int32_t>()[i]); } } else if (initializer->data_type() == TensorProto_DataType_INT64) { input_values.resize(element_cnt); - for (int64_t i = 0; i < element_cnt; ++i) { + for (size_t i = 0; i < element_cnt; ++i) { input_values[i] = tensor.Data<int64_t>()[i]; } } @@ -2928,7 +2928,7 @@ Status Graph::SaveShapeValuesFromDataPropagation(Node& node, const std::string& raw = initializer->raw_data(); const int64_t* data = reinterpret_cast<const int64_t*>(raw.data()); input_values.resize(element_cnt); - for (int64_t i = 0; i < element_cnt; ++i) { + for (size_t i = 0; i < element_cnt; ++i) { input_values[i] = data[i]; } } else { @@ -2936,7 +2936,7 @@ Status Graph::SaveShapeValuesFromDataPropagation(Node& node, std::vector<int32_t> values; values.assign(initializer->int32_data().begin(), initializer->int32_data().end()); input_values.resize(element_cnt); - for (int64_t i = 0; i < element_cnt; ++i) { + for (size_t i = 0; i < element_cnt; ++i) { input_values[i] = static_cast<int64_t>(values[i]); } } else if (initializer->data_type() == TensorProto_DataType_INT64) { @@ -2947,9 +2947,11 @@ Status Graph::SaveShapeValuesFromDataPropagation(Node& node, }; - // Following operators, e.g. Size, Squeeze and Unsqueeze, calling their PartialDataPropagationFunction() alone - // to get the inferred shape values is still not enough to get the proper values. - // So, ignore the inferred values from PartialDataPropagationFunction() and infer the output values here. + // For certain operators (e.g., Size, Squeeze, Unsqueeze), invoking PartialDataPropagationFunction() + // alone does not yield fully accurate inferred shape values. + // Therefore, we ignore the inferred output shape values produced by PartialDataPropagationFunction() and manually infer + // the output shape values here. + if (node.OpType() == "Size") { // Size operator generates a scalar output const auto* input_0 = node.GetDefinitions().input_defs[0]; auto& tensor_shape_proto = input_0->inferred_shape_values_; auto get_num_elements = [&](const TensorShapeProto& tensor_shape_proto) -> void { @@ -3027,11 +3029,20 @@ Status Graph::SaveShapeValuesFromDataPropagation(Node& node, } output_def.inferred_shape_values_->clear_dim(); + int64_t dim_index = 0; for (const auto& dim : tensor_shape_proto.dim()) { auto value = dim.dim_value(); - if (axes_set.find(value) == axes_set.end() && value != 1) { - output_def.inferred_shape_values_->add_dim()->set_dim_value(value); + if (axes_set.size() > 0) { + if (axes_set.find(dim_index) == axes_set.end()) { + output_def.inferred_shape_values_->add_dim()->set_dim_value(value); + } + } else { + if (value != 1) { + output_def.inferred_shape_values_->add_dim()->set_dim_value(value); + } } + + dim_index++; } } }; @@ -3127,9 +3138,9 @@ Status Graph::SaveShapeValuesFromDataPropagation(Node& node, return Status::OK(); } - // If dimension size is 0, it could indicate one of the following cases: + // For the rest of the operators, if dimension size is 0, it could indicate one of the following cases: // 1. The inferred output is a scalar. - // 2. The node's input is a scalar, and the operator's PartialDataPropagationFunction() cannot handle it. + // 2. The node's input is a scalar, and the operator's PartialDataPropagationFunction() doesn't handle it. // // In other words, some operators' PartialDataPropagationFunction() implementations do not support // scalar inputs or outputs. In such cases, attempt data propagation manually and store the inferred @@ -3417,14 +3428,6 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso } } - if (node.OpType() == "Shape") { - std::cout << "Shape" << std::endl; - } - - if (node.OpType() == "Size") { - std::cout << "Size" << std::endl; - } - // Apply ONNX's type/shape inference to this node. // This will call InferAndVerifySubgraphTypes if the ONNX level type/shape inferencing for the Node attempts // to do subgraph type/shape inferencing (Scan/If/Loop nodes).
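A note on the Squeeze fix in the hunk above: the corrected loop drops dimensions by index after normalizing negative axes, and only falls back to dropping every dimension equal to 1 when no axes are given. The following is a minimal standalone sketch of that intended behavior; it uses plain STL types and a hypothetical SqueezeShapeValues helper rather than the TensorShapeProto plumbing from the patch, so treat it as an illustration of the semantics under those assumptions, not code from this PR.

#include <cstdint>
#include <unordered_set>
#include <vector>

// Hypothetical helper mirroring the Squeeze shape-value propagation above.
// Negative axes count from the back; with explicit axes only those indices
// are dropped, otherwise every dimension whose value is 1 is dropped.
std::vector<int64_t> SqueezeShapeValues(const std::vector<int64_t>& dims,
                                        const std::vector<int64_t>& axes) {
  const int64_t rank = static_cast<int64_t>(dims.size());
  std::unordered_set<int64_t> axes_set;
  for (int64_t axis : axes) {
    axes_set.insert(axis < 0 ? axis + rank : axis);
  }
  std::vector<int64_t> result;
  for (int64_t i = 0; i < rank; ++i) {
    const int64_t value = dims[static_cast<size_t>(i)];
    if (!axes_set.empty()) {
      if (axes_set.count(i) == 0) {
        result.push_back(value);  // index not listed in axes: keep it
      }
    } else if (value != 1) {
      result.push_back(value);  // no axes given: drop only size-1 dimensions
    }
  }
  return result;
}

// For example, SqueezeShapeValues({1, 3, 1, 64}, {0}) yields {3, 1, 64},
// while SqueezeShapeValues({1, 3, 1, 64}, {}) yields {3, 64}.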
@@ -3795,6 +3798,26 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) { } } + ORT_RETURN_IF_ERROR(CleanUpShapeValuesFromDataPropagation()); + + return Status::OK(); +} + +Status Graph::CleanUpShapeValuesFromDataPropagation() { + for (auto node_index : nodes_in_topological_order_) { + auto& node = *GetNode(node_index); + + for (auto node_arg : node.MutableInputDefs()) { + node_arg->inferred_shape_values_.reset(); + node_arg->inferred_scalar_value_.reset(); + } + + for (auto node_arg : node.MutableOutputDefs()) { + node_arg->inferred_shape_values_.reset(); + node_arg->inferred_scalar_value_.reset(); + } + } + return Status::OK(); } From 8bbebf4ce36e7d61faff63642a5c050db88cee09 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 21 Oct 2025 14:18:06 -0700 Subject: [PATCH 22/23] add more tests --- .../test/framework/shape_inference_test.cc | 44 ++++++++ ...opagation_with_shape_related_nodes_v3.onnx | Bin 0 -> 1070 bytes ...propagation_with_shape_related_nodes_v3.py | 98 ++++++++++++++++++ 3 files changed, 142 insertions(+) create mode 100644 onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v3.onnx create mode 100644 onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v3.py diff --git a/onnxruntime/test/framework/shape_inference_test.cc b/onnxruntime/test/framework/shape_inference_test.cc index a7910e28b1d6d..80e3c6dcd6961 100644 --- a/onnxruntime/test/framework/shape_inference_test.cc +++ b/onnxruntime/test/framework/shape_inference_test.cc @@ -80,6 +80,7 @@ TEST_F(ShapeInferenceTest, BasicTest) { TEST(ShapeInferenceV2Test, PartialDataPropagationTest) { { + // Model #1 // This model contains "Shape" and "Reshape" operators. auto model_path = ORT_TSTR("testdata/test_shape_data_propagation_with_shape_related_nodes.onnx"); @@ -109,6 +110,7 @@ TEST(ShapeInferenceV2Test, PartialDataPropagationTest) { } { + // Model #2 // This model contains "Shape", "Reshape", "Gather" and "Unsqueeze" operators. auto model_path = ORT_TSTR("testdata/test_shape_data_propagation_with_shape_related_nodes_v2.onnx"); @@ -135,6 +137,48 @@ TEST(ShapeInferenceV2Test, PartialDataPropagationTest) { EXPECT_TRUE(output_shape[1] == 3) << "The second dimension should have 3 as value"; EXPECT_TRUE(output_shape[2] == 4096) << "The second dimension should have 4096 as value"; } + + { + // Model #3 + // This model extends model #2 and appends Unsqueeze -> Unsqueeze -> Squeeze -> Squeeze -> Reshape to the end. + auto model_path = ORT_TSTR("testdata/test_shape_data_propagation_with_shape_related_nodes_v3.onnx"); + + Ort::SessionOptions session_options{}; + session_options.SetGraphOptimizationLevel(ORT_DISABLE_ALL); + + const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ORT_ENFORCE(g_ort->AddFreeDimensionOverrideByName(session_options, "batch", 1) == nullptr); + ORT_ENFORCE(g_ort->AddFreeDimensionOverrideByName(session_options, "width", 64) == nullptr); + ORT_ENFORCE(g_ort->AddFreeDimensionOverrideByName(session_options, "height", 64) == nullptr); + + // Even though all graph optimizations are disabled, the free dimension override is still enabled by default. + // The shape of graph's output should be correctly inferred by shape inference and data propagation. 
+ Ort::Session session(*ort_env, model_path, session_options); + + // This graph only has one output + ORT_ENFORCE(session.GetOutputCount() == 1); + + Ort::TypeInfo type_info = session.GetOutputTypeInfo(0); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + std::vector output_shape = tensor_info.GetShape(); + EXPECT_TRUE(output_shape.size() == 3) << "The output shape should have 3 dimensions"; + EXPECT_TRUE(output_shape[0] == 1) << "The first dimension should have 1 as value"; + EXPECT_TRUE(output_shape[1] == 3) << "The second dimension should have 3 as value"; + EXPECT_TRUE(output_shape[2] == 4096) << "The second dimension should have 4096 as value"; + } + + { + // Model #4 + // This model contains Shape, Reshape, Squeeze, Range, ReduceSum. + // It's from SoftmaxGrad_DefaultAxis test. + auto model_path = ORT_TSTR("testdata/test_shape_data_propagation_with_shape_related_nodes_v4.onnx"); + + Ort::SessionOptions session_options{}; + session_options.SetGraphOptimizationLevel(ORT_DISABLE_ALL); + + // Make sure it can load the model and run shape inference without errors. + Ort::Session session(*ort_env, model_path, session_options); + } } namespace { diff --git a/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v3.onnx b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v3.onnx new file mode 100644 index 0000000000000000000000000000000000000000..2889ec34afd4122eb171429d884baa0b1f052962 GIT binary patch literal 1070 zcmaiz%TC-d6o%uuB;gFDqr(Ov^}_T5L_)=8RW_B5m~ly@Au}Zjjk@Yf zu;)=Zu^q-4M(WD>eZF(C^E({j_9D;g6ICbI%egwg3t>{c$Sa?<~j z-NI{VHSQE?vQ%=eew8{tyl6BWtcd4^xXG^44srx*x!_-&%4}9+vo^i_uSuDz8`wm0 zOY#nR4C~#fsIaK`KYC`-vkv-F7omx0)e#J?tGPz{->Nwr>dBipB(*AJl^4HH(XY%Y zwtkSTl=`bQs0VgHPM|r210F%!!65~a<_r$`5o{k?O=$}uOc(hyF<0(?*5ALnZSMmMP?aLD zkroqVCyUT+lJ!<8VwDm-dW#&?!O=Hqli<73uESSkSE(&2jhUpm32A|7G-i{5=l%N> zvp9M-4$heuTOS(9NNTeg*(0e9Wn`bEHchk}GL!9iVACewv?-#oNoR5UAD-+yVG6#^ z6yoOx`)Kh!ybT%BOU2=3QchFoh5bhS$!28=A^NFmUgYHo*h|AVa6lO^x*0pTturEn J#{=-j_yF Date: Tue, 21 Oct 2025 14:20:38 -0700 Subject: [PATCH 23/23] revert --- include/onnxruntime/core/framework/tensor.h | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h index 11ff427597ed7..c7f7f23f70334 100644 --- a/include/onnxruntime/core/framework/tensor.h +++ b/include/onnxruntime/core/framework/tensor.h @@ -217,11 +217,8 @@ class Tensor final { template const T* Data() const { // Type check - do { - if (!(utils::IsPrimitiveDataType(dtype_))) { - throw ::onnxruntime::OnnxRuntimeException(::onnxruntime::CodeLocation("C:\\Users\\lochi\\repos\\ort\\include\\onnxruntime\\core\\framework\\tensor.h", 221, static_cast(__FUNCTION__), ::onnxruntime::GetStackTrace()), "utils::IsPrimitiveDataType(dtype_)", ::onnxruntime::MakeString("Tensor type mismatch. ", "T ", "!=", dtype_)); - } - } while (false); + ORT_ENFORCE(utils::IsPrimitiveDataType(dtype_), "Tensor type mismatch. ", + "T ", "!=", dtype_); return reinterpret_cast(static_cast(p_data_) + byte_offset_); }