diff --git a/python/monarch/tools/commands.py b/python/monarch/tools/commands.py
index bedc0a556..fa78eb1ab 100644
--- a/python/monarch/tools/commands.py
+++ b/python/monarch/tools/commands.py
@@ -9,7 +9,10 @@
 import argparse
+import asyncio
 import functools
 import inspect
+import logging
 import os
+from datetime import timedelta
 from typing import Any, Callable, Mapping, Optional, Union
 
 from monarch.tools.config import (  # @manual=//monarch/python/monarch/tools/config/meta:defaults
@@ -18,12 +21,13 @@
 )
 
 from monarch.tools.mesh_spec import mesh_spec_from_metadata, ServerSpec
-
 from torchx.runner import Runner
-from torchx.specs import AppDef, AppDryRunInfo, CfgVal
+from torchx.specs import AppDef, AppDryRunInfo, AppState, CfgVal
 from torchx.specs.builders import parse_args
 from torchx.util.types import decode, decode_optional
 
+logger: logging.Logger = logging.getLogger(__name__)
+
 
 def torchx_runner() -> Runner:
     # namespace is currently unused so make it empty str
@@ -165,15 +169,73 @@ def info(server_handle: str) -> Optional[ServerSpec]:
     if appdef is None:
         return None
 
+    # host status grouped by mesh (role) names
+    replica_status = {r.role: r.replicas for r in status.roles}
+
     mesh_specs = []
     for role in appdef.roles:
         spec = mesh_spec_from_metadata(appdef, role.name)
         assert spec is not None, "cannot be 'None' since we iterate over appdef's roles"
+
+        # null-guard since some schedulers do not fill replica_status
+        if host_status := replica_status.get(role.name):
+            spec.hostnames = [h.hostname for h in host_status]
+
         mesh_specs.append(spec)
+
     return ServerSpec(name=appdef.name, state=status.state, meshes=mesh_specs)
 
 
+_5_SECONDS = timedelta(seconds=5)
+
+
+async def server_ready(
+    server_handle: str, check_interval: timedelta = _5_SECONDS
+) -> Optional[ServerSpec]:
+    """Waits until the server's job is RUNNING (or in a terminal state) and returns the server spec.
+    Returns `None` if the server does not exist.
+
+    NOTE: Certain fields such as `hostnames` are only filled (and valid) when the server is RUNNING.
+
+    Usage:
+
+    .. code-block:: python
+
+        server_info = await server_ready("slurm:///123")
+        if not server_info:
+            print("Job does not exist")
+        else:
+            if server_info.is_running:
+                for mesh in server_info.meshes:
+                    connect_to(mesh.hostnames)
+            else:
+                print(f"Job in {server_info.state} state. Hostnames are not available")
+
+    """
+
+    while True:
+        server_spec = info(server_handle)
+
+        if not server_spec:  # server not found
+            return None
+
+        if server_spec.state <= AppState.PENDING:  # UNSUBMITTED or SUBMITTED or PENDING
+            # NOTE: TorchX currently does not have async APIs, so we poll on an interval
+            # TODO maybe inverse exponential backoff instead of constant interval?
+            check_interval_seconds = check_interval.total_seconds()
+            logger.info(
+                "waiting for %s to be %s (current: %s), will check again in %g seconds...",
+                server_handle,
+                AppState.RUNNING,
+                server_spec.state,
+                check_interval_seconds,
+            )
+            await asyncio.sleep(check_interval_seconds)  # non-blocking; don't stall the event loop
+            continue
+        else:
+            return server_spec
+
+
 def kill(server_handle: str) -> None:
     with torchx_runner() as runner:
         runner.cancel(server_handle)
diff --git a/python/monarch/tools/mesh_spec.py b/python/monarch/tools/mesh_spec.py
index 258911e27..ba83f7335 100644
--- a/python/monarch/tools/mesh_spec.py
+++ b/python/monarch/tools/mesh_spec.py
@@ -6,7 +6,7 @@
 # pyre-strict
 
 import string
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, Optional
 
 from torchx import specs
@@ -29,6 +29,7 @@ class MeshSpec:
     host_type: str
     gpus: int
     port: int = DEFAULT_REMOTE_ALLOCATOR_PORT
+    hostnames: list[str] = field(default_factory=list)
 
 
 def _tag(mesh_name: str, tag_template: str) -> str:
@@ -84,6 +85,10 @@ class ServerSpec:
     state: specs.AppState
     meshes: list[MeshSpec]
 
+    @property
+    def is_running(self) -> bool:
+        return self.state == specs.AppState.RUNNING
+
     def get_mesh_spec(self, mesh_name: str) -> MeshSpec:
         for mesh_spec in self.meshes:
             if mesh_spec.name == mesh_name:
@@ -115,6 +120,7 @@ def to_json(self) -> dict[str, Any]:
                     "host_type": mesh.host_type,
                     "hosts": mesh.num_hosts,
                     "gpus": mesh.gpus,
+                    "hostnames": mesh.hostnames,
                 }
                 for mesh in self.meshes
             },
diff --git a/python/tests/tools/test_cli.py b/python/tests/tools/test_cli.py
index b957e347e..7879c9a24 100644
--- a/python/tests/tools/test_cli.py
+++ b/python/tests/tools/test_cli.py
@@ -68,12 +68,14 @@ def test_info(self, mock_cmd_info: mock.MagicMock) -> None:
                 "trainer": {
                     "host_type": "gpu.medium",
                     "hosts": 4,
-                    "gpus": 2
+                    "gpus": 2,
+                    "hostnames": []
                 },
                 "generator": {
                     "host_type": "gpu.small",
                     "hosts": 16,
-                    "gpus": 1
+                    "gpus": 1,
+                    "hostnames": []
                 }
             }
         }
diff --git a/python/tests/tools/test_commands.py b/python/tests/tools/test_commands.py
index 9ee60a128..90fdc7e96 100644
--- a/python/tests/tools/test_commands.py
+++ b/python/tests/tools/test_commands.py
@@ -7,10 +7,11 @@
 # pyre-strict
 
 import unittest
+from datetime import timedelta
 from unittest import mock
 
 from monarch.tools import commands
-from monarch.tools.commands import component_args_from_cli
+from monarch.tools.commands import component_args_from_cli, server_ready
 
 from monarch.tools.config import (  # @manual=//monarch/python/monarch/tools/config/meta:defaults
     defaults,
@@ -101,3 +102,78 @@ def test_info(
             ),
             commands.info("slurm:///job-id"),
         )
+
+
+UNUSED = "__UNUSED__"
+_5_MS = timedelta(milliseconds=5)
+
+
+def server(state: AppState) -> ServerSpec:
+    mesh_x = MeshSpec(name="x", num_hosts=2, host_type=UNUSED, gpus=-1)
+    mesh_y = MeshSpec(name="y", num_hosts=4, host_type=UNUSED, gpus=-1)
+    meshes = [mesh_x, mesh_y]
+
+    if state == AppState.RUNNING:
+        for mesh in meshes:
+            mesh.hostnames = [f"node{i}" for i in range(mesh.num_hosts)]
+
+    return ServerSpec(name=UNUSED, state=state, meshes=meshes)
+
+
+class TestCommandsAsync(unittest.IsolatedAsyncioTestCase):
+    async def test_server_ready_server_does_not_exist(self) -> None:
+        with mock.patch(
+            "monarch.tools.commands.info",
+            return_value=None,
+        ):
+            server_info = await server_ready("slurm:///123", check_interval=_5_MS)
+            self.assertIsNone(server_info)
+
+    async def test_server_ready_pending_to_running(self) -> None:
+        with mock.patch(
+            "monarch.tools.commands.info",
+            side_effect=[
+                server(AppState.UNSUBMITTED),
+                server(AppState.SUBMITTED),
+                server(AppState.PENDING),
+                server(AppState.PENDING),
+                server(AppState.RUNNING),
+                server(AppState.CANCELLED),  # never consumed; proves polling stops at RUNNING
+            ],
+        ) as mock_info:
+            server_info = await server_ready("slurm:///123", check_interval=_5_MS)
+
+            self.assertIsNotNone(server_info)
+            self.assertTrue(server_info.is_running)
+            self.assertEqual(server_info.state, AppState.RUNNING)
+
+            mesh_x = server_info.get_mesh_spec("x")
+            mesh_y = server_info.get_mesh_spec("y")
+            self.assertListEqual(mesh_x.hostnames, ["node0", "node1"])
+            self.assertListEqual(mesh_y.hostnames, ["node0", "node1", "node2", "node3"])
+
+            mock_info.assert_called()
+            # called 5 times: UNSUBMITTED, SUBMITTED, PENDING, PENDING, RUNNING
+            self.assertEqual(mock_info.call_count, 5)
+
+    async def test_server_ready_pending_to_terminal(self) -> None:
+        for terminal_state in [AppState.SUCCEEDED, AppState.FAILED, AppState.CANCELLED]:
+            with self.subTest(terminal_state=terminal_state):
+                with mock.patch(
+                    "monarch.tools.commands.info",
+                    side_effect=[
+                        server(AppState.SUBMITTED),
+                        server(AppState.PENDING),
+                        server(AppState.PENDING),
+                        server(terminal_state),
+                    ],
+                ) as mock_info:
+                    server_info = await server_ready(
+                        "slurm:///123",
+                        check_interval=_5_MS,
+                    )
+
+                    self.assertIsNotNone(server_info)
+                    self.assertEqual(server_info.state, terminal_state)
+                    mock_info.assert_called()
+                    self.assertEqual(mock_info.call_count, 4)
diff --git a/python/tests/tools/test_mesh_spec.py b/python/tests/tools/test_mesh_spec.py
index f2f892790..b62189870 100644
--- a/python/tests/tools/test_mesh_spec.py
+++ b/python/tests/tools/test_mesh_spec.py
@@ -82,7 +82,11 @@ def test_mesh_spec_from_metadata(self) -> None:
 
     def test_mesh_spec_can_dump_as_json(self) -> None:
         mesh_spec = MeshSpec(
-            name="trainer", num_hosts=4, host_type="gpu.medium", gpus=2
+            name="trainer",
+            num_hosts=4,
+            host_type="gpu.medium",
+            gpus=2,
+            hostnames=["n0", "n1", "n2", "n3"],
         )
         expected = """
 {
@@ -90,7 +94,13 @@ def test_mesh_spec_can_dump_as_json(self) -> None:
   "num_hosts": 4,
   "host_type": "gpu.medium",
   "gpus": 2,
-  "port": 26600
+  "port": 26600,
+  "hostnames": [
+    "n0",
+    "n1",
+    "n2",
+    "n3"
+  ]
 }
 """
         self.assertEqual(expected.strip("\n"), json.dumps(asdict(mesh_spec), indent=2))
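
A minimal usage sketch (illustrative, not part of the diff): the `slurm:///123` handle mirrors the docstring example and the prints are placeholder logic; everything else uses only APIs added in this diff.

```python
import asyncio

from monarch.tools.commands import server_ready


async def main() -> None:
    # Polls commands.info() on an interval until the job leaves PENDING
    # (RUNNING or a terminal state), then returns the ServerSpec.
    server_info = await server_ready("slurm:///123")

    if server_info is None:
        print("job does not exist")
    elif server_info.is_running:
        # MeshSpec.hostnames is only filled while the server is RUNNING
        for mesh in server_info.meshes:
            print(f"mesh '{mesh.name}' hosts: {mesh.hostnames}")
    else:
        print(f"job ended in {server_info.state} state; hostnames unavailable")


if __name__ == "__main__":
    asyncio.run(main())
```

Because the wait sleeps via `asyncio.sleep`, the event loop stays free, so a caller can await `server_ready` for several servers concurrently (e.g. with `asyncio.gather`).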