Skip to content

Commit c1fea89

Browse files
authored
Adding State.from_.any() (#497)
* feat: Adding State.from_.any() * feat: Items such states and transitions now appear on the order they are defined * refac: Using CallbackGroup on transition_list
1 parent bf55631 commit c1fea89

File tree

16 files changed

+343
-230
lines changed

16 files changed

+343
-230
lines changed

.github/workflows/release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ jobs:
2121
python-version: '3.13'
2222

2323
- name: Setup Graphviz
24-
uses: ts-graphviz/setup-graphviz@v1
24+
uses: ts-graphviz/setup-graphviz@v2
2525

2626
- name: Install uv
2727
uses: astral-sh/setup-uv@v3

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ Easily iterate over all states:
168168

169169
```py
170170
>>> [s.id for s in sm.states]
171-
['green', 'red', 'yellow']
171+
['green', 'yellow', 'red']
172172

173173
```
174174

statemachine/callbacks.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,15 @@ def __init__(
7171
func,
7272
group: CallbackGroup,
7373
is_convention=False,
74+
is_event: bool = False,
7475
cond=None,
7576
priority: CallbackPriority = CallbackPriority.NAMING,
7677
expected_value=None,
7778
):
7879
self.func = func
7980
self.group = group
8081
self.is_convention = is_convention
82+
self.is_event = is_event
8183
self.cond = cond
8284
self.expected_value = expected_value
8385
self.priority = priority
@@ -88,7 +90,12 @@ def __init__(
8890
elif callable(func):
8991
self.reference = SpecReference.CALLABLE
9092
self.is_bounded = hasattr(func, "__self__")
91-
self.attr_name = func.__name__
93+
self.attr_name = (
94+
func.__name__ if not self.is_event or self.is_bounded else f"_{func.__name__}_"
95+
)
96+
if not self.is_bounded:
97+
func.attr_name = self.attr_name
98+
func.is_event = is_event
9299
else:
93100
self.reference = SpecReference.NAME
94101
self.attr_name = func
@@ -114,11 +121,6 @@ def __eq__(self, other):
114121
def __hash__(self):
115122
return id(self)
116123

117-
def _update_func(self, func: Callable, attr_name: str):
118-
self.func = func
119-
self.reference = SpecReference.CALLABLE
120-
self.attr_name = attr_name
121-
122124

123125
class SpecListGrouper:
124126
def __init__(self, list: "CallbackSpecList", group: CallbackGroup) -> None:
@@ -158,7 +160,7 @@ def __init__(self, factory=CallbackSpec):
158160
def __repr__(self):
159161
return f"{type(self).__name__}({self.items!r}, factory={self.factory!r})"
160162

161-
def _add_unbounded_callback(self, func, is_event=False, transitions=None, **kwargs):
163+
def _add_unbounded_callback(self, func, transitions=None, **kwargs):
162164
"""This list was a target for adding a func using decorator
163165
`@<state|event>[.on|before|after|enter|exit]` syntax.
164166
@@ -181,11 +183,7 @@ def _add_unbounded_callback(self, func, is_event=False, transitions=None, **kwar
181183
event.
182184
183185
"""
184-
spec = self._add(func, **kwargs)
185-
if not getattr(func, "_specs_to_update", None):
186-
func._specs_to_update = set()
187-
if is_event:
188-
func._specs_to_update.add(spec._update_func)
186+
self._add(func, **kwargs)
189187
func._transitions = transitions
190188

191189
return func
@@ -202,7 +200,10 @@ def grouper(self, group: CallbackGroup) -> SpecListGrouper:
202200
return self._groupers[group]
203201

204202
def _add(self, func, group: CallbackGroup, **kwargs):
205-
spec = self.factory(func, group, **kwargs)
203+
if isinstance(func, CallbackSpec):
204+
spec = func
205+
else:
206+
spec = self.factory(func, group, **kwargs)
206207

207208
if spec in self.items:
208209
return

statemachine/engines/async_.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ..exceptions import InvalidDefinition
88
from ..exceptions import TransitionNotAllowed
99
from ..i18n import _
10+
from ..state import State
1011
from ..transition import Transition
1112

1213
if TYPE_CHECKING:
@@ -82,7 +83,7 @@ async def processing_loop(self):
8283
async def _trigger(self, trigger_data: TriggerData):
8384
event_data = None
8485
if trigger_data.event == "__initial__":
85-
transition = Transition(None, self.sm._get_initial_state(), event="__initial__")
86+
transition = Transition(State(), self.sm._get_initial_state(), event="__initial__")
8687
transition._specs.clear()
8788
event_data = EventData(trigger_data=trigger_data, transition=transition)
8889
await self._activate(event_data)

statemachine/engines/sync.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from ..event_data import EventData
66
from ..event_data import TriggerData
77
from ..exceptions import TransitionNotAllowed
8+
from ..state import State
89
from ..transition import Transition
910

1011
if TYPE_CHECKING:
@@ -85,7 +86,7 @@ def processing_loop(self):
8586
def _trigger(self, trigger_data: TriggerData):
8687
event_data = None
8788
if trigger_data.event == "__initial__":
88-
transition = Transition(None, self.sm._get_initial_state(), event="__initial__")
89+
transition = Transition(State(), self.sm._get_initial_state(), event="__initial__")
8990
transition._specs.clear()
9091
event_data = EventData(trigger_data=trigger_data, transition=transition)
9192
self._activate(event_data)

statemachine/factory.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from typing import Dict
55
from typing import List
66
from typing import Tuple
7-
from uuid import uuid4
87

98
from . import registry
109
from .event import Event
@@ -179,7 +178,7 @@ def add_inherited(cls, bases):
179178
cls.add_event(event=Event(id=event.id, name=event.name))
180179

181180
def add_from_attributes(cls, attrs): # noqa: C901
182-
for key, value in sorted(attrs.items(), key=lambda pair: pair[0]):
181+
for key, value in attrs.items():
183182
if isinstance(value, States):
184183
cls._add_states_from_dict(value)
185184
if isinstance(value, State):
@@ -195,7 +194,7 @@ def add_from_attributes(cls, attrs): # noqa: C901
195194
),
196195
old_event=value,
197196
)
198-
elif getattr(value, "_specs_to_update", None):
197+
elif getattr(value, "attr_name", None):
199198
cls._add_unbounded_callback(key, value)
200199

