diff --git a/torchx/components/dist.py b/torchx/components/dist.py index 55718474d..ebeaa36c2 100644 --- a/torchx/components/dist.py +++ b/torchx/components/dist.py @@ -132,9 +132,6 @@ def spmd( j: {nnodes}x{nproc_per_node}. For GPU hosts omitting nproc_per_node will infer it from the GPU count on the host env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3) max_retries: the number of scheduler retries allowed - rdzv_port: the port on rank0's host to use for hosting the c10d store used for rendezvous. - Only takes effect when running multi-node. When running single node, this parameter - is ignored and a random free port is chosen. mounts: (for docker based runs only) mounts to mount into the worker environment/container (ex. type=,src=/host,dst=/job[,readonly]). debug: whether to run with preset debug flags enabled @@ -174,6 +171,7 @@ def ddp( max_retries: int = 0, rdzv_port: int = 29500, rdzv_backend: str = "c10d", + rdzv_conf: Optional[str] = None, mounts: Optional[List[str]] = None, debug: bool = False, tee: int = 3, @@ -208,6 +206,7 @@ def ddp( Only takes effect when running multi-node. When running single node, this parameter is ignored and a random free port is chosen. rdzv_backend: the rendezvous backend to use. Only takes effect when running multi-node. + rdzv_conf: the additional rendezvous configuration to use (ex. join_timeout=600,close_timeout=600,timeout=600). mounts: mounts to mount into the worker environment/container (ex. type=,src=/host,dst=/job[,readonly]). See scheduler documentation for more info. debug: whether to run with preset debug flags enabled @@ -258,6 +257,7 @@ def ddp( "torchrun", "--rdzv_backend", rdzv_backend, + *(["--rdzv_conf", rdzv_conf] if rdzv_conf is not None else []), "--rdzv_endpoint", rdzv_endpoint, "--rdzv_id", diff --git a/torchx/components/test/dist_test.py b/torchx/components/test/dist_test.py index ac57a1bf0..d63ce43b9 100644 --- a/torchx/components/test/dist_test.py +++ b/torchx/components/test/dist_test.py @@ -41,8 +41,10 @@ def test_ddp_debug(self) -> None: self.assertEqual(env[k], v) def test_ddp_rdzv_backend_static(self) -> None: - app = ddp(script="foo.py", rdzv_backend="static") + rdzv_conf = "join_timeout=600,close_timeout=600,timeout=600" + app = ddp(script="foo.py", rdzv_backend="static", rdzv_conf=rdzv_conf) cmd = app.roles[0].args[1] + self.assertTrue(f"--rdzv_conf {rdzv_conf}" in cmd) self.assertTrue("--rdzv_backend static" in cmd) self.assertTrue("--node_rank" in cmd)