diff --git a/torchx/specs/api.py b/torchx/specs/api.py index abff9e250..10e79d06d 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -936,42 +936,12 @@ class runopt: Represents the metadata about the specific run option """ - class AutoAlias(IntEnum): - snake_case = 0x1 - SNAKE_CASE = 0x2 - camelCase = 0x4 - - @staticmethod - def convert_to_camel_case(alias: str) -> str: - words = re.split(r"[_\-\s]+|(?<=[a-z])(?=[A-Z])", alias) - words = [w for w in words if w] # Remove empty strings - if not words: - return "" - return words[0].lower() + "".join(w.capitalize() for w in words[1:]) - - @staticmethod - def convert_to_snake_case(alias: str) -> str: - alias = re.sub(r"[-\s]+", "_", alias) - alias = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", alias) - alias = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", alias) - return alias.lower() - - @staticmethod - def convert_to_const_case(alias: str) -> str: - return runopt.AutoAlias.convert_to_snake_case(alias).upper() - - class alias(str): - pass - - class deprecated(str): - pass - default: CfgVal opt_type: Type[CfgVal] is_required: bool help: str - aliases: set[alias] | None = None - deprecated_aliases: set[deprecated] | None = None + aliases: list[str] | None = None + deprecated_aliases: list[str] | None = None @property def is_type_list_of_str(self) -> bool: @@ -1257,85 +1227,23 @@ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]: cfg[key] = val return cfg - def _generate_aliases( - self, auto_alias: int, aliases: set[str] - ) -> set[runopt.alias]: - generated_aliases = set() - for alias in aliases: - if auto_alias & runopt.AutoAlias.camelCase: - generated_aliases.add(runopt.AutoAlias.convert_to_camel_case(alias)) - if auto_alias & runopt.AutoAlias.snake_case: - generated_aliases.add(runopt.AutoAlias.convert_to_snake_case(alias)) - if auto_alias & runopt.AutoAlias.SNAKE_CASE: - generated_aliases.add(runopt.AutoAlias.convert_to_const_case(alias)) - return generated_aliases - - def _get_primary_key_and_aliases( - self, - cfg_key: list[str | int] | str, - ) -> tuple[str, set[runopt.alias], set[runopt.deprecated]]: - """ - Returns the primary key and aliases for the given cfg_key. - """ - if isinstance(cfg_key, str): - return cfg_key, set(), set() - - if len(cfg_key) == 0: - raise ValueError("cfg_key must be a non-empty list") - - if isinstance(cfg_key[0], runopt.alias) or isinstance( - cfg_key[0], runopt.deprecated - ): - warnings.warn( - "The main name of the run option should be the head of the list.", - UserWarning, - stacklevel=2, - ) - primary_key = None - auto_alias = 0x0 - aliases = set[runopt.alias]() - deprecated_aliases = set[runopt.deprecated]() - for name in cfg_key: - if isinstance(name, runopt.alias): - aliases.add(name) - elif isinstance(name, runopt.deprecated): - deprecated_aliases.add(name) - elif isinstance(name, int): - auto_alias = auto_alias | name - else: - if primary_key is not None: - raise ValueError( - f" Given more than one primary key: {primary_key}, {name}. Please use runopt.alias type for aliases. " - ) - primary_key = name - if primary_key is None or primary_key == "": - raise ValueError( - "Missing cfg_key. Please provide one other than the aliases." - ) - if auto_alias != 0x0: - aliases_to_generate_for = aliases | {primary_key} - additional_aliases = self._generate_aliases( - auto_alias, aliases_to_generate_for - ) - aliases.update(additional_aliases) - return primary_key, aliases, deprecated_aliases - def add( self, - cfg_key: str | list[str | int], + cfg_key: str, type_: Type[CfgVal], help: str, default: CfgVal = None, required: bool = False, + aliases: Optional[list[str]] = None, + deprecated_aliases: Optional[list[str]] = None, ) -> None: """ Adds the ``config`` option with the given help string and ``default`` value (if any). If the ``default`` is not specified then this option is a required option. """ - primary_key, aliases, deprecated_aliases = self._get_primary_key_and_aliases( - cfg_key - ) + aliases = aliases or [] + deprecated_aliases = deprecated_aliases or [] if required and default is not None: raise ValueError( f"Required option: {cfg_key} must not specify default value. Given: {default}" @@ -1346,12 +1254,20 @@ def add( f"Option: {cfg_key}, must be of type: {type_}." f" Given: {default} ({type(default).__name__})" ) - opt = runopt(default, type_, required, help, aliases, deprecated_aliases) + + opt = runopt( + default, + type_, + required, + help, + list(set(aliases)), + list(set(deprecated_aliases)), + ) for alias in aliases: - self._alias_to_key[alias] = primary_key + self._alias_to_key[alias] = cfg_key for deprecated_alias in deprecated_aliases: - self._alias_to_key[deprecated_alias] = primary_key - self._opts[primary_key] = opt + self._alias_to_key[deprecated_alias] = cfg_key + self._opts[cfg_key] = opt def update(self, other: "runopts") -> None: self._opts.update(other._opts) diff --git a/torchx/specs/test/api_test.py b/torchx/specs/test/api_test.py index 5b92a7609..49be2474e 100644 --- a/torchx/specs/test/api_test.py +++ b/torchx/specs/test/api_test.py @@ -605,7 +605,8 @@ def test_runopts_add(self) -> None: def test_runopts_add_with_aliases(self) -> None: opts = runopts() opts.add( - ["job_priority", runopt.alias("jobPriority")], + "job_priority", + aliases=["jobPriority"], type_=str, help="priority for the job", ) @@ -616,7 +617,8 @@ def test_runopts_add_with_aliases(self) -> None: def test_runopts_resolve_with_aliases(self) -> None: opts = runopts() opts.add( - ["job_priority", runopt.alias("jobPriority")], + "job_priority", + aliases=["jobPriority"], type_=str, help="priority for the job", ) @@ -628,71 +630,45 @@ def test_runopts_resolve_with_aliases(self) -> None: def test_runopts_resolve_with_none_valued_aliases(self) -> None: opts = runopts() opts.add( - ["job_priority", runopt.alias("jobPriority")], + "job_priority", + aliases=["jobPriority"], type_=str, help="priority for the job", ) opts.add( - ["modelTypeName", runopt.alias("model_type_name")], + "model_type_name", + aliases=["modelTypeName"], type_=Union[str, None], help="ML Hub Model Type to attribute resource utilization for job", ) - resolved_opts = opts.resolve({"model_type_name": None, "jobPriority": "low"}) - self.assertEqual(resolved_opts.get("model_type_name"), None) + resolved_opts = opts.resolve({"modelTypeName": None, "jobPriority": "low"}) + self.assertEqual(resolved_opts.get("modelTypeName"), None) self.assertEqual(resolved_opts.get("jobPriority"), "low") - self.assertEqual(resolved_opts, {"model_type_name": None, "jobPriority": "low"}) + self.assertEqual(resolved_opts, {"modelTypeName": None, "jobPriority": "low"}) with self.assertRaises(InvalidRunConfigException): - opts.resolve({"model_type_name": None, "modelTypeName": "low"}) + opts.resolve({"modelTypeName": None, "model_type_name": "low"}) def test_runopts_add_with_deprecated_aliases(self) -> None: opts = runopts() - with warnings.catch_warnings(record=True) as w: - opts.add( - [runopt.deprecated("jobPriority"), "job_priority"], - type_=str, - help="run as user", - ) - self.assertEqual(len(w), 1) - self.assertEqual(w[0].category, UserWarning) - self.assertEqual( - str(w[0].message), - "The main name of the run option should be the head of the list.", - ) + opts.add( + "job_priority", + deprecated_aliases=["priority"], + type_=str, + help="run as user", + ) opts.resolve({"job_priority": "high"}) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - opts.resolve({"jobPriority": "high"}) + opts.resolve({"priority": "high"}) self.assertEqual(len(w), 1) self.assertEqual(w[0].category, UserWarning) self.assertEqual( str(w[0].message), - "Run option `jobPriority` is deprecated, use `job_priority` instead", + "Run option `priority` is deprecated, use `job_priority` instead", ) - def test_runopt_auto_aliases(self) -> None: - opts = runopts() - opts.add( - ["job_priority", runopt.AutoAlias.camelCase], - type_=str, - help="run as user", - ) - opts.add( - [ - "model_type_name", - runopt.AutoAlias.camelCase | runopt.AutoAlias.SNAKE_CASE, - ], - type_=str, - help="run as user", - ) - self.assertEqual(2, len(opts._opts)) - self.assertIsNotNone(opts.get("job_priority")) - self.assertIsNotNone(opts.get("jobPriority")) - self.assertIsNotNone(opts.get("model_type_name")) - self.assertIsNotNone(opts.get("modelTypeName")) - self.assertIsNotNone(opts.get("MODEL_TYPE_NAME")) - def get_runopts(self) -> runopts: opts = runopts() opts.add("run_as", type_=str, help="run as user", required=True)