Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 39 additions & 6 deletions torchx/specs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pathlib
import re
import typing
import warnings
from dataclasses import asdict, dataclass, field
from datetime import datetime
from enum import Enum
Expand Down Expand Up @@ -894,11 +895,15 @@ class runopt:
class alias(str):
pass

class deprecated(str):
pass

default: CfgVal
opt_type: Type[CfgVal]
is_required: bool
help: str
aliases: list[alias] | None = None
deprecated_aliases: list[deprecated] | None = None

@property
def is_type_list_of_str(self) -> bool:
Expand Down Expand Up @@ -990,7 +995,7 @@ class runopts:

def __init__(self) -> None:
self._opts: Dict[str, runopt] = {}
self._alias_to_key: dict[runopt.alias, str] = {}
self._alias_to_key: dict[str, str] = {}

def __iter__(self) -> Iterator[Tuple[str, runopt]]:
return self._opts.items().__iter__()
Expand Down Expand Up @@ -1044,12 +1049,24 @@ def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]:
val = resolved_cfg.get(cfg_key)
resolved_name = None
aliases = runopt.aliases or []
deprecated_aliases = runopt.deprecated_aliases or []
if val is None:
for alias in aliases:
val = resolved_cfg.get(alias)
if alias in cfg or val is not None:
resolved_name = alias
break
for alias in deprecated_aliases:
val = resolved_cfg.get(alias)
if val is not None:
resolved_name = alias
use_instead = self._alias_to_key.get(alias)
warnings.warn(
f"Run option `{alias}` is deprecated, use `{use_instead}` instead",
UserWarning,
stacklevel=2,
)
break
else:
resolved_name = cfg_key
for alias in aliases:
Expand Down Expand Up @@ -1175,20 +1192,32 @@ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]:
def _get_primary_key_and_aliases(
self,
cfg_key: list[str] | str,
) -> tuple[str, list[runopt.alias]]:
) -> tuple[str, list[runopt.alias], list[runopt.deprecated]]:
"""
Returns the primary key and aliases for the given cfg_key.
"""
if isinstance(cfg_key, str):
return cfg_key, []
return cfg_key, [], []

if len(cfg_key) == 0:
raise ValueError("cfg_key must be a non-empty list")

if isinstance(cfg_key[0], runopt.alias) or isinstance(
cfg_key[0], runopt.deprecated
):
warnings.warn(
"The main name of the run option should be the head of the list.",
UserWarning,
stacklevel=2,
)
primary_key = None
aliases = list[runopt.alias]()
deprecated_aliases = list[runopt.deprecated]()
for name in cfg_key:
if isinstance(name, runopt.alias):
aliases.append(name)
elif isinstance(name, runopt.deprecated):
deprecated_aliases.append(name)
else:
if primary_key is not None:
raise ValueError(
Expand All @@ -1199,7 +1228,7 @@ def _get_primary_key_and_aliases(
raise ValueError(
"Missing cfg_key. Please provide one other than the aliases."
)
return primary_key, aliases
return primary_key, aliases, deprecated_aliases

def add(
self,
Expand All @@ -1214,7 +1243,9 @@ def add(
value (if any). If the ``default`` is not specified then this option
is a required option.
"""
primary_key, aliases = self._get_primary_key_and_aliases(cfg_key)
primary_key, aliases, deprecated_aliases = self._get_primary_key_and_aliases(
cfg_key
)
if required and default is not None:
raise ValueError(
f"Required option: {cfg_key} must not specify default value. Given: {default}"
Expand All @@ -1225,9 +1256,11 @@ def add(
f"Option: {cfg_key}, must be of type: {type_}."
f" Given: {default} ({type(default).__name__})"
)
opt = runopt(default, type_, required, help, aliases)
opt = runopt(default, type_, required, help, aliases, deprecated_aliases)
for alias in aliases:
self._alias_to_key[alias] = primary_key
for deprecated_alias in deprecated_aliases:
self._alias_to_key[deprecated_alias] = primary_key
self._opts[primary_key] = opt

def update(self, other: "runopts") -> None:
Expand Down
27 changes: 27 additions & 0 deletions torchx/specs/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import tempfile
import time
import unittest
import warnings
from dataclasses import asdict
from pathlib import Path
from typing import Dict, List, Mapping, Tuple, Union
Expand Down Expand Up @@ -621,6 +622,32 @@ def test_runopts_resolve_with_none_valued_aliases(self) -> None:
with self.assertRaises(InvalidRunConfigException):
opts.resolve({"model_type_name": None, "modelTypeName": "low"})

def test_runopts_add_with_deprecated_aliases(self) -> None:
opts = runopts()
with warnings.catch_warnings(record=True) as w:
opts.add(
[runopt.deprecated("jobPriority"), "job_priority"],
type_=str,
help="run as user",
)
self.assertEqual(len(w), 1)
self.assertEqual(w[0].category, UserWarning)
self.assertEqual(
str(w[0].message),
"The main name of the run option should be the head of the list.",
)

opts.resolve({"job_priority": "high"})
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
opts.resolve({"jobPriority": "high"})
self.assertEqual(len(w), 1)
self.assertEqual(w[0].category, UserWarning)
self.assertEqual(
str(w[0].message),
"Run option `jobPriority` is deprecated, use `job_priority` instead",
)

def get_runopts(self) -> runopts:
opts = runopts()
opts.add("run_as", type_=str, help="run as user", required=True)
Expand Down
Loading