diff --git a/torchx/specs/api.py b/torchx/specs/api.py index 5e4afa188..14e919f83 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -15,9 +15,10 @@ import pathlib import re import typing +import warnings from dataclasses import asdict, dataclass, field from datetime import datetime -from enum import Enum +from enum import Enum, IntEnum from json import JSONDecodeError from string import Template from typing import ( @@ -891,10 +892,42 @@ class runopt: Represents the metadata about the specific run option """ + class AutoAlias(IntEnum): + snake_case = 0x1 + SNAKE_CASE = 0x2 + camelCase = 0x4 + + @staticmethod + def convert_to_camel_case(alias: str) -> str: + words = re.split(r"[_\-\s]+|(?<=[a-z])(?=[A-Z])", alias) + words = [w for w in words if w] # Remove empty strings + if not words: + return "" + return words[0].lower() + "".join(w.capitalize() for w in words[1:]) + + @staticmethod + def convert_to_snake_case(alias: str) -> str: + alias = re.sub(r"[-\s]+", "_", alias) + alias = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", alias) + alias = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", alias) + return alias.lower() + + @staticmethod + def convert_to_const_case(alias: str) -> str: + return runopt.AutoAlias.convert_to_snake_case(alias).upper() + + class alias(str): + pass + + class deprecated(str): + pass + default: CfgVal opt_type: Type[CfgVal] is_required: bool help: str + aliases: set[alias] | None = None + deprecated_aliases: set[deprecated] | None = None @property def is_type_list_of_str(self) -> bool: @@ -986,6 +1019,7 @@ class runopts: def __init__(self) -> None: self._opts: Dict[str, runopt] = {} + self._alias_to_key: dict[str, str] = {} def __iter__(self) -> Iterator[Tuple[str, runopt]]: return self._opts.items().__iter__() @@ -1013,9 +1047,16 @@ def is_type(obj: CfgVal, tp: Type[CfgVal]) -> bool: def get(self, name: str) -> Optional[runopt]: """ - Returns option if any was registered, or None otherwise + Returns option if any was registered, or None otherwise. + First searches for the option by ``name``, then falls-back to matching ``name`` with any + registered aliases. + """ - return self._opts.get(name, None) + if name in self._opts: + return self._opts[name] + if name in self._alias_to_key: + return self._opts[self._alias_to_key[name]] + return None def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]: """ @@ -1030,6 +1071,36 @@ def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]: for cfg_key, runopt in self._opts.items(): val = resolved_cfg.get(cfg_key) + resolved_name = None + aliases = runopt.aliases or [] + deprecated_aliases = runopt.deprecated_aliases or [] + if val is None: + for alias in aliases: + val = resolved_cfg.get(alias) + if alias in cfg or val is not None: + resolved_name = alias + break + for alias in deprecated_aliases: + val = resolved_cfg.get(alias) + if val is not None: + resolved_name = alias + use_instead = self._alias_to_key.get(alias) + warnings.warn( + f"Run option `{alias}` is deprecated, use `{use_instead}` instead", + UserWarning, + stacklevel=2, + ) + break + else: + resolved_name = cfg_key + for alias in aliases: + duplicate_val = resolved_cfg.get(alias) + if alias in cfg or duplicate_val is not None: + raise InvalidRunConfigException( + f"Duplicate opt name. runopt: `{resolved_name}``, is an alias of runopt: `{alias}`", + resolved_name, + cfg, + ) # check required opt if runopt.is_required and val is None: @@ -1049,7 +1120,7 @@ def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]: ) # not required and not set, set to default - if val is None: + if val is None and resolved_name is None: resolved_cfg[cfg_key] = runopt.default return resolved_cfg @@ -1142,9 +1213,72 @@ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]: cfg[key] = val return cfg + def _generate_aliases( + self, auto_alias: int, aliases: set[str] + ) -> set[runopt.alias]: + generated_aliases = set() + for alias in aliases: + if auto_alias & runopt.AutoAlias.camelCase: + generated_aliases.add(runopt.AutoAlias.convert_to_camel_case(alias)) + if auto_alias & runopt.AutoAlias.snake_case: + generated_aliases.add(runopt.AutoAlias.convert_to_snake_case(alias)) + if auto_alias & runopt.AutoAlias.SNAKE_CASE: + generated_aliases.add(runopt.AutoAlias.convert_to_const_case(alias)) + return generated_aliases + + def _get_primary_key_and_aliases( + self, + cfg_key: list[str | int] | str, + ) -> tuple[str, set[runopt.alias], set[runopt.deprecated]]: + """ + Returns the primary key and aliases for the given cfg_key. + """ + if isinstance(cfg_key, str): + return cfg_key, set(), set() + + if len(cfg_key) == 0: + raise ValueError("cfg_key must be a non-empty list") + + if isinstance(cfg_key[0], runopt.alias) or isinstance( + cfg_key[0], runopt.deprecated + ): + warnings.warn( + "The main name of the run option should be the head of the list.", + UserWarning, + stacklevel=2, + ) + primary_key = None + auto_alias = 0x0 + aliases = set[runopt.alias]() + deprecated_aliases = set[runopt.deprecated]() + for name in cfg_key: + if isinstance(name, runopt.alias): + aliases.add(name) + elif isinstance(name, runopt.deprecated): + deprecated_aliases.add(name) + elif isinstance(name, int): + auto_alias = auto_alias | name + else: + if primary_key is not None: + raise ValueError( + f" Given more than one primary key: {primary_key}, {name}. Please use runopt.alias type for aliases. " + ) + primary_key = name + if primary_key is None or primary_key == "": + raise ValueError( + "Missing cfg_key. Please provide one other than the aliases." + ) + if auto_alias != 0x0: + aliases_to_generate_for = aliases | {primary_key} + additional_aliases = self._generate_aliases( + auto_alias, aliases_to_generate_for + ) + aliases.update(additional_aliases) + return primary_key, aliases, deprecated_aliases + def add( self, - cfg_key: str, + cfg_key: str | list[str | int], type_: Type[CfgVal], help: str, default: CfgVal = None, @@ -1155,6 +1289,9 @@ def add( value (if any). If the ``default`` is not specified then this option is a required option. """ + primary_key, aliases, deprecated_aliases = self._get_primary_key_and_aliases( + cfg_key + ) if required and default is not None: raise ValueError( f"Required option: {cfg_key} must not specify default value. Given: {default}" @@ -1165,8 +1302,12 @@ def add( f"Option: {cfg_key}, must be of type: {type_}." f" Given: {default} ({type(default).__name__})" ) - - self._opts[cfg_key] = runopt(default, type_, required, help) + opt = runopt(default, type_, required, help, aliases, deprecated_aliases) + for alias in aliases: + self._alias_to_key[alias] = primary_key + for deprecated_alias in deprecated_aliases: + self._alias_to_key[deprecated_alias] = primary_key + self._opts[primary_key] = opt def update(self, other: "runopts") -> None: self._opts.update(other._opts) diff --git a/torchx/specs/test/api_test.py b/torchx/specs/test/api_test.py index 6bbacd5ee..d748579a1 100644 --- a/torchx/specs/test/api_test.py +++ b/torchx/specs/test/api_test.py @@ -13,6 +13,7 @@ import tempfile import time import unittest +import warnings from dataclasses import asdict from pathlib import Path from typing import Dict, List, Mapping, Tuple, Union @@ -578,6 +579,97 @@ def test_runopts_add(self) -> None: # this print is intentional (demonstrates the intended usecase) print(opts) + def test_runopts_add_with_aliases(self) -> None: + opts = runopts() + opts.add( + ["job_priority", runopt.alias("jobPriority")], + type_=str, + help="priority for the job", + ) + self.assertEqual(1, len(opts._opts)) + self.assertIsNotNone(opts.get("job_priority")) + self.assertIsNotNone(opts.get("jobPriority")) + + def test_runopts_resolve_with_aliases(self) -> None: + opts = runopts() + opts.add( + ["job_priority", runopt.alias("jobPriority")], + type_=str, + help="priority for the job", + ) + opts.resolve({"job_priority": "high"}) + opts.resolve({"jobPriority": "low"}) + with self.assertRaises(InvalidRunConfigException): + opts.resolve({"job_priority": "high", "jobPriority": "low"}) + + def test_runopts_resolve_with_none_valued_aliases(self) -> None: + opts = runopts() + opts.add( + ["job_priority", runopt.alias("jobPriority")], + type_=str, + help="priority for the job", + ) + opts.add( + ["modelTypeName", runopt.alias("model_type_name")], + type_=Union[str, None], + help="ML Hub Model Type to attribute resource utilization for job", + ) + resolved_opts = opts.resolve({"model_type_name": None, "jobPriority": "low"}) + self.assertEqual(resolved_opts.get("model_type_name"), None) + self.assertEqual(resolved_opts.get("jobPriority"), "low") + self.assertEqual(resolved_opts, {"model_type_name": None, "jobPriority": "low"}) + + with self.assertRaises(InvalidRunConfigException): + opts.resolve({"model_type_name": None, "modelTypeName": "low"}) + + def test_runopts_add_with_deprecated_aliases(self) -> None: + opts = runopts() + with warnings.catch_warnings(record=True) as w: + opts.add( + [runopt.deprecated("jobPriority"), "job_priority"], + type_=str, + help="run as user", + ) + self.assertEqual(len(w), 1) + self.assertEqual(w[0].category, UserWarning) + self.assertEqual( + str(w[0].message), + "The main name of the run option should be the head of the list.", + ) + + opts.resolve({"job_priority": "high"}) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + opts.resolve({"jobPriority": "high"}) + self.assertEqual(len(w), 1) + self.assertEqual(w[0].category, UserWarning) + self.assertEqual( + str(w[0].message), + "Run option `jobPriority` is deprecated, use `job_priority` instead", + ) + + def test_runopt_auto_aliases(self) -> None: + opts = runopts() + opts.add( + ["job_priority", runopt.AutoAlias.camelCase], + type_=str, + help="run as user", + ) + opts.add( + [ + "model_type_name", + runopt.AutoAlias.camelCase | runopt.AutoAlias.SNAKE_CASE, + ], + type_=str, + help="run as user", + ) + self.assertEqual(2, len(opts._opts)) + self.assertIsNotNone(opts.get("job_priority")) + self.assertIsNotNone(opts.get("jobPriority")) + self.assertIsNotNone(opts.get("model_type_name")) + self.assertIsNotNone(opts.get("modelTypeName")) + self.assertIsNotNone(opts.get("MODEL_TYPE_NAME")) + def get_runopts(self) -> runopts: opts = runopts() opts.add("run_as", type_=str, help="run as user", required=True)