Commit 4c1383e

Merge pull request #250 from UiPath/fix/shared_memory
fix: shared memory for evals
2 parents c9dfedc + 8dca8d6 commit 4c1383e

9 files changed: +238 −241 lines changed

src/uipath_langchain/_cli/_runtime/_context.py

Lines changed: 0 additions & 10 deletions
This file was deleted.

src/uipath_langchain/_cli/_runtime/_input.py

Lines changed: 2 additions & 2 deletions
@@ -8,18 +8,18 @@
     UiPathErrorCategory,
     UiPathResumeTrigger,
     UiPathResumeTriggerType,
+    UiPathRuntimeContext,
 )
 from uipath._cli._runtime._hitl import HitlReader

-from ._context import LangGraphRuntimeContext
 from ._conversation import uipath_to_human_messages
 from ._exception import LangGraphErrorCode, LangGraphRuntimeError

 logger = logging.getLogger(__name__)


 async def get_graph_input(
-    context: LangGraphRuntimeContext,
+    context: UiPathRuntimeContext,
     memory: AsyncSqliteSaver,
     resume_triggers_table: str = "__uipath_resume_triggers",
 ) -> Any:
src/uipath_langchain/_cli/_runtime/_memory.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+import os
+from contextlib import asynccontextmanager
+
+from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
+from uipath._cli._runtime._contracts import UiPathRuntimeContext
+
+
+def get_connection_string(context: UiPathRuntimeContext) -> str:
+    if context.runtime_dir and context.state_file:
+        path = os.path.join(context.runtime_dir, context.state_file)
+        if not context.resume and context.job_id is None:
+            # If not resuming and no job id, delete the previous state file
+            if os.path.exists(path):
+                os.remove(path)
+        os.makedirs(context.runtime_dir, exist_ok=True)
+        return path
+    return os.path.join("__uipath", "state.db")
+
+
+@asynccontextmanager
+async def get_memory(context: UiPathRuntimeContext):
+    """Create and manage the AsyncSqliteSaver instance."""
+    async with AsyncSqliteSaver.from_conn_string(
+        get_connection_string(context)
+    ) as memory:
+        yield memory
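
As a usage sketch (not part of the diff): get_memory is meant to be entered once so every consumer shares a single checkpointer. The bare with_defaults() call and the body below are illustrative assumptions, not code from this PR.

import asyncio

from uipath._cli._runtime._contracts import UiPathRuntimeContext

from uipath_langchain._cli._runtime._memory import get_memory


async def main() -> None:
    # Assumption: with_defaults() can be called here without extra kwargs.
    context = UiPathRuntimeContext.with_defaults()

    # One AsyncSqliteSaver is opened for the whole block; everything executed
    # inside it reuses the same checkpoint database resolved by
    # get_connection_string (runtime_dir/state_file, or __uipath/state.db).
    async with get_memory(context) as memory:
        print(type(memory).__name__)  # AsyncSqliteSaver


asyncio.run(main())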

src/uipath_langchain/_cli/_runtime/_runtime.py

Lines changed: 81 additions & 98 deletions
@@ -1,7 +1,6 @@
 import logging
 import os
-from contextlib import asynccontextmanager
-from typing import Any, AsyncGenerator, AsyncIterator, Optional, Sequence
+from typing import Any, AsyncGenerator, Optional, Sequence
 from uuid import uuid4

 from langchain_core.runnables.config import RunnableConfig
@@ -16,6 +15,7 @@
     UiPathErrorCategory,
     UiPathErrorCode,
     UiPathResumeTrigger,
+    UiPathRuntimeContext,
     UiPathRuntimeResult,
     UiPathRuntimeStatus,
 )
@@ -27,7 +27,6 @@
 )

 from .._utils._schema import generate_schema_from_graph
-from ._context import LangGraphRuntimeContext
 from ._exception import LangGraphErrorCode, LangGraphRuntimeError
 from ._graph_resolver import AsyncResolver, LangGraphJsonResolver
 from ._input import get_graph_input
@@ -42,58 +41,40 @@ class LangGraphRuntime(UiPathBaseRuntime):
     This allows using the class with 'async with' statements.
     """

-    def __init__(self, context: LangGraphRuntimeContext, graph_resolver: AsyncResolver):
+    def __init__(
+        self,
+        context: UiPathRuntimeContext,
+        graph_resolver: AsyncResolver,
+        memory: AsyncSqliteSaver,
+    ):
         super().__init__(context)
-        self.context: LangGraphRuntimeContext = context
+        self.context: UiPathRuntimeContext = context
         self.graph_resolver: AsyncResolver = graph_resolver
+        self.memory: AsyncSqliteSaver = memory
         self.resume_triggers_table: str = "__uipath_resume_triggers"

-    @asynccontextmanager
-    async def _get_or_create_memory(self) -> AsyncIterator[AsyncSqliteSaver]:
-        """
-        Get existing memory from context or create a new one.
-
-        If memory is created, it will be automatically disposed at the end.
-        If memory already exists in context, it will be reused without disposal.
-
-        Yields:
-            AsyncSqliteSaver instance
-        """
-        # Check if memory already exists in context
-        if self.context.memory is not None:
-            # Use existing memory, don't dispose
-            yield self.context.memory
-        else:
-            # Create new memory and dispose at the end
-            async with AsyncSqliteSaver.from_conn_string(
-                self.state_file_path
-            ) as memory:
-                yield memory
-                # Memory is automatically disposed by the context manager
-
     async def execute(self) -> Optional[UiPathRuntimeResult]:
         """Execute the graph with the provided input and configuration."""
         graph = await self.graph_resolver()
         if not graph:
             return None

         try:
-            async with self._get_or_create_memory() as memory:
-                compiled_graph = await self._setup_graph(memory, graph)
-                graph_input = await self._get_graph_input(memory)
-                graph_config = self._get_graph_config()
-
-                # Execute without streaming
-                graph_output = await compiled_graph.ainvoke(
-                    graph_input,
-                    graph_config,
-                    interrupt_before=self.context.breakpoints,
-                )
+            compiled_graph = await self._setup_graph(self.memory, graph)
+            graph_input = await self._get_graph_input(self.memory)
+            graph_config = self._get_graph_config()
+
+            # Execute without streaming
+            graph_output = await compiled_graph.ainvoke(
+                graph_input,
+                graph_config,
+                interrupt_before=self.context.breakpoints,
+            )

-                # Get final state and create result
-                self.context.result = await self._create_runtime_result(
-                    compiled_graph, graph_config, memory, graph_output
-                )
+            # Get final state and create result
+            self.context.result = await self._create_runtime_result(
+                compiled_graph, graph_config, self.memory, graph_output
+            )

             return self.context.result

@@ -137,61 +118,60 @@ async def stream(
             return

         try:
-            async with self._get_or_create_memory() as memory:
-                compiled_graph = await self._setup_graph(memory, graph)
-                graph_input = await self._get_graph_input(memory)
-                graph_config = self._get_graph_config()
-
-                # Track final chunk for result creation
-                final_chunk: Optional[dict[Any, Any]] = None
-
-                # Stream events from graph
-                async for stream_chunk in compiled_graph.astream(
-                    graph_input,
-                    graph_config,
-                    interrupt_before=self.context.breakpoints,
-                    stream_mode=["messages", "updates"],
-                    subgraphs=True,
-                ):
-                    _, chunk_type, data = stream_chunk
-
-                    # Emit UiPathAgentMessageEvent for messages
-                    if chunk_type == "messages":
-                        if isinstance(data, tuple):
-                            message, _ = data
-                            event = UiPathAgentMessageEvent(
-                                payload=message,
-                                execution_id=self.context.execution_id,
-                            )
-                            yield event
-
-                    # Emit UiPathAgentStateEvent for state updates
-                    elif chunk_type == "updates":
-                        if isinstance(data, dict):
-                            final_chunk = data
-
-                            # Emit state update event for each node
-                            for node_name, agent_data in data.items():
-                                if isinstance(agent_data, dict):
-                                    state_event = UiPathAgentStateEvent(
-                                        payload=agent_data,
-                                        node_name=node_name,
-                                        execution_id=self.context.execution_id,
-                                    )
-                                    yield state_event
-
-                # Extract output from final chunk
-                graph_output = self._extract_graph_result(
-                    final_chunk, compiled_graph.output_channels
-                )
+            compiled_graph = await self._setup_graph(self.memory, graph)
+            graph_input = await self._get_graph_input(self.memory)
+            graph_config = self._get_graph_config()
+
+            # Track final chunk for result creation
+            final_chunk: Optional[dict[Any, Any]] = None
+
+            # Stream events from graph
+            async for stream_chunk in compiled_graph.astream(
+                graph_input,
+                graph_config,
+                interrupt_before=self.context.breakpoints,
+                stream_mode=["messages", "updates"],
+                subgraphs=True,
+            ):
+                _, chunk_type, data = stream_chunk
+
+                # Emit UiPathAgentMessageEvent for messages
+                if chunk_type == "messages":
+                    if isinstance(data, tuple):
+                        message, _ = data
+                        event = UiPathAgentMessageEvent(
+                            payload=message,
+                            execution_id=self.context.execution_id,
+                        )
+                        yield event
+
+                # Emit UiPathAgentStateEvent for state updates
+                elif chunk_type == "updates":
+                    if isinstance(data, dict):
+                        final_chunk = data
+
+                        # Emit state update event for each node
+                        for node_name, agent_data in data.items():
+                            if isinstance(agent_data, dict):
+                                state_event = UiPathAgentStateEvent(
+                                    payload=agent_data,
+                                    node_name=node_name,
+                                    execution_id=self.context.execution_id,
+                                )
+                                yield state_event
+
+            # Extract output from final chunk
+            graph_output = self._extract_graph_result(
+                final_chunk, compiled_graph.output_channels
+            )

-                # Get final state and create result
-                self.context.result = await self._create_runtime_result(
-                    compiled_graph, graph_config, memory, graph_output
-                )
+            # Get final state and create result
+            self.context.result = await self._create_runtime_result(
+                compiled_graph, graph_config, self.memory, graph_output
+            )

-                # Yield the final result as last event
-                yield self.context.result
+            # Yield the final result as last event
+            yield self.context.result

         except Exception as e:
             raise self._create_runtime_error(e) from e
@@ -480,10 +460,13 @@ class LangGraphScriptRuntime(LangGraphRuntime):
     """

     def __init__(
-        self, context: LangGraphRuntimeContext, entrypoint: Optional[str] = None
+        self,
+        context: UiPathRuntimeContext,
+        memory: AsyncSqliteSaver,
+        entrypoint: Optional[str] = None,
     ):
         self.resolver = LangGraphJsonResolver(entrypoint=entrypoint)
-        super().__init__(context, self.resolver)
+        super().__init__(context, self.resolver, memory=memory)

     @override
     async def get_entrypoint(self) -> Entrypoint:
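
Signature sketch (not from the diff): with this change the caller owns the saver's lifetime and hands it to the runtime. run_once below is a hypothetical helper shown only to illustrate the new constructor; the PR itself builds runtimes through UiPathRuntimeFactory (see cli_debug.py).

from uipath._cli._runtime._contracts import UiPathRuntimeContext

from uipath_langchain._cli._runtime._memory import get_memory
from uipath_langchain._cli._runtime._runtime import LangGraphScriptRuntime


async def run_once(context: UiPathRuntimeContext) -> None:
    async with get_memory(context) as memory:
        # The runtime no longer opens its own SQLite connection; it uses
        # whatever AsyncSqliteSaver the caller passes in.
        runtime = LangGraphScriptRuntime(context, memory, context.entrypoint)
        async with runtime:
            await runtime.execute()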

src/uipath_langchain/_cli/cli_debug.py

Lines changed: 26 additions & 34 deletions
@@ -1,5 +1,4 @@
 import asyncio
-import os
 from typing import Optional

 from openinference.instrumentation.langchain import (
@@ -9,17 +8,16 @@
 from uipath._cli._debug._bridge import UiPathDebugBridge, get_debug_bridge
 from uipath._cli._debug._runtime import UiPathDebugRuntime
 from uipath._cli._runtime._contracts import (
+    UiPathRuntimeContext,
     UiPathRuntimeFactory,
 )
 from uipath._cli.middlewares import MiddlewareResult
 from uipath.tracing import LlmOpsHttpExporter

 from .._tracing import _instrument_traceable_attributes
 from ._runtime._exception import LangGraphRuntimeError
-from ._runtime._runtime import ( # type: ignore[attr-defined]
-    LangGraphRuntimeContext,
-    LangGraphScriptRuntime,
-)
+from ._runtime._memory import get_memory
+from ._runtime._runtime import LangGraphScriptRuntime
 from ._utils._graph import LangGraphConfig


@@ -36,46 +34,40 @@ def langgraph_debug_middleware(
     try:

         async def execute():
-            context = LangGraphRuntimeContext.with_defaults(**kwargs)
+            context = UiPathRuntimeContext.with_defaults(**kwargs)
             context.entrypoint = entrypoint
             context.input = input
             context.resume = resume
-            context.execution_id = context.job_id or "default"

             _instrument_traceable_attributes()

-            def generate_runtime(
-                ctx: LangGraphRuntimeContext,
-            ) -> LangGraphScriptRuntime:
-                runtime = LangGraphScriptRuntime(ctx, ctx.entrypoint)
-                # If not resuming and no job id, delete the previous state file
-                if not ctx.resume and ctx.job_id is None:
-                    if os.path.exists(runtime.state_file_path):
-                        os.remove(runtime.state_file_path)
-                return runtime
+            async with get_memory(context) as memory:
+                runtime_factory = UiPathRuntimeFactory(
+                    LangGraphScriptRuntime,
+                    UiPathRuntimeContext,
+                    runtime_generator=lambda ctx: LangGraphScriptRuntime(
+                        ctx, memory, ctx.entrypoint
+                    ),
+                    context_generator=lambda: context,
+                )

-            runtime_factory = UiPathRuntimeFactory(
-                LangGraphScriptRuntime,
-                LangGraphRuntimeContext,
-                runtime_generator=generate_runtime,
-                context_generator=lambda: context,
-            )
+                if context.job_id:
+                    runtime_factory.add_span_exporter(
+                        LlmOpsHttpExporter(extra_process_spans=True)
+                    )

-            if context.job_id:
-                runtime_factory.add_span_exporter(
-                    LlmOpsHttpExporter(extra_process_spans=True)
+                runtime_factory.add_instrumentor(
+                    LangChainInstrumentor, get_current_span
                 )

-            runtime_factory.add_instrumentor(LangChainInstrumentor, get_current_span)
-
-            debug_bridge: UiPathDebugBridge = get_debug_bridge(context)
+                debug_bridge: UiPathDebugBridge = get_debug_bridge(context)

-            async with UiPathDebugRuntime.from_debug_context(
-                factory=runtime_factory,
-                context=context,
-                debug_bridge=debug_bridge,
-            ) as debug_runtime:
-                await debug_runtime.execute()
+                async with UiPathDebugRuntime.from_debug_context(
+                    factory=runtime_factory,
+                    context=context,
+                    debug_bridge=debug_bridge,
+                ) as debug_runtime:
+                    await debug_runtime.execute()

         asyncio.run(execute())
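
The gist of the fix, as a sketch: the lambda passed to runtime_generator closes over the single saver opened by get_memory, so every runtime the factory creates (for example, one per eval case) shares one checkpoint database instead of each run opening, and previously deleting, its own state file. The helper below and its names are hypothetical, for illustration only.

from uipath._cli._runtime._contracts import UiPathRuntimeContext

from uipath_langchain._cli._runtime._memory import get_memory
from uipath_langchain._cli._runtime._runtime import LangGraphScriptRuntime


async def shared_memory_demo(context: UiPathRuntimeContext) -> None:
    async with get_memory(context) as memory:
        # Mirrors the runtime_generator lambda: each call builds a new runtime
        # around the same AsyncSqliteSaver instance.
        def make_runtime(ctx: UiPathRuntimeContext) -> LangGraphScriptRuntime:
            return LangGraphScriptRuntime(ctx, memory, ctx.entrypoint)

        first = make_runtime(context)
        second = make_runtime(context)
        assert first.memory is second.memory  # both reuse the shared saver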
