Skip to content

Commit 6177a78

Browse files
Support export program in intermediate numeric discrepancy detector (#12944)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #12581 by @Gasoonjia ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/27/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/27/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/26/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/27/orig @diff-train-skip-merge --------- Co-authored-by: gasoonjia <gasoonjia@icloud.com>
1 parent c5dd931 commit 6177a78

File tree

3 files changed

+98
-19
lines changed

3 files changed

+98
-19
lines changed

devtools/inspector/_inspector.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
map_runtime_aot_intermediate_outputs,
6363
merge_runtime_overlapping_debug_handles,
6464
ProgramOutput,
65+
propagate_back_debug_handle,
6566
RESERVED_FRAMEWORK_EVENT_NAMES,
6667
TimeScale,
6768
verify_debug_data_equivalence,
@@ -1166,7 +1167,18 @@ def _get_aot_intermediate_outputs_and_op_names(
11661167
"""
11671168
if self._etrecord._representative_inputs is None:
11681169
return {}, {}
1169-
export_program = self._etrecord.edge_dialect_program
1170+
1171+
export_program = None
1172+
1173+
# Will use the exported program to extract intermediate output if and only if exported_program has been provided, and it is the greatest ancestor of the edge_dialect_program
1174+
if self._etrecord.exported_program and propagate_back_debug_handle(
1175+
self._etrecord.exported_program,
1176+
self._etrecord.export_graph_id,
1177+
self._etrecord.edge_dialect_program,
1178+
):
1179+
export_program = self._etrecord.exported_program
1180+
else:
1181+
export_program = self._etrecord.edge_dialect_program
11701182
graph_module = export_program.module()
11711183
aot_debug_handle_to_op_name = get_aot_debug_handle_to_op_name_mapping(
11721184
graph_module

devtools/inspector/tests/inspector_test.py

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from executorch.devtools import generate_etrecord, parse_etrecord
2626
from executorch.devtools.debug_format.et_schema import OperatorNode
2727
from executorch.devtools.etdump.schema_flatcc import ProfileEvent
28-
from executorch.devtools.etrecord._etrecord import ETRecord
2928
from executorch.devtools.etrecord.tests.etrecord_test import TestETRecord
3029

3130
from executorch.devtools.inspector import (
@@ -480,7 +479,7 @@ def test_populate_debugging_related_fields_passes_for_consistent_events(self):
480479
events=events,
481480
)
482481

483-
def test_etrecord_populates_correct_aot_intermediate_outputs(self):
482+
def test_etrecord_populates_correct_edge_dialect_aot_intermediate_outputs(self):
484483
with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file:
485484
etrecord_path = tmp_file.name
486485
mod = model_registry["ConvLinearModel"]()
@@ -513,15 +512,11 @@ def test_etrecord_populates_correct_aot_intermediate_outputs(self):
513512
etdump_path=ETDUMP_PATH,
514513
etrecord=etrecord_path,
515514
)
516-
etrecord = ETRecord(
517-
edge_dialect_program=inspector_instance._etrecord.edge_dialect_program,
518-
graph_map=inspector_instance._etrecord.graph_map,
519-
_debug_handle_map=inspector_instance._etrecord._debug_handle_map,
520-
_delegate_map=inspector_instance._etrecord._delegate_map,
521-
_reference_outputs=inspector_instance._etrecord._reference_outputs,
522-
_representative_inputs=aten_model.example_inputs[0],
515+
516+
inspector_instance._etrecord._representative_inputs = (
517+
aten_model.example_inputs[0]
523518
)
524-
inspector_instance._etrecord = etrecord
519+
525520
aot_intermediate_outputs, aot_debug_handle_to_op_names = (
526521
inspector_instance._get_aot_intermediate_outputs_and_op_names()
527522
)
@@ -534,7 +529,61 @@ def test_etrecord_populates_correct_aot_intermediate_outputs(self):
534529

535530
self.assertTrue(
536531
check_if_debug_handle_to_op_names_match(
537-
"ConvLinearModel", aot_debug_handle_to_op_names
532+
aot_debug_handle_to_op_names,
533+
mod.get_edge_dialect_expected_debug_handle_to_op_names(),
534+
)
535+
)
536+
537+
def test_etrecord_populates_correct_export_program_aot_intermediate_outputs(self):
538+
with tempfile.NamedTemporaryFile(suffix=".bin") as tmp_file:
539+
etrecord_path = tmp_file.name
540+
mod = model_registry["ConvLinearModel"]()
541+
input_tensor = mod.get_input()
542+
aten_model: ExportedProgram = export(mod, (input_tensor,), strict=True)
543+
edge_program_manager: EdgeProgramManager = to_edge(aten_model)
544+
edge_program_manager_copy = copy.deepcopy(edge_program_manager)
545+
et_program_manager: ExecutorchProgramManager = (
546+
edge_program_manager.to_executorch()
547+
)
548+
# Generate ETRecord with the exported program
549+
generate_etrecord(
550+
etrecord_path,
551+
edge_program_manager_copy,
552+
et_program_manager,
553+
exported_program=aten_model,
554+
)
555+
with patch.object(
556+
Inspector, "_consume_etrecord", return_value=None
557+
), patch.object(
558+
_inspector, "gen_etdump_object", return_value=None
559+
), patch.object(
560+
EventBlock, "_gen_from_etdump"
561+
), patch.object(
562+
_inspector, "gen_graphs_from_etrecord"
563+
):
564+
# Call the constructor of Inspector
565+
inspector_instance = Inspector(
566+
etdump_path=ETDUMP_PATH,
567+
etrecord=etrecord_path,
568+
)
569+
570+
inspector_instance._etrecord._representative_inputs = (
571+
aten_model.example_inputs[0]
572+
)
573+
574+
aot_intermediate_outputs, aot_debug_handle_to_op_names = (
575+
inspector_instance._get_aot_intermediate_outputs_and_op_names()
576+
)
577+
self.assertTrue(
578+
check_if_intermediate_outputs_match(
579+
aot_intermediate_outputs,
580+
mod.get_exported_program_expected_intermediate_outputs(),
581+
)
582+
)
583+
self.assertTrue(
584+
check_if_debug_handle_to_op_names_match(
585+
aot_debug_handle_to_op_names,
586+
mod.get_exported_program_expected_debug_handle_to_op_names(),
538587
)
539588
)
540589

devtools/inspector/tests/inspector_test_utils.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def get_edge_dialect_expected_intermediate_outputs():
7979
}
8080

8181
@staticmethod
82-
def get_expected_debug_handle_to_op_names():
82+
def get_edge_dialect_expected_debug_handle_to_op_names():
8383
"""
8484
Returns the expected debug handle and op names mapping for this model for the given input.
8585
"""
@@ -100,7 +100,7 @@ def get_expected_debug_handle_to_op_names():
100100
@staticmethod
101101
def get_exported_program_expected_intermediate_outputs():
102102
"""
103-
Returns the expected outputs of the debug handles and intermediate output mapping for edge dialect graph of this model for the given input.
103+
Returns the expected outputs of the debug handles and intermediate output mapping for export graph of this model for the given input.
104104
"""
105105
return {
106106
(UNSET_DEBUG_HANDLE,): torch.tensor([[5.4000, 13.5200]]),
@@ -117,6 +117,26 @@ def get_exported_program_expected_intermediate_outputs():
117117
(11,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
118118
}
119119

120+
@staticmethod
121+
def get_exported_program_expected_debug_handle_to_op_names():
122+
"""
123+
Returns the expected debug handle and op name mapping for this model for the given input.
124+
"""
125+
return {
126+
(UNSET_DEBUG_HANDLE,): ["_assert_tensor_metadata_default", "to"],
127+
(1,): ["conv2d"],
128+
(2,): ["view"],
129+
(3,): ["linear"],
130+
(4,): ["add"],
131+
(5,): ["sub"],
132+
(6,): ["mul"],
133+
(7,): ["add_1"],
134+
(8,): ["div"],
135+
(9,): ["relu"],
136+
(10,): ["sigmoid"],
137+
(11,): ["split"],
138+
}
139+
120140

121141
# Global model registry
122142
model_registry = {
@@ -153,15 +173,13 @@ def check_if_intermediate_outputs_match(
153173
return True
154174

155175

156-
def check_if_debug_handle_to_op_names_match(model_name, actual_debug_handle_to_op_name):
176+
def check_if_debug_handle_to_op_names_match(
177+
actual_debug_handle_to_op_name, expected_debug_handle_to_op_name
178+
):
157179
"""
158180
Checks if the actual op names match the expected op names for the specified model.
159181
Returns True if all match, otherwise returns False.
160182
"""
161-
model_instance = model_registry[model_name]
162-
expected_debug_handle_to_op_name = (
163-
model_instance.get_expected_debug_handle_to_op_names()
164-
)
165183
if len(actual_debug_handle_to_op_name) != len(expected_debug_handle_to_op_name):
166184
return False
167185
for debug_handle, expected_op_name in expected_debug_handle_to_op_name.items():

0 commit comments

Comments
 (0)