From 8dca8d6a2eb1aec22234debcc1f1513c35a9a038 Mon Sep 17 00:00:00 2001 From: Cristian Pufu Date: Mon, 3 Nov 2025 09:17:23 +0200 Subject: [PATCH] fix: shared memory for evals fixes #222 --- .../_cli/_runtime/_context.py | 10 - src/uipath_langchain/_cli/_runtime/_input.py | 4 +- src/uipath_langchain/_cli/_runtime/_memory.py | 26 +++ .../_cli/_runtime/_runtime.py | 179 ++++++++---------- src/uipath_langchain/_cli/cli_debug.py | 60 +++--- src/uipath_langchain/_cli/cli_dev.py | 33 ++-- src/uipath_langchain/_cli/cli_eval.py | 88 +++++---- src/uipath_langchain/_cli/cli_run.py | 70 +++---- src/uipath_langchain/runtime_factories.py | 9 +- 9 files changed, 238 insertions(+), 241 deletions(-) delete mode 100644 src/uipath_langchain/_cli/_runtime/_context.py create mode 100644 src/uipath_langchain/_cli/_runtime/_memory.py diff --git a/src/uipath_langchain/_cli/_runtime/_context.py b/src/uipath_langchain/_cli/_runtime/_context.py deleted file mode 100644 index 6165875..0000000 --- a/src/uipath_langchain/_cli/_runtime/_context.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Optional - -from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver -from uipath._cli._runtime._contracts import UiPathRuntimeContext - - -class LangGraphRuntimeContext(UiPathRuntimeContext): - """Context information passed throughout the runtime execution.""" - - memory: Optional[AsyncSqliteSaver] = None diff --git a/src/uipath_langchain/_cli/_runtime/_input.py b/src/uipath_langchain/_cli/_runtime/_input.py index ad0b252..3de45e8 100644 --- a/src/uipath_langchain/_cli/_runtime/_input.py +++ b/src/uipath_langchain/_cli/_runtime/_input.py @@ -8,10 +8,10 @@ UiPathErrorCategory, UiPathResumeTrigger, UiPathResumeTriggerType, + UiPathRuntimeContext, ) from uipath._cli._runtime._hitl import HitlReader -from ._context import LangGraphRuntimeContext from ._conversation import uipath_to_human_messages from ._exception import LangGraphErrorCode, LangGraphRuntimeError @@ -19,7 +19,7 @@ async def get_graph_input( - context: LangGraphRuntimeContext, + context: UiPathRuntimeContext, memory: AsyncSqliteSaver, resume_triggers_table: str = "__uipath_resume_triggers", ) -> Any: diff --git a/src/uipath_langchain/_cli/_runtime/_memory.py b/src/uipath_langchain/_cli/_runtime/_memory.py new file mode 100644 index 0000000..d805a5c --- /dev/null +++ b/src/uipath_langchain/_cli/_runtime/_memory.py @@ -0,0 +1,26 @@ +import os +from contextlib import asynccontextmanager + +from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver +from uipath._cli._runtime._contracts import UiPathRuntimeContext + + +def get_connection_string(context: UiPathRuntimeContext) -> str: + if context.runtime_dir and context.state_file: + path = os.path.join(context.runtime_dir, context.state_file) + if not context.resume and context.job_id is None: + # If not resuming and no job id, delete the previous state file + if os.path.exists(path): + os.remove(path) + os.makedirs(context.runtime_dir, exist_ok=True) + return path + return os.path.join("__uipath", "state.db") + + +@asynccontextmanager +async def get_memory(context: UiPathRuntimeContext): + """Create and manage the AsyncSqliteSaver instance.""" + async with AsyncSqliteSaver.from_conn_string( + get_connection_string(context) + ) as memory: + yield memory diff --git a/src/uipath_langchain/_cli/_runtime/_runtime.py b/src/uipath_langchain/_cli/_runtime/_runtime.py index 6fc20e7..73802b3 100644 --- a/src/uipath_langchain/_cli/_runtime/_runtime.py +++ b/src/uipath_langchain/_cli/_runtime/_runtime.py @@ -1,7 +1,6 @@ import logging import os -from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator, AsyncIterator, Optional, Sequence +from typing import Any, AsyncGenerator, Optional, Sequence from uuid import uuid4 from langchain_core.runnables.config import RunnableConfig @@ -16,6 +15,7 @@ UiPathErrorCategory, UiPathErrorCode, UiPathResumeTrigger, + UiPathRuntimeContext, UiPathRuntimeResult, UiPathRuntimeStatus, ) @@ -27,7 +27,6 @@ ) from .._utils._schema import generate_schema_from_graph -from ._context import LangGraphRuntimeContext from ._exception import LangGraphErrorCode, LangGraphRuntimeError from ._graph_resolver import AsyncResolver, LangGraphJsonResolver from ._input import get_graph_input @@ -42,35 +41,18 @@ class LangGraphRuntime(UiPathBaseRuntime): This allows using the class with 'async with' statements. """ - def __init__(self, context: LangGraphRuntimeContext, graph_resolver: AsyncResolver): + def __init__( + self, + context: UiPathRuntimeContext, + graph_resolver: AsyncResolver, + memory: AsyncSqliteSaver, + ): super().__init__(context) - self.context: LangGraphRuntimeContext = context + self.context: UiPathRuntimeContext = context self.graph_resolver: AsyncResolver = graph_resolver + self.memory: AsyncSqliteSaver = memory self.resume_triggers_table: str = "__uipath_resume_triggers" - @asynccontextmanager - async def _get_or_create_memory(self) -> AsyncIterator[AsyncSqliteSaver]: - """ - Get existing memory from context or create a new one. - - If memory is created, it will be automatically disposed at the end. - If memory already exists in context, it will be reused without disposal. - - Yields: - AsyncSqliteSaver instance - """ - # Check if memory already exists in context - if self.context.memory is not None: - # Use existing memory, don't dispose - yield self.context.memory - else: - # Create new memory and dispose at the end - async with AsyncSqliteSaver.from_conn_string( - self.state_file_path - ) as memory: - yield memory - # Memory is automatically disposed by the context manager - async def execute(self) -> Optional[UiPathRuntimeResult]: """Execute the graph with the provided input and configuration.""" graph = await self.graph_resolver() @@ -78,22 +60,21 @@ async def execute(self) -> Optional[UiPathRuntimeResult]: return None try: - async with self._get_or_create_memory() as memory: - compiled_graph = await self._setup_graph(memory, graph) - graph_input = await self._get_graph_input(memory) - graph_config = self._get_graph_config() - - # Execute without streaming - graph_output = await compiled_graph.ainvoke( - graph_input, - graph_config, - interrupt_before=self.context.breakpoints, - ) + compiled_graph = await self._setup_graph(self.memory, graph) + graph_input = await self._get_graph_input(self.memory) + graph_config = self._get_graph_config() + + # Execute without streaming + graph_output = await compiled_graph.ainvoke( + graph_input, + graph_config, + interrupt_before=self.context.breakpoints, + ) - # Get final state and create result - self.context.result = await self._create_runtime_result( - compiled_graph, graph_config, memory, graph_output - ) + # Get final state and create result + self.context.result = await self._create_runtime_result( + compiled_graph, graph_config, self.memory, graph_output + ) return self.context.result @@ -137,61 +118,60 @@ async def stream( return try: - async with self._get_or_create_memory() as memory: - compiled_graph = await self._setup_graph(memory, graph) - graph_input = await self._get_graph_input(memory) - graph_config = self._get_graph_config() - - # Track final chunk for result creation - final_chunk: Optional[dict[Any, Any]] = None - - # Stream events from graph - async for stream_chunk in compiled_graph.astream( - graph_input, - graph_config, - interrupt_before=self.context.breakpoints, - stream_mode=["messages", "updates"], - subgraphs=True, - ): - _, chunk_type, data = stream_chunk - - # Emit UiPathAgentMessageEvent for messages - if chunk_type == "messages": - if isinstance(data, tuple): - message, _ = data - event = UiPathAgentMessageEvent( - payload=message, - execution_id=self.context.execution_id, - ) - yield event - - # Emit UiPathAgentStateEvent for state updates - elif chunk_type == "updates": - if isinstance(data, dict): - final_chunk = data - - # Emit state update event for each node - for node_name, agent_data in data.items(): - if isinstance(agent_data, dict): - state_event = UiPathAgentStateEvent( - payload=agent_data, - node_name=node_name, - execution_id=self.context.execution_id, - ) - yield state_event - - # Extract output from final chunk - graph_output = self._extract_graph_result( - final_chunk, compiled_graph.output_channels - ) + compiled_graph = await self._setup_graph(self.memory, graph) + graph_input = await self._get_graph_input(self.memory) + graph_config = self._get_graph_config() + + # Track final chunk for result creation + final_chunk: Optional[dict[Any, Any]] = None + + # Stream events from graph + async for stream_chunk in compiled_graph.astream( + graph_input, + graph_config, + interrupt_before=self.context.breakpoints, + stream_mode=["messages", "updates"], + subgraphs=True, + ): + _, chunk_type, data = stream_chunk + + # Emit UiPathAgentMessageEvent for messages + if chunk_type == "messages": + if isinstance(data, tuple): + message, _ = data + event = UiPathAgentMessageEvent( + payload=message, + execution_id=self.context.execution_id, + ) + yield event + + # Emit UiPathAgentStateEvent for state updates + elif chunk_type == "updates": + if isinstance(data, dict): + final_chunk = data + + # Emit state update event for each node + for node_name, agent_data in data.items(): + if isinstance(agent_data, dict): + state_event = UiPathAgentStateEvent( + payload=agent_data, + node_name=node_name, + execution_id=self.context.execution_id, + ) + yield state_event + + # Extract output from final chunk + graph_output = self._extract_graph_result( + final_chunk, compiled_graph.output_channels + ) - # Get final state and create result - self.context.result = await self._create_runtime_result( - compiled_graph, graph_config, memory, graph_output - ) + # Get final state and create result + self.context.result = await self._create_runtime_result( + compiled_graph, graph_config, self.memory, graph_output + ) - # Yield the final result as last event - yield self.context.result + # Yield the final result as last event + yield self.context.result except Exception as e: raise self._create_runtime_error(e) from e @@ -480,10 +460,13 @@ class LangGraphScriptRuntime(LangGraphRuntime): """ def __init__( - self, context: LangGraphRuntimeContext, entrypoint: Optional[str] = None + self, + context: UiPathRuntimeContext, + memory: AsyncSqliteSaver, + entrypoint: Optional[str] = None, ): self.resolver = LangGraphJsonResolver(entrypoint=entrypoint) - super().__init__(context, self.resolver) + super().__init__(context, self.resolver, memory=memory) @override async def get_entrypoint(self) -> Entrypoint: diff --git a/src/uipath_langchain/_cli/cli_debug.py b/src/uipath_langchain/_cli/cli_debug.py index 3c0c526..857fcec 100644 --- a/src/uipath_langchain/_cli/cli_debug.py +++ b/src/uipath_langchain/_cli/cli_debug.py @@ -1,5 +1,4 @@ import asyncio -import os from typing import Optional from openinference.instrumentation.langchain import ( @@ -9,6 +8,7 @@ from uipath._cli._debug._bridge import UiPathDebugBridge, get_debug_bridge from uipath._cli._debug._runtime import UiPathDebugRuntime from uipath._cli._runtime._contracts import ( + UiPathRuntimeContext, UiPathRuntimeFactory, ) from uipath._cli.middlewares import MiddlewareResult @@ -16,10 +16,8 @@ from .._tracing import _instrument_traceable_attributes from ._runtime._exception import LangGraphRuntimeError -from ._runtime._runtime import ( # type: ignore[attr-defined] - LangGraphRuntimeContext, - LangGraphScriptRuntime, -) +from ._runtime._memory import get_memory +from ._runtime._runtime import LangGraphScriptRuntime from ._utils._graph import LangGraphConfig @@ -36,46 +34,40 @@ def langgraph_debug_middleware( try: async def execute(): - context = LangGraphRuntimeContext.with_defaults(**kwargs) + context = UiPathRuntimeContext.with_defaults(**kwargs) context.entrypoint = entrypoint context.input = input context.resume = resume - context.execution_id = context.job_id or "default" _instrument_traceable_attributes() - def generate_runtime( - ctx: LangGraphRuntimeContext, - ) -> LangGraphScriptRuntime: - runtime = LangGraphScriptRuntime(ctx, ctx.entrypoint) - # If not resuming and no job id, delete the previous state file - if not ctx.resume and ctx.job_id is None: - if os.path.exists(runtime.state_file_path): - os.remove(runtime.state_file_path) - return runtime + async with get_memory(context) as memory: + runtime_factory = UiPathRuntimeFactory( + LangGraphScriptRuntime, + UiPathRuntimeContext, + runtime_generator=lambda ctx: LangGraphScriptRuntime( + ctx, memory, ctx.entrypoint + ), + context_generator=lambda: context, + ) - runtime_factory = UiPathRuntimeFactory( - LangGraphScriptRuntime, - LangGraphRuntimeContext, - runtime_generator=generate_runtime, - context_generator=lambda: context, - ) + if context.job_id: + runtime_factory.add_span_exporter( + LlmOpsHttpExporter(extra_process_spans=True) + ) - if context.job_id: - runtime_factory.add_span_exporter( - LlmOpsHttpExporter(extra_process_spans=True) + runtime_factory.add_instrumentor( + LangChainInstrumentor, get_current_span ) - runtime_factory.add_instrumentor(LangChainInstrumentor, get_current_span) - - debug_bridge: UiPathDebugBridge = get_debug_bridge(context) + debug_bridge: UiPathDebugBridge = get_debug_bridge(context) - async with UiPathDebugRuntime.from_debug_context( - factory=runtime_factory, - context=context, - debug_bridge=debug_bridge, - ) as debug_runtime: - await debug_runtime.execute() + async with UiPathDebugRuntime.from_debug_context( + factory=runtime_factory, + context=context, + debug_bridge=debug_bridge, + ) as debug_runtime: + await debug_runtime.execute() asyncio.run(execute()) diff --git a/src/uipath_langchain/_cli/cli_dev.py b/src/uipath_langchain/_cli/cli_dev.py index 563c9a4..41dde69 100644 --- a/src/uipath_langchain/_cli/cli_dev.py +++ b/src/uipath_langchain/_cli/cli_dev.py @@ -6,12 +6,12 @@ get_current_span, ) from uipath._cli._dev._terminal import UiPathDevTerminal -from uipath._cli._runtime._contracts import UiPathRuntimeFactory +from uipath._cli._runtime._contracts import UiPathRuntimeContext, UiPathRuntimeFactory from uipath._cli._utils._console import ConsoleLogger from uipath._cli.middlewares import MiddlewareResult from .._tracing import _instrument_traceable_attributes -from ._runtime._context import LangGraphRuntimeContext +from ._runtime._memory import get_memory from ._runtime._runtime import LangGraphScriptRuntime from ._utils._graph import LangGraphConfig @@ -28,20 +28,27 @@ def langgraph_dev_middleware(interface: Optional[str]) -> MiddlewareResult: try: if interface == "terminal": + _instrument_traceable_attributes() - def generate_runtime( - ctx: LangGraphRuntimeContext, - ) -> LangGraphScriptRuntime: - return LangGraphScriptRuntime(ctx, ctx.entrypoint) + async def execute(): + context = UiPathRuntimeContext.with_defaults() - runtime_factory = UiPathRuntimeFactory( - LangGraphScriptRuntime, LangGraphRuntimeContext, generate_runtime - ) + async with get_memory(context) as memory: + runtime_factory = UiPathRuntimeFactory( + LangGraphScriptRuntime, + UiPathRuntimeContext, + lambda ctx: LangGraphScriptRuntime(ctx, memory, ctx.entrypoint), + ) - _instrument_traceable_attributes() - runtime_factory.add_instrumentor(LangChainInstrumentor, get_current_span) - app = UiPathDevTerminal(runtime_factory) - asyncio.run(app.run_async()) + runtime_factory.add_instrumentor( + LangChainInstrumentor, get_current_span + ) + + app = UiPathDevTerminal(runtime_factory) + + await app.run_async() + + asyncio.run(execute()) else: console.error(f"Unknown interface: {interface}") except KeyboardInterrupt: diff --git a/src/uipath_langchain/_cli/cli_eval.py b/src/uipath_langchain/_cli/cli_eval.py index 5b5d5f4..f62d4d0 100644 --- a/src/uipath_langchain/_cli/cli_eval.py +++ b/src/uipath_langchain/_cli/cli_eval.py @@ -10,6 +10,7 @@ from uipath._cli._evals._progress_reporter import StudioWebProgressReporter from uipath._cli._evals._runtime import UiPathEvalContext from uipath._cli._runtime._contracts import ( + UiPathRuntimeContext, UiPathRuntimeFactory, ) from uipath._cli._utils._eval_set import EvalHelpers @@ -18,7 +19,7 @@ from uipath.eval._helpers import auto_discover_entrypoint from uipath.tracing import LlmOpsHttpExporter -from uipath_langchain._cli._runtime._context import LangGraphRuntimeContext +from uipath_langchain._cli._runtime._memory import get_memory from uipath_langchain._cli._runtime._runtime import LangGraphScriptRuntime from uipath_langchain._cli._utils._graph import LangGraphConfig from uipath_langchain._tracing import ( @@ -38,52 +39,59 @@ def langgraph_eval_middleware( try: _instrument_traceable_attributes() - event_bus = EventBus() + async def execute(): + event_bus = EventBus() - if kwargs.get("register_progress_reporter", False): - progress_reporter = StudioWebProgressReporter( - spans_exporter=LlmOpsHttpExporter(extra_process_spans=True) - ) - asyncio.run(progress_reporter.subscribe_to_eval_runtime_events(event_bus)) - console_reporter = ConsoleProgressReporter() - asyncio.run(console_reporter.subscribe_to_eval_runtime_events(event_bus)) + if kwargs.get("register_progress_reporter", False): + progress_reporter = StudioWebProgressReporter( + spans_exporter=LlmOpsHttpExporter(extra_process_spans=True) + ) + await progress_reporter.subscribe_to_eval_runtime_events(event_bus) - def generate_runtime_context( - context_entrypoint: str, **context_kwargs - ) -> LangGraphRuntimeContext: - context = LangGraphRuntimeContext.with_defaults(**context_kwargs) - context.entrypoint = context_entrypoint - return context + console_reporter = ConsoleProgressReporter() + await console_reporter.subscribe_to_eval_runtime_events(event_bus) - runtime_entrypoint = entrypoint or auto_discover_entrypoint() + def generate_runtime_context( + context_entrypoint: str, **context_kwargs + ) -> UiPathRuntimeContext: + context = UiPathRuntimeContext.with_defaults(**context_kwargs) + context.entrypoint = context_entrypoint + return context - eval_context = UiPathEvalContext.with_defaults( - entrypoint=runtime_entrypoint, **kwargs - ) - eval_context.eval_set = eval_set or EvalHelpers.auto_discover_eval_set() - eval_context.eval_ids = eval_ids - - def generate_runtime(ctx: LangGraphRuntimeContext) -> LangGraphScriptRuntime: - return LangGraphScriptRuntime(ctx, ctx.entrypoint) - - runtime_factory = UiPathRuntimeFactory( - LangGraphScriptRuntime, - LangGraphRuntimeContext, - context_generator=lambda **context_kwargs: generate_runtime_context( - context_entrypoint=runtime_entrypoint, - **context_kwargs, - ), - runtime_generator=generate_runtime, - ) + runtime_entrypoint = entrypoint or auto_discover_entrypoint() - if eval_context.job_id: - runtime_factory.add_span_exporter( - LlmOpsHttpExporter(extra_process_spans=True) + eval_context = UiPathEvalContext.with_defaults( + entrypoint=runtime_entrypoint, **kwargs ) + eval_context.eval_set = eval_set or EvalHelpers.auto_discover_eval_set() + eval_context.eval_ids = eval_ids + + async with get_memory(eval_context) as memory: + runtime_factory = UiPathRuntimeFactory( + LangGraphScriptRuntime, + UiPathRuntimeContext, + context_generator=lambda **context_kwargs: generate_runtime_context( + context_entrypoint=runtime_entrypoint, + **context_kwargs, + ), + runtime_generator=lambda ctx: LangGraphScriptRuntime( + ctx, memory, ctx.entrypoint + ), + ) + + if eval_context.job_id: + runtime_factory.add_span_exporter( + LlmOpsHttpExporter(extra_process_spans=True) + ) + + runtime_factory.add_instrumentor( + LangChainInstrumentor, get_current_span + ) + + await evaluate(runtime_factory, eval_context, event_bus) + + asyncio.run(execute()) - runtime_factory.add_instrumentor(LangChainInstrumentor, get_current_span) - - asyncio.run(evaluate(runtime_factory, eval_context, event_bus)) return MiddlewareResult(should_continue=False) except Exception as e: diff --git a/src/uipath_langchain/_cli/cli_run.py b/src/uipath_langchain/_cli/cli_run.py index 819c2ae..8208023 100644 --- a/src/uipath_langchain/_cli/cli_run.py +++ b/src/uipath_langchain/_cli/cli_run.py @@ -1,5 +1,4 @@ import asyncio -import os from typing import Optional from openinference.instrumentation.langchain import ( @@ -8,6 +7,7 @@ ) from uipath._cli._debug._bridge import ConsoleDebugBridge, UiPathDebugBridge from uipath._cli._runtime._contracts import ( + UiPathRuntimeContext, UiPathRuntimeFactory, UiPathRuntimeResult, ) @@ -19,10 +19,8 @@ _instrument_traceable_attributes, ) from ._runtime._exception import LangGraphRuntimeError -from ._runtime._runtime import ( # type: ignore[attr-defined] - LangGraphRuntimeContext, - LangGraphScriptRuntime, -) +from ._runtime._memory import get_memory +from ._runtime._runtime import LangGraphScriptRuntime from ._utils._graph import LangGraphConfig @@ -41,49 +39,43 @@ def langgraph_run_middleware( ) # Continue with normal flow if no langgraph.json try: + _instrument_traceable_attributes() async def execute(): - context = LangGraphRuntimeContext.with_defaults(**kwargs) + context = UiPathRuntimeContext.with_defaults(**kwargs) context.entrypoint = entrypoint context.input = input context.resume = resume - context.execution_id = context.job_id or "default" - _instrument_traceable_attributes() - - def generate_runtime( - ctx: LangGraphRuntimeContext, - ) -> LangGraphScriptRuntime: - runtime = LangGraphScriptRuntime(ctx, ctx.entrypoint) - # If not resuming and no job id, delete the previous state file - if not ctx.resume and ctx.job_id is None: - if os.path.exists(runtime.state_file_path): - os.remove(runtime.state_file_path) - return runtime - runtime_factory = UiPathRuntimeFactory( - LangGraphScriptRuntime, - LangGraphRuntimeContext, - runtime_generator=generate_runtime, - ) + async with get_memory(context) as memory: + runtime_factory = UiPathRuntimeFactory( + LangGraphScriptRuntime, + UiPathRuntimeContext, + runtime_generator=lambda ctx: LangGraphScriptRuntime( + ctx, memory, ctx.entrypoint + ), + ) - runtime_factory.add_instrumentor(LangChainInstrumentor, get_current_span) + runtime_factory.add_instrumentor( + LangChainInstrumentor, get_current_span + ) - if trace_file: - runtime_factory.add_span_exporter(JsonLinesFileExporter(trace_file)) + if trace_file: + runtime_factory.add_span_exporter(JsonLinesFileExporter(trace_file)) - if context.job_id: - runtime_factory.add_span_exporter( - LlmOpsHttpExporter(extra_process_spans=True) - ) - await runtime_factory.execute(context) - else: - debug_bridge: UiPathDebugBridge = ConsoleDebugBridge() - await debug_bridge.emit_execution_started(context.execution_id) - async for event in runtime_factory.stream(context): - if isinstance(event, UiPathRuntimeResult): - await debug_bridge.emit_execution_completed(event) - elif isinstance(event, UiPathAgentStateEvent): - await debug_bridge.emit_state_update(event) + if context.job_id: + runtime_factory.add_span_exporter( + LlmOpsHttpExporter(extra_process_spans=True) + ) + await runtime_factory.execute(context) + else: + debug_bridge: UiPathDebugBridge = ConsoleDebugBridge() + await debug_bridge.emit_execution_started("default") + async for event in runtime_factory.stream(context): + if isinstance(event, UiPathRuntimeResult): + await debug_bridge.emit_execution_completed(event) + elif isinstance(event, UiPathAgentStateEvent): + await debug_bridge.emit_state_update(event) asyncio.run(execute()) diff --git a/src/uipath_langchain/runtime_factories.py b/src/uipath_langchain/runtime_factories.py index 2244345..03022a8 100644 --- a/src/uipath_langchain/runtime_factories.py +++ b/src/uipath_langchain/runtime_factories.py @@ -1,21 +1,20 @@ """Runtime factory for LangGraph projects.""" -from uipath._cli._runtime._contracts import UiPathRuntimeFactory +from uipath._cli._runtime._contracts import UiPathRuntimeContext, UiPathRuntimeFactory -from ._cli._runtime._context import LangGraphRuntimeContext from ._cli._runtime._runtime import LangGraphScriptRuntime class LangGraphRuntimeFactory( - UiPathRuntimeFactory[LangGraphScriptRuntime, LangGraphRuntimeContext] + UiPathRuntimeFactory[LangGraphScriptRuntime, UiPathRuntimeContext] ): """Factory for LangGraph runtimes.""" def __init__(self): super().__init__( LangGraphScriptRuntime, - LangGraphRuntimeContext, - context_generator=lambda **kwargs: LangGraphRuntimeContext.with_defaults( + UiPathRuntimeContext, + context_generator=lambda **kwargs: UiPathRuntimeContext.with_defaults( **kwargs ), )