diff --git a/torchx/cli/cmd_run.py b/torchx/cli/cmd_run.py index 3246652d3..4e3659514 100644 --- a/torchx/cli/cmd_run.py +++ b/torchx/cli/cmd_run.py @@ -21,6 +21,7 @@ import torchx.specs as specs from torchx.cli.argparse_util import ArgOnceAction, torchxconfig_run from torchx.cli.cmd_base import SubCommand +from torchx.cli.cmd_log import get_logs from torchx.runner import config, get_runner, Runner from torchx.runner.config import load_sections from torchx.schedulers import get_default_scheduler_name, get_scheduler_factories @@ -186,6 +187,12 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None: help="optional parent run ID that this run belongs to." " It can be used to group runs for experiment tracking purposes", ) + subparser.add_argument( + "--tee_logs", + action="store_true", + default=False, + help="Add additional prefix to log lines to indicate which replica is printing the log", + ) subparser.add_argument( "component_name_and_args", nargs=argparse.REMAINDER, @@ -237,14 +244,18 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None: print(app_handle) if args.scheduler.startswith("local"): - self._wait_and_exit(runner, app_handle, log=True) + self._wait_and_exit( + runner, app_handle, log=True, tee_logs=args.tee_logs + ) else: logger.info(f"Launched app: {app_handle}") app_status = runner.status(app_handle) if app_status: logger.info(app_status.format()) if args.wait or args.log: - self._wait_and_exit(runner, app_handle, log=args.log) + self._wait_and_exit( + runner, app_handle, log=args.log, tee_logs=args.tee_logs + ) except (ComponentValidationException, ComponentNotFoundException) as e: error_msg = f"\nFailed to run component `{component}` got errors: \n {e}" @@ -267,10 +278,16 @@ def run(self, args: argparse.Namespace) -> None: with get_runner(component_defaults=component_defaults) as runner: self._run(runner, args) - def _wait_and_exit(self, runner: Runner, app_handle: str, log: bool) -> None: + def _wait_and_exit( + self, runner: Runner, app_handle: str, log: bool, tee_logs: bool = False + ) -> None: logger.info("Waiting for the app to finish...") - log_thread = self._start_log_thread(runner, app_handle) if log else None + log_thread = ( + self._start_log_thread(runner, app_handle, tee_logs_enabled=tee_logs) + if log + else None + ) status = runner.wait(app_handle, wait_interval=1) if not status: @@ -287,15 +304,30 @@ def _wait_and_exit(self, runner: Runner, app_handle: str, log: bool) -> None: else: logger.debug(status) - def _start_log_thread(self, runner: Runner, app_handle: str) -> threading.Thread: - thread = tee_logs( - dst=sys.stderr, - app_handle=app_handle, - regex=None, - runner=runner, - should_tail=True, - streams=None, - colorize=not sys.stderr.closed and sys.stderr.isatty(), - ) + def _start_log_thread( + self, runner: Runner, app_handle: str, tee_logs_enabled: bool = False + ) -> threading.Thread: + if tee_logs_enabled: + thread = tee_logs( + dst=sys.stderr, + app_handle=app_handle, + regex=None, + runner=runner, + should_tail=True, + streams=None, + colorize=not sys.stderr.closed and sys.stderr.isatty(), + ) + else: + thread = threading.Thread( + target=get_logs, + kwargs={ + "file": sys.stderr, + "runner": runner, + "identifier": app_handle, + "regex": None, + "should_tail": True, + }, + ) + thread.daemon = True thread.start() return thread