Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 46 additions & 14 deletions torchx/cli/cmd_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torchx.specs as specs
from torchx.cli.argparse_util import ArgOnceAction, torchxconfig_run
from torchx.cli.cmd_base import SubCommand
from torchx.cli.cmd_log import get_logs
from torchx.runner import config, get_runner, Runner
from torchx.runner.config import load_sections
from torchx.schedulers import get_default_scheduler_name, get_scheduler_factories
Expand Down Expand Up @@ -186,6 +187,12 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
help="optional parent run ID that this run belongs to."
" It can be used to group runs for experiment tracking purposes",
)
subparser.add_argument(
"--tee_logs",
action="store_true",
default=False,
help="Add additional prefix to log lines to indicate which replica is printing the log",
)
subparser.add_argument(
"component_name_and_args",
nargs=argparse.REMAINDER,
Expand Down Expand Up @@ -237,14 +244,18 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
print(app_handle)

if args.scheduler.startswith("local"):
self._wait_and_exit(runner, app_handle, log=True)
self._wait_and_exit(
runner, app_handle, log=True, tee_logs=args.tee_logs
)
else:
logger.info(f"Launched app: {app_handle}")
app_status = runner.status(app_handle)
if app_status:
logger.info(app_status.format())
if args.wait or args.log:
self._wait_and_exit(runner, app_handle, log=args.log)
self._wait_and_exit(
runner, app_handle, log=args.log, tee_logs=args.tee_logs
)

except (ComponentValidationException, ComponentNotFoundException) as e:
error_msg = f"\nFailed to run component `{component}` got errors: \n {e}"
Expand All @@ -267,10 +278,16 @@ def run(self, args: argparse.Namespace) -> None:
with get_runner(component_defaults=component_defaults) as runner:
self._run(runner, args)

def _wait_and_exit(self, runner: Runner, app_handle: str, log: bool) -> None:
def _wait_and_exit(
self, runner: Runner, app_handle: str, log: bool, tee_logs: bool = False
) -> None:
logger.info("Waiting for the app to finish...")

log_thread = self._start_log_thread(runner, app_handle) if log else None
log_thread = (
self._start_log_thread(runner, app_handle, tee_logs_enabled=tee_logs)
if log
else None
)

status = runner.wait(app_handle, wait_interval=1)
if not status:
Expand All @@ -287,15 +304,30 @@ def _wait_and_exit(self, runner: Runner, app_handle: str, log: bool) -> None:
else:
logger.debug(status)

def _start_log_thread(self, runner: Runner, app_handle: str) -> threading.Thread:
thread = tee_logs(
dst=sys.stderr,
app_handle=app_handle,
regex=None,
runner=runner,
should_tail=True,
streams=None,
colorize=not sys.stderr.closed and sys.stderr.isatty(),
)
def _start_log_thread(
    self, runner: Runner, app_handle: str, tee_logs_enabled: bool = False
) -> threading.Thread:
    """Start a background daemon thread that streams the app's logs to stderr.

    Args:
        runner: Runner handed through to the log-fetching callable.
        app_handle: Handle of the app whose logs should be followed.
        tee_logs_enabled: When true, the thread is obtained from
            ``tee_logs`` (colorized only if stderr is open and attached to
            a TTY). When false, a plain thread targeting ``get_logs`` is
            created instead.

    Returns:
        The log thread, already flagged as a daemon and started.
    """
    if not tee_logs_enabled:
        # Plain log streaming: run ``get_logs`` directly on its own thread.
        log_thread = threading.Thread(
            target=get_logs,
            kwargs=dict(
                file=sys.stderr,
                runner=runner,
                identifier=app_handle,
                regex=None,
                should_tail=True,
            ),
        )
    else:
        # Only colorize when stderr is still open and is an interactive terminal.
        use_color = not sys.stderr.closed and sys.stderr.isatty()
        log_thread = tee_logs(
            dst=sys.stderr,
            app_handle=app_handle,
            regex=None,
            runner=runner,
            should_tail=True,
            streams=None,
            colorize=use_color,
        )

    # Daemon so that a lingering log tail never blocks interpreter exit.
    log_thread.daemon = True
    log_thread.start()
    return log_thread
Loading