Skip to content

Commit 44994ed

Browse files
committed
feat: Hierarquical statemachines with compose and parallel
1 parent e1a3af5 commit 44994ed

20 files changed

+217
-81
lines changed

docs/transitions.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ Syntax:
8484
>>> draft = State("Draft")
8585

8686
>>> draft.to.itself()
87-
TransitionList([Transition('Draft', 'Draft', event='', internal=False)])
87+
TransitionList([Transition('Draft', 'Draft', event=[], internal=False)])
8888

8989
```
9090

@@ -101,7 +101,7 @@ Syntax:
101101
>>> draft = State("Draft")
102102

103103
>>> draft.to.itself(internal=True)
104-
TransitionList([Transition('Draft', 'Draft', event='', internal=True)])
104+
TransitionList([Transition('Draft', 'Draft', event=[], internal=True)])
105105

106106
```
107107

statemachine/contrib/diagram.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,10 +172,10 @@ def _transition_as_edge(self, transition):
172172

173173
def get_graph(self):
174174
graph = self._get_graph(self.machine)
175-
self._graph_states(self.machine, graph)
175+
self._graph_states(self.machine, graph, is_root=True)
176176
return graph
177177

178-
def _graph_states(self, state, graph):
178+
def _graph_states(self, state, graph, is_root=False):
179179
initial_node = self._initial_node(state)
180180
initial_subgraph = pydot.Subgraph(
181181
graph_name=f"{initial_node.get_name()}_initial",
@@ -193,8 +193,9 @@ def _graph_states(self, state, graph):
193193
graph.add_subgraph(initial_subgraph)
194194
graph.add_subgraph(atomic_states_subgraph)
195195

196-
initial = next(s for s in state.states if s.initial)
197-
graph.add_edge(self._initial_edge(initial_node, initial))
196+
if is_root:
197+
initial = next(s for s in state.states if s.initial)
198+
graph.add_edge(self._initial_edge(initial_node, initial))
198199

199200
for substate in state.states:
200201
if substate.states:

statemachine/engines/base.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,10 @@ def _select_transitions(
241241
def first_transition_that_matches(state: State, event: Event) -> "Transition | None":
242242
for s in chain([state], state.ancestors()):
243243
for transition in s.transitions:
244-
if predicate(transition, event) and self._conditions_match(
245-
transition, trigger_data
244+
if (
245+
not transition.initial
246+
and predicate(transition, event)
247+
and self._conditions_match(transition, trigger_data)
246248
):
247249
return transition
248250

@@ -481,31 +483,28 @@ def add_descendant_states_to_enter(
481483
# Handle compound states
482484
states_for_default_entry.add(info)
483485
initial_state = next(s for s in state.states if s.initial)
484-
for transition in initial_state.transitions:
485-
info_initial = StateTransition(
486-
transition=transition,
487-
target=transition.target,
488-
source=transition.source,
489-
)
490-
self.add_descendant_states_to_enter(
491-
info_initial,
492-
states_to_enter,
493-
states_for_default_entry,
494-
default_history_content,
495-
)
496-
for transition in initial_state.transitions:
497-
info_initial = StateTransition(
498-
transition=transition,
499-
target=transition.target,
500-
source=transition.source,
501-
)
502-
self.add_ancestor_states_to_enter(
503-
info_initial,
504-
state,
505-
states_to_enter,
506-
states_for_default_entry,
507-
default_history_content,
508-
)
486+
transition = next(
487+
t for t in state.transitions if t.initial and t.target == initial_state
488+
)
489+
info_initial = StateTransition(
490+
transition=transition,
491+
target=transition.target,
492+
source=transition.source,
493+
)
494+
self.add_descendant_states_to_enter(
495+
info_initial,
496+
states_to_enter,
497+
states_for_default_entry,
498+
default_history_content,
499+
)
500+
501+
self.add_ancestor_states_to_enter(
502+
info_initial,
503+
state,
504+
states_to_enter,
505+
states_for_default_entry,
506+
default_history_content,
507+
)
509508
elif state.parallel:
510509
# Handle parallel states
511510
for child_state in state.states:

statemachine/events.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@ class Events:
88
def __init__(self):
99
self._items: list[Event] = []
1010

11-
def __repr__(self):
11+
def __str__(self):
1212
sep = " " if len(self._items) > 1 else ""
1313
return sep.join(item for item in self._items)
1414

15+
def __repr__(self):
16+
return f"{self._items!r}"
17+
1518
def __iter__(self):
1619
return iter(self._items)
1720

statemachine/factory.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def __init__(
3232
registry.register(cls)
3333
cls.name = cls.__name__
3434
cls.id = cls.name.lower()
35+
# TODO: Experiment with the IDEA of a root state
36+
# cls.root = State(id=cls.id, name=cls.name)
3537
cls.states: States = States()
3638
cls.states_map: Dict[Any, State] = {}
3739
"""Map of ``state.value`` to the corresponding :ref:`state`."""
@@ -50,7 +52,7 @@ def __init__(
5052
if not cls.states:
5153
return
5254

53-
cls._initials_by_document_order(cls.states)
55+
cls._initials_by_document_order(cls.states, parent=None)
5456

5557
initials = [s for s in cls.states if s.initial]
5658
parallels = [s.id for s in cls.states if s.parallel]
@@ -79,16 +81,24 @@ def __init__(
7981

8082
def __getattr__(self, attribute: str) -> Any: ...
8183

82-
def _initials_by_document_order(cls, states):
84+
def _initials_by_document_order(cls, states, parent: "State | None" = None):
8385
"""Set initial state by document order if no explicit initial state is set"""
84-
has_initial = False
86+
initial: "State | None" = None
8587
for s in states:
86-
cls._initials_by_document_order(s.states)
88+
cls._initials_by_document_order(s.states, s)
8789
if s.initial:
88-
has_initial = True
90+
initial = s
8991
break
90-
if not has_initial and states:
91-
states[0]._initial = True
92+
if not initial and states:
93+
initial = states[0]
94+
initial._initial = True
95+
96+
if (
97+
parent
98+
and initial
99+
and not any(t for t in parent.transitions if t.initial and t.target == initial)
100+
):
101+
parent.to(initial, initial=True)
92102

93103
def _unpack_builders_callbacks(cls):
94104
callbacks = {}

statemachine/graph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ def visit_connected_states(state):
1111
continue
1212
already_visited.add(state)
1313
yield state
14-
visit.extend(s for s in state.states if s.initial)
1514
visit.extend(t.target for t in state.transitions)
1615

1716

statemachine/io/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class TransitionDict(TypedDict, total=False):
2222
target: str
2323
event: "str | None"
2424
internal: bool
25+
initial: bool
2526
validators: bool
2627
cond: "str | ActionProtocol | Sequence[str] | Sequence[ActionProtocol]"
2728
unless: "str | ActionProtocol | Sequence[str] | Sequence[ActionProtocol]"
@@ -117,6 +118,7 @@ def create_machine_class_from_definition(
117118
transition = source.to(
118119
target,
119120
event=event_name,
121+
initial=transition_data.get("initial"),
120122
cond=transition_data.get("cond"),
121123
unless=transition_data.get("unless"),
122124
on=transition_data.get("on"),

statemachine/io/scxml/parser.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import re
22
import xml.etree.ElementTree as ET
3+
from typing import Set
34

45
from .schema import Action
56
from .schema import AssignAction
@@ -32,6 +33,12 @@ def strip_namespaces(tree: ET.Element):
3233
attrib[new_name] = attrib.pop(name)
3334

3435

36+
def _parse_initial(initial_content: "str | None") -> Set[str]:
37+
if initial_content is None:
38+
return set()
39+
return set(initial_content.split())
40+
41+
3542
def parse_scxml(scxml_content: str) -> StateMachineDefinition:
3643
root = ET.fromstring(scxml_content)
3744
strip_namespaces(root)
@@ -40,9 +47,9 @@ def parse_scxml(scxml_content: str) -> StateMachineDefinition:
4047
if scxml is None:
4148
raise ValueError("No scxml element found in document")
4249

43-
initial_state = scxml.get("initial")
50+
initial_state = _parse_initial(scxml.get("initial"))
4451

45-
definition = StateMachineDefinition(initial_state=initial_state)
52+
definition = StateMachineDefinition(initial_states=initial_state)
4653

4754
# Parse datamodel
4855
datamodel = parse_datamodel(scxml)
@@ -52,19 +59,19 @@ def parse_scxml(scxml_content: str) -> StateMachineDefinition:
5259
# Parse states
5360
for state_elem in scxml:
5461
if state_elem.tag == "state":
55-
state = parse_state(state_elem, definition.initial_state)
62+
state = parse_state(state_elem, definition.initial_states)
5663
definition.states[state.id] = state
5764
elif state_elem.tag == "final":
58-
state = parse_state(state_elem, definition.initial_state, is_final=True)
65+
state = parse_state(state_elem, definition.initial_states, is_final=True)
5966
definition.states[state.id] = state
6067
elif state_elem.tag == "parallel":
61-
state = parse_state(state_elem, definition.initial_state, is_parallel=True)
68+
state = parse_state(state_elem, definition.initial_states, is_parallel=True)
6269
definition.states[state.id] = state
6370

6471
# If no initial state was specified, pick the first state
65-
if not definition.initial_state and definition.states:
66-
definition.initial_state = next(iter(definition.states.keys()))
67-
definition.states[definition.initial_state].initial = True
72+
if not definition.initial_states and definition.states:
73+
definition.initial_states = next(iter(definition.states.keys()))
74+
definition.states[definition.initial_states].initial = True
6875

6976
return definition
7077

@@ -95,15 +102,15 @@ def parse_datamodel(root: ET.Element) -> "DataModel | None":
95102

96103
def parse_state(
97104
state_elem: ET.Element,
98-
initial_state: "str | None",
105+
initial_states: Set[str],
99106
is_final: bool = False,
100107
is_parallel: bool = False,
101108
) -> State:
102109
state_id = state_elem.get("id")
103110
if not state_id:
104111
raise ValueError("State must have an 'id' attribute")
105112

106-
initial = state_id == initial_state
113+
initial = state_id in initial_states
107114
state = State(id=state_id, initial=initial, final=is_final, parallel=is_parallel)
108115

109116
# Parse onentry actions
@@ -122,18 +129,25 @@ def parse_state(
122129
state.transitions.append(transition)
123130

124131
# Parse child states
125-
initial_state = state_elem.get("initial")
132+
initial_states |= _parse_initial(state_elem.get("initial"))
133+
initial_elem = state_elem.find("initial")
134+
if initial_elem is not None:
135+
for trans_elem in initial_elem.findall("transition"):
136+
transition = parse_transition(trans_elem, initial=True)
137+
state.transitions.append(transition)
138+
initial_states |= _parse_initial(trans_elem.get("target"))
139+
126140
for child_state_elem in state_elem.findall("state"):
127-
child_state = parse_state(child_state_elem, initial_state=initial_state)
141+
child_state = parse_state(child_state_elem, initial_states=initial_states)
128142
state.states[child_state.id] = child_state
129143
for child_state_elem in state_elem.findall("parallel"):
130-
state = parse_state(child_state_elem, initial_state=initial_state, is_parallel=True)
144+
state = parse_state(child_state_elem, initial_states=initial_states, is_parallel=True)
131145
state.states[child_state.id] = child_state
132146

133147
return state
134148

135149

136-
def parse_transition(trans_elem: ET.Element) -> Transition:
150+
def parse_transition(trans_elem: ET.Element, initial: bool = False) -> Transition:
137151
target = trans_elem.get("target")
138152
if not target:
139153
raise ValueError("Transition must have a 'target' attribute")
@@ -143,7 +157,9 @@ def parse_transition(trans_elem: ET.Element) -> Transition:
143157

144158
executable_content = parse_executable_content(trans_elem)
145159

146-
return Transition(target=target, event=event, cond=cond, on=executable_content)
160+
return Transition(
161+
target=target, initial=initial, event=event, cond=cond, on=executable_content
162+
)
147163

148164

149165
def parse_executable_content(element: ET.Element) -> ExecutableContent:

statemachine/io/scxml/processor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,10 @@ def _process_transitions(self, transitions: List[Transition]):
8888
event = transition.event or None
8989
if event not in on_dict:
9090
on_dict[event] = []
91-
transition_dict: TransitionDict = {"target": transition.target}
91+
transition_dict: TransitionDict = {
92+
"target": transition.target,
93+
"initial": transition.initial,
94+
}
9295

9396
# Process cond
9497
if transition.cond:

statemachine/io/scxml/schema.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from dataclasses import field
33
from typing import Dict
44
from typing import List
5+
from typing import Set
56

67

78
@dataclass
@@ -93,6 +94,7 @@ class ScriptAction(Action):
9394
@dataclass
9495
class Transition:
9596
target: str
97+
initial: bool = False
9698
event: "str | None" = None
9799
cond: "str | None" = None
98100
on: "ExecutableContent | None" = None
@@ -127,5 +129,5 @@ class DataModel:
127129
@dataclass
128130
class StateMachineDefinition:
129131
states: Dict[str, State] = field(default_factory=dict)
130-
initial_state: "str | None" = None
132+
initial_states: Set[str] = field(default_factory=set)
131133
datamodel: "DataModel | None" = None

0 commit comments

Comments
 (0)