@@ -43,3 +43,100 @@ def test_get_local_schedulers(self, mock_load_group: MagicMock) -> None:
4343
4444 for scheduler in schedulers .values ():
4545 self .assertEqual ("test_session" , scheduler .session_name )
46+
47+ @patch ("torchx.schedulers.load_group" )
48+ def test_custom_schedulers_merged (self , mock_load_group : MagicMock ) -> None :
49+ mock_scheduler = MagicMock ()
50+ mock_load_group .return_value = {"custom" : mock_scheduler }
51+
52+ factories = get_scheduler_factories ()
53+
54+ self .assertIn ("custom" , factories )
55+ self .assertEqual (factories ["custom" ], mock_scheduler )
56+ self .assertIn ("local_docker" , factories )
57+
58+ @patch ("torchx.schedulers.load_group" )
59+ def test_custom_scheduler_overrides_default (
60+ self , mock_load_group : MagicMock
61+ ) -> None :
62+ mock_scheduler = MagicMock ()
63+ mock_load_group .return_value = {"local_docker" : mock_scheduler }
64+
65+ factories = get_scheduler_factories ()
66+
67+ self .assertEqual (factories ["local_docker" ], mock_scheduler )
68+
69+ @patch ("torchx.schedulers.load_group" )
70+ def test_skip_defaults_with_custom_schedulers (
71+ self , mock_load_group : MagicMock
72+ ) -> None :
73+ mock_scheduler = MagicMock ()
74+ mock_load_group .return_value = {"custom" : mock_scheduler }
75+
76+ factories = get_scheduler_factories (skip_defaults = True )
77+
78+ self .assertEqual (factories , {"custom" : mock_scheduler })
79+ self .assertNotIn ("local_docker" , factories )
80+
81+ @patch ("torchx.schedulers.load_group" )
82+ def test_with_custom_schedulers_skip_defaults_false (
83+ self , mock_load_group : MagicMock
84+ ) -> None :
85+ """with custom schedulers, skip_defaults=False returns both"""
86+ mock_aws = MagicMock ()
87+ mock_custom = MagicMock ()
88+ mock_load_group .return_value = {"aws_batch" : mock_aws , "custom_1" : mock_custom }
89+
90+ factories = get_scheduler_factories (skip_defaults = False )
91+
92+ self .assertIn ("aws_batch" , factories )
93+ self .assertIn ("custom_1" , factories )
94+ self .assertIn ("local_docker" , factories )
95+ self .assertIn ("slurm" , factories )
96+
97+ @patch ("torchx.schedulers.load_group" )
98+ def test_with_custom_schedulers_skip_defaults_true (
99+ self , mock_load_group : MagicMock
100+ ) -> None :
101+ """with custom schedulers, skip_defaults=True returns only custom"""
102+ mock_aws = MagicMock ()
103+ mock_custom = MagicMock ()
104+ mock_load_group .return_value = {"aws_batch" : mock_aws , "custom_1" : mock_custom }
105+
106+ factories = get_scheduler_factories (skip_defaults = True )
107+
108+ self .assertEqual (set (factories .keys ()), {"aws_batch" , "custom_1" })
109+
110+ @patch ("torchx.schedulers.load_group" )
111+ def test_no_custom_schedulers_skip_defaults_false (
112+ self , mock_load_group : MagicMock
113+ ) -> None :
114+ """no custom schedulers, skip_defaults=False returns defaults"""
115+ mock_load_group .return_value = {}
116+
117+ factories = get_scheduler_factories (skip_defaults = False )
118+
119+ self .assertIn ("local_docker" , factories )
120+ self .assertIn ("slurm" , factories )
121+
122+ @patch ("torchx.schedulers.load_group" )
123+ def test_no_custom_schedulers_skip_defaults_true (
124+ self , mock_load_group : MagicMock
125+ ) -> None :
126+ """no custom schedulers, skip_defaults=True returns empty"""
127+ mock_load_group .return_value = {}
128+
129+ factories = get_scheduler_factories (skip_defaults = True )
130+
131+ self .assertEqual (factories , {})
132+
133+ @patch ("torchx.schedulers.load_group" )
134+ def test_custom_scheduler_is_default (self , mock_load_group : MagicMock ) -> None :
135+ """first custom scheduler becomes the default"""
136+ mock_aws = MagicMock ()
137+ mock_custom = MagicMock ()
138+ mock_load_group .return_value = {"aws_batch" : mock_aws , "custom_1" : mock_custom }
139+
140+ default_name = get_default_scheduler_name ()
141+
142+ self .assertIn (default_name , ["aws_batch" , "custom_1" ])
0 commit comments