Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions src/uipath_langchain/_cli/_runtime/_context.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/uipath_langchain/_cli/_runtime/_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@
UiPathErrorCategory,
UiPathResumeTrigger,
UiPathResumeTriggerType,
UiPathRuntimeContext,
)
from uipath._cli._runtime._hitl import HitlReader

from ._context import LangGraphRuntimeContext
from ._conversation import uipath_to_human_messages
from ._exception import LangGraphErrorCode, LangGraphRuntimeError

logger = logging.getLogger(__name__)


async def get_graph_input(
context: LangGraphRuntimeContext,
context: UiPathRuntimeContext,
memory: AsyncSqliteSaver,
resume_triggers_table: str = "__uipath_resume_triggers",
) -> Any:
Expand Down
26 changes: 26 additions & 0 deletions src/uipath_langchain/_cli/_runtime/_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
from contextlib import asynccontextmanager

from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from uipath._cli._runtime._contracts import UiPathRuntimeContext


def get_connection_string(context: UiPathRuntimeContext) -> str:
if context.runtime_dir and context.state_file:
path = os.path.join(context.runtime_dir, context.state_file)
if not context.resume and context.job_id is None:
# If not resuming and no job id, delete the previous state file
if os.path.exists(path):
os.remove(path)
os.makedirs(context.runtime_dir, exist_ok=True)
return path
return os.path.join("__uipath", "state.db")


@asynccontextmanager
async def get_memory(context: UiPathRuntimeContext):
"""Create and manage the AsyncSqliteSaver instance."""
async with AsyncSqliteSaver.from_conn_string(
get_connection_string(context)
) as memory:
yield memory
179 changes: 81 additions & 98 deletions src/uipath_langchain/_cli/_runtime/_runtime.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
import os
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, AsyncIterator, Optional, Sequence
from typing import Any, AsyncGenerator, Optional, Sequence
from uuid import uuid4

from langchain_core.runnables.config import RunnableConfig
Expand All @@ -16,6 +15,7 @@
UiPathErrorCategory,
UiPathErrorCode,
UiPathResumeTrigger,
UiPathRuntimeContext,
UiPathRuntimeResult,
UiPathRuntimeStatus,
)
Expand All @@ -27,7 +27,6 @@
)

from .._utils._schema import generate_schema_from_graph
from ._context import LangGraphRuntimeContext
from ._exception import LangGraphErrorCode, LangGraphRuntimeError
from ._graph_resolver import AsyncResolver, LangGraphJsonResolver
from ._input import get_graph_input
Expand All @@ -42,58 +41,40 @@ class LangGraphRuntime(UiPathBaseRuntime):
This allows using the class with 'async with' statements.
"""

def __init__(self, context: LangGraphRuntimeContext, graph_resolver: AsyncResolver):
def __init__(
self,
context: UiPathRuntimeContext,
graph_resolver: AsyncResolver,
memory: AsyncSqliteSaver,
):
super().__init__(context)
self.context: LangGraphRuntimeContext = context
self.context: UiPathRuntimeContext = context
self.graph_resolver: AsyncResolver = graph_resolver
self.memory: AsyncSqliteSaver = memory
self.resume_triggers_table: str = "__uipath_resume_triggers"

@asynccontextmanager
async def _get_or_create_memory(self) -> AsyncIterator[AsyncSqliteSaver]:
"""
Get existing memory from context or create a new one.

If memory is created, it will be automatically disposed at the end.
If memory already exists in context, it will be reused without disposal.

