diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h
index 9a0708d72b4f8..f08bff9876494 100644
--- a/include/onnxruntime/core/graph/graph.h
+++ b/include/onnxruntime/core/graph/graph.h
@@ -1760,6 +1760,12 @@ class Graph {  // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
                                                     std::vector<const ONNX_NAMESPACE::TypeProto*>& output_types,
                                                     const Graph::ResolveOptions& options);
 
+  // If the shape values are inferred after executing the ONNX operator's PartialDataPropagationFunction(),
+  // save them to the output NodeArg as a TensorShapeProto or a scalar value so that downstream (consumer) nodes
+  // can use them later for their TypeAndShapeInferenceFunction() and PartialDataPropagationFunction().
+  common::Status SaveShapeValuesFromDataPropagation(Node& node, NodeArg& output_def,
+                                                    const ONNX_NAMESPACE::TypeProto& propagated_value_as_type_proto) const;
+
   // Apply type-inference and type-checking to all inputs and initializers:
   common::Status TypeCheckInputsAndInitializers();
 
diff --git a/include/onnxruntime/core/graph/node_arg.h b/include/onnxruntime/core/graph/node_arg.h
index 0ddf1a2b9d3de..325e9f50e4ea2 100644
--- a/include/onnxruntime/core/graph/node_arg.h
+++ b/include/onnxruntime/core/graph/node_arg.h
@@ -107,6 +107,9 @@ class NodeArg {
   /** Gets this NodeArg as a NodeArgInfo, AKA ValueInfoProto. */
   const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; }
 
+  /** Gets the inferred shape values as a TensorShapeProto. */
+  const std::optional<ONNX_NAMESPACE::TensorShapeProto>& GetInferredShapeValues() const noexcept { return inferred_shape_values_; }
+
   /** Gets a flag indicating whether this NodeArg exists or not.
   Optional inputs are allowed in ONNX and an empty #Name represents a non-existent input argument. */
   bool Exists() const noexcept;
@@ -128,6 +131,22 @@ class NodeArg {
   // Node arg name, type and shape.
   NodeArgInfo node_arg_info_;
 
+  // This variable stores the inferred shape data as a TensorShapeProto after executing
+  // the ONNX operator's PartialDataPropagationFunction().
+  //
+  // Calling an operator's TypeAndShapeInferenceFunction() alone is sometimes insufficient
+  // for complete shape inference. For example, the Shape operator only provides the
+  // output's rank (1-dimensional) but not its actual dimension values.
+  // The PartialDataPropagationFunction(), defined in the ONNX operator schema, must also
+  // be executed to obtain the concrete output shape values, allowing accurate propagation
+  // of shape information throughout the graph.
+  std::optional<ONNX_NAMESPACE::TensorShapeProto> inferred_shape_values_;
+
+  // This variable stores the inferred scalar output.
+  // It is also used for shape inference and data propagation to ensure consistent shape and
+  // value information throughout the graph.
+  std::optional<int64_t> inferred_scalar_value_;
+
   // Flag indicates whether <*this> node arg exists or not.
   bool exists_;
 };
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index 8b599dc86d997..17fcfa0a01e6c 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -44,6 +44,7 @@
 #include "core/graph/schema_registry.h"
 #include "onnx/checker.h"
 #include "onnx/defs/parser.h"
+#include "onnx/defs/tensor_proto_util.h"
 
 using namespace ONNX_NAMESPACE::checker;
 #endif
@@ -2675,8 +2676,8 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext {
     if (!def)
       return nullptr;
 
-    // only return data if it's for a constant initializer. checks for outer scope initializers
-    // if this is a subgraph and the name isn't found locally.
+    // Only return data if it's for a constant initializer.
+    // Check for outer scope initializers if this is a subgraph and the name isn't found locally.
     const TensorProto* initializer = graph_.GetConstantInitializer(def->Name(), true);
     if (initializer != nullptr) {
       // Check if this is in-memory external data (data stored in OrtValue)
@@ -2698,8 +2699,43 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext {
                     " has in-memory external data but cannot get OrtValue during shape inference");
         }
       }
+
+      return initializer;
+    }
+
+    // The following code handles cases where a node stores the previously inferred output shape values in its NodeArg.
+    //
+    // For example, the Reshape operator's shape input may come from a producer node such as a Shape operator,
+    // and the inferred output shape values are then already stored as a TensorShapeProto in the corresponding NodeArg.
+    //
+    // In such cases, that TensorShapeProto is converted into a TensorProto here.
+    // The resulting TensorProto will then be treated as an initializer during ONNX shape inference,
+    // allowing the real dimension values to be correctly used.
+    const auto& tensor_shape_proto = def->GetInferredShapeValues();
+
+    // Make sure the returned shape tensor (as a TensorProto) has rank > 0 and all of its dimensions
+    // have concrete values (not symbolic).
+    if (tensor_shape_proto.has_value() && tensor_shape_proto->dim_size() > 0) {
+      TensorProto tensor_proto;
+      tensor_proto.set_data_type(TensorProto_DataType_INT64);
+      tensor_proto.add_dims(tensor_shape_proto->dim_size());
+      bool all_values = true;
+      for (const auto& dim : tensor_shape_proto->dim()) {
+        if (dim.has_dim_value()) {
+          tensor_proto.add_int64_data(dim.dim_value());
+        } else {
+          all_values = false;
+          break;
+        }
+      }
+
+      if (all_values) {
+        temp_tensor_protos_.push_back(std::make_unique<TensorProto>(std::move(tensor_proto)));
+        return temp_tensor_protos_.back().get();
+      }
     }
 
-    return initializer;
+
+    return nullptr;
   }
 
   // ORT does not implement partial data propagation yet so just return nullptr.
@@ -2742,9 +2778,463 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext {
   // These need to outlive the shape inference call, so we store them here
   // Inference is per node and the instance of this context is on the stack,
   // so this is safe.
+  // It can also be used to temporarily save the inferred shape values as a TensorProto.
   mutable InlinedVector<std::unique_ptr<ONNX_NAMESPACE::TensorProto>> temp_tensor_protos_;
 };
 
+// An implementation of the DataPropagationContext interface used by operator-specific
+// data propagation for onnxruntime graphs.
+// Please see the description and usage of ONNX's data propagation here:
+// https://github.com/onnx/onnx/blob/main/onnx/defs/shape_inference.h#L117-L127
+class DataPropagationContextImpl : public ONNX_NAMESPACE::DataPropagationContext {
+ public:
+  DataPropagationContextImpl(Node& node) noexcept : node_(node) {
+    node_output_types_.resize(node.OutputDefs().size());
+  }
+
+  const AttributeProto* getAttribute(const std::string& name) const override {
+    auto& attribute_value_map = node_.GetAttributes();
+    auto iter = attribute_value_map.find(name);
+    if (iter == attribute_value_map.end()) {
+      return nullptr;
+    }
+    return &iter->second;
+  }
+
+  size_t getNumInputs() const noexcept override {
+    return node_.InputDefs().size();
+  }
+
+  const TypeProto* getInputType(size_t index) const override {
+    if (index >= getNumInputs()) {
+      return nullptr;
+    }
+
+    const TypeProto* type = nullptr;
+    auto p_node_arg = node_.InputDefs().at(index);
+    if ((nullptr != p_node_arg) && p_node_arg->Exists()) {
+      type = p_node_arg->TypeAsProto();
+    }
+
+    return type;
+  }
+
+  size_t getNumOutputs() const noexcept override {
+    return node_output_types_.size();
+  }
+
+  const TypeProto* getOutputType(size_t index) const override {
+    if (index >= getNumOutputs()) {
+      return nullptr;
+    }
+
+    return &node_output_types_[index];
+  }
+
+  const TensorShapeProto* getInputData(size_t index) override {
+    if (index >= getNumInputs()) {
+      return nullptr;
+    }
+
+    auto def = node_.InputDefs()[index];
+    if (!def)
+      return nullptr;
+
+    auto has_shape_values = [&](const TensorShapeProto& tensor_shape_proto) -> bool {
+      // The TensorShapeProto (inferred shape values) should have rank > 0 and
+      // all of its dimensions should have concrete values (not symbolic).
+      if (tensor_shape_proto.dim_size() > 0) {
+        for (const auto& dim : tensor_shape_proto.dim()) {
+          if (!dim.has_dim_value()) {
+            return false;
+          }
+        }
+        return true;
+      }
+
+      return false;
+    };
+
+    // Try to get the previously inferred shape values that are stored in the NodeArg's inferred_shape_values_.
+    auto& tensor_shape_proto = def->GetInferredShapeValues();
+    if (tensor_shape_proto.has_value() && has_shape_values(*tensor_shape_proto)) {
+      return &*tensor_shape_proto;
+    }
+
+    return nullptr;
+  }
+
+  void addOutputData(size_t index, TensorShapeProto&& tsp) override {
+    if (index >= node_output_types_.size()) return;
+
+    TypeProto& type_proto = node_output_types_[index];
+    *type_proto.mutable_tensor_type()->mutable_shape() = std::move(tsp);
+  }
+
+  void RunInferencing() {
+    auto schema = node_.Op();
+    if (nullptr != schema) {
+      schema->GetDataPropagationFunction()(*this);
+    }
+  }
+
+  std::vector<TypeProto> InferredOutputTypes() const { return node_output_types_; }
+
+ private:
+  Node& node_;
+  std::vector<TypeProto> node_output_types_;
+};
+
+Status Graph::SaveShapeValuesFromDataPropagation(Node& node,
+                                                 NodeArg& output_def,
+                                                 const TypeProto& onnx_inferred_type_after_data_propagation) const {
+  auto dim_size = onnx_inferred_type_after_data_propagation.tensor_type().shape().dim_size();
+
+  // Helper function to get the input values if the input is an initializer.
+  auto get_initialized_input_values = [&](const std::string& input_name, std::vector<int64_t>& input_values) -> void {
+    const TensorProto* initializer = this->GetConstantInitializer(input_name, true);
+
+    if (initializer) {
+      // Get the shape from the TensorProto and calculate the element count.
+      // If the shape has zero dimensions, the tensor is a scalar and has exactly one value.
+      int64_t element_cnt = 1;
+      for (auto& dim : initializer->dims()) {
+        element_cnt *= dim;
+      }
+
+      // Check if this is in-memory external data (data stored in OrtValue)
+      if (utils::HasExternalDataInMemory(*initializer)) {
+        // Try to get the OrtValue for this initializer
+        OrtValue ort_value;
+        if (this->GetOrtValueInitializer(input_name, ort_value, true)) {
+          const Tensor& tensor = ort_value.Get<Tensor>();
+          if (initializer->data_type() == TensorProto_DataType_INT32) {
+            input_values.resize(element_cnt);
+            for (int64_t i = 0; i < element_cnt; ++i) {
+              input_values[i] = static_cast<int64_t>(tensor.Data<int32_t>()[i]);
+            }
+          } else if (initializer->data_type() == TensorProto_DataType_INT64) {
+            input_values.resize(element_cnt);
+            for (int64_t i = 0; i < element_cnt; ++i) {
+              input_values[i] = tensor.Data<int64_t>()[i];
+            }
+          }
+        } else {
+          // If we can't get the OrtValue, it is a bug
+          ORT_THROW("Initializer ", input_name,
+                    " has in-memory external data but cannot get OrtValue during shape inference");
+        }
+      } else if (utils::HasRawData(*initializer)) {
+        const std::string& raw = initializer->raw_data();
+        const int64_t* data = reinterpret_cast<const int64_t*>(raw.data());
+        input_values.resize(element_cnt);
+        for (int64_t i = 0; i < element_cnt; ++i) {
+          input_values[i] = data[i];
+        }
+      } else {
+        if (initializer->data_type() == TensorProto_DataType_INT32) {
+          std::vector<int32_t> values;
+          values.assign(initializer->int32_data().begin(), initializer->int32_data().end());
+          input_values.resize(element_cnt);
+          for (int64_t i = 0; i < element_cnt; ++i) {
+            input_values[i] = static_cast<int64_t>(values[i]);
+          }
+        } else if (initializer->data_type() == TensorProto_DataType_INT64) {
+          input_values.resize(element_cnt);
+          input_values.assign(initializer->int64_data().begin(), initializer->int64_data().end());
+        }
+      }
+    }
+  };
+
+  // For the following operators (e.g. Size, Squeeze and Unsqueeze), calling the PartialDataPropagationFunction()
+  // alone is still not enough to obtain the proper inferred shape values.
+  // So, ignore the values inferred by the PartialDataPropagationFunction() and infer the output values here.
+  if (node.OpType() == "Size") {
+    // The Size operator generates a scalar output.
+    const auto* input_0 = node.GetDefinitions().input_defs[0];
+    auto& tensor_shape_proto = input_0->inferred_shape_values_;
+    auto get_num_elements = [&](const TensorShapeProto& tensor_shape_proto) -> void {
+      int64_t num_elements = 1;
+      // The TensorShapeProto (inferred shape values) should have rank > 0 and
+      // all of its dimensions should have concrete values (not symbolic).
+      if (tensor_shape_proto.dim_size() > 0) {
+        for (const auto& dim : tensor_shape_proto.dim()) {
+          if (!dim.has_dim_value()) {
+            return;
+          }
+          num_elements *= dim.dim_value();
+        }
+        output_def.inferred_scalar_value_ = num_elements;
+      }
+    };
+
+    if (tensor_shape_proto.has_value()) {
+      get_num_elements(*tensor_shape_proto);
+    }
+
+    return Status::OK();
+
+  } else if (node.OpType() == "Squeeze") {
+    const auto* input_0 = node.GetDefinitions().input_defs[0];
+    auto& tensor_shape_proto = input_0->inferred_shape_values_;
+
+    auto update_output_shape_values = [&](const TensorShapeProto& tensor_shape_proto) -> void {
+      // The TensorShapeProto (inferred shape values) should have rank > 0 and
+      // all of its dimensions should have concrete values (not symbolic).
+      if (tensor_shape_proto.dim_size() > 0) {
+        for (const auto& dim : tensor_shape_proto.dim()) {
+          if (!dim.has_dim_value()) {
+            return;
+          }
+        }
+      }
+
+      if (tensor_shape_proto.dim_size() == 1) {
+        output_def.inferred_scalar_value_ = tensor_shape_proto.dim(0).dim_value();
+      } else if (tensor_shape_proto.dim_size() > 1) {
+        // Get the axes values
+        std::vector<int64_t> axes;
+        std::unordered_set<int64_t> axes_set;
+
+        // Note: Starting from opset 13, "axes" is provided as a second input to the Squeeze operator.
+        // In opset 11 and earlier, "axes" is defined as a node attribute instead.
+        if (node.GetDefinitions().input_defs.size() > 1) {
+          const auto* input_1 = node.GetDefinitions().input_defs[1];
+          get_initialized_input_values(input_1->Name(), axes);
+        } else {
+          const auto& attrs = node.GetAttributes();
+          auto it = attrs.find("axes");
+          if (it != attrs.end()) {
+            const auto& axes_attr = it->second;
+            for (const auto& i : axes_attr.ints()) {
+              axes.push_back(i);
+            }
+          }
+        }
+
+        for (size_t i = 0; i < axes.size(); ++i) {
+          // A negative value means counting dimensions from the back.
+          if (axes[i] < 0) {
+            axes_set.insert(axes[i] + tensor_shape_proto.dim_size());
+          } else {
+            axes_set.insert(axes[i]);
+          }
+        }
+
+        if (!output_def.inferred_shape_values_.has_value()) {
+          output_def.inferred_shape_values_.emplace();
+        }
+        output_def.inferred_shape_values_->clear_dim();
+
+        for (const auto& dim : tensor_shape_proto.dim()) {
+          auto value = dim.dim_value();
+          if (axes_set.find(value) == axes_set.end() && value != 1) {
+            output_def.inferred_shape_values_->add_dim()->set_dim_value(value);
+          }
+        }
+      }
+    };
+
+    if (tensor_shape_proto.has_value()) {
+      update_output_shape_values(*tensor_shape_proto);
+    }
+
+    return Status::OK();
+
+  } else if (node.OpType() == "Unsqueeze") {
+    const auto* input_0 = node.GetDefinitions().input_defs[0];
+    auto& tensor_shape_proto = input_0->inferred_shape_values_;
+
+    auto update_output_shape_values = [&](const TensorShapeProto& tensor_shape_proto) -> void {
+      // The TensorShapeProto (inferred shape values) should have rank > 0 and
+      // all of its dimensions should have concrete values (not symbolic).
+      if (tensor_shape_proto.dim_size() > 0) {
+        for (const auto& dim : tensor_shape_proto.dim()) {
+          if (!dim.has_dim_value()) {
+            return;
+          }
+        }
+      }
+
+      if (tensor_shape_proto.dim_size() > 0) {
+        // Get the axes values
+        std::vector<int64_t> axes;
+        std::unordered_set<int64_t> axes_set;
+
+        // Note: Starting from opset 13, "axes" is provided as a second input to the Unsqueeze operator.
+        // In opset 11 and earlier, "axes" is defined as a node attribute instead.
+        if (node.GetDefinitions().input_defs.size() > 1) {
+          const auto* input_1 = node.GetDefinitions().input_defs[1];
+          get_initialized_input_values(input_1->Name(), axes);
+        } else {
+          const auto& attrs = node.GetAttributes();
+          auto it = attrs.find("axes");
+          if (it != attrs.end()) {
+            const auto& axes_attr = it->second;
+            for (const auto& i : axes_attr.ints()) {
+              axes.push_back(i);
+            }
+          }
+        }
+
+        // "axes" is required; if it is not provided, just do nothing and return.
+        if (axes.empty()) {
+          return;
+        }
+
+        for (size_t i = 0; i < axes.size(); ++i) {
+          // A negative value means counting dimensions from the back.
+          if (axes[i] < 0) {
+            axes_set.insert(axes[i] + tensor_shape_proto.dim_size());
+          } else {
+            axes_set.insert(axes[i]);
+          }
+        }
+
+        if (!output_def.inferred_shape_values_.has_value()) {
+          output_def.inferred_shape_values_.emplace();
+        }
+        output_def.inferred_shape_values_->clear_dim();
+
+        int64_t axis = 0;
+        for (const auto& dim : tensor_shape_proto.dim()) {
+          if (axes_set.find(axis) != axes_set.end()) {
+            output_def.inferred_shape_values_->add_dim()->set_dim_value(1);
+          }
+
+          auto value = dim.dim_value();
+          output_def.inferred_shape_values_->add_dim()->set_dim_value(value);
+
+          axis += 1;
+        }
+      }
+    };
+
+    if (dim_size == 0 && input_0->inferred_scalar_value_.has_value()) {
+      // The following code expands a scalar into a one-dimensional array,
+      // e.g. scalar shape data 64 becomes [64]. In this case, the axis should be 0.
+      if (!output_def.inferred_shape_values_.has_value()) {
+        output_def.inferred_shape_values_.emplace();
+      }
+      output_def.inferred_shape_values_->clear_dim();
+      output_def.inferred_shape_values_->add_dim()->set_dim_value(*input_0->inferred_scalar_value_);
+
+    } else if (tensor_shape_proto.has_value()) {
+      update_output_shape_values(*tensor_shape_proto);
+    }
+
+    return Status::OK();
+  }
+
+  // If the dimension size is 0, it could indicate one of the following cases:
+  // 1. The inferred output is a scalar.
+  // 2. The node's input is a scalar, and the operator's PartialDataPropagationFunction() cannot handle it.
+  //
+  // In other words, some operators' PartialDataPropagationFunction() implementations do not support
+  // scalar inputs or outputs. In such cases, attempt data propagation manually and store the inferred
+  // scalar value, if any, in the NodeArg.
+  if (dim_size == 0) {
+    if (node.OpType() == "Gather") {
+      // The following code extracts an element from a 1-D array if all conditions are met.
+      // e.g.
+      //   shape data is [1, 3, 64, 64] -> gets 64 if the index is 2.
+      //   shape data is [1, 3, 64, 64] -> gets 3 if the index is 1.
+
+      // Try to get the "data" input.
+      // Note: The "data" input should be a one-dimensional array in this case.
+      const auto* input_0 = node.GetDefinitions().input_defs[0];
+
+      // Try to get the "indices" input.
+      // Note: The "indices" input should be a scalar; otherwise, if it's a tensor with dimension size > 0,
+      // the operator's type and shape inference should already have inferred the output shape value.
+      const auto* input_1 = node.GetDefinitions().input_defs[1];
+
+      // The "indices" should be an initializer because it's a scalar in this case.
+      std::vector<int64_t> indices;
+      get_initialized_input_values(input_1->Name(), indices);
+
+      // Get the previously inferred dimension values.
+      auto& tensor_shape_proto = input_0->inferred_shape_values_;
+
+      // Save the dimension value in the NodeArg.
+      // The index value is expected to be within bounds [-s, s-1] along an axis of size s.
+      if (tensor_shape_proto.has_value()) {
+        if (indices.size() == 1 &&
+            indices[0] < tensor_shape_proto->dim_size() && indices[0] >= -tensor_shape_proto->dim_size()) {
+          if (indices[0] < 0) {
+            indices[0] = tensor_shape_proto->dim_size() + indices[0];
+          }
+
+          auto& dim = tensor_shape_proto->dim(static_cast<int>(indices[0]));
+          if (dim.has_dim_value()) {
+            output_def.inferred_scalar_value_ = dim.dim_value();
+          }
+        }
+      }
+
+    } else if (node.OpType() == "Add") {
+      // Try to get the "A" input
+      const auto* input_0 = node.GetDefinitions().input_defs[0];
+      // Try to get the "B" input
+      const auto* input_1 = node.GetDefinitions().input_defs[1];
+
+      if (input_0->inferred_scalar_value_.has_value() && input_1->inferred_scalar_value_.has_value()) {
+        output_def.inferred_scalar_value_ = *input_0->inferred_scalar_value_ + *input_1->inferred_scalar_value_;
+      }
+    } else if (node.OpType() == "Sub") {
+      // Try to get the "A" input
+      const auto* input_0 = node.GetDefinitions().input_defs[0];
+      // Try to get the "B" input
+      const auto* input_1 = node.GetDefinitions().input_defs[1];
+
+      if (input_0->inferred_scalar_value_.has_value() && input_1->inferred_scalar_value_.has_value()) {
+        output_def.inferred_scalar_value_ = *input_0->inferred_scalar_value_ - *input_1->inferred_scalar_value_;
+      }
+    } else if (node.OpType() == "Mul") {
+      // Try to get the "A" input
+      const auto* input_0 = node.GetDefinitions().input_defs[0];
+      // Try to get the "B" input
+      const auto* input_1 = node.GetDefinitions().input_defs[1];
+
+      if (input_0->inferred_scalar_value_.has_value() && input_1->inferred_scalar_value_.has_value()) {
+        output_def.inferred_scalar_value_ = *input_0->inferred_scalar_value_ * (*input_1->inferred_scalar_value_);
+      }
+    } else if (node.OpType() == "Div") {
+      // Try to get the "A" input
+      const auto* input_0 = node.GetDefinitions().input_defs[0];
+      // Try to get the "B" input
+      const auto* input_1 = node.GetDefinitions().input_defs[1];
+
+      if (input_0->inferred_scalar_value_.has_value() && input_1->inferred_scalar_value_.has_value()) {
+        output_def.inferred_scalar_value_ = *input_0->inferred_scalar_value_ / *input_1->inferred_scalar_value_;
+      }
+    }
+  }
+  // If the dimension size is greater than 0.
+  else if (dim_size > 0) {
+    // Only handle the inferred shape values from data propagation if the data has rank > 0
+    // and all dimensions have concrete (non-symbolic) values.
+    for (int i = 0; i < dim_size; ++i) {
+      if (!onnx_inferred_type_after_data_propagation.tensor_type().shape().dim(i).has_dim_value()) {
+        return Status::OK();
+      }
+    }
+
+    if (!output_def.inferred_shape_values_.has_value()) {
+      output_def.inferred_shape_values_.emplace();
+    }
+
+    output_def.inferred_shape_values_->clear_dim();
+    for (int i = 0; i < dim_size; ++i) {
+      auto value = onnx_inferred_type_after_data_propagation.tensor_type().shape().dim(i).dim_value();
+      output_def.inferred_shape_values_->add_dim()->set_dim_value(value);
+    }
+  }
+
+  return Status::OK();
+}
+
 Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
                                           const std::vector<const TypeProto*>& input_types,
                                           std::vector<const TypeProto*>& output_types,
@@ -2936,11 +3434,20 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso
   // returned here.
   SubgraphInferencingFunc func(Graph::InferAndVerifySubgraphTypes);
   InferenceContextImpl context(node, func, *this, options);
+  DataPropagationContextImpl data_propagation_context(node);
 
   {
     auto status = Status::OK();
     ORT_TRY {
       context.RunInferencing();
+
+      // Calling an operator's TypeAndShapeInferenceFunction() alone is sometimes insufficient
+      // for complete shape inference. For example, the Shape operator only provides the
+      // output's rank (1-dimensional) but not its actual dimension values.
+      // The PartialDataPropagationFunction(), defined in the ONNX operator schema, must also
+      // be executed to obtain the concrete output shape values, allowing accurate propagation
+      // of shape information throughout the graph.
+      data_propagation_context.RunInferencing();
     }
     ORT_CATCH(const std::exception& ex) {
       ORT_HANDLE_EXCEPTION([&]() {
@@ -2952,6 +3459,8 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso
 
   const auto& onnx_inferred_types(context.InferredOutputTypes());
 
+  const auto& onnx_inferred_types_after_data_propagation(data_propagation_context.InferredOutputTypes());
+
   // Infer and verify node output arg type information.
   int i = -1;
   for (auto& output_def : node.MutableDefinitions().output_defs) {
@@ -2968,6 +3477,12 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso
     auto op_formal_parameter = op.outputs().at(operand_index);
 
     const TypeProto& onnx_inferred_type = onnx_inferred_types[i];
+    const TypeProto& onnx_inferred_type_after_data_propagation = onnx_inferred_types_after_data_propagation[i];
+
+    ORT_RETURN_IF_ERROR(SaveShapeValuesFromDataPropagation(node,
+                                                           *output_def,
+                                                           onnx_inferred_type_after_data_propagation));
+
     DataType existing_type = output_def->Type();
     DataType inferred_type = nullptr;
diff --git a/onnxruntime/test/framework/shape_inference_test.cc b/onnxruntime/test/framework/shape_inference_test.cc
index 2d5c3a43ee8ed..a7910e28b1d6d 100644
--- a/onnxruntime/test/framework/shape_inference_test.cc
+++ b/onnxruntime/test/framework/shape_inference_test.cc
@@ -16,6 +16,8 @@
 
 using namespace ONNX_NAMESPACE;
 
+extern std::unique_ptr<Ort::Env> ort_env;
+
 namespace onnxruntime {
 namespace test {
 
@@ -76,6 +78,65 @@ TEST_F(ShapeInferenceTest, BasicTest) {
   CheckShapeEquality(InputShape(node), OutputShape(node));
 }
 
+TEST(ShapeInferenceV2Test, PartialDataPropagationTest) {
+  {
+    // This model contains "Shape" and "Reshape" operators.
+    auto model_path = ORT_TSTR("testdata/test_shape_data_propagation_with_shape_related_nodes.onnx");
+
+    Ort::SessionOptions session_options{};
+    session_options.SetGraphOptimizationLevel(ORT_DISABLE_ALL);
+
+    const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+    ORT_ENFORCE(g_ort->AddFreeDimensionOverrideByName(session_options, "batch", 1) == nullptr);
+    ORT_ENFORCE(g_ort->AddFreeDimensionOverrideByName(session_options, "width", 64) == nullptr);
+    ORT_ENFORCE(g_ort->AddFreeDimensionOverrideByName(session_options, "height", 64) == nullptr);
+
+    // Even though all graph optimizations are disabled, the free dimension overrides are still applied.
+    // The shape of the graph's output should be correctly inferred by shape inference and data propagation.
+    Ort::Session session(*ort_env, model_path, session_options);
+
+    // This graph only has one output
+    ORT_ENFORCE(session.GetOutputCount() == 1);
+
+    Ort::TypeInfo type_info = session.GetOutputTypeInfo(0);
+    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
+    std::vector<int64_t> output_shape = tensor_info.GetShape();
+    EXPECT_TRUE(output_shape.size() == 4) << "The output shape should have 4 dimensions";
+    EXPECT_TRUE(output_shape[0] == 1) << "The first dimension should have 1 as value";
+    EXPECT_TRUE(output_shape[1] == 3) << "The second dimension should have 3 as value";
+    EXPECT_TRUE(output_shape[2] == 64) << "The third dimension should have 64 as value";
+    EXPECT_TRUE(output_shape[3] == 64) << "The fourth dimension should have 64 as value";
+  }
+
+  {
+    // This model contains "Shape", "Reshape", "Gather" and "Unsqueeze" operators.
+    auto model_path = ORT_TSTR("testdata/test_shape_data_propagation_with_shape_related_nodes_v2.onnx");
+
+    Ort::SessionOptions session_options{};
+    session_options.SetGraphOptimizationLevel(ORT_DISABLE_ALL);
+
+    const OrtApi* g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+    ORT_ENFORCE(g_ort->AddFreeDimensionOverrideByName(session_options, "batch", 1) == nullptr);
+    ORT_ENFORCE(g_ort->AddFreeDimensionOverrideByName(session_options, "width", 64) == nullptr);
+    ORT_ENFORCE(g_ort->AddFreeDimensionOverrideByName(session_options, "height", 64) == nullptr);
+
+    // Even though all graph optimizations are disabled, the free dimension overrides are still applied.
+    // The shape of the graph's output should be correctly inferred by shape inference and data propagation.
+    Ort::Session session(*ort_env, model_path, session_options);
+
+    // This graph only has one output
+    ORT_ENFORCE(session.GetOutputCount() == 1);
+
+    Ort::TypeInfo type_info = session.GetOutputTypeInfo(0);
+    auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
+    std::vector<int64_t> output_shape = tensor_info.GetShape();
+    EXPECT_TRUE(output_shape.size() == 3) << "The output shape should have 3 dimensions";
+    EXPECT_TRUE(output_shape[0] == 1) << "The first dimension should have 1 as value";
+    EXPECT_TRUE(output_shape[1] == 3) << "The second dimension should have 3 as value";
+    EXPECT_TRUE(output_shape[2] == 4096) << "The third dimension should have 4096 as value";
+  }
+}
+
 namespace {
 struct MyCustomKernelWithOptionalInput {
   MyCustomKernelWithOptionalInput(const OrtKernelInfo* /*info*/) {
diff --git a/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.onnx b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.onnx
new file mode 100644
index 0000000000000..e18aa31e414e5
Binary files /dev/null and b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.onnx differ
diff --git a/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.py b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.py
new file mode 100644
index 0000000000000..6537a3cd357c3
--- /dev/null
+++ b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes.py
@@ -0,0 +1,53 @@
+import onnx
+from onnx import TensorProto, helper
+
+# 1. Define graph input with symbolic shape ['batch', 3, 'width', 'height']
+input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", 3, "width", "height"])
+
+# 2. Define intermediate and output tensors
+shape_out = helper.make_tensor_value_info("shape_out", TensorProto.INT64, [4])  # Shape output
+reshape_a_out = helper.make_tensor_value_info("reshape_a_out", TensorProto.FLOAT, None)
+output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, None)
+
+# 3. Create the initializer for Reshape A's 'shape' input: [0, 32, -1]
+shape_initializer = helper.make_tensor(
+    name="reshape_a_shape",
+    data_type=TensorProto.INT64,
+    dims=[3],
+    vals=[0, 32, -1],
+)
+
+# 4. Create nodes:
+# Shape node
+shape_node = helper.make_node("Shape", inputs=["input"], outputs=["shape_out"], name="ShapeNode")
+
+# Reshape A node: takes input + constant shape
+reshape_a_node = helper.make_node(
+    "Reshape", inputs=["input", "reshape_a_shape"], outputs=["reshape_a_out"], name="ReshapeA"
+)
+
+# Reshape B node: takes Shape + ReshapeA outputs, outputs final output
+reshape_b_node = helper.make_node("Reshape", inputs=["reshape_a_out", "shape_out"], outputs=["output"], name="ReshapeB")
+
+# 5. Assemble the graph
+graph = helper.make_graph(
+    nodes=[shape_node, reshape_a_node, reshape_b_node],
+    name="Shape_Reshape_Model",
+    inputs=[input_tensor],
+    outputs=[output_tensor],
+    initializer=[shape_initializer],
+    value_info=[shape_out, reshape_a_out],
+)
+
+# 6. Define the model (set IR and opset)
+model = helper.make_model(
+    graph,
+    opset_imports=[helper.make_operatorsetid("", 18)],
+    producer_name="onnx-example-generator",
+)
+model.ir_version = onnx.IR_VERSION
+
+# 7. Save the model
+onnx.save(model, "test_shape_data_propagation_with_shape_related_nodes.onnx")
+
+print("Model saved to test_shape_data_propagation_with_shape_related_nodes.onnx")
diff --git a/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.onnx b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.onnx
new file mode 100644
index 0000000000000..ff41075ff64cc
Binary files /dev/null and b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.onnx differ
diff --git a/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.py b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.py
new file mode 100644
index 0000000000000..7cfbcca8d4d03
--- /dev/null
+++ b/onnxruntime/test/testdata/test_shape_data_propagation_with_shape_related_nodes_v2.py
@@ -0,0 +1,59 @@
+import onnx
+from onnx import TensorProto, helper
+
+# === Graph input/output ===
+input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", 3, "width", "height"])
+output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", 3, "width*height"])
+
+# === Initializers ===
+B = helper.make_tensor("B", TensorProto.FLOAT, [], [1.0])
+
+# Gather indices
+g0_idx = helper.make_tensor("g0_idx", TensorProto.INT64, [], [0])
+g1_idx = helper.make_tensor("g1_idx", TensorProto.INT64, [], [1])
+g2_idx = helper.make_tensor("g2_idx", TensorProto.INT64, [], [2])
+g3_idx = helper.make_tensor("g3_idx", TensorProto.INT64, [], [3])
+
+# Unsqueeze axes tensors
+axes_unsq0 = helper.make_tensor("axes_unsq0", TensorProto.INT64, [1], [0])
+axes_unsq1 = helper.make_tensor("axes_unsq1", TensorProto.INT64, [1], [0])
+axes_unsq2 = helper.make_tensor("axes_unsq2", TensorProto.INT64, [1], [0])
+
+# === Nodes ===
+div = helper.make_node("Div", ["input", "B"], ["div_out"])
+
+# Two Shape nodes from Div
+shape_left = helper.make_node("Shape", ["div_out"], ["shape_left_out"])
+shape_right = helper.make_node("Shape", ["div_out"], ["shape_right_out"])
+
+# Left Shape path
+gather0 = helper.make_node("Gather", ["shape_left_out", "g0_idx"], ["g0_out"])
+gather1 = helper.make_node("Gather", ["shape_left_out", "g1_idx"], ["g1_out"])
+unsq0 = helper.make_node("Unsqueeze", ["g0_out", "axes_unsq0"], ["u0_out"])
+unsq1 = helper.make_node("Unsqueeze", ["g1_out", "axes_unsq1"], ["u1_out"])
+
+# Right Shape path
+gather2 = helper.make_node("Gather", ["shape_right_out", "g2_idx"], ["g2_out"])
+gather3 = helper.make_node("Gather", ["shape_right_out", "g3_idx"], ["g3_out"])
+mul = helper.make_node("Mul", ["g2_out", "g3_out"], ["mul_out"])
+unsq2 = helper.make_node("Unsqueeze", ["mul_out", "axes_unsq2"], ["u2_out"])
+
+# Combine
+concat = helper.make_node("Concat", ["u0_out", "u1_out", "u2_out"], ["concat_out"], axis=0)
+reshape = helper.make_node("Reshape", ["div_out", "concat_out"], ["output"])
+
+# === Graph ===
+graph = helper.make_graph(
+    [div, shape_left, shape_right, gather0, gather1, gather2, gather3, mul, unsq0, unsq1, unsq2, concat, reshape],
+    "Div_Shape_Gather_Concat_Reshape",
+    [input_tensor],
+    [output_tensor],
+    initializer=[B, g0_idx, g1_idx, g2_idx, g3_idx, axes_unsq0, axes_unsq1, axes_unsq2],
+)
+
+# === Model ===
+model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)], producer_name="onnx-example")
+onnx.checker.check_model(model)
+onnx.save(model, "test_shape_data_propagation_with_shape_related_nodes_v2.onnx")
+
+print("✅ Model saved as test_shape_data_propagation_with_shape_related_nodes_v2.onnx")
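
A minimal Python sketch of how the new behavior can be observed outside the C++ test (not part of this diff). It assumes an onnxruntime wheel built with this change and the test_shape_data_propagation_with_shape_related_nodes.onnx model generated by the script above in the working directory:

    # Sketch only: requires a build that includes the data propagation change above.
    import onnxruntime as ort

    so = ort.SessionOptions()
    # Disable all graph optimizations so the fully resolved output shape can only come
    # from shape inference plus data propagation, not from constant folding.
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
    # Free dimension overrides are applied even when optimizations are disabled.
    so.add_free_dimension_override_by_name("batch", 1)
    so.add_free_dimension_override_by_name("width", 64)
    so.add_free_dimension_override_by_name("height", 64)

    sess = ort.InferenceSession(
        "test_shape_data_propagation_with_shape_related_nodes.onnx",
        sess_options=so,
        providers=["CPUExecutionProvider"],
    )

    # With data propagation, the single graph output should report the concrete
    # shape [1, 3, 64, 64] instead of an unknown/symbolic shape.
    print(sess.get_outputs()[0].shape)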