Skip to content

Commit cabdb83

Browse files
feat: Add finalize action after every transition (#386)
1 parent b0367f0 commit cabdb83

File tree

6 files changed

+434
-40
lines changed

6 files changed

+434
-40
lines changed

docs/actions.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ when something changes, and are not bound to a specific state or event:
2626

2727
- `after_transition()`
2828

29+
- `finalize()`
30+
2931
The following example offers an overview of the "generic" callbacks available:
3032

3133
```py
@@ -54,6 +56,10 @@ The following example offers an overview of the "generic" callbacks available:
5456
...
5557
... def after_transition(self, event, state):
5658
... print(f"After '{event}', on the '{state.id}' state.")
59+
...
60+
... def finalize(self, event, source, target, state):
61+
... print(f"Finalizing transition {event} from {source.id} to {target.id}")
62+
... print(f"Current state: {state.id}")
5763

5864

5965
>>> sm = ExampleStateMachine() # On initialization, the machine run a special event `__initial__`
@@ -65,6 +71,8 @@ Exiting 'initial' state from 'loop' event.
6571
On 'loop', on the 'initial' state.
6672
Entering 'initial' state from 'loop' event.
6773
After 'loop', on the 'initial' state.
74+
Finalizing transition loop from initial to initial
75+
Current state: initial
6876
['before_transition_return', 'on_transition_return']
6977

7078
>>> sm.go()
@@ -73,6 +81,8 @@ Exiting 'initial' state from 'go' event.
7381
On 'go', on the 'initial' state.
7482
Entering 'final' state from 'go' event.
7583
After 'go', on the 'final' state.
84+
Finalizing transition go from initial to final
85+
Current state: final
7686
['before_transition_return', 'on_transition_return']
7787

7888
```
@@ -346,6 +356,10 @@ Actions registered on the same group don't have order guaranties and are execute
346356
- `after_<event>()`, `after_transition()`
347357
- `destination`
348358
- Callbacks declared in the transition or event.
359+
* - Finalize
360+
- `finalize()`
361+
- `destination`
362+
- Guaranteed to run after every transition attempt, whether successful or failed.
349363
350364
```
351365

@@ -381,6 +395,9 @@ defined explicitly. The following provides an example:
381395
... def on_loop(self):
382396
... return "On loop"
383397
...
398+
... def finalize(self):
399+
... # Finalize return values are not included in results
400+
... return "Finalize"
384401

385402
>>> sm = ExampleStateMachine()
386403

statemachine/callbacks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class CallbackGroup(IntEnum):
4949
ON = auto()
5050
AFTER = auto()
5151
COND = auto()
52+
FINALIZE = auto()
5253

5354
def build_key(self, specs: "CallbackSpecList") -> str:
5455
return f"{self.name}@{id(specs)}"

statemachine/engines/async_.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import TYPE_CHECKING
22

3+
from ..callbacks import CallbackGroup
34
from ..event_data import EventData
45
from ..event_data import TriggerData
56
from ..exceptions import InvalidDefinition
@@ -101,30 +102,49 @@ async def _activate(self, trigger_data: TriggerData, transition: "Transition"):
101102
event_data = EventData(trigger_data=trigger_data, transition=transition)
102103
args, kwargs = event_data.args, event_data.extended_kwargs
103104

104-
await self.sm._callbacks.async_call(transition.validators.key, *args, **kwargs)
105-
if not await self.sm._callbacks.async_all(transition.cond.key, *args, **kwargs):
106-
return False, None
105+
try:
106+
await self.sm._callbacks.async_call(transition.validators.key, *args, **kwargs)
107+
if not await self.sm._callbacks.async_all(transition.cond.key, *args, **kwargs):
108+
return False, None
109+
110+
source = transition.source
111+
target = transition.target
107112

108-
source = transition.source
109-
target = transition.target
113+
result = await self.sm._callbacks.async_call(transition.before.key, *args, **kwargs)
114+
if source is not None and not transition.internal:
115+
await self.sm._callbacks.async_call(source.exit.key, *args, **kwargs)
110116

111-
result = await self.sm._callbacks.async_call(transition.before.key, *args, **kwargs)
112-
if source is not None and not transition.internal:
113-
await self.sm._callbacks.async_call(source.exit.key, *args, **kwargs)
117+
result += await self.sm._callbacks.async_call(transition.on.key, *args, **kwargs)
114118

115-
result += await self.sm._callbacks.async_call(transition.on.key, *args, **kwargs)
119+
self.sm.current_state = target
120+
event_data.state = target
121+
kwargs["state"] = target
116122

117-
self.sm.current_state = target
118-
event_data.state = target
119-
kwargs["state"] = target
123+
if not transition.internal:
124+
await self.sm._callbacks.async_call(target.enter.key, *args, **kwargs)
125+
await self.sm._callbacks.async_call(transition.after.key, *args, **kwargs)
120126

121-
if not transition.internal:
122-
await self.sm._callbacks.async_call(target.enter.key, *args, **kwargs)
123-
await self.sm._callbacks.async_call(transition.after.key, *args, **kwargs)
127+
if len(result) == 0:
128+
result = None
129+
elif len(result) == 1:
130+
result = result[0]
124131

125-
if len(result) == 0:
126-
result = None
127-
elif len(result) == 1:
128-
result = result[0]
132+
return True, result
133+
finally:
134+
# Run finalize actions regardless of success/failure
135+
await self._run_finalize_actions(event_data)
129136

130-
return True, result
137+
async def _run_finalize_actions(self, event_data: EventData):
138+
"""Run finalize actions after a transition attempt."""
139+
try:
140+
args, kwargs = event_data.args, event_data.extended_kwargs
141+
await self.sm._callbacks.async_call(
142+
CallbackGroup.FINALIZE.build_key(event_data.transition._specs),
143+
*args,
144+
**kwargs,
145+
)
146+
except Exception as e:
147+
# Log but don't re-raise finalize errors
148+
import logging
149+
150+
logging.error(f"Error in finalize action: {e}")

statemachine/engines/sync.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import TYPE_CHECKING
22

3+
from ..callbacks import CallbackGroup
34
from ..event_data import EventData
45
from ..event_data import TriggerData
56
from ..exceptions import TransitionNotAllowed
@@ -103,30 +104,49 @@ def _activate(self, trigger_data: TriggerData, transition: "Transition"):
103104
event_data = EventData(trigger_data=trigger_data, transition=transition)
104105
args, kwargs = event_data.args, event_data.extended_kwargs
105106

106-
self.sm._callbacks.call(transition.validators.key, *args, **kwargs)
107-
if not self.sm._callbacks.all(transition.cond.key, *args, **kwargs):
108-
return False, None
107+
try:
108+
self.sm._callbacks.call(transition.validators.key, *args, **kwargs)
109+
if not self.sm._callbacks.all(transition.cond.key, *args, **kwargs):
110+
return False, None
111+
112+
source = transition.source
113+
target = transition.target
109114

110-
source = transition.source
111-
target = transition.target
115+
result = self.sm._callbacks.call(transition.before.key, *args, **kwargs)
116+
if source is not None and not transition.internal:
117+
self.sm._callbacks.call(source.exit.key, *args, **kwargs)
112118

113-
result = self.sm._callbacks.call(transition.before.key, *args, **kwargs)
114-
if source is not None and not transition.internal:
115-
self.sm._callbacks.call(source.exit.key, *args, **kwargs)
119+
result += self.sm._callbacks.call(transition.on.key, *args, **kwargs)
116120

117-
result += self.sm._callbacks.call(transition.on.key, *args, **kwargs)
121+
self.sm.current_state = target
122+
event_data.state = target
123+
kwargs["state"] = target
118124

119-
self.sm.current_state = target
120-
event_data.state = target
121-
kwargs["state"] = target
125+
if not transition.internal:
126+
self.sm._callbacks.call(target.enter.key, *args, **kwargs)
127+
self.sm._callbacks.call(transition.after.key, *args, **kwargs)
122128

123-
if not transition.internal:
124-
self.sm._callbacks.call(target.enter.key, *args, **kwargs)
125-
self.sm._callbacks.call(transition.after.key, *args, **kwargs)
129+
if len(result) == 0:
130+
result = None
131+
elif len(result) == 1:
132+
result = result[0]
126133

127-
if len(result) == 0:
128-
result = None
129-
elif len(result) == 1:
130-
result = result[0]
134+
return True, result
135+
finally:
136+
# Run finalize actions regardless of success/failure
137+
self._run_finalize_actions(event_data)
131138

132-
return True, result
139+
def _run_finalize_actions(self, event_data: EventData):
140+
"""Run finalize actions after a transition attempt."""
141+
try:
142+
args, kwargs = event_data.args, event_data.extended_kwargs
143+
self.sm._callbacks.call(
144+
CallbackGroup.FINALIZE.build_key(event_data.transition._specs),
145+
*args,
146+
**kwargs,
147+
)
148+
except Exception as e:
149+
# Log but don't re-raise finalize errors
150+
import logging
151+
152+
logging.error(f"Error in finalize action: {e}")

statemachine/transition.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ class Transition:
3434
before the transition is executed.
3535
after (Optional[Union[str, Callable, List[Callable]]]): The callbacks to be invoked
3636
after the transition is executed.
37+
finalize (Optional[Union[str, Callable, List[Callable]]]): The callbacks to be invoked
38+
after the transition is executed.
3739
"""
3840

3941
def __init__(
@@ -48,6 +50,7 @@ def __init__(
4850
on=None,
4951
before=None,
5052
after=None,
53+
finalize=None,
5154
):
5255
self.source = source
5356
self.target = target
@@ -68,6 +71,9 @@ def __init__(
6871
self.after = self._specs.grouper(CallbackGroup.AFTER).add(
6972
after, priority=CallbackPriority.INLINE
7073
)
74+
self.finalize = self._specs.grouper(CallbackGroup.FINALIZE).add(
75+
finalize, priority=CallbackPriority.INLINE
76+
)
7177
self.cond = (
7278
self._specs.grouper(CallbackGroup.COND)
7379
.add(cond, priority=CallbackPriority.INLINE, expected_value=True)
@@ -87,6 +93,7 @@ def _setup(self):
8793
before = self.before.add
8894
on = self.on.add
8995
after = self.after.add
96+
finalize = self.finalize.add
9097

9198
before("before_transition", priority=CallbackPriority.GENERIC, is_convention=True)
9299
on("on_transition", priority=CallbackPriority.GENERIC, is_convention=True)
@@ -118,6 +125,12 @@ def _setup(self):
118125
is_convention=True,
119126
)
120127

128+
finalize(
129+
"finalize",
130+
priority=CallbackPriority.AFTER,
131+
is_convention=True,
132+
)
133+
121134
def match(self, event: str):
122135
return self._events.match(event)
123136

0 commit comments

Comments
 (0)