Skip to content

Commit 1e3df20

Browse files
authored
Allow dryrun from cli in JSON launching case
Differential Revision: D83681938 Pull Request resolved: #1135
1 parent 1f2b94e commit 1e3df20

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

torchx/cli/cmd_run.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,12 +379,16 @@ def _get_torchx_stdin_args(
379379
if not args.stdin:
380380
return None
381381
if self._stdin_data_json is None:
382-
self._stdin_data_json = self.torchx_json_from_stdin()
382+
self._stdin_data_json = self.torchx_json_from_stdin(args)
383383
return self._stdin_data_json
384384

385-
def torchx_json_from_stdin(self) -> Dict[str, Any]:
385+
def torchx_json_from_stdin(
386+
self, args: Optional[argparse.Namespace] = None
387+
) -> Dict[str, Any]:
386388
try:
387389
stdin_data_json = json.load(sys.stdin)
390+
if args and args.dryrun:
391+
stdin_data_json["dryrun"] = True
388392
if not isinstance(stdin_data_json, dict):
389393
logger.error(
390394
"Invalid JSON input for `torchx run` command. Expected a dictionary."
@@ -413,6 +417,8 @@ def verify_no_extra_args(self, args: argparse.Namespace) -> None:
413417
continue
414418
if action.dest == "help": # Skip help
415419
continue
420+
if action.dest == "dryrun": # Skip dryrun
421+
continue
416422

417423
current_value = getattr(args, action.dest, None)
418424
default_value = action.default

torchx/cli/test/cmd_run_test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,12 +393,17 @@ def test_verify_no_extra_args_stdin_with_scheduler(self) -> None:
393393

394394
def test_verify_no_extra_args_stdin_with_boolean_flags(self) -> None:
395395
"""Test that boolean flags conflict with stdin."""
396-
boolean_flags = ["--dryrun", "--wait", "--log", "--tee_logs"]
396+
boolean_flags = ["--wait", "--log", "--tee_logs"]
397397
for flag in boolean_flags:
398398
args = self.parser.parse_args(["--stdin", flag])
399399
with self.assertRaises(SystemExit):
400400
self.cmd_run.verify_no_extra_args(args)
401401

402+
def test_verify_no_extra_args_stdin_dryrun_pass(self) -> None:
403+
"""Test that dryrun is allowed."""
404+
args = self.parser.parse_args(["--stdin", "--dryrun"])
405+
self.cmd_run.verify_no_extra_args(args)
406+
402407
def test_verify_no_extra_args_stdin_with_value_args(self) -> None:
403408
"""Test that arguments with values conflict with stdin."""
404409
args = self.parser.parse_args(["--stdin", "--workspace", "/custom/path"])

0 commit comments

Comments
 (0)