201200
def _add_states_from_dict(cls, states):
@@ -205,13 +204,10 @@ def _add_states_from_dict(cls, states):
205204
def _add_unbounded_callback(cls, attr_name, func):
206205
# if func is an event, the `attr_name` will be replaced by an event trigger,
207206
# so we'll also give the ``func`` a new unique name to be used by the callback
208-
# machinery.
209-
cls.add_event(event=Event(func._transitions, id=attr_name, name=attr_name))
210-
attr_name = f"_{attr_name}_{uuid4().hex}"
211-
setattr(cls, attr_name, func)
212-
213-
for ref in func._specs_to_update:
214-
ref(getattr(cls, attr_name), attr_name)
207+
# machinery that is stored at ``func.attr_name``
208+
setattr(cls, func.attr_name, func)
209+
if func.is_event:
210+
cls.add_event(event=Event(func._transitions, id=attr_name, name=attr_name))
215211

216212
def add_state(cls, id, state: State):
217213
state._set_id(id)
@@ -236,7 +232,7 @@ def add_event(
236232

237233
transitions = event._transitions
238234
if transitions is not None:
239-
transitions.add_event(event)
235+
transitions._on_event_defined(event=event, states=list(cls.states))
240236

241237
if event not in cls._events:
242238
cls._events[event] = None

statemachine/state.py

Lines changed: 58 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import TYPE_CHECKING
22
from typing import Any
33
from typing import Dict
4+
from typing import List
45
from weakref import ref
56

67
from .callbacks import CallbackGroup
@@ -15,6 +16,37 @@
1516
from .statemachine import StateMachine
1617

1718

19+
class _TransitionBuilder:
20+
def __init__(self, state: "State"):
21+
self._state = state
22+
23+
def itself(self, **kwargs):
24+
return self.__call__(self._state, **kwargs)
25+
26+
def __call__(self, *states: "State", **kwargs):
27+
raise NotImplementedError
28+
29+
30+
class _ToState(_TransitionBuilder):
31+
def __call__(self, *states: "State", **kwargs):
32+
transitions = TransitionList(Transition(self._state, state, **kwargs) for state in states)
33+
self._state.transitions.add_transitions(transitions)
34+
return transitions
35+
36+
37+
class _FromState(_TransitionBuilder):
38+
def any(self, **kwargs):
39+
return self.__call__(AnyState(), **kwargs)
40+
41+
def __call__(self, *states: "State", **kwargs):
42+
transitions = TransitionList()
43+
for origin in states:
44+
transition = Transition(origin, self._state, **kwargs)
45+
origin.transitions.add_transitions(transition)
46+
transitions.add_transitions(transition)
47+
return transitions
48+
49+
1850
class State:
1951
"""
2052
A State in a :ref:`StateMachine` describes a particular behavior of the machine.
@@ -136,6 +168,12 @@ def _setup(self):
136168
self.exit.add("on_exit_state", priority=CallbackPriority.GENERIC, is_convention=True)
137169
self.exit.add(f"on_exit_{self.id}", priority=CallbackPriority.NAMING, is_convention=True)
138170

171+
def _on_event_defined(self, event: str, transition: Transition, states: List["State"]):
172+
"""Called by statemachine factory when an event is defined having a transition
173+
starting from this state.
174+
"""
175+
pass
176+
139177
def __repr__(self):
140178
return (
141179
f"{type(self).__name__}({self.name!r}, id={self.id!r}, value={self.value!r}, "
@@ -172,38 +210,15 @@ def _set_id(self, id: str):
172210
if not self.name:
173211
self.name = self._id.replace("_", " ").capitalize()
174212

175-
def _to_(self, *states: "State", **kwargs):
176-
transitions = TransitionList(Transition(self, state, **kwargs) for state in states)
177-
self.transitions.add_transitions(transitions)
178-
return transitions
179-
180-
def _from_(self, *states: "State", **kwargs):
181-
transitions = TransitionList()
182-
for origin in states:
183-
transition = Transition(origin, self, **kwargs)
184-
origin.transitions.add_transitions(transition)
185-
transitions.add_transitions(transition)
186-
return transitions
187-
188-
def _get_proxy_method_to_itself(self, method):
189-
def proxy(*states: "State", **kwargs):
190-
return method(*states, **kwargs)
191-
192-
def proxy_to_itself(**kwargs):
193-
return proxy(self, **kwargs)
194-
195-
proxy.itself = proxy_to_itself
196-
return proxy
197-
198213
@property
199-
def to(self):
214+
def to(self) -> _ToState:
200215
"""Create transitions to the given target states."""
201-
return self._get_proxy_method_to_itself(self._to_)
216+
return _ToState(self)
202217

203218
@property
204-
def from_(self):
219+
def from_(self) -> _FromState:
205220
"""Create transitions from the given target states (reversed)."""
206-
return self._get_proxy_method_to_itself(self._from_)
221+
return _FromState(self)
207222

208223
@property
209224
def initial(self):
@@ -269,3 +284,19 @@ def id(self) -> str:
269284
@property
270285
def is_active(self):
271286
return self._machine().current_state == self
287+
288+
289+
class AnyState(State):
290+
"""A special state that works as a "ANY" placeholder.
291+
292+
It is used as the "From" state of a transtion,
293+
until the state machine class is evaluated.
294+
"""
295+
296+
def _on_event_defined(self, event: str, transition: Transition, states: List[State]):
297+
for state in states:
298+
if state.final:
299+
continue
300+
new_transition = transition._copy_with_args(source=state, event=event)
301+
302+
state.transitions.add_transitions(new_transition)

statemachine/transition.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1+
from copy import deepcopy
2+
from typing import TYPE_CHECKING
3+
14
from .callbacks import CallbackGroup
25
from .callbacks import CallbackPriority
36
from .callbacks import CallbackSpecList
47
from .events import Events
58
from .exceptions import InvalidDefinition
69

10+
if TYPE_CHECKING:
11+
from .statemachine import State
12+
713

814
class Transition:
915
"""A transition holds reference to the source and target state.
@@ -32,8 +38,8 @@ class Transition:
3238

3339
def __init__(
3440
self,
35-
source,
36-
target,
41+
source: "State",
42+
target: "State",
3743
event=None,
3844
internal=False,
3945
validators=None,
@@ -125,3 +131,17 @@ def events(self):
125131

126132
def add_event(self, value):
127133
self._events.add(value)
134+
135+
def _copy_with_args(self, **kwargs):
136+
source = kwargs.pop("source", self.source)
137+
target = kwargs.pop("target", self.target)
138+
event = kwargs.pop("event", self.event)
139+
internal = kwargs.pop("internal", self.internal)
140+
new_transition = Transition(
141+
source=source, target=target, event=event, internal=internal, **kwargs
142+
)
143+
for spec in self._specs:
144+
new_spec = deepcopy(spec)
145+
new_transition._specs.add(new_spec, new_spec.group)
146+
147+
return new_transition

0 commit comments

Comments
 (0)