|
13 | 13 |
|
14 | 14 | import hydra |
15 | 15 |
|
16 | | -from torchrl import torchrl_logger |
| 16 | +from torchrl import merge_ray_runtime_env, torchrl_logger |
17 | 17 | from torchrl.data.llm.history import History |
18 | 18 | from torchrl.record.loggers.wandb import WandbLogger |
19 | 19 | from torchrl.weight_update.llm import get_model_metadata |
@@ -319,19 +319,9 @@ def main(cfg): |
319 | 319 | if not k.startswith("_") |
320 | 320 | } |
321 | 321 |
|
322 | | - # Add computed GPU configuration |
| 322 | + # Add computed GPU configuration and merge with default runtime_env |
323 | 323 | ray_init_config["num_gpus"] = device_config["ray_num_gpus"] |
324 | | - # Ensure runtime_env and env_vars exist |
325 | | - if "runtime_env" not in ray_init_config: |
326 | | - ray_init_config["runtime_env"] = {} |
327 | | - if not isinstance(ray_init_config["runtime_env"], dict): |
328 | | - ray_init_config["runtime_env"] = dict(ray_init_config["runtime_env"]) |
329 | | - if "env_vars" not in ray_init_config["runtime_env"]: |
330 | | - ray_init_config["runtime_env"]["env_vars"] = {} |
331 | | - if not isinstance(ray_init_config["runtime_env"]["env_vars"], dict): |
332 | | - ray_init_config["runtime_env"]["env_vars"] = dict( |
333 | | - ray_init_config["runtime_env"]["env_vars"] |
334 | | - ) |
| 324 | + ray_init_config = merge_ray_runtime_env(ray_init_config) |
335 | 325 | torchrl_logger.info(f"Ray init config: {ray_init_config=}") |
336 | 326 | ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY") |
337 | 327 | if ray_managed_externally: |
|
0 commit comments