Yields:
AsyncSqliteSaver instance
"""
# Check if memory already exists in context
if self.context.memory is not None:
# Use existing memory, don't dispose
yield self.context.memory
else:
# Create new memory and dispose at the end
async with AsyncSqliteSaver.from_conn_string(
self.state_file_path
) as memory:
yield memory
# Memory is automatically disposed by the context manager

async def execute(self) -> Optional[UiPathRuntimeResult]:
"""Execute the graph with the provided input and configuration."""
graph = await self.graph_resolver()
if not graph:
return None

try:
async with self._get_or_create_memory() as memory:
compiled_graph = await self._setup_graph(memory, graph)
graph_input = await self._get_graph_input(memory)
graph_config = self._get_graph_config()

# Execute without streaming
graph_output = await compiled_graph.ainvoke(
graph_input,
graph_config,
interrupt_before=self.context.breakpoints,
)
compiled_graph = await self._setup_graph(self.memory, graph)
graph_input = await self._get_graph_input(self.memory)
graph_config = self._get_graph_config()

# Execute without streaming
graph_output = await compiled_graph.ainvoke(
graph_input,
graph_config,
interrupt_before=self.context.breakpoints,
)

# Get final state and create result
self.context.result = await self._create_runtime_result(
compiled_graph, graph_config, memory, graph_output
)
# Get final state and create result
self.context.result = await self._create_runtime_result(
compiled_graph, graph_config, self.memory, graph_output
)

return self.context.result

Expand Down Expand Up @@ -137,61 +118,60 @@ async def stream(
return

try:
async with self._get_or_create_memory() as memory:
compiled_graph = await self._setup_graph(memory, graph)
graph_input = await self._get_graph_input(memory)
graph_config = self._get_graph_config()

# Track final chunk for result creation
final_chunk: Optional[dict[Any, Any]] = None

# Stream events from graph
async for stream_chunk in compiled_graph.astream(
graph_input,
graph_config,
interrupt_before=self.context.breakpoints,
stream_mode=["messages", "updates"],
subgraphs=True,
):
_, chunk_type, data = stream_chunk

# Emit UiPathAgentMessageEvent for messages
if chunk_type == "messages":
if isinstance(data, tuple):
message, _ = data
event = UiPathAgentMessageEvent(
payload=message,
execution_id=self.context.execution_id,
)
yield event

# Emit UiPathAgentStateEvent for state updates
elif chunk_type == "updates":
if isinstance(data, dict):
final_chunk = data

# Emit state update event for each node
for node_name, agent_data in data.items():
if isinstance(agent_data, dict):
state_event = UiPathAgentStateEvent(
payload=agent_data,
node_name=node_name,
execution_id=self.context.execution_id,
)
yield state_event

# Extract output from final chunk
graph_output = self._extract_graph_result(
final_chunk, compiled_graph.output_channels
)
compiled_graph = await self._setup_graph(self.memory, graph)
graph_input = await self._get_graph_input(self.memory)
graph_config = self._get_graph_config()

# Track final chunk for result creation
final_chunk: Optional[dict[Any, Any]] = None

# Stream events from graph
async for stream_chunk in compiled_graph.astream(
graph_input,
graph_config,
interrupt_before=self.context.breakpoints,
stream_mode=["messages", "updates"],
subgraphs=True,
):
_, chunk_type, data = stream_chunk

# Emit UiPathAgentMessageEvent for messages
if chunk_type == "messages":
if isinstance(data, tuple):
message, _ = data
event = UiPathAgentMessageEvent(
payload=message,
execution_id=self.context.execution_id,
)
yield event

# Emit UiPathAgentStateEvent for state updates
elif chunk_type == "updates":
if isinstance(data, dict):
final_chunk = data

# Emit state update event for each node
for node_name, agent_data in data.items():
if isinstance(agent_data, dict):
state_event = UiPathAgentStateEvent(
payload=agent_data,
node_name=node_name,
execution_id=self.context.execution_id,
)
yield state_event

# Extract output from final chunk
graph_output = self._extract_graph_result(
final_chunk, compiled_graph.output_channels
)

# Get final state and create result
self.context.result = await self._create_runtime_result(
compiled_graph, graph_config, memory, graph_output
)
# Get final state and create result
self.context.result = await self._create_runtime_result(
compiled_graph, graph_config, self.memory, graph_output
)

# Yield the final result as last event
yield self.context.result
# Yield the final result as last event
yield self.context.result

except Exception as e:
raise self._create_runtime_error(e) from e
Expand Down Expand Up @@ -480,10 +460,13 @@ class LangGraphScriptRuntime(LangGraphRuntime):
"""

