Skip to content
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
d345c26
update
chilo-ms Oct 6, 2025
cb52305
Modify to make it work correctly
chilo-ms Oct 6, 2025
fb5a424
Save scalar value after data propagation
chilo-ms Oct 7, 2025
bd14235
correctly save the inferred shape values in NodeArg for other ops' da…
chilo-ms Oct 9, 2025
e2c532f
refactor the code and add comments
chilo-ms Oct 9, 2025
f79f65b
add tests
chilo-ms Oct 9, 2025
dbe36a8
update error message
chilo-ms Oct 9, 2025
d310a26
fix warnings from pipelines
chilo-ms Oct 10, 2025
3fa22a0
handle for Unsqueeze 11 and eariler that axes is node attribute
chilo-ms Oct 11, 2025
92c71be
fix issue for 'indices' input to Gather has negative value
chilo-ms Oct 13, 2025
f652eec
address issue from pipeline
chilo-ms Oct 14, 2025
78ed993
fix pipeline warning
chilo-ms Oct 14, 2025
035a66e
fix warning in pipeline
chilo-ms Oct 14, 2025
e3e26a0
refactor the code
chilo-ms Oct 14, 2025
6d478cb
Merge branch 'main' into chi/fix_shape_inference
chilo-ms Oct 14, 2025
9dc51aa
address reviewer's comments
chilo-ms Oct 14, 2025
1ec4bf9
fix bug after using std::optional to store value
chilo-ms Oct 14, 2025
1391a15
address reviewer's comments
chilo-ms Oct 14, 2025
21d5e33
add check for get initializer as in-memory external
chilo-ms Oct 14, 2025
e59216b
fix type issue
chilo-ms Oct 15, 2025
eb8be39
Merge branch 'main' into chi/fix_shape_inference
chilo-ms Oct 16, 2025
2616c6f
refactor code and address corner case
chilo-ms Oct 20, 2025
73e38ad
Add clean up for inferred shape values and fix bugs
chilo-ms Oct 21, 2025
8bbebf4
add more tests
chilo-ms Oct 21, 2025
de2ad68
revert
chilo-ms Oct 21, 2025
24ed992
Add test model
chilo-ms Oct 21, 2025
40c8d09
lintrunner -a
chilo-ms Oct 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -1760,6 +1760,12 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
std::vector<const ONNX_NAMESPACE::TypeProto*>& output_types,
const Graph::ResolveOptions& options);

// If the dimension values are inferred after executing ONNX operator's PartialDataPropagationFunction(),
// save them to the output NodeArg as a TensorShapeProto or a scalar value so that downstream (consumer) nodes
// can use them later for their PartialDataPropagationFunction() and PartialDataPropagationFunction().
common::Status SaveValuesFromDataPropagation(Node& node, NodeArg& output_def,
const ONNX_NAMESPACE::TypeProto& propagated_value_as_type_proto) const;

// Apply type-inference and type-checking to all inputs and initializers:
common::Status TypeCheckInputsAndInitializers();

Expand Down
19 changes: 19 additions & 0 deletions include/onnxruntime/core/graph/node_arg.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@
/** Gets this NodeArg as a NodeArgInfo, AKA ValueInfoProto. */
const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; }

/** Gets the inferred shape as a TensorShapeProto. */
const ONNX_NAMESPACE::TensorShapeProto& GetValuesAfterDataPropagation() const noexcept { return values_after_data_propagation_; }

/** Gets a flag indicating whether this NodeArg exists or not.
Optional inputs are allowed in ONNX and an empty #Name represents a non-existent input argument. */
bool Exists() const noexcept;
Expand All @@ -128,6 +131,22 @@
// Node arg name, type and shape.
NodeArgInfo node_arg_info_;

// This variable stores the inferred shape data as a TensorShapeProto after executing
// the ONNX operator's PartialDataPropagationFunction().
//
// Calling an operator's TypeAndShapeInferenceFunction() alone is sometimes insufficient
// for complete shape inference. For example, the Shape operator only provides the
// output's rank (1-dimensional) but not its actual dimension values.
// The PartialDataPropagationFunction(), defined in the ONNX operator schema, must also
// be executed to obtain the concrete output shape values, allowing accurate propagation
// of shape information throughout the graph.
ONNX_NAMESPACE::TensorShapeProto values_after_data_propagation_;

// This variable stores the inferred scalar output.
// It is also used for shape inference and data propagation to ensure consistent shape and
// value information throughout the graph.
int64_t scalar_value_after_data_propagation_ = std::numeric_limits<int64_t>::min();

Check warning on line 148 in include/onnxruntime/core/graph/node_arg.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <limits> for numeric_limits<> [build/include_what_you_use] [4] Raw Output: include/onnxruntime/core/graph/node_arg.h:148: Add #include <limits> for numeric_limits<> [build/include_what_you_use] [4]

// Flag indicates whether <*this> node arg exists or not.
bool exists_;
};
Expand Down
Loading
Loading