Skip to content

Commit c6378f3

Browse files
AbishekSfacebook-github-bot
authored andcommitted
Introduce StructuredRunOpt for more complex runopts (#1154)
Summary: StructuredRunOpt can be used to create new complex runopts This can be used as for example: ``` dataclass class UlimitTest(StructuredRunOpt): name: str hard: int soft: int def template(self) -> str: return "{name},{soft:d},{hard:d}" ``` This comes with 1. template() that helps the from_repr() use that template to map to the fields with types 2. __eq__ to check equality between two Add this new type to CfgVal's acceptable types. Also modify a piece of code that could use CfgVal instead of typing out entire list of types in. https://www.internalfb.com/code/fbsource/[37f968940832a633afa761e829e81184858cf6b8]/fbcode/msl/experimental/training_execution_environment/monarch_backend/api/launch_cluster.py?lines=196-199 Differential Revision: D85159071
1 parent 7f68b89 commit c6378f3

File tree

4 files changed

+136
-2
lines changed

4 files changed

+136
-2
lines changed

docs/source/specs.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ Run Configs
5252
.. autoclass:: runopts
5353
:members:
5454

55+
.. autoclass:: StructuredRunOpt
56+
:members:
57+
5558
Run Status
5659
--------------
5760
.. autoclass:: AppStatus

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ docker
44
filelock
55
fsspec>=2023.10.0
66
tabulate
7+
parse

torchx/specs/api.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-strict
88

9+
import abc
910
import asyncio
1011
import copy
1112
import inspect
@@ -17,6 +18,7 @@
1718
import shutil
1819
import typing
1920
import warnings
21+
from abc import abstractmethod
2022
from dataclasses import asdict, dataclass, field
2123
from datetime import datetime
2224
from enum import Enum, IntEnum
@@ -39,6 +41,8 @@
3941
Union,
4042
)
4143

44+
import parse
45+
4246
from torchx.util.types import to_dict
4347

4448
_APP_STATUS_FORMAT_TEMPLATE = """AppStatus:
@@ -877,11 +881,86 @@ def __init__(self, status: AppStatus, *args: object) -> None:
877881
self.status = status
878882

879883

884+
U = TypeVar("U", bound="StructuredRunOpt")
885+
886+
887+
class StructuredRunOpt(abc.ABC):
888+
"""
889+
StructuredRunOpt is a class that represents a structured run option.
890+
This is to allow for more complex types than currently supported.
891+
892+
Usage
893+
894+
.. code-block:: python
895+
896+
@dataclass
897+
class Ulimit(StructuredRunOpt):
898+
name: str
899+
hard: int
900+
soft: int
901+
902+
def template(self) -> str:
903+
return "{name},{soft:d},{hard:d}"
904+
905+
"""
906+
907+
@abstractmethod
908+
def template(self) -> str:
909+
"""
910+
Returns the template string for the StructuredRunOpt.
911+
These are mapped to the field names of the StructuredRunOpt object.
912+
913+
Usage
914+
915+
.. code-block:: python
916+
917+
@dataclass
918+
class Ulimit(StructuredRunOpt):
919+
name: str
920+
hard: int
921+
soft: int
922+
923+
def template(self) -> str:
924+
# The template string should contain the field names of the Ulimit object.
925+
# Template strings also may need types as below where `:d` is for integer type.
926+
return "{name},{soft:d},{hard:d}"
927+
928+
opts = runopts()
929+
opts.add("ulimit", type_=self.Ulimit, help="ulimits for the container")
930+
931+
# .from_repr() is used to create a Ulimit object from a string representation that is the template.
932+
cfg = opts.resolve(
933+
{
934+
"ulimit": self.Ulimit.from_repr(
935+
"test,50,100",
936+
)
937+
}
938+
)
939+
940+
"""
941+
pass
942+
943+
def __repr__(self) -> str:
944+
return self.template().format(**asdict(self))
945+
946+
def __eq__(self, other: object) -> bool:
947+
return isinstance(other, type(self)) and asdict(self) == asdict(other)
948+
949+
@classmethod
950+
def from_repr(cls: Type[U], repr: str) -> U:
951+
"""
952+
Parses the repr string and returns a StructuredRunOpt object
953+
"""
954+
tmpl = cls.__new__(cls).template()
955+
result = parse.parse(tmpl, repr)
956+
return cls(**result.named)
957+
958+
880959
# valid run cfg values; only support primitives (str, int, float, bool, List[str], Dict[str, str])
881960
# TODO(wilsonhong): python 3.9+ supports list[T] in typing, which can be used directly
882961
# in isinstance(). Should replace with that.
883962
# see: https://docs.python.org/3/library/stdtypes.html#generic-alias-type
884-
CfgVal = Union[str, int, float, bool, List[str], Dict[str, str], None]
963+
CfgVal = Union[str, int, float, bool, List[str], Dict[str, str], StructuredRunOpt, None]
885964

886965

887966
T = TypeVar("T")

torchx/specs/test/api_test.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import time
1515
import unittest
1616
import warnings
17-
from dataclasses import asdict
17+
from dataclasses import asdict, dataclass
1818
from pathlib import Path
1919
from typing import Dict, List, Mapping, Tuple, Union
2020
from unittest import mock
@@ -43,6 +43,7 @@
4343
RoleStatus,
4444
runopt,
4545
runopts,
46+
StructuredRunOpt,
4647
TORCHX_HOME,
4748
Workspace,
4849
)
@@ -550,6 +551,56 @@ def test_getset_metadata(self) -> None:
550551
self.assertEqual(None, app.metadata.get("non_existent"))
551552

552553

554+
class StructuredRunOptTest(unittest.TestCase):
555+
556+
@dataclass
557+
class UlimitTest(StructuredRunOpt):
558+
name: str
559+
hard: int
560+
soft: int
561+
562+
def template(self) -> str:
563+
return "{name},{soft:d},{hard:d}"
564+
565+
def test_structured_runopt(self) -> None:
566+
opt = self.UlimitTest(name="test", hard=100, soft=50)
567+
568+
# Test class from_repr
569+
self.assertEqual(
570+
opt,
571+
self.UlimitTest.from_repr(
572+
"test,50,100",
573+
),
574+
)
575+
576+
# Test repr
577+
self.assertEqual(
578+
"StructuredRunOptTest.UlimitTest(name='test', hard=100, soft=50)", repr(opt)
579+
)
580+
581+
# Test equality
582+
opt_other = self.UlimitTest(name="test", hard=100, soft=50)
583+
self.assertEqual(opt, opt_other)
584+
opt_other = self.UlimitTest(name="test", hard=100, soft=70)
585+
self.assertNotEqual(opt, opt_other)
586+
587+
# Test with runopts
588+
589+
opts = runopts()
590+
opts.add("ulimit", type_=self.UlimitTest, help="test ulimit")
591+
cfg = opts.resolve(
592+
{
593+
"ulimit": self.UlimitTest.from_repr(
594+
"test,50,100",
595+
)
596+
}
597+
)
598+
self.assertEqual(
599+
cfg.get("ulimit"),
600+
self.UlimitTest(name="test", hard=100, soft=50),
601+
)
602+
603+
553604
class RunConfigTest(unittest.TestCase):
554605
def get_cfg(self) -> Mapping[str, CfgVal]:
555606
return {

0 commit comments

Comments
 (0)