Skip to content

Commit 48d5c7d

Browse files
AbishekSfacebook-github-bot
authored andcommitted
Introduce StructuredRunOpt for more complex runopts
Summary: StructuredRunOpt can be used to create new complex runopts This can be used as for example: ``` dataclass class UlimitTest(StructuredRunOpt): name: str hard: int soft: int def template(self) -> str: return "{name},{soft:d},{hard:d}" ``` This comes with 1. template() that helps the from_repr() use that template to map to the fields with types 2. __eq__ to check equality between two Differential Revision: D85159071
1 parent 5053c87 commit 48d5c7d

File tree

2 files changed

+68
-1
lines changed

2 files changed

+68
-1
lines changed

torchx/specs/api.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-strict
88

9+
import abc
910
import asyncio
1011
import copy
1112
import inspect
@@ -17,6 +18,7 @@
1718
import shutil
1819
import typing
1920
import warnings
21+
from abc import abstractmethod
2022
from dataclasses import asdict, dataclass, field
2123
from datetime import datetime
2224
from enum import Enum, IntEnum
@@ -39,6 +41,8 @@
3941
Union,
4042
)
4143

44+
import parse
45+
4246
from torchx.util.types import to_dict
4347

4448
_APP_STATUS_FORMAT_TEMPLATE = """AppStatus:
@@ -930,6 +934,34 @@ def get_type_name(tp: Type[CfgVal]) -> str:
930934
return str(tp)
931935

932936

937+
U = TypeVar("U", bound="StructuredRunOpt")
938+
939+
940+
class StructuredRunOpt(abc.ABC):
941+
942+
@abstractmethod
943+
def template(self) -> str:
944+
"""
945+
Returns the template string for the StructuredRunOpt
946+
"""
947+
pass
948+
949+
def __repr__(self) -> str:
950+
return self.template().format(**asdict(self))
951+
952+
def __eq__(self, other: Type[T]) -> bool:
953+
return isinstance(other, type(self)) and asdict(self) == asdict(other)
954+
955+
@classmethod
956+
def from_repr(cls: Type[U], repr: str) -> U:
957+
"""
958+
Parses the repr string and returns a StructuredRunOpt object
959+
"""
960+
tmpl = cls.__new__(cls).template()
961+
result = parse.parse(tmpl, repr)
962+
return cls(**result.named)
963+
964+
933965
@dataclass
934966
class runopt:
935967
"""

torchx/specs/test/api_test.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import time
1515
import unittest
1616
import warnings
17-
from dataclasses import asdict
17+
from dataclasses import asdict, dataclass
1818
from pathlib import Path
1919
from typing import Dict, List, Mapping, Tuple, Union
2020
from unittest import mock
@@ -43,6 +43,7 @@
4343
RoleStatus,
4444
runopt,
4545
runopts,
46+
StructuredRunOpt,
4647
TORCHX_HOME,
4748
Workspace,
4849
)
@@ -550,6 +551,40 @@ def test_getset_metadata(self) -> None:
550551
self.assertEqual(None, app.metadata.get("non_existent"))
551552

552553

554+
class StructuredRunOptTest(unittest.TestCase):
555+
556+
@dataclass
557+
class UlimitTest(StructuredRunOpt):
558+
name: str
559+
hard: int
560+
soft: int
561+
562+
def template(self) -> str:
563+
return "{name},{soft:d},{hard:d}"
564+
565+
def test_structured_runopt(self) -> None:
566+
opt = self.UlimitTest(name="test", hard=100, soft=50)
567+
568+
# Test class from_repr
569+
self.assertEqual(
570+
opt,
571+
self.UlimitTest.from_repr(
572+
"test,50,100",
573+
),
574+
)
575+
576+
# Test repr
577+
self.assertEqual(
578+
"StructuredRunOptTest.UlimitTest(name='test', hard=100, soft=50)", repr(opt)
579+
)
580+
581+
# Test equality
582+
opt_other = self.UlimitTest(name="test", hard=100, soft=50)
583+
self.assertEqual(opt, opt_other)
584+
opt_other = self.UlimitTest(name="test", hard=100, soft=70)
585+
self.assertNotEqual(opt, opt_other)
586+
587+
553588
class RunConfigTest(unittest.TestCase):
554589
def get_cfg(self) -> Mapping[str, CfgVal]:
555590
return {

0 commit comments

Comments
 (0)