Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions torchx/cli/cmd_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
)
from torchx.util.log_tee_helpers import tee_logs
from torchx.util.types import none_throws
from torchx.workspace import Workspace


MISSING_COMPONENT_ERROR_MSG = (
Expand Down Expand Up @@ -92,7 +93,7 @@ def torchx_run_args_from_json(json_data: Dict[str, Any]) -> TorchXRunArgs:

torchx_args = TorchXRunArgs(**filtered_json_data)
if torchx_args.workspace == "":
torchx_args.workspace = f"file://{Path.cwd()}"
torchx_args.workspace = f"{Path.cwd()}"
return torchx_args


Expand Down Expand Up @@ -250,7 +251,7 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
subparser.add_argument(
"--workspace",
"--buck-target",
default=f"file://{Path.cwd()}",
default=f"{Path.cwd()}",
action=torchxconfig_run,
help="local workspace to build/patch (buck-target of main binary if using buck)",
)
Expand Down Expand Up @@ -289,12 +290,14 @@ def _run_inner(self, runner: Runner, args: TorchXRunArgs) -> None:
else args.component_args
)
try:
workspace = Workspace.from_str(args.workspace) if args.workspace else None

if args.dryrun:
dryrun_info = runner.dryrun_component(
args.component_name,
component_args,
args.scheduler,
workspace=args.workspace,
workspace=workspace,
cfg=args.scheduler_cfg,
parent_run_id=args.parent_run_id,
)
Expand Down
12 changes: 6 additions & 6 deletions torchx/cli/test/cmd_run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def test_verify_no_extra_args_stdin_with_boolean_flags(self) -> None:

def test_verify_no_extra_args_stdin_with_value_args(self) -> None:
"""Test that arguments with values conflict with stdin."""
args = self.parser.parse_args(["--stdin", "--workspace", "file:///custom/path"])
args = self.parser.parse_args(["--stdin", "--workspace", "/custom/path"])
with self.assertRaises(SystemExit):
self.cmd_run.verify_no_extra_args(args)

Expand Down Expand Up @@ -499,7 +499,7 @@ def test_torchx_run_args_from_json(self) -> None:
self.assertEqual(result.dryrun, False)
self.assertEqual(result.wait, False)
self.assertEqual(result.log, False)
self.assertEqual(result.workspace, f"file://{Path.cwd()}")
self.assertEqual(result.workspace, f"{Path.cwd()}")
self.assertEqual(result.parent_run_id, None)
self.assertEqual(result.tee_logs, False)
self.assertEqual(result.component_args, {})
Expand All @@ -515,7 +515,7 @@ def test_torchx_run_args_from_json(self) -> None:
"dryrun": True,
"wait": True,
"log": True,
"workspace": "file:///custom/path",
"workspace": "/custom/path",
"parent_run_id": "parent123",
"tee_logs": True,
}
Expand All @@ -529,7 +529,7 @@ def test_torchx_run_args_from_json(self) -> None:
self.assertEqual(result2.dryrun, True)
self.assertEqual(result2.wait, True)
self.assertEqual(result2.log, True)
self.assertEqual(result2.workspace, "file:///custom/path")
self.assertEqual(result2.workspace, "/custom/path")
self.assertEqual(result2.parent_run_id, "parent123")
self.assertEqual(result2.tee_logs, True)

Expand Down Expand Up @@ -626,7 +626,7 @@ def test_torchx_run_args_from_argparse(self) -> None:
args.dryrun = True
args.wait = False
args.log = True
args.workspace = "file:///custom/workspace"
args.workspace = "/custom/workspace"
args.parent_run_id = "parent_123"
args.tee_logs = False

Expand Down Expand Up @@ -654,7 +654,7 @@ def test_torchx_run_args_from_argparse(self) -> None:
self.assertEqual(result.dryrun, True)
self.assertEqual(result.wait, False)
self.assertEqual(result.log, True)
self.assertEqual(result.workspace, "file:///custom/workspace")
self.assertEqual(result.workspace, "/custom/workspace")
self.assertEqual(result.parent_run_id, "parent_123")
self.assertEqual(result.tee_logs, False)
self.assertEqual(result.component_args, {})
Expand Down
55 changes: 30 additions & 25 deletions torchx/runner/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from torchx.util.session import get_session_id_or_create_new, TORCHX_INTERNAL_SESSION_ID

