@@ -224,11 +224,13 @@ def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None:

     # Initializes the nodes metadata and runs the GenerateMemoryViewConstraints,
     # GenerateSliceAndSelectNopConstraints, and GenerateCatNopConstraints passes.
-    def run_memory_planning(self, original, alloc_graph_input=True) -> GraphModule:
+    def run_memory_planning(
+        self, original, opt_level=2, alloc_graph_input=True
+    ) -> GraphModule:
         graph_module = SpecPropPass().call(original).graph_module
         return CadenceMemoryPlanning(
             get_default_memory_config(),
-            opt_level=2,
+            opt_level=opt_level,
             mem_algo=1,  # greedy_by_size_for_offset_calculation_with_hierarchy
             alloc_graph_input=alloc_graph_input,
         )(graph_module).graph_module
@@ -535,130 +537,239 @@ def test_optimize_cat_with_slice_infeasible(self) -> None:
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_slice_Tensor(self) -> None:
-        class SliceTensor(torch.nn.Module):
-            def forward(self, x, y, z):
-                x1 = torch.add(x, 2.4, 3.1)
-                # This slice should always be optimized, since x1 is not placeholder
-                # and the slice is along the outermost dim
-                t1 = torch.ops.aten.slice(x1, 0, 1, 2)
-                # This slice should not be optimized when alloc_graph_input=False,
-                # since y is a placeholder node
-                t2 = torch.ops.aten.slice(y, 0, 0, 1)
-                # This slice should be always optimized, since the dims before
-                # sliced dims are 1
-                z1 = torch.add(z, 2.4, 3.1)
-                t3 = torch.ops.aten.slice(z1, 1, 4, 5)
-                return (t1 + t2) * t3
-
-        x = torch.ones(3, 6)
-        y = torch.ones(2, 6)
-        z = torch.ones(1, 6)
-        # Run the memory planning pass and get the graph module
-        graph_module = (
-            compiler.export_to_executorch_gen_etrecord(
-                SliceTensor(),
-                (x, y, z),
-                opt_level=2,
-                mem_algo=1,
-                alloc_graph_input=False,
-            )
-            .exported_program()
-            .graph_module
+    def test_optimize_slice_outermost(self) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.ones(3, 6, dtype=torch.float32))
+        to_add_to_x = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([3, 6], 123.0),
+            kwargs={"dtype": torch.float32},
+        )
+        add_x = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(x, to_add_to_x),
+        )
+        slice_out = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1, 6], 0.0),
+            kwargs={"dtype": torch.float32},
         )
+        # This slice should always be optimized, since add_x is not placeholder
+        # and the slice is along the outermost dim
+        slice_result = builder.call_operator(
+            op=torch.ops.aten.slice_copy.Tensor_out,
+            args=(
+                add_x,
+                0,  # dim
+                1,  # start
+                2,  # end
+                1,  # step
+            ),
+            kwargs={"out": slice_out},
+        )
+        builder.output([slice_result])
+        original = builder.get_graph_module()
+        graph_module = self.run_memory_planning(original, alloc_graph_input=False)
         graph_module.graph.eliminate_dead_code()
-        # Assert that t2 is not optimized away
         self.assertEqual(
-            count_node(graph_module, torch.ops.aten.slice_copy.Tensor_out), 1
+            count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 1
+        )
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_optimize_slice_non_outermost(self) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.ones(1, 6, dtype=torch.float32))
+        to_add_to_x = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1, 6], 123.0),
+            kwargs={"dtype": torch.float32},
+        )
+        add_x = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(x, to_add_to_x),
+        )
+        slice_out = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1, 2], 0.0),
+            kwargs={"dtype": torch.float32},
+        )
+        # This slice should always be optimized, since the dims before
+        # the sliced dim are 1.
+        slice_result = builder.call_operator(
+            op=torch.ops.aten.slice_copy.Tensor_out,
+            args=(
+                add_x,
+                1,  # dim
+                4,  # start
+                6,  # end
+                1,  # step
+            ),
+            kwargs={"out": slice_out},
         )
-        # Assert that t1 and t3 are optimized to slice_copy_nop veresion
+        builder.output([slice_result])
+        original = builder.get_graph_module()
+        graph_module = self.run_memory_planning(original, alloc_graph_input=False)
+        graph_module.graph.eliminate_dead_code()
         self.assertEqual(
-            count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 2
+            count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 1
         )
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_optimize_slice_depending_on_opt_level(self) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.ones(2, 6, dtype=torch.float32))
+        slice_out = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1, 6], 0.0),
+            kwargs={"dtype": torch.float32},
+        )
+        # This slice should not be optimized when alloc_graph_input=False,
+        # since x is a placeholder node
+        slice_result = builder.call_operator(
+            op=torch.ops.aten.slice_copy.Tensor_out,
+            args=(
+                x,
+                0,  # dim
+                0,  # start
+                1,  # end
+                1,  # step
+            ),
+            kwargs={"out": slice_out},
+        )
+        builder.output([slice_result])
+        original = builder.get_graph_module()
+        graph_module = self.run_memory_planning(
+            original, opt_level=2, alloc_graph_input=False
+        )
+        graph_module.graph.eliminate_dead_code()
+        self.assertEqual(
+            count_node(graph_module, torch.ops.aten.slice_copy.Tensor_out), 1
+        )
+        self.verify_nop_memory_alloc(graph_module)
+
         # When we compile with alloc_graph_input=True, all the slice ops must
