|
7 | 7 | # pyre-strict |
8 | 8 |
|
9 | 9 | import unittest |
| 10 | +from datetime import timedelta |
10 | 11 | from unittest import mock |
11 | 12 |
|
12 | 13 | from monarch.tools import commands |
13 | | -from monarch.tools.commands import component_args_from_cli |
| 14 | +from monarch.tools.commands import component_args_from_cli, server_ready |
14 | 15 |
|
15 | 16 | from monarch.tools.config import ( # @manual=//monarch/python/monarch/tools/config/meta:defaults |
16 | 17 | defaults, |
@@ -101,3 +102,78 @@ def test_info( |
101 | 102 | ), |
102 | 103 | commands.info("slurm:///job-id"), |
103 | 104 | ) |
| 105 | + |
| 106 | + |
| 107 | +UNUSED = "__UNUSED__" |
| 108 | +_5_MS = timedelta(milliseconds=5) |
| 109 | + |
| 110 | + |
| 111 | +def server(state: AppState) -> ServerSpec: |
| 112 | + mesh_x = MeshSpec(name="x", num_hosts=2, host_type=UNUSED, gpus=-1) |
| 113 | + mesh_y = MeshSpec(name="y", num_hosts=4, host_type=UNUSED, gpus=-1) |
| 114 | + meshes = [mesh_x, mesh_y] |
| 115 | + |
| 116 | + if state == AppState.RUNNING: |
| 117 | + for mesh in meshes: |
| 118 | + mesh.hostnames = [f"node{i}" for i in range(mesh.num_hosts)] |
| 119 | + |
| 120 | + return ServerSpec(name=UNUSED, state=state, meshes=meshes) |
| 121 | + |
| 122 | + |
| 123 | +class TestCommandsAsync(unittest.IsolatedAsyncioTestCase): |
| 124 | + async def test_server_ready_server_does_not_exist(self) -> None: |
| 125 | + with mock.patch( |
| 126 | + "monarch.tools.commands.info", |
| 127 | + return_value=None, |
| 128 | + ): |
| 129 | + server_info = await server_ready("slurm:///123", check_interval=_5_MS) |
| 130 | + self.assertIsNone(server_info) |
| 131 | + |
| 132 | + async def test_server_ready_pending_to_running(self) -> None: |
| 133 | + with mock.patch( |
| 134 | + "monarch.tools.commands.info", |
| 135 | + side_effect=[ |
| 136 | + server(AppState.UNSUBMITTED), |
| 137 | + server(AppState.SUBMITTED), |
| 138 | + server(AppState.PENDING), |
| 139 | + server(AppState.PENDING), |
| 140 | + server(AppState.RUNNING), |
| 141 | + server(AppState.CANCELLED), |
| 142 | + ], |
| 143 | + ) as mock_info: |
| 144 | + server_info = await server_ready("slurm:///123", check_interval=_5_MS) |
| 145 | + |
| 146 | + self.assertIsNotNone(server_info) |
| 147 | + self.assertTrue(server_info.is_running) |
| 148 | + self.assertEqual(server_info.state, AppState.RUNNING) |
| 149 | + |
| 150 | + mesh_x = server_info.get_mesh_spec("x") |
| 151 | + mesh_y = server_info.get_mesh_spec("y") |
| 152 | + self.assertListEqual(mesh_x.hostnames, ["node0", "node1"]) |
| 153 | + self.assertListEqual(mesh_y.hostnames, ["node0", "node1", "node2", "node3"]) |
| 154 | + |
| 155 | + mock_info.assert_called() |
| 156 | + # called 5 times, once for UNSUBMITTED, SUBMITTED, PENDING, PENDING, and RUNNING |
| 157 | + self.assertEqual(mock_info.call_count, 5) |
| 158 | + |
| 159 | + async def test_server_ready_pending_to_terminal(self) -> None: |
| 160 | + for terminal_state in [AppState.SUCCEEDED, AppState.FAILED, AppState.CANCELLED]: |
| 161 | + with self.subTest(terminal_state=terminal_state): |
| 162 | + with mock.patch( |
| 163 | + "monarch.tools.commands.info", |
| 164 | + side_effect=[ |
| 165 | + server(AppState.SUBMITTED), |
| 166 | + server(AppState.PENDING), |
| 167 | + server(AppState.PENDING), |
| 168 | + server(terminal_state), |
| 169 | + ], |
| 170 | + ) as mock_info: |
| 171 | + server_info = await server_ready( |
| 172 | + "slurm:///123", |
| 173 | + check_interval=_5_MS, |
| 174 | + ) |
| 175 | + |
| 176 | + self.assertIsNotNone(server_info) |
| 177 | + self.assertEqual(server_info.state, terminal_state) |
| 178 | + mock_info.assert_called() |
| 179 | + self.assertEqual(mock_info.call_count, 4) |
0 commit comments