diff --git a/docs/actions.md b/docs/actions.md index f1c10c52..6ca6e0b4 100644 --- a/docs/actions.md +++ b/docs/actions.md @@ -26,6 +26,8 @@ when something changes, and are not bound to a specific state or event: - `after_transition()` +- `finalize()` + The following example offers an overview of the "generic" callbacks available: ```py @@ -54,6 +56,10 @@ The following example offers an overview of the "generic" callbacks available: ... ... def after_transition(self, event, state): ... print(f"After '{event}', on the '{state.id}' state.") +... +... def finalize(self, event, source, target, state): +... print(f"Finalizing transition {event} from {source.id} to {target.id}") +... print(f"Current state: {state.id}") >>> sm = ExampleStateMachine() # On initialization, the machine run a special event `__initial__` @@ -65,6 +71,8 @@ Exiting 'initial' state from 'loop' event. On 'loop', on the 'initial' state. Entering 'initial' state from 'loop' event. After 'loop', on the 'initial' state. +Finalizing transition loop from initial to initial +Current state: initial ['before_transition_return', 'on_transition_return'] >>> sm.go() @@ -73,6 +81,8 @@ Exiting 'initial' state from 'go' event. On 'go', on the 'initial' state. Entering 'final' state from 'go' event. After 'go', on the 'final' state. +Finalizing transition go from initial to final +Current state: final ['before_transition_return', 'on_transition_return'] ``` @@ -346,6 +356,10 @@ Actions registered on the same group don't have order guaranties and are execute - `after_()`, `after_transition()` - `destination` - Callbacks declared in the transition or event. +* - Finalize + - `finalize()` + - `destination` + - Guaranteed to run after every transition attempt, whether successful or failed. ``` @@ -381,6 +395,9 @@ defined explicitly. The following provides an example: ... def on_loop(self): ... return "On loop" ... +... def finalize(self): +... # Finalize return values are not included in results +... return "Finalize" >>> sm = ExampleStateMachine() diff --git a/statemachine/callbacks.py b/statemachine/callbacks.py index 0a6613c1..bfb79574 100644 --- a/statemachine/callbacks.py +++ b/statemachine/callbacks.py @@ -49,6 +49,7 @@ class CallbackGroup(IntEnum): ON = auto() AFTER = auto() COND = auto() + FINALIZE = auto() def build_key(self, specs: "CallbackSpecList") -> str: return f"{self.name}@{id(specs)}" diff --git a/statemachine/engines/async_.py b/statemachine/engines/async_.py index 9d2b3f9f..fca9bb87 100644 --- a/statemachine/engines/async_.py +++ b/statemachine/engines/async_.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING +from ..callbacks import CallbackGroup from ..event_data import EventData from ..event_data import TriggerData from ..exceptions import InvalidDefinition @@ -101,30 +102,49 @@ async def _activate(self, trigger_data: TriggerData, transition: "Transition"): event_data = EventData(trigger_data=trigger_data, transition=transition) args, kwargs = event_data.args, event_data.extended_kwargs - await self.sm._callbacks.async_call(transition.validators.key, *args, **kwargs) - if not await self.sm._callbacks.async_all(transition.cond.key, *args, **kwargs): - return False, None + try: + await self.sm._callbacks.async_call(transition.validators.key, *args, **kwargs) + if not await self.sm._callbacks.async_all(transition.cond.key, *args, **kwargs): + return False, None + + source = transition.source + target = transition.target - source = transition.source - target = transition.target + result = await self.sm._callbacks.async_call(transition.before.key, *args, **kwargs) + if source is not None and not transition.internal: + await self.sm._callbacks.async_call(source.exit.key, *args, **kwargs) - result = await self.sm._callbacks.async_call(transition.before.key, *args, **kwargs) - if source is not None and not transition.internal: - await self.sm._callbacks.async_call(source.exit.key, *args, **kwargs) + result += await self.sm._callbacks.async_call(transition.on.key, *args, **kwargs) - result += await self.sm._callbacks.async_call(transition.on.key, *args, **kwargs) + self.sm.current_state = target + event_data.state = target + kwargs["state"] = target - self.sm.current_state = target - event_data.state = target - kwargs["state"] = target + if not transition.internal: + await self.sm._callbacks.async_call(target.enter.key, *args, **kwargs) + await self.sm._callbacks.async_call(transition.after.key, *args, **kwargs) - if not transition.internal: - await self.sm._callbacks.async_call(target.enter.key, *args, **kwargs) - await self.sm._callbacks.async_call(transition.after.key, *args, **kwargs) + if len(result) == 0: + result = None + elif len(result) == 1: + result = result[0] - if len(result) == 0: - result = None - elif len(result) == 1: - result = result[0] + return True, result + finally: + # Run finalize actions regardless of success/failure + await self._run_finalize_actions(event_data) - return True, result + async def _run_finalize_actions(self, event_data: EventData): + """Run finalize actions after a transition attempt.""" + try: + args, kwargs = event_data.args, event_data.extended_kwargs + await self.sm._callbacks.async_call( + CallbackGroup.FINALIZE.build_key(event_data.transition._specs), + *args, + **kwargs, + ) + except Exception as e: + # Log but don't re-raise finalize errors + import logging + + logging.error(f"Error in finalize action: {e}") diff --git a/statemachine/engines/sync.py b/statemachine/engines/sync.py index 4400cd08..9c5dd925 100644 --- a/statemachine/engines/sync.py +++ b/statemachine/engines/sync.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING +from ..callbacks import CallbackGroup from ..event_data import EventData from ..event_data import TriggerData from ..exceptions import TransitionNotAllowed @@ -103,30 +104,49 @@ def _activate(self, trigger_data: TriggerData, transition: "Transition"): event_data = EventData(trigger_data=trigger_data, transition=transition) args, kwargs = event_data.args, event_data.extended_kwargs - self.sm._callbacks.call(transition.validators.key, *args, **kwargs) - if not self.sm._callbacks.all(transition.cond.key, *args, **kwargs): - return False, None + try: + self.sm._callbacks.call(transition.validators.key, *args, **kwargs) + if not self.sm._callbacks.all(transition.cond.key, *args, **kwargs): + return False, None + + source = transition.source + target = transition.target - source = transition.source - target = transition.target + result = self.sm._callbacks.call(transition.before.key, *args, **kwargs) + if source is not None and not transition.internal: + self.sm._callbacks.call(source.exit.key, *args, **kwargs) - result = self.sm._callbacks.call(transition.before.key, *args, **kwargs) - if source is not None and not transition.internal: - self.sm._callbacks.call(source.exit.key, *args, **kwargs) + result += self.sm._callbacks.call(transition.on.key, *args, **kwargs) - result += self.sm._callbacks.call(transition.on.key, *args, **kwargs) + self.sm.current_state = target + event_data.state = target + kwargs["state"] = target - self.sm.current_state = target - event_data.state = target - kwargs["state"] = target + if not transition.internal: + self.sm._callbacks.call(target.enter.key, *args, **kwargs) + self.sm._callbacks.call(transition.after.key, *args, **kwargs) - if not transition.internal: - self.sm._callbacks.call(target.enter.key, *args, **kwargs) - self.sm._callbacks.call(transition.after.key, *args, **kwargs) + if len(result) == 0: + result = None + elif len(result) == 1: + result = result[0] - if len(result) == 0: - result = None - elif len(result) == 1: - result = result[0] + return True, result + finally: + # Run finalize actions regardless of success/failure + self._run_finalize_actions(event_data) - return True, result + def _run_finalize_actions(self, event_data: EventData): + """Run finalize actions after a transition attempt.""" + try: + args, kwargs = event_data.args, event_data.extended_kwargs + self.sm._callbacks.call( + CallbackGroup.FINALIZE.build_key(event_data.transition._specs), + *args, + **kwargs, + ) + except Exception as e: + # Log but don't re-raise finalize errors + import logging + + logging.error(f"Error in finalize action: {e}") diff --git a/statemachine/transition.py b/statemachine/transition.py index a9044f0f..c33f5b49 100644 --- a/statemachine/transition.py +++ b/statemachine/transition.py @@ -34,6 +34,8 @@ class Transition: before the transition is executed. after (Optional[Union[str, Callable, List[Callable]]]): The callbacks to be invoked after the transition is executed. + finalize (Optional[Union[str, Callable, List[Callable]]]): The callbacks to be invoked + after the transition is executed. """ def __init__( @@ -48,6 +50,7 @@ def __init__( on=None, before=None, after=None, + finalize=None, ): self.source = source self.target = target @@ -68,6 +71,9 @@ def __init__( self.after = self._specs.grouper(CallbackGroup.AFTER).add( after, priority=CallbackPriority.INLINE ) + self.finalize = self._specs.grouper(CallbackGroup.FINALIZE).add( + finalize, priority=CallbackPriority.INLINE + ) self.cond = ( self._specs.grouper(CallbackGroup.COND) .add(cond, priority=CallbackPriority.INLINE, expected_value=True) @@ -87,6 +93,7 @@ def _setup(self): before = self.before.add on = self.on.add after = self.after.add + finalize = self.finalize.add before("before_transition", priority=CallbackPriority.GENERIC, is_convention=True) on("on_transition", priority=CallbackPriority.GENERIC, is_convention=True) @@ -118,6 +125,12 @@ def _setup(self): is_convention=True, ) + finalize( + "finalize", + priority=CallbackPriority.AFTER, + is_convention=True, + ) + def match(self, event: str): return self._events.match(event) diff --git a/tests/test_finalize.py b/tests/test_finalize.py new file mode 100644 index 00000000..950b9b04 --- /dev/null +++ b/tests/test_finalize.py @@ -0,0 +1,323 @@ +import pytest + +from statemachine import State +from statemachine import StateMachine +from statemachine.transition import Transition + + +class TrafficLightMachine(StateMachine): + """A simple traffic light state machine for testing.""" + + red = State("Red", initial=True) + yellow = State("Yellow") + green = State("Green") + + cycle = red.to(green) | green.to(yellow) | yellow.to(red) + + +class AsyncTrafficLightMachine(TrafficLightMachine): + """Traffic light with async callbacks.""" + + async def before_cycle(self): + # Makes this an async state machine + pass + + +# Basic finalize functionality tests +def test_sync_finalize_basic(): + """Test basic finalize action in sync mode.""" + calls = [] + + class TestMachine(TrafficLightMachine): + def finalize(self, state): + calls.append(("finalize", state.id)) + + sm = TestMachine() + assert sm.current_state == sm.red + sm.cycle() # red -> green + assert sm.current_state == sm.green + assert calls == [("finalize", "green")] + + sm.cycle() # green -> yellow + assert sm.current_state == sm.yellow + assert calls == [("finalize", "green"), ("finalize", "yellow")] + + +# Error handling tests +def test_sync_finalize_with_error(): + """Test finalize action when transition fails.""" + calls = [] + + class FailingTrafficLight(TrafficLightMachine): + def before_cycle(self): + raise ValueError("Simulated failure") + + def finalize(self, state): + calls.append(("finalize", state.id)) + + sm = FailingTrafficLight() + + assert sm.current_state == sm.red + with pytest.raises(ValueError): # noqa: PT011 + sm.cycle() + + assert sm.current_state == sm.red # State unchanged due to error + assert calls == [("finalize", "red")] # Finalize still called + + +def test_sync_finalize_error_propagation(): + """Test that finalize errors are logged but don't affect state machine operation.""" + calls = [] + + class TestMachine(TrafficLightMachine): + def finalize(self, state): + calls.append("failing") + raise ValueError("Simulated failure") + + sm = TestMachine() + sm.cycle() # Should complete despite finalize error + assert calls == ["failing"] + assert sm.current_state == sm.green + + +# Full dependency injection tests +def test_sync_finalize_dependency_injection(): + """Test that finalize method supports dependency injection.""" + results = {} + + class TestMachine(TrafficLightMachine): + def finalize( + self, + message, + event, + source, + target, + state, + model, + transition, + *args, + **kwargs, + ): + results.update( + { + "message": message, + "event": event, + "source": source.id, + "target": target.id, + "current_state": state.id, + "model": model, + "transition": transition, + "args": args, + "kwargs": kwargs, + } + ) + + sm = TestMachine() + sm.cycle(123, message="test") # Pass some args and kwargs + # Verify all injected parameters + assert results["event"] == "cycle" + assert results["source"] == "red" + assert results["target"] == "green" + assert results["current_state"] == "green" + assert results["model"] is sm.model + assert isinstance(results["transition"], Transition) + assert results["kwargs"]["event_data"].args == (123,) + assert results["message"] == "test" + + +def test_callback_ordering(): + """Test that callbacks are executed in the correct order.""" + execution_order = [] + + class OrderedCallbackMachine(StateMachine): + state1 = State("State1", initial=True) + state2 = State("State2") + + transition = state1.to(state2) + + def before_transition(self): + execution_order.append("before") + + def on_exit_state1(self): + execution_order.append("exit") + + def on_transition(self): + execution_order.append("on") + + def on_enter_state2(self): + execution_order.append("enter") + + def after_transition(self): + execution_order.append("after") + + def finalize(self): + execution_order.append("finalize") + + sm = OrderedCallbackMachine() + sm.transition() + + assert execution_order == [ + "before", + "exit", + "on", + "enter", + "after", + "finalize", + ], "validate run ordering" + + +###### Async tests ###### + + +@pytest.mark.asyncio() +async def test_async_finalize_basic(): + """Test basic finalize action in async mode.""" + calls = [] + + class TestMachine(AsyncTrafficLightMachine): + async def finalize(self, state): + calls.append(("finalize", state.id)) + + sm = TestMachine() + await sm.activate_initial_state() + + assert sm.current_state == sm.red + await sm.cycle() # red -> green + assert sm.current_state == sm.green + assert calls == [("finalize", "green")] + + await sm.cycle() # green -> yellow + assert sm.current_state == sm.yellow + assert calls == [("finalize", "green"), ("finalize", "yellow")] + + +@pytest.mark.asyncio() +async def test_async_finalize_with_error(): + """Test finalize action when async transition fails.""" + calls = [] + + class AsyncFailingTrafficLight(AsyncTrafficLightMachine): + async def before_cycle(self): + raise ValueError("Simulated async failure") + + async def finalize(self, state): + calls.append(("finalize", state.id)) + + sm = AsyncFailingTrafficLight() + await sm.activate_initial_state() + + assert sm.current_state == sm.red + with pytest.raises(ValueError): # noqa: PT011 + await sm.cycle() + + assert sm.current_state == sm.red # State unchanged due to error + assert calls == [("finalize", "red")] # Finalize still called + + +@pytest.mark.asyncio() +async def test_async_finalize_with_async_error(): + """Test that async finalize errors are properly handled.""" + calls = [] + + class TestMachine(AsyncTrafficLightMachine): + async def finalize(self, state): + calls.append(("before_error", state.id)) + raise ValueError("Simulated async error") + + sm = TestMachine() + await sm.activate_initial_state() + + assert sm.current_state == sm.red + await sm.cycle() # Should complete despite async finalize error + assert sm.current_state == sm.green + assert calls == [("before_error", "green")] + + +@pytest.mark.asyncio() +async def test_async_finalize_with_dependency_injection(): + """Test that async finalize supports dependency injection.""" + results = {} + + class TestMachine(AsyncTrafficLightMachine): + async def finalize( + self, + message, + event, + source, + target, + state, + model, + transition, + *args, + **kwargs, + ): + results.update( + { + "message": message, + "event": event, + "source": source.id, + "target": target.id, + "current_state": state.id, + "model": model, + "transition": transition, + "args": args, + "kwargs": kwargs, + } + ) + + sm = TestMachine() + await sm.activate_initial_state() + await sm.cycle(123, message="test") # Pass some args and kwargs + + # Verify all injected parameters + assert results["event"] == "cycle" + assert results["source"] == "red" + assert results["target"] == "green" + assert results["current_state"] == "green" + assert results["model"] is sm.model + assert isinstance(results["transition"], Transition) + assert results["kwargs"]["event_data"].args == (123,) + assert results["message"] == "test" + + +@pytest.mark.asyncio() +async def test_async_callback_ordering(): + """Test that callbacks are executed in the correct order in async mode.""" + execution_order = [] + + class AsyncOrderedCallbackMachine(StateMachine): + state1 = State("State1", initial=True) + state2 = State("State2") + + transition = state1.to(state2) + + async def before_transition(self): + execution_order.append("before") + + async def on_exit_state1(self): + execution_order.append("exit") + + async def on_transition(self): + execution_order.append("on") + + async def on_enter_state2(self): + execution_order.append("enter") + + async def after_transition(self): + execution_order.append("after") + + async def finalize(self): + execution_order.append("finalize") + + sm = AsyncOrderedCallbackMachine() + await sm.activate_initial_state() + await sm.transition() + + assert execution_order == [ + "before", + "exit", + "on", + "enter", + "after", + "finalize", + ], "validate run ordering"