-        # be optimized.
-        # Optimizing cat ops is only at opt_level 2+, and requires the memory planning
-        # pass to run:
-        graph_module = (
-            compiler.export_to_executorch_gen_etrecord(
-                SliceTensor(),
-                (x, y, z),
-                opt_level=3,
-                mem_algo=1,
-                alloc_graph_input=True,
-            )
-            .exported_program()
-            .graph_module
+        # be optimized, which is available only at opt_level 2+.
+        graph_module = self.run_memory_planning(
+            original, opt_level=3, alloc_graph_input=True
         )
         graph_module.graph.eliminate_dead_code()
-        self.assertFalse(count_node(graph_module, torch.ops.aten.slice_copy.Tensor_out))
         self.assertEqual(
-            count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 3
+            count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 1
         )
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_select_Tensor(self) -> None:
-        class SelectTensor(torch.nn.Module):
-            def forward(self, x, y, z):
-                x1 = torch.add(x, 2.4, 3.1)
-                # This select should always be optimized, since x1 is not
-                # placeholder, and the select is along the outermost dim
-                t1 = torch.select_copy(x1, 0, 1)
-                # This select should not be optimized if alloc_graph_input=False,
-                # since y is a placeholder node.
-                t2 = torch.select_copy(y, 0, 0)
-                # This select should always be optimized, since the dims before
-                # select dims are 1
-                z1 = torch.add(z, 2.4, 3.1)
-                t3 = torch.select(z1, 1, 4)
-                return (t1 + t2) * t3
-
-        x = torch.ones(3, 6)
-        y = torch.ones(2, 6)
-        z = torch.ones(1, 6)
-        # Optimizing select ops is only at opt_level 2+, and requires the memory planning
-        # pass to run:
-        graph_module = (
-            compiler.export_to_executorch_gen_etrecord(
-                SelectTensor(),
-                (x, y, z),
-                opt_level=2,
-                mem_algo=1,
-                alloc_graph_input=False,
-            )
-            .exported_program()
-            .graph_module
+    def test_optimize_select_outermost(self) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.ones(3, 6, dtype=torch.float32))
+        to_add_to_x = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([3, 6], 123.0),
+            kwargs={"dtype": torch.float32},
         )
+        add_x = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(x, to_add_to_x),
+        )
+        slice_out = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1, 6], 0.0),
+            kwargs={"dtype": torch.float32},
+        )
+        # This select should always be optimized, since add_x is not placeholder
+        # and the select is along the outermost dim
+        slice_result = builder.call_operator(
+            op=torch.ops.aten.select_copy.int_out,
+            args=(
+                add_x,
+                0,  # dim
+                1,  # index
+            ),
+            kwargs={"out": slice_out},
+        )
+        builder.output([slice_result])
+        original = builder.get_graph_module()
+        graph_module = self.run_memory_planning(original, alloc_graph_input=False)
         graph_module.graph.eliminate_dead_code()
-        # Assert that t2 is not optimized away
         self.assertEqual(
-            count_node(graph_module, torch.ops.aten.select_copy.int_out), 1
+            count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 1
+        )
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_optimize_select_non_outermost(self) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.ones(1, 6, dtype=torch.float32))
+        to_add_to_x = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1, 6], 123.0),
+            kwargs={"dtype": torch.float32},
+        )
+        add_x = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(x, to_add_to_x),
+        )
+        slice_out = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1, 2], 0.0),
+            kwargs={"dtype": torch.float32},
+        )
+        # This select should always be optimized, since the dims before
+        # select dims are 1
+        slice_result = builder.call_operator(
+            op=torch.ops.aten.select_copy.int_out,
+            args=(
+                add_x,
+                1,  # dim
+                4,  # index
+            ),
+            kwargs={"out": slice_out},
         )
-        # Assert that t1 and t3 are optimized to select_copy_nop veresion
+        builder.output([slice_result])
+        original = builder.get_graph_module()
+        graph_module = self.run_memory_planning(original, alloc_graph_input=False)
+        graph_module.graph.eliminate_dead_code()
         self.assertEqual(
-            count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 2
+            count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 1
         )
-        # When we compile with alloc_graph_input=True, all the select ops must
-        # be optimized.
-        # Optimizing select ops is only at opt_level 2+, and requires the memory planning
-        # pass to run:
-        graph_module = (
-            compiler.export_to_executorch_gen_etrecord(
-                SelectTensor(),
-                (x, y, z),
-                opt_level=3,
-                mem_algo=1,
-                alloc_graph_input=True,
-            )
-            .exported_program()
-            .graph_module
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_optimize_select_depending_on_opt_level(self) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.ones(2, 6, dtype=torch.float32))
+        slice_out = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1, 6], 0.0),
+            kwargs={"dtype": torch.float32},
+        )
+        # This select should not be optimized if alloc_graph_input=False,
+        # since x is a placeholder node.
+        slice_result = builder.call_operator(
+            op=torch.ops.aten.select_copy.int_out,
+            args=(
+                x,
+                0,  # dim
+                0,  # index
+            ),
+            kwargs={"out": slice_out},
+        )
+        builder.output([slice_result])
+        original = builder.get_graph_module()
+        graph_module = self.run_memory_planning(
+            original, opt_level=2, alloc_graph_input=False
         )
         graph_module.graph.eliminate_dead_code()
         self.assertEqual(
-            count_node(graph_module, torch.ops.aten.select_copy.int_out), 0
+            count_node(graph_module, torch.ops.aten.select_copy.int_out), 1
         )
+        self.verify_nop_memory_alloc(graph_module)
+
+        # When we compile with alloc_graph_input=True, all the select ops must
+        # be optimized, which is available only at opt_level 2+.
+        graph_module = self.run_memory_planning(
+            original, opt_level=3, alloc_graph_input=True
+        )
+        graph_module.graph.eliminate_dead_code()
         self.assertEqual(
-            count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 3
+            count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 1
         )
         self.verify_nop_memory_alloc(graph_module)
