Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions torchx/specs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
RoleStatus,
runopt,
runopts,
TORCHX_HOME,
UnknownAppException,
UnknownSchedulerException,
VolumeMount,
Expand All @@ -53,6 +54,7 @@

GiB: int = 1024


ResourceFactory = Callable[[], Resource]

AWS_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
Expand Down
28 changes: 28 additions & 0 deletions torchx/specs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import inspect
import json
import logging as logger
import os
import pathlib
import re
import typing
from dataclasses import asdict, dataclass, field
Expand Down Expand Up @@ -66,6 +68,32 @@
RESET = "\033[0m"


def TORCHX_HOME(*subdir_paths: str) -> pathlib.Path:
"""
Path to the "dot-directory" for torchx.
Defaults to `~/.torchx` and is overridable via the `TORCHX_HOME` environment variable.

Usage:

.. doc-test::

from pathlib import Path
from torchx.specs import TORCHX_HOME

assert TORCHX_HOME() == Path.home() / ".torchx"
assert TORCHX_HOME("conda-pack-out") == Path.home() / ".torchx" / "conda-pack-out"
```
"""

default_dir = str(pathlib.Path.home() / ".torchx")
torchx_home = pathlib.Path(os.getenv("TORCHX_HOME", default_dir))

torchx_home = torchx_home / os.path.sep.join(subdir_paths)
torchx_home.mkdir(parents=True, exist_ok=True)

return torchx_home


# ========================================
# ==== Distributed AppDef API =======
# ========================================
Expand Down
31 changes: 31 additions & 0 deletions torchx/specs/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@
import asyncio
import concurrent
import os
import tempfile
import time
import unittest
from dataclasses import asdict
from pathlib import Path
from typing import Dict, List, Mapping, Tuple, Union
from unittest import mock
from unittest.mock import MagicMock

import torchx.specs.named_resources_aws as named_resources_aws
Expand All @@ -40,9 +43,37 @@
RoleStatus,
runopt,
runopts,
TORCHX_HOME,
)


class TorchXHomeTest(unittest.TestCase):
# guard against TORCHX_HOME set outside the test
@mock.patch.dict(os.environ, {}, clear=True)
def test_TORCHX_HOME_default(self) -> None:
with tempfile.TemporaryDirectory() as tmpdir:
user_home = Path(tmpdir) / "sally"
with mock.patch("pathlib.Path.home", return_value=user_home):
torchx_home = TORCHX_HOME()
self.assertEqual(torchx_home, user_home / ".torchx")
self.assertTrue(torchx_home.exists())

def test_TORCHX_HOME_override(self) -> None:
with tempfile.TemporaryDirectory() as tmpdir:
override_torchx_home = Path(tmpdir) / "test" / ".torchx"
with mock.patch.dict(
os.environ, {"TORCHX_HOME": str(override_torchx_home)}
):
torchx_home = TORCHX_HOME()
conda_pack_out = TORCHX_HOME("conda-pack", "out")

self.assertEqual(override_torchx_home, torchx_home)
self.assertEqual(torchx_home / "conda-pack" / "out", conda_pack_out)

self.assertTrue(torchx_home.is_dir())
self.assertTrue(conda_pack_out.is_dir())


class AppDryRunInfoTest(unittest.TestCase):
def test_repr(self) -> None:
request_mock = MagicMock()
Expand Down
Loading