def __init__(
self, context: LangGraphRuntimeContext, entrypoint: Optional[str] = None
self,
context: UiPathRuntimeContext,
memory: AsyncSqliteSaver,
entrypoint: Optional[str] = None,
):
self.resolver = LangGraphJsonResolver(entrypoint=entrypoint)
super().__init__(context, self.resolver)
super().__init__(context, self.resolver, memory=memory)

@override
async def get_entrypoint(self) -> Entrypoint:
Expand Down
60 changes: 26 additions & 34 deletions src/uipath_langchain/_cli/cli_debug.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import os
from typing import Optional

from openinference.instrumentation.langchain import (
Expand All @@ -9,17 +8,16 @@
from uipath._cli._debug._bridge import UiPathDebugBridge, get_debug_bridge
from uipath._cli._debug._runtime import UiPathDebugRuntime
from uipath._cli._runtime._contracts import (
UiPathRuntimeContext,
UiPathRuntimeFactory,
)
from uipath._cli.middlewares import MiddlewareResult
from uipath.tracing import LlmOpsHttpExporter

from .._tracing import _instrument_traceable_attributes
from ._runtime._exception import LangGraphRuntimeError
from ._runtime._runtime import ( # type: ignore[attr-defined]
LangGraphRuntimeContext,
LangGraphScriptRuntime,
)
from ._runtime._memory import get_memory
from ._runtime._runtime import LangGraphScriptRuntime
from ._utils._graph import LangGraphConfig


Expand All @@ -36,46 +34,40 @@ def langgraph_debug_middleware(
try:

async def execute():
context = LangGraphRuntimeContext.with_defaults(**kwargs)
context = UiPathRuntimeContext.with_defaults(**kwargs)
context.entrypoint = entrypoint
context.input = input
context.resume = resume
context.execution_id = context.job_id or "default"

_instrument_traceable_attributes()

def generate_runtime(
ctx: LangGraphRuntimeContext,
) -> LangGraphScriptRuntime:
runtime = LangGraphScriptRuntime(ctx, ctx.entrypoint)
# If not resuming and no job id, delete the previous state file
if not ctx.resume and ctx.job_id is None:
if os.path.exists(runtime.state_file_path):
os.remove(runtime.state_file_path)
return runtime
async with get_memory(context) as memory:
runtime_factory = UiPathRuntimeFactory(
LangGraphScriptRuntime,
UiPathRuntimeContext,
runtime_generator=lambda ctx: LangGraphScriptRuntime(
ctx, memory, ctx.entrypoint
),
context_generator=lambda: context,
)

runtime_factory = UiPathRuntimeFactory(
LangGraphScriptRuntime,
LangGraphRuntimeContext,
runtime_generator=generate_runtime,
context_generator=lambda: context,
)
if context.job_id:
runtime_factory.add_span_exporter(
LlmOpsHttpExporter(extra_process_spans=True)
)

if context.job_id:
runtime_factory.add_span_exporter(
LlmOpsHttpExporter(extra_process_spans=True)
runtime_factory.add_instrumentor(
LangChainInstrumentor, get_current_span
)

runtime_factory.add_instrumentor(LangChainInstrumentor, get_current_span)

debug_bridge: UiPathDebugBridge = get_debug_bridge(context)
debug_bridge: UiPathDebugBridge = get_debug_bridge(context)

async with UiPathDebugRuntime.from_debug_context(
factory=runtime_factory,
context=context,
debug_bridge=debug_bridge,
) as debug_runtime:
await debug_runtime.execute()
async with UiPathDebugRuntime.from_debug_context(
factory=runtime_factory,
context=context,
debug_bridge=debug_bridge,
) as debug_runtime:
await debug_runtime.execute()

asyncio.run(execute())

Expand Down
Loading