Skip to content

Commit e657840

Browse files
authored
[pass] Preserve value metadata in LiftConstantsToInitializersPass (#204)
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent ec6f7be commit e657840

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

src/onnx_ir/passes/common/constant_manipulation.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,12 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
6969
shape=tensor.shape, # type: ignore[arg-type]
7070
type=ir.TensorType(tensor.dtype),
7171
const_value=tensor,
72+
# Preserve metadata from Constant value into the onnx model
73+
metadata_props=node.outputs[0].metadata_props.copy(),
7274
)
75+
# Preserve value meta from the Constant output for intermediate analysis
76+
initializer.meta.update(node.outputs[0].meta)
77+
7378
assert node.graph is not None
7479
node.graph.register_initializer(initializer)
7580
# Replace the constant node with the initializer

src/onnx_ir/passes/common/constant_manipulation_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ def test_pass_with_lifting_float_and_int_constants_to_initializers(
3636
const_node = ir.node(
3737
"Constant", inputs=[], attributes={"value": constant_tensor}, num_outputs=1
3838
)
39+
const_node.outputs[0].meta["meta_key"] = "meta_val"
40+
const_node.outputs[0].metadata_props["metadata_key"] = "metadata_val"
3941
add_node = ir.node("Add", inputs=[inputs[0], const_node.outputs[0]])
4042
mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]])
4143

@@ -72,6 +74,15 @@ def test_pass_with_lifting_float_and_int_constants_to_initializers(
7274
self.assertEqual(
7375
len([node for node in result.model.graph if node.op_type == "Constant"]), 0
7476
)
77+
# Metadata is preserved
78+
self.assertEqual(
79+
result.model.graph.initializers["val_0"].meta,
80+
{"meta_key": "meta_val"},
81+
)
82+
self.assertEqual(
83+
result.model.graph.initializers["val_0"].metadata_props,
84+
{"metadata_key": "metadata_val"},
85+
)
7586

7687
@parameterized.parameterized.expand(
7788
[

0 commit comments

Comments
 (0)