diff --git a/torchx/util/modules.py b/torchx/util/modules.py index 5ac418ee7..f8f40847b 100644 --- a/torchx/util/modules.py +++ b/torchx/util/modules.py @@ -8,16 +8,13 @@ import importlib from types import ModuleType -from typing import Callable, Optional, Union +from typing import Callable, Optional, TypeVar, Union def load_module(path: str) -> Union[ModuleType, Optional[Callable[..., object]]]: """ Loads and returns the module/module attr represented by the ``path``: ``full.module.path:optional_attr`` - :: - - 1. ``load_module("this.is.a_module:fn")`` -> equivalent to ``this.is.a_module.fn`` 1. ``load_module("this.is.a_module")`` -> equivalent to ``this.is.a_module`` """ @@ -33,3 +30,36 @@ def load_module(path: str) -> Union[ModuleType, Optional[Callable[..., object]]] return getattr(module, method) if method else module except Exception: return None + + +T = TypeVar("T") + + +def import_attr(name: str, attr: str, default: T) -> T: + """ + Imports ``name.attr`` and returns it if the module is found. + Otherwise, returns the specified ``default``. + Useful when getting an attribute from an optional dependency. + + Note that the ``default`` parameter is intentionally not an optional + since this function is intended to be used with modules that may not be + installed as a dependency. Therefore the caller must ALWAYS provide a + sensible default. + + Usage: + + .. code-block:: python + + aws_resources = import_attr("torchx.specs.named_resources_aws", "NAMED_RESOURCES", default={}) + all_resources.update(aws_resources) + + Raises: + AttributeError: If the module exists (e.g. can be imported) + but does not have an attribute with name ``attr``. + """ + try: + mod = importlib.import_module(name) + except ModuleNotFoundError: + return default + else: + return getattr(mod, attr) diff --git a/torchx/util/test/modules_test.py b/torchx/util/test/modules_test.py index 7b490fc70..4ea62ed6c 100644 --- a/torchx/util/test/modules_test.py +++ b/torchx/util/test/modules_test.py @@ -4,9 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import unittest -from torchx.util.modules import load_module +from torchx.util.modules import import_attr, load_module class ModulesTest(unittest.TestCase): @@ -21,3 +23,23 @@ def test_load_module_method(self) -> None: import os self.assertEqual(result, os.path.join) + + def test_try_import(self) -> None: + def _join(_0: str, *_1: str) -> str: + return "" # should never be called + + os_path_join = import_attr("os.path", "join", default=_join) + import os + + self.assertEqual(os.path.join, os_path_join) + + def test_try_import_non_existent_module(self) -> None: + should_default = import_attr("non.existent", "foo", default="bar") + self.assertEqual("bar", should_default) + + def test_try_import_non_existent_attr(self) -> None: + def _join(_0: str, *_1: str) -> str: + return "" # should never be called + + with self.assertRaises(AttributeError): + import_attr("os.path", "joyin", default=_join)