From 9164dc884c07ceaef750f2c33aa103d353e2684d Mon Sep 17 00:00:00 2001 From: Robert Andrei Moldoveanu Date: Tue, 15 Apr 2025 13:01:50 +0300 Subject: [PATCH] fix: process infer optional types --- src/uipath_langchain/_cli/cli_init.py | 36 ++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/src/uipath_langchain/_cli/cli_init.py b/src/uipath_langchain/_cli/cli_init.py index 72e2cc7b..c8e4f507 100644 --- a/src/uipath_langchain/_cli/cli_init.py +++ b/src/uipath_langchain/_cli/cli_init.py @@ -32,6 +32,25 @@ def resolve_refs(schema, root=None): return schema +def process_nullable_types( + schema: Dict[str, Any] | list[Any] | Any, +) -> Dict[str, Any] | list[Any]: + """Process the schema to handle nullable types by removing anyOf with null and keeping the base type.""" + if isinstance(schema, dict): + if "anyOf" in schema and len(schema["anyOf"]) == 2: + types = [t.get("type") for t in schema["anyOf"]] + if "null" in types: + non_null_type = next( + t for t in schema["anyOf"] if t.get("type") != "null" + ) + return non_null_type + + return {k: process_nullable_types(v) for k, v in schema.items()} + elif isinstance(schema, list): + return [process_nullable_types(item) for item in schema] + return schema + + def generate_schema_from_graph(graph: CompiledStateGraph) -> Dict[str, Any]: """Extract input/output schema from a LangGraph graph""" schema = { @@ -42,12 +61,14 @@ def generate_schema_from_graph(graph: CompiledStateGraph) -> Dict[str, Any]: if hasattr(graph, "input_schema"): if hasattr(graph.input_schema, "model_json_schema"): input_schema = graph.input_schema.model_json_schema() - unpacked_ref_def_properties = resolve_refs(input_schema) - schema["input"]["properties"] = unpacked_ref_def_properties.get( - "properties", {} + # Process the schema to handle nullable types + processed_properties = process_nullable_types( + unpacked_ref_def_properties.get("properties", {}) ) + + schema["input"]["properties"] = processed_properties schema["input"]["required"] = unpacked_ref_def_properties.get( "required", [] ) @@ -55,11 +76,14 @@ def generate_schema_from_graph(graph: CompiledStateGraph) -> Dict[str, Any]: if hasattr(graph, "output_schema"): if hasattr(graph.output_schema, "model_json_schema"): output_schema = graph.output_schema.model_json_schema() - unpacked_ref_def_properties = resolve_refs(output_schema) - schema["output"]["properties"] = unpacked_ref_def_properties.get( - "properties", {} + + # Process the schema to handle nullable types + processed_properties = process_nullable_types( + unpacked_ref_def_properties.get("properties", {}) ) + + schema["output"]["properties"] = processed_properties schema["output"]["required"] = unpacked_ref_def_properties.get( "required", [] )