Skip to content

Commit 8f0c6b0

Browse files
committed
feat: list all registered schedulers (#1009)
1 parent 5957532 commit 8f0c6b0

File tree

2 files changed

+44
-9
lines changed

2 files changed

+44
-9
lines changed

torchx/schedulers/__init__.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,16 @@ def get_scheduler_factories(
4848
4949
The first scheduler in the dictionary is used as the default scheduler.
5050
"""
51-
52-
if skip_defaults:
53-
default_schedulers = {}
54-
else:
55-
default_schedulers: dict[str, SchedulerFactory] = {}
56-
for scheduler, path in DEFAULT_SCHEDULER_MODULES.items():
57-
default_schedulers[scheduler] = _defer_load_scheduler(path)
58-
59-
return load_group(group, default=default_schedulers)
51+
valid_schedulers = (
52+
{}
53+
if skip_defaults
54+
else {
55+
name: _defer_load_scheduler(path)
56+
for name, path in DEFAULT_SCHEDULER_MODULES.items()
57+
}
58+
)
59+
valid_schedulers.update(load_group(group, default={}))
60+
return valid_schedulers
6061

6162

6263
def get_default_scheduler_name() -> str:

torchx/schedulers/test/registry_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,37 @@ def test_get_local_schedulers(self, mock_load_group: MagicMock) -> None:
4343

4444
for scheduler in schedulers.values():
4545
self.assertEqual("test_session", scheduler.session_name)
46+
47+
@patch("torchx.schedulers.load_group")
48+
def test_custom_schedulers_merged(self, mock_load_group: MagicMock) -> None:
49+
mock_scheduler = MagicMock()
50+
mock_load_group.return_value = {"custom": mock_scheduler}
51+
52+
factories = get_scheduler_factories()
53+
54+
self.assertIn("custom", factories)
55+
self.assertEqual(factories["custom"], mock_scheduler)
56+
self.assertIn("local_docker", factories)
57+
58+
@patch("torchx.schedulers.load_group")
59+
def test_custom_scheduler_overrides_default(
60+
self, mock_load_group: MagicMock
61+
) -> None:
62+
mock_scheduler = MagicMock()
63+
mock_load_group.return_value = {"local_docker": mock_scheduler}
64+
65+
factories = get_scheduler_factories()
66+
67+
self.assertEqual(factories["local_docker"], mock_scheduler)
68+
69+
@patch("torchx.schedulers.load_group")
70+
def test_skip_defaults_with_custom_schedulers(
71+
self, mock_load_group: MagicMock
72+
) -> None:
73+
mock_scheduler = MagicMock()
74+
mock_load_group.return_value = {"custom": mock_scheduler}
75+
76+
factories = get_scheduler_factories(skip_defaults=True)
77+
78+
self.assertEqual(factories, {"custom": mock_scheduler})
79+
self.assertNotIn("local_docker", factories)

0 commit comments

Comments
 (0)