@@ -43,3 +43,37 @@ def test_get_local_schedulers(self, mock_load_group: MagicMock) -> None:
4343
4444 for scheduler in schedulers .values ():
4545 self .assertEqual ("test_session" , scheduler .session_name )
46+
47+ @patch ("torchx.schedulers.load_group" )
48+ def test_custom_schedulers_merged (self , mock_load_group : MagicMock ) -> None :
49+ mock_scheduler = MagicMock ()
50+ mock_load_group .return_value = {"custom" : mock_scheduler }
51+
52+ factories = get_scheduler_factories ()
53+
54+ self .assertIn ("custom" , factories )
55+ self .assertEqual (factories ["custom" ], mock_scheduler )
56+ self .assertIn ("local_docker" , factories )
57+
58+ @patch ("torchx.schedulers.load_group" )
59+ def test_custom_scheduler_overrides_default (
60+ self , mock_load_group : MagicMock
61+ ) -> None :
62+ mock_scheduler = MagicMock ()
63+ mock_load_group .return_value = {"local_docker" : mock_scheduler }
64+
65+ factories = get_scheduler_factories ()
66+
67+ self .assertEqual (factories ["local_docker" ], mock_scheduler )
68+
69+ @patch ("torchx.schedulers.load_group" )
70+ def test_skip_defaults_with_custom_schedulers (
71+ self , mock_load_group : MagicMock
72+ ) -> None :
73+ mock_scheduler = MagicMock ()
74+ mock_load_group .return_value = {"custom" : mock_scheduler }
75+
76+ factories = get_scheduler_factories (skip_defaults = True )
77+
78+ self .assertEqual (factories , {"custom" : mock_scheduler })
79+ self .assertNotIn ("local_docker" , factories )
0 commit comments