from torchx.util.types import none_throws
from torchx.workspace.api import WorkspaceMixin
from torchx.workspace.api import Workspace, WorkspaceMixin

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down Expand Up @@ -171,7 +171,7 @@ def run_component(
component_args: Union[list[str], dict[str, Any]],
scheduler: str,
cfg: Optional[Mapping[str, CfgVal]] = None,
workspace: Optional[str] = None,
workspace: Optional[Union[Workspace, str]] = None,
parent_run_id: Optional[str] = None,
) -> AppHandle:
"""
Expand Down Expand Up @@ -206,7 +206,7 @@ def run_component(
ComponentNotFoundException: if the ``component_path`` fails to resolve.
"""

with log_event("run_component", workspace=workspace) as ctx:
with log_event("run_component") as ctx:
dryrun_info = self.dryrun_component(
component,
component_args,
Expand All @@ -217,7 +217,8 @@ def run_component(
)
handle = self.schedule(dryrun_info)
app = none_throws(dryrun_info._app)
ctx._torchx_event.workspace = workspace

ctx._torchx_event.workspace = str(workspace)
ctx._torchx_event.scheduler = none_throws(dryrun_info._scheduler)
ctx._torchx_event.app_image = app.roles[0].image
ctx._torchx_event.app_id = parse_app_handle(handle)[2]
Expand All @@ -230,7 +231,7 @@ def dryrun_component(
component_args: Union[list[str], dict[str, Any]],
scheduler: str,
cfg: Optional[Mapping[str, CfgVal]] = None,
workspace: Optional[str] = None,
workspace: Optional[Union[Workspace, str]] = None,
parent_run_id: Optional[str] = None,
) -> AppDryRunInfo:
"""
Expand Down Expand Up @@ -259,7 +260,7 @@ def run(
app: AppDef,
scheduler: str,
cfg: Optional[Mapping[str, CfgVal]] = None,
workspace: Optional[str] = None,
workspace: Optional[Union[Workspace, str]] = None,
parent_run_id: Optional[str] = None,
) -> AppHandle:
"""
Expand All @@ -272,9 +273,7 @@ def run(
An application handle that is used to call other action APIs on the app.
"""

with log_event(
api="run", runcfg=json.dumps(cfg) if cfg else None, workspace=workspace
) as ctx:
with log_event(api="run") as ctx:
dryrun_info = self.dryrun(
app,
scheduler,
Expand All @@ -283,10 +282,15 @@ def run(
parent_run_id=parent_run_id,
)
handle = self.schedule(dryrun_info)
ctx._torchx_event.scheduler = none_throws(dryrun_info._scheduler)
ctx._torchx_event.app_image = none_throws(dryrun_info._app).roles[0].image
ctx._torchx_event.app_id = parse_app_handle(handle)[2]
ctx._torchx_event.app_metadata = app.metadata

event = ctx._torchx_event
event.scheduler = scheduler
event.runcfg = json.dumps(cfg) if cfg else None
event.workspace = str(workspace)
event.app_id = parse_app_handle(handle)[2]
event.app_image = none_throws(dryrun_info._app).roles[0].image
event.app_metadata = app.metadata

return handle

def schedule(self, dryrun_info: AppDryRunInfo) -> AppHandle:
Expand Down Expand Up @@ -320,21 +324,22 @@ def schedule(self, dryrun_info: AppDryRunInfo) -> AppHandle:

"""
scheduler = none_throws(dryrun_info._scheduler)
app_image = none_throws(dryrun_info._app).roles[0].image
cfg = dryrun_info._cfg
with log_event(
"schedule",
scheduler,
app_image=app_image,
runcfg=json.dumps(cfg) if cfg else None,
) as ctx:
with log_event("schedule") as ctx:
sched = self._scheduler(scheduler)
app_id = sched.schedule(dryrun_info)
app_handle = make_app_handle(scheduler, self._name, app_id)

app = none_throws(dryrun_info._app)
self._apps[app_handle] = app
_, _, app_id = parse_app_handle(app_handle)
ctx._torchx_event.app_id = app_id

event = ctx._torchx_event
event.scheduler = scheduler
event.runcfg = json.dumps(cfg) if cfg else None
event.app_id = app_id
event.app_image = none_throws(dryrun_info._app).roles[0].image
event.app_metadata = app.metadata

return app_handle

def name(self) -> str:
Expand All @@ -345,7 +350,7 @@ def dryrun(
app: AppDef,
scheduler: str,
cfg: Optional[Mapping[str, CfgVal]] = None,
workspace: Optional[str] = None,
workspace: Optional[Union[Workspace, str]] = None,
parent_run_id: Optional[str] = None,
) -> AppDryRunInfo:
"""
Expand Down Expand Up @@ -414,7 +419,7 @@ def dryrun(
"dryrun",
scheduler,
runcfg=json.dumps(cfg) if cfg else None,
workspace=workspace,
workspace=str(workspace),
):
sched = self._scheduler(scheduler)
resolved_cfg = sched.run_opts().resolve(cfg)
Expand All @@ -429,7 +434,7 @@ def dryrun(
logger.info(
'To disable workspaces pass: --workspace="" from CLI or workspace=None programmatically.'
)
sched.build_workspace_and_update_role(role, workspace, resolved_cfg)
sched.build_workspace_and_update_role2(role, workspace, resolved_cfg)

if old_img != role.image:
logger.info(
Expand Down
30 changes: 20 additions & 10 deletions torchx/runner/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@

from .api import SourceType, TorchxEvent # noqa F401

# pyre-fixme[9]: _events_logger is a global variable
_events_logger: logging.Logger = None
_events_logger: Optional[logging.Logger] = None

log: logging.Logger = logging.getLogger(__name__)


def _get_or_create_logger(destination: str = "null") -> logging.Logger:
Expand All @@ -51,19 +52,28 @@ def _get_or_create_logger(destination: str = "null") -> logging.Logger:
a new logger if None provided.
"""
global _events_logger

if _events_logger:
return _events_logger
logging_handler = get_logging_handler(destination)
logging_handler.setLevel(logging.DEBUG)
_events_logger = logging.getLogger(f"torchx-events-{destination}")
# Do not propagate message to the root logger
_events_logger.propagate = False
_events_logger.addHandler(logging_handler)
return _events_logger
else:
logging_handler = get_logging_handler(destination)
logging_handler.setLevel(logging.DEBUG)
_events_logger = logging.getLogger(f"torchx-events-{destination}")
# Do not propagate message to the root logger
_events_logger.propagate = False
_events_logger.addHandler(logging_handler)

assert _events_logger # make type-checker happy
return _events_logger


def record(event: TorchxEvent, destination: str = "null") -> None:
_get_or_create_logger(destination).info(event.serialize())
try:
serialized_event = event.serialize()
except Exception:
log.exception("failed to serialize event, will not record event")
else:
_get_or_create_logger(destination).info(serialized_event)


class log_event:
Expand Down
2 changes: 1 addition & 1 deletion torchx/runner/events/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class TorchxEvent:
scheduler: Scheduler that is used to execute request
api: Api name
app_id: Unique id that is set by the underlying scheduler
image: Image/container bundle that is used to execute request.
app_image: Image/container bundle that is used to execute request.
app_metadata: metadata to the app (treatment of metadata is scheduler dependent)
runcfg: Run config that was used to schedule app.
source: Type of source the event is generated.
Expand Down
29 changes: 29 additions & 0 deletions torchx/runner/test/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from torchx.schedulers.api import DescribeAppResponse, ListAppResponse, Stream
from torchx.specs import AppDef, AppDryRunInfo, CfgVal, runopts
from torchx.test.fixtures import TestWithTmpDir
from torchx.workspace import Workspace


class TestScheduler(Scheduler):
Expand Down Expand Up @@ -506,3 +507,31 @@ def test_dump_and_load_all_registered_schedulers(self) -> None:
opt_name in cfg,
f"missing {opt_name} in {sched} run opts with cfg {cfg}",
)

def test_get_workspace_config(self) -> None:
# Checks that the ``workspace`` key of the ``[cli:run]`` section in a
# ``.torchxconfig`` file can be fetched with ``get_config`` and parsed by
# ``Workspace.from_str`` into a path -> target mapping (``projects``).
configdir = self.tmpdir
# ``self.write`` is a fixture helper not shown here; presumably it creates
# the file at the given path with the given content -- confirm in the fixture.
self.write(
str(configdir / ".torchxconfig"),
"""#
[cli:run]
workspace =
/home/foo/third-party/verl: verl
/home/foo/bar/scripts/.torchxconfig: verl/.torchxconfig
/home/foo/baz:
""",
)

# Look the value up using only the temporary directory as the config search dir.
workspace_config = get_config(
prefix="cli", name="run", key="workspace", dirs=[str(configdir)]
)
self.assertIsNotNone(workspace_config)

# Each ``path: target`` entry should become one item in ``projects``; the
# entry with nothing after the colon is expected to map to an empty string.
workspace = Workspace.from_str(workspace_config)
self.assertDictEqual(
{
"/home/foo/third-party/verl": "verl",
"/home/foo/bar/scripts/.torchxconfig": "verl/.torchxconfig",
"/home/foo/baz": "",
},
workspace.projects,
)
Loading
Loading