Skip to content

Commit aefbe61

Browse files
authored
Merge pull request #231 from UiPath/bai/simplify-middleware
feat: refactor cli internals, add langgraph runtime factory
2 parents 3a36a91 + 5e123a4 commit aefbe61

File tree

7 files changed

+138
-99
lines changed

7 files changed

+138
-99
lines changed

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
[project]
22
name = "uipath-langchain"
3-
version = "0.0.144"
3+
version = "0.0.145"
44
description = "UiPath Langchain"
55
readme = { file = "README.md", content-type = "text/markdown" }
66
requires-python = ">=3.10"
77
dependencies = [
8-
"uipath>=2.1.103, <2.2.0",
8+
"uipath>=2.1.110, <2.2.0",
99
"langgraph>=0.5.0, <0.7.0",
1010
"langchain-core>=0.3.34",
1111
"langgraph-checkpoint-sqlite>=2.0.3",
@@ -111,4 +111,4 @@ asyncio_mode = "auto"
111111
name = "testpypi"
112112
url = "https://test.pypi.org/simple/"
113113
publish-url = "https://test.pypi.org/legacy/"
114-
explicit = true
114+
explicit = true

src/uipath_langchain/_cli/_runtime/_runtime.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
import os
33
from contextlib import asynccontextmanager
44
from typing import Any, AsyncGenerator, AsyncIterator, Optional, Sequence
5+
from uuid import uuid4
56

67
from langchain_core.runnables.config import RunnableConfig
78
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
89
from langgraph.errors import EmptyInputError, GraphRecursionError, InvalidUpdateError
910
from langgraph.graph.state import CompiledStateGraph, StateGraph
1011
from langgraph.types import Interrupt, StateSnapshot
12+
from typing_extensions import override
1113
from uipath._cli._runtime._contracts import (
1214
UiPathBaseRuntime,
1315
UiPathBreakpointResult,
@@ -17,12 +19,14 @@
1719
UiPathRuntimeResult,
1820
UiPathRuntimeStatus,
1921
)
22+
from uipath._cli.models.runtime_schema import Entrypoint
2023
from uipath._events._events import (
2124
UiPathAgentMessageEvent,
2225
UiPathAgentStateEvent,
2326
UiPathRuntimeEvent,
2427
)
2528

29+
from .._utils._schema import generate_schema_from_graph
2630
from ._context import LangGraphRuntimeContext
2731
from ._exception import LangGraphErrorCode, LangGraphRuntimeError
2832
from ._graph_resolver import AsyncResolver, LangGraphJsonResolver
@@ -481,6 +485,21 @@ def __init__(
481485
self.resolver = LangGraphJsonResolver(entrypoint=entrypoint)
482486
super().__init__(context, self.resolver)
483487

488+
@override
489+
async def get_entrypoint(self) -> Entrypoint:
490+
"""Get entrypoint for this LangGraph runtime."""
491+
graph = await self.resolver()
492+
compiled_graph = graph.compile()
493+
schema = generate_schema_from_graph(compiled_graph)
494+
495+
return Entrypoint(
496+
file_path=self.context.entrypoint, # type: ignore[call-arg]
497+
unique_id=str(uuid4()),
498+
type="agent",
499+
input=schema["input"],
500+
output=schema["output"],
501+
)
502+
484503
async def cleanup(self) -> None:
485504
"""Cleanup runtime resources including resolver."""
486505
await super().cleanup()
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from typing import Any, Dict
2+
3+
from langgraph.graph.state import CompiledStateGraph
4+
5+
6+
def resolve_refs(schema, root=None):
7+
"""Recursively resolves $ref references in a JSON schema."""
8+
if root is None:
9+
root = schema # Store the root schema to resolve $refs
10+
11+
if isinstance(schema, dict):
12+
if "$ref" in schema:
13+
ref_path = schema["$ref"].lstrip("#/").split("/")
14+
ref_schema = root
15+
for part in ref_path:
16+
ref_schema = ref_schema.get(part, {})
17+
return resolve_refs(ref_schema, root)
18+
19+
return {k: resolve_refs(v, root) for k, v in schema.items()}
20+
21+
elif isinstance(schema, list):
22+
return [resolve_refs(item, root) for item in schema]
23+
24+
return schema
25+
26+
27+
def process_nullable_types(
28+
schema: Dict[str, Any] | list[Any] | Any,
29+
) -> Dict[str, Any] | list[Any]:
30+
"""Process the schema to handle nullable types by removing anyOf with null and keeping the base type."""
31+
if isinstance(schema, dict):
32+
if "anyOf" in schema and len(schema["anyOf"]) == 2:
33+
types = [t.get("type") for t in schema["anyOf"]]
34+
if "null" in types:
35+
non_null_type = next(
36+
t for t in schema["anyOf"] if t.get("type") != "null"
37+
)
38+
return non_null_type
39+
40+
return {k: process_nullable_types(v) for k, v in schema.items()}
41+
elif isinstance(schema, list):
42+
return [process_nullable_types(item) for item in schema]
43+
return schema
44+
45+
46+
def generate_schema_from_graph(
47+
graph: CompiledStateGraph[Any, Any, Any],
48+
) -> Dict[str, Any]:
49+
"""Extract input/output schema from a LangGraph graph"""
50+
schema = {
51+
"input": {"type": "object", "properties": {}, "required": []},
52+
"output": {"type": "object", "properties": {}, "required": []},
53+
}
54+
55+
if hasattr(graph, "input_schema"):
56+
if hasattr(graph.input_schema, "model_json_schema"):
57+
input_schema = graph.input_schema.model_json_schema()
58+
unpacked_ref_def_properties = resolve_refs(input_schema)
59+
60+
# Process the schema to handle nullable types
61+
processed_properties = process_nullable_types(
62+
unpacked_ref_def_properties.get("properties", {})
63+
)
64+
65+
schema["input"]["properties"] = processed_properties
66+
schema["input"]["required"] = unpacked_ref_def_properties.get(
67+
"required", []
68+
)
69+
70+
if hasattr(graph, "output_schema"):
71+
if hasattr(graph.output_schema, "model_json_schema"):
72+
output_schema = graph.output_schema.model_json_schema()
73+
unpacked_ref_def_properties = resolve_refs(output_schema)
74+
75+
# Process the schema to handle nullable types
76+
processed_properties = process_nullable_types(
77+
unpacked_ref_def_properties.get("properties", {})
78+
)
79+
80+
schema["output"]["properties"] = processed_properties
81+
schema["output"]["required"] = unpacked_ref_def_properties.get(
82+
"required", []
83+
)
84+
85+
return schema

src/uipath_langchain/_cli/cli_eval.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
get_current_span,
77
)
88
from uipath._cli._evals._console_progress_reporter import ConsoleProgressReporter
9+
from uipath._cli._evals._evaluate import evaluate
910
from uipath._cli._evals._progress_reporter import StudioWebProgressReporter
10-
from uipath._cli._evals._runtime import UiPathEvalContext, UiPathEvalRuntime
11+
from uipath._cli._evals._runtime import UiPathEvalContext
1112
from uipath._cli._runtime._contracts import (
1213
UiPathRuntimeFactory,
1314
)
@@ -82,14 +83,7 @@ def generate_runtime(ctx: LangGraphRuntimeContext) -> LangGraphScriptRuntime:
8283

8384
runtime_factory.add_instrumentor(LangChainInstrumentor, get_current_span)
8485

85-
async def execute():
86-
async with UiPathEvalRuntime.from_eval_context(
87-
factory=runtime_factory, context=eval_context, event_bus=event_bus
88-
) as eval_runtime:
89-
await eval_runtime.execute()
90-
await event_bus.wait_for_all()
91-
92-
asyncio.run(execute())
86+
asyncio.run(evaluate(runtime_factory, eval_context, event_bus))
9387
return MiddlewareResult(should_continue=False)
9488

9589
except Exception as e:

src/uipath_langchain/_cli/cli_init.py

Lines changed: 3 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@
66
import uuid
77
from collections.abc import Generator
88
from enum import Enum
9-
from typing import Any, Callable, Dict, overload
9+
from typing import Any, Callable, overload
1010

1111
import click
1212
from langgraph.graph.state import CompiledStateGraph
1313
from uipath._cli._utils._console import ConsoleLogger
1414
from uipath._cli._utils._parse_ast import generate_bindings_json # type: ignore
1515
from uipath._cli.middlewares import MiddlewareResult
1616

17+
from uipath_langchain._cli._utils._schema import generate_schema_from_graph
18+
1719
from ._utils._graph import LangGraphConfig
1820

1921
console = ConsoleLogger()
@@ -27,88 +29,6 @@ class FileOperationStatus(str, Enum):
2729
SKIPPED = "skipped"
2830

2931

30-
def resolve_refs(schema, root=None):
31-
"""Recursively resolves $ref references in a JSON schema."""
32-
if root is None:
33-
root = schema # Store the root schema to resolve $refs
34-
35-
if isinstance(schema, dict):
36-
if "$ref" in schema:
37-
ref_path = schema["$ref"].lstrip("#/").split("/")
38-
ref_schema = root
39-
for part in ref_path:
40-
ref_schema = ref_schema.get(part, {})
41-
return resolve_refs(ref_schema, root)
42-
43-
return {k: resolve_refs(v, root) for k, v in schema.items()}
44-
45-
elif isinstance(schema, list):
46-
return [resolve_refs(item, root) for item in schema]
47-
48-
return schema
49-
50-
51-
def process_nullable_types(
52-
schema: Dict[str, Any] | list[Any] | Any,
53-
) -> Dict[str, Any] | list[Any]:
54-
"""Process the schema to handle nullable types by removing anyOf with null and keeping the base type."""
55-
if isinstance(schema, dict):
56-
if "anyOf" in schema and len(schema["anyOf"]) == 2:
57-
types = [t.get("type") for t in schema["anyOf"]]
58-
if "null" in types:
59-
non_null_type = next(
60-
t for t in schema["anyOf"] if t.get("type") != "null"
61-
)
62-
return non_null_type
63-
64-
return {k: process_nullable_types(v) for k, v in schema.items()}
65-
elif isinstance(schema, list):
66-
return [process_nullable_types(item) for item in schema]
67-
return schema
68-
69-
70-
def generate_schema_from_graph(
71-
graph: CompiledStateGraph[Any, Any, Any],
72-
) -> Dict[str, Any]:
73-
"""Extract input/output schema from a LangGraph graph"""
74-
schema = {
75-
"input": {"type": "object", "properties": {}, "required": []},
76-
"output": {"type": "object", "properties": {}, "required": []},
77-
}
78-
79-
if hasattr(graph, "input_schema"):
80-
if hasattr(graph.input_schema, "model_json_schema"):
81-
input_schema = graph.input_schema.model_json_schema()
82-
unpacked_ref_def_properties = resolve_refs(input_schema)
83-
84-
# Process the schema to handle nullable types
85-
processed_properties = process_nullable_types(
86-
unpacked_ref_def_properties.get("properties", {})
87-
)
88-
89-
schema["input"]["properties"] = processed_properties
90-
schema["input"]["required"] = unpacked_ref_def_properties.get(
91-
"required", []
92-
)
93-
94-
if hasattr(graph, "output_schema"):
95-
if hasattr(graph.output_schema, "model_json_schema"):
96-
output_schema = graph.output_schema.model_json_schema()
97-
unpacked_ref_def_properties = resolve_refs(output_schema)
98-
99-
# Process the schema to handle nullable types
100-
processed_properties = process_nullable_types(
101-
unpacked_ref_def_properties.get("properties", {})
102-
)
103-
104-
schema["output"]["properties"] = processed_properties
105-
schema["output"]["required"] = unpacked_ref_def_properties.get(
106-
"required", []
107-
)
108-
109-
return schema
110-
111-
11232
def generate_agent_md_file(
11333
target_directory: str,
11434
file_name: str,
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""Runtime factory for LangGraph projects."""
2+
3+
from uipath._cli._runtime._contracts import UiPathRuntimeFactory
4+
5+
from ._cli._runtime._context import LangGraphRuntimeContext
6+
from ._cli._runtime._runtime import LangGraphScriptRuntime
7+
8+
9+
class LangGraphRuntimeFactory(
10+
UiPathRuntimeFactory[LangGraphScriptRuntime, LangGraphRuntimeContext]
11+
):
12+
"""Factory for LangGraph runtimes."""
13+
14+
def __init__(self):
15+
super().__init__(
16+
LangGraphScriptRuntime,
17+
LangGraphRuntimeContext,
18+
context_generator=lambda **kwargs: LangGraphRuntimeContext.with_defaults(
19+
**kwargs
20+
),
21+
)

uv.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)