Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -1760,6 +1760,12 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
std::vector<const ONNX_NAMESPACE::TypeProto*>& output_types,
const Graph::ResolveOptions& options);

// If the shape values are inferred after executing ONNX operator's PartialDataPropagationFunction(),
// save them to the output NodeArg as a TensorShapeProto or a scalar value so that downstream (consumer) nodes
// can use them later for their TypeAndShapeInferenceFunction() and PartialDataPropagationFunction().
common::Status SaveShapeValuesFromDataPropagation(Node& node, NodeArg& output_def,
const ONNX_NAMESPACE::TypeProto& propagated_value_as_type_proto) const;

// Apply type-inference and type-checking to all inputs and initializers:
common::Status TypeCheckInputsAndInitializers();

Expand Down
19 changes: 19 additions & 0 deletions include/onnxruntime/core/graph/node_arg.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ class NodeArg {
/** Gets this NodeArg as a NodeArgInfo, AKA ValueInfoProto. */
const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; }

/** Gets the inferred shape values as a TensorShapeProto. */
const std::optional<ONNX_NAMESPACE::TensorShapeProto>& GetInferredShapeValues() const noexcept { return inferred_shape_values_; }

/** Gets a flag indicating whether this NodeArg exists or not.
Optional inputs are allowed in ONNX and an empty #Name represents a non-existent input argument. */
bool Exists() const noexcept;
Expand All @@ -128,6 +131,22 @@ class NodeArg {
// Node arg name, type and shape.
NodeArgInfo node_arg_info_;

// This variable stores the inferred shape data as a TensorShapeProto after executing
// the ONNX operator's PartialDataPropagationFunction().
//
// Calling an operator's TypeAndShapeInferenceFunction() alone is sometimes insufficient
// for complete shape inference. For example, the Shape operator only provides the
// output's rank (1-dimensional) but not its actual dimension values.
// The PartialDataPropagationFunction(), defined in the ONNX operator schema, must also
// be executed to obtain the concrete output shape values, allowing accurate propagation
// of shape information throughout the graph.
std::optional<ONNX_NAMESPACE::TensorShapeProto> inferred_shape_values_;

// This variable stores the inferred scalar output.
// It is also used for shape inference and data propagation to ensure consistent shape and
// value information throughout the graph.
std::optional<int64_t> inferred_scalar_value_;

// Flag indicates whether <*this> node arg exists or not.
bool exists_;
};
Expand Down
Loading
Loading