diff --git a/docs/source/specs.rst b/docs/source/specs.rst index 6fd1ea2c4..62224243d 100644 --- a/docs/source/specs.rst +++ b/docs/source/specs.rst @@ -52,6 +52,9 @@ Run Configs .. autoclass:: runopts :members: +.. autoclass:: StructuredRunOpt + :members: + Run Status -------------- .. autoclass:: AppStatus diff --git a/torchx/specs/__init__.py b/torchx/specs/__init__.py index d43470433..da0a21cdf 100644 --- a/torchx/specs/__init__.py +++ b/torchx/specs/__init__.py @@ -43,6 +43,7 @@ RoleStatus, runopt, runopts, + StructuredRunOpt, TORCHX_HOME, UnknownAppException, UnknownSchedulerException, @@ -226,6 +227,7 @@ def gpu_x_1() -> Dict[str, Resource]: "RoleStatus", "runopt", "runopts", + "StructuredRunOpt", "UnknownAppException", "UnknownSchedulerException", "InvalidRunConfigException", diff --git a/torchx/specs/api.py b/torchx/specs/api.py index 10e79d06d..530e17461 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -6,6 +6,7 @@ # pyre-strict +import abc import asyncio import copy import inspect @@ -17,9 +18,10 @@ import shutil import typing import warnings -from dataclasses import asdict, dataclass, field +from abc import abstractmethod +from dataclasses import asdict, dataclass, field, fields from datetime import datetime -from enum import Enum, IntEnum +from enum import Enum from json import JSONDecodeError from string import Template from typing import ( @@ -36,10 +38,10 @@ Tuple, Type, TypeVar, - Union, ) from torchx.util.types import to_dict +from typing_extensions import Self _APP_STATUS_FORMAT_TEMPLATE = """AppStatus: State: ${state} @@ -877,11 +879,81 @@ def __init__(self, status: AppStatus, *args: object) -> None: self.status = status -# valid run cfg values; only support primitives (str, int, float, bool, List[str], Dict[str, str]) +U = TypeVar("U", bound="StructuredRunOpt") + + +class StructuredRunOpt(abc.ABC): + """ + StructuredRunOpt is a class that represents a structured run option. + This is to allow for more complex types than currently supported. + + Usage + + .. doctest:: + @dataclass + class Ulimit(StructuredRunOpt): + name: str + hard: int + soft: int + + def template(self) -> str: + # The template string should contain the field names of the Ulimit object. + # The field names are mapped to the keys in the repr string. + # These are comma seperated and wrapped in curly braces. + return "{name},{soft},{hard}" + + opts = runopts() + opts.add("ulimit", type_=self.Ulimit, help="ulimits for the container") + + # .from_repr() is used to create a Ulimit object from a string representation that is the template. + cfg = opts.resolve( + { + "ulimit": self.Ulimit.from_repr( + "test,50,100", + ) + } + ) + + """ + + @abstractmethod + def template(self) -> str: + """ + Returns the template string for the StructuredRunOpt. + These are mapped to the field names of the StructuredRunOpt object. + """ + ... + + def __repr__(self) -> str: + key_value = ", ".join(**asdict(self)) + return f"{self.__class__.__name__}({key_value})" + + def __eq__(self, other: object) -> bool: + return isinstance(other, type(self)) and asdict(self) == asdict(other) + + @classmethod + def from_repr(cls, repr: str) -> Self: + """ + Parses the repr string and returns a StructuredRunOpt object + """ + tmpl = cls.__new__(cls).template() + fields_from_tmpl = [field.strip("{}") for field in tmpl.split(",")] + values = repr.split(",") + gd = dict(zip(fields_from_tmpl, values, strict=True)) + for field_cls in fields(cls): + name = field_cls.name + field_type = field_cls.type + value = gd.get(name) + gd[name] = field_type(value) + return cls(**gd) + + +# valid run cfg values; support primitives (str, int, float, bool, List[str], Dict[str, str]) +# And StructuredRunOpt Type for more complex types. # TODO(wilsonhong): python 3.9+ supports list[T] in typing, which can be used directly # in isinstance(). Should replace with that. # see: https://docs.python.org/3/library/stdtypes.html#generic-alias-type -CfgVal = Union[str, int, float, bool, List[str], Dict[str, str], None] +CfgVal = str | int | float | bool | List[str] | Dict[str, str] | StructuredRunOpt | None T = TypeVar("T") diff --git a/torchx/specs/test/api_test.py b/torchx/specs/test/api_test.py index 49be2474e..d2b51a0b5 100644 --- a/torchx/specs/test/api_test.py +++ b/torchx/specs/test/api_test.py @@ -14,7 +14,7 @@ import time import unittest import warnings -from dataclasses import asdict +from dataclasses import asdict, dataclass from pathlib import Path from typing import Dict, List, Mapping, Tuple, Union from unittest import mock @@ -43,6 +43,7 @@ RoleStatus, runopt, runopts, + StructuredRunOpt, TORCHX_HOME, Workspace, ) @@ -550,6 +551,56 @@ def test_getset_metadata(self) -> None: self.assertEqual(None, app.metadata.get("non_existent")) +class StructuredRunOptTest(unittest.TestCase): + + @dataclass + class UlimitTest(StructuredRunOpt): + name: str + hard: int + soft: int + + def template(self) -> str: + return "{name},{soft},{hard}" + + def test_structured_runopt(self) -> None: + opt = self.UlimitTest(name="test", hard=100, soft=50) + + # Test class from_repr + self.assertEqual( + opt, + self.UlimitTest.from_repr( + "test,50,100", + ), + ) + + # Test repr + self.assertEqual( + "StructuredRunOptTest.UlimitTest(name='test', hard=100, soft=50)", repr(opt) + ) + + # Test equality + opt_other = self.UlimitTest(name="test", hard=100, soft=50) + self.assertEqual(opt, opt_other) + opt_other = self.UlimitTest(name="test", hard=100, soft=70) + self.assertNotEqual(opt, opt_other) + + # Test with runopts + + opts = runopts() + opts.add("ulimit", type_=self.UlimitTest, help="test ulimit") + cfg = opts.resolve( + { + "ulimit": self.UlimitTest.from_repr( + "test,50,100", + ) + } + ) + self.assertEqual( + cfg.get("ulimit"), + self.UlimitTest(name="test", hard=100, soft=50), + ) + + class RunConfigTest(unittest.TestCase): def get_cfg(self) -> Mapping[str, CfgVal]: return {