1
+ from __future__ import annotations
2
+
1
3
import asyncio
2
4
import functools
3
5
import uuid
4
6
from typing import (
5
7
TYPE_CHECKING ,
6
8
Any ,
7
9
Awaitable ,
8
- Dict ,
9
- List ,
10
- Optional ,
11
10
Protocol ,
12
- Set ,
13
11
Type ,
14
12
)
15
13
@@ -33,20 +31,20 @@ class CheckpointCallback(Protocol):
33
31
def __call__ (
34
32
self ,
35
33
run_id : str ,
36
- last_completed_step : Optional [ str ] ,
37
- input_ev : Optional [ Event ] ,
38
- output_ev : Optional [ Event ] ,
34
+ last_completed_step : str | None ,
35
+ input_ev : Event | None ,
36
+ output_ev : Event | None ,
39
37
ctx : "Context" ,
40
38
) -> Awaitable [None ]: ...
41
39
42
40
43
41
class Checkpoint (BaseModel ):
44
42
model_config = ConfigDict (arbitrary_types_allowed = True )
45
43
id_ : str = Field (default_factory = lambda : str (uuid .uuid4 ()))
46
- last_completed_step : Optional [ str ]
47
- input_event : Optional [ Event ]
48
- output_event : Optional [ Event ]
49
- ctx_state : Dict [str , Any ]
44
+ last_completed_step : str | None
45
+ input_event : Event | None
46
+ output_event : Event | None
47
+ ctx_state : dict [str , Any ]
50
48
51
49
52
50
class WorkflowCheckpointer :
@@ -63,8 +61,8 @@ class WorkflowCheckpointer:
63
61
def __init__ (
64
62
self ,
65
63
workflow : "Workflow" ,
66
- checkpoint_serializer : Optional [ BaseSerializer ] = None ,
67
- disabled_steps : List [str ] = [],
64
+ checkpoint_serializer : BaseSerializer | None = None ,
65
+ disabled_steps : list [str ] = [],
68
66
):
69
67
"""
70
68
Create a WorkflowCheckpointer object.
@@ -73,15 +71,15 @@ def __init__(
73
71
workflow (Workflow): The wrapped workflow.
74
72
checkpoint_serializer (Optional[BaseSerializer], optional): The serializer to use
75
73
for serializing associated `Context` of a Workflow run. Defaults to None.
76
- disabled_steps (List [str], optional): Steps for which to disable checkpointing. Defaults to [].
74
+ disabled_steps (list [str], optional): Steps for which to disable checkpointing. Defaults to [].
77
75
78
76
"""
79
- self ._checkpoints : Dict [str , List [Checkpoint ]] = {}
77
+ self ._checkpoints : dict [str , list [Checkpoint ]] = {}
80
78
self ._checkpoint_serializer = checkpoint_serializer or JsonSerializer ()
81
79
self ._lock : asyncio .Lock = asyncio .Lock ()
82
80
83
81
self .workflow = workflow
84
- self .enabled_checkpoints : Set [str ] = {
82
+ self .enabled_checkpoints : set [str ] = {
85
83
k for k in workflow ._get_steps () if k != "_done"
86
84
}
87
85
for step_name in disabled_steps :
@@ -122,17 +120,17 @@ def run_from(self, checkpoint: Checkpoint, **kwargs: Any) -> "WorkflowHandler":
122
120
)
123
121
124
122
@property
125
- def checkpoints (self ) -> Dict [str , List [Checkpoint ]]:
123
+ def checkpoints (self ) -> dict [str , list [Checkpoint ]]:
126
124
return self ._checkpoints
127
125
128
126
def new_checkpoint_callback_for_run (self ) -> CheckpointCallback :
129
127
"""Closure to generate a new `CheckpointCallback` with a unique run-id."""
130
128
131
129
async def _create_checkpoint (
132
130
run_id : str ,
133
- last_completed_step : Optional [ str ] ,
134
- input_ev : Optional [ Event ] ,
135
- output_ev : Optional [ Event ] ,
131
+ last_completed_step : str | None ,
132
+ input_ev : Event | None ,
133
+ output_ev : Event | None ,
136
134
ctx : "Context" ,
137
135
) -> None :
138
136
"""Build a checkpoint around the last completed step."""
@@ -156,9 +154,9 @@ async def _create_checkpoint(
156
154
def _checkpoint_filter_condition (
157
155
self ,
158
156
ckpt : Checkpoint ,
159
- last_completed_step : Optional [ str ] ,
160
- input_event_type : Optional [ Type [Event ]] ,
161
- output_event_type : Optional [ Type [Event ]] ,
157
+ last_completed_step : str | None ,
158
+ input_event_type : Type [Event ] | None ,
159
+ output_event_type : Type [Event ] | None ,
162
160
) -> bool :
163
161
if last_completed_step and ckpt .last_completed_step is not last_completed_step :
164
162
return False
@@ -170,11 +168,11 @@ def _checkpoint_filter_condition(
170
168
171
169
def filter_checkpoints (
172
170
self ,
173
- run_id : Optional [ str ] = None ,
174
- last_completed_step : Optional [ str ] = None ,
175
- input_event_type : Optional [ Type [Event ]] = None ,
176
- output_event_type : Optional [ Type [Event ]] = None ,
177
- ) -> List [Checkpoint ]:
171
+ run_id : str | None = None ,
172
+ last_completed_step : str | None = None ,
173
+ input_event_type : Type [Event ] | None = None ,
174
+ output_event_type : Type [Event ] | None = None ,
175
+ ) -> list [Checkpoint ]:
178
176
"""Returns a list of Checkpoint's based on user provided filters."""
179
177
if (
180
178
not run_id
0 commit comments