11import logging
22import os
3- from contextlib import asynccontextmanager
4- from typing import Any , AsyncGenerator , AsyncIterator , Optional , Sequence
3+ from typing import Any , AsyncGenerator , Optional , Sequence
54from uuid import uuid4
65
76from langchain_core .runnables .config import RunnableConfig
1615 UiPathErrorCategory ,
1716 UiPathErrorCode ,
1817 UiPathResumeTrigger ,
18+ UiPathRuntimeContext ,
1919 UiPathRuntimeResult ,
2020 UiPathRuntimeStatus ,
2121)
2727)
2828
2929from .._utils ._schema import generate_schema_from_graph
30- from ._context import LangGraphRuntimeContext
3130from ._exception import LangGraphErrorCode , LangGraphRuntimeError
3231from ._graph_resolver import AsyncResolver , LangGraphJsonResolver
3332from ._input import get_graph_input
@@ -42,58 +41,40 @@ class LangGraphRuntime(UiPathBaseRuntime):
4241 This allows using the class with 'async with' statements.
4342 """
4443
45- def __init__ (self , context : LangGraphRuntimeContext , graph_resolver : AsyncResolver ):
44+ def __init__ (
45+ self ,
46+ context : UiPathRuntimeContext ,
47+ graph_resolver : AsyncResolver ,
48+ memory : AsyncSqliteSaver ,
49+ ):
4650 super ().__init__ (context )
47- self .context : LangGraphRuntimeContext = context
51+ self .context : UiPathRuntimeContext = context
4852 self .graph_resolver : AsyncResolver = graph_resolver
53+ self .memory : AsyncSqliteSaver = memory
4954 self .resume_triggers_table : str = "__uipath_resume_triggers"
5055
51- @asynccontextmanager
52- async def _get_or_create_memory (self ) -> AsyncIterator [AsyncSqliteSaver ]:
53- """
54- Get existing memory from context or create a new one.
55-
56- If memory is created, it will be automatically disposed at the end.
57- If memory already exists in context, it will be reused without disposal.
58-
59- Yields:
60- AsyncSqliteSaver instance
61- """
62- # Check if memory already exists in context
63- if self .context .memory is not None :
64- # Use existing memory, don't dispose
65- yield self .context .memory
66- else :
67- # Create new memory and dispose at the end
68- async with AsyncSqliteSaver .from_conn_string (
69- self .state_file_path
70- ) as memory :
71- yield memory
72- # Memory is automatically disposed by the context manager
73-
7456 async def execute (self ) -> Optional [UiPathRuntimeResult ]:
7557 """Execute the graph with the provided input and configuration."""
7658 graph = await self .graph_resolver ()
7759 if not graph :
7860 return None
7961
8062 try :
81- async with self ._get_or_create_memory () as memory :
82- compiled_graph = await self ._setup_graph (memory , graph )
83- graph_input = await self ._get_graph_input (memory )
84- graph_config = self ._get_graph_config ()
85-
86- # Execute without streaming
87- graph_output = await compiled_graph .ainvoke (
88- graph_input ,
89- graph_config ,
90- interrupt_before = self .context .breakpoints ,
91- )
63+ compiled_graph = await self ._setup_graph (self .memory , graph )
64+ graph_input = await self ._get_graph_input (self .memory )
65+ graph_config = self ._get_graph_config ()
66+
67+ # Execute without streaming
68+ graph_output = await compiled_graph .ainvoke (
69+ graph_input ,
70+ graph_config ,
71+ interrupt_before = self .context .breakpoints ,
72+ )
9273
93- # Get final state and create result
94- self .context .result = await self ._create_runtime_result (
95- compiled_graph , graph_config , memory , graph_output
96- )
74+ # Get final state and create result
75+ self .context .result = await self ._create_runtime_result (
76+ compiled_graph , graph_config , self . memory , graph_output
77+ )
9778
9879 return self .context .result
9980
@@ -137,61 +118,60 @@ async def stream(
137118 return
138119
139120 try :
140- async with self ._get_or_create_memory () as memory :
141- compiled_graph = await self ._setup_graph (memory , graph )
142- graph_input = await self ._get_graph_input (memory )
143- graph_config = self ._get_graph_config ()
144-
145- # Track final chunk for result creation
146- final_chunk : Optional [dict [Any , Any ]] = None
147-
148- # Stream events from graph
149- async for stream_chunk in compiled_graph .astream (
150- graph_input ,
151- graph_config ,
152- interrupt_before = self .context .breakpoints ,
153- stream_mode = ["messages" , "updates" ],
154- subgraphs = True ,
155- ):
156- _ , chunk_type , data = stream_chunk
157-
158- # Emit UiPathAgentMessageEvent for messages
159- if chunk_type == "messages" :
160- if isinstance (data , tuple ):
161- message , _ = data
162- event = UiPathAgentMessageEvent (
163- payload = message ,
164- execution_id = self .context .execution_id ,
165- )
166- yield event
167-
168- # Emit UiPathAgentStateEvent for state updates
169- elif chunk_type == "updates" :
170- if isinstance (data , dict ):
171- final_chunk = data
172-
173- # Emit state update event for each node
174- for node_name , agent_data in data .items ():
175- if isinstance (agent_data , dict ):
176- state_event = UiPathAgentStateEvent (
177- payload = agent_data ,
178- node_name = node_name ,
179- execution_id = self .context .execution_id ,
180- )
181- yield state_event
182-
183- # Extract output from final chunk
184- graph_output = self ._extract_graph_result (
185- final_chunk , compiled_graph .output_channels
186- )
121+ compiled_graph = await self ._setup_graph (self .memory , graph )
122+ graph_input = await self ._get_graph_input (self .memory )
123+ graph_config = self ._get_graph_config ()
124+
125+ # Track final chunk for result creation
126+ final_chunk : Optional [dict [Any , Any ]] = None
127+
128+ # Stream events from graph
129+ async for stream_chunk in compiled_graph .astream (
130+ graph_input ,
131+ graph_config ,
132+ interrupt_before = self .context .breakpoints ,
133+ stream_mode = ["messages" , "updates" ],
134+ subgraphs = True ,
135+ ):
136+ _ , chunk_type , data = stream_chunk
137+
138+ # Emit UiPathAgentMessageEvent for messages
139+ if chunk_type == "messages" :
140+ if isinstance (data , tuple ):
141+ message , _ = data
142+ event = UiPathAgentMessageEvent (
143+ payload = message ,
144+ execution_id = self .context .execution_id ,
145+ )
146+ yield event
147+
148+ # Emit UiPathAgentStateEvent for state updates
149+ elif chunk_type == "updates" :
150+ if isinstance (data , dict ):
151+ final_chunk = data
152+
153+ # Emit state update event for each node
154+ for node_name , agent_data in data .items ():
155+ if isinstance (agent_data , dict ):
156+ state_event = UiPathAgentStateEvent (
157+ payload = agent_data ,
158+ node_name = node_name ,
159+ execution_id = self .context .execution_id ,
160+ )
161+ yield state_event
162+
163+ # Extract output from final chunk
164+ graph_output = self ._extract_graph_result (
165+ final_chunk , compiled_graph .output_channels
166+ )
187167
188- # Get final state and create result
189- self .context .result = await self ._create_runtime_result (
190- compiled_graph , graph_config , memory , graph_output
191- )
168+ # Get final state and create result
169+ self .context .result = await self ._create_runtime_result (
170+ compiled_graph , graph_config , self . memory , graph_output
171+ )
192172
193- # Yield the final result as last event
194- yield self .context .result
173+ # Yield the final result as last event
174+ yield self .context .result
195175
196176 except Exception as e :
197177 raise self ._create_runtime_error (e ) from e
@@ -480,10 +460,13 @@ class LangGraphScriptRuntime(LangGraphRuntime):
480460 """
481461
482462 def __init__ (
483- self , context : LangGraphRuntimeContext , entrypoint : Optional [str ] = None
463+ self ,
464+ context : UiPathRuntimeContext ,
465+ memory : AsyncSqliteSaver ,
466+ entrypoint : Optional [str ] = None ,
484467 ):
485468 self .resolver = LangGraphJsonResolver (entrypoint = entrypoint )
486- super ().__init__ (context , self .resolver )
469+ super ().__init__ (context , self .resolver , memory = memory )
487470
488471 @override
489472 async def get_entrypoint (self ) -> Entrypoint :
0 commit comments