Skip to content

Commit 2b9e2d0

Browse files
committed
JsonSchemaTransformer: Add back support for the simplify_nullable_unions kwarg
1 parent c977993 commit 2b9e2d0

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

pydantic_ai_slim/pydantic_ai/_json_schema.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,15 @@ def __init__(
2525
*,
2626
strict: bool | None = None,
2727
prefer_inlined_defs: bool = False,
28+
simplify_nullable_unions: bool = False, # TODO (v2): Remove this, no longer used
2829
):
2930
self.schema = schema
3031

3132
self.strict = strict
3233
self.is_strict_compatible = True # Can be set to False by subclasses to set `strict` on `ToolDefinition` when set not set by user explicitly
3334

3435
self.prefer_inlined_defs = prefer_inlined_defs
36+
self.simplify_nullable_unions = simplify_nullable_unions # TODO (v2): Remove this, no longer used
3537

3638
self.defs: dict[str, JsonSchema] = self.schema.get('$defs', {})
3739
self.refs_stack: list[str] = []
@@ -144,11 +146,39 @@ def _handle_union(self, schema: JsonSchema, union_kind: Literal['anyOf', 'oneOf'
144146

145147
handled = [self._handle(member) for member in members]
146148

149+
# TODO (v2): Remove this feature, no longer used
150+
if self.simplify_nullable_unions:
151+
handled = self._simplify_nullable_union(handled)
152+
153+
if len(handled) == 1:
154+
# In this case, no need to retain the union
155+
return handled[0] | schema
156+
147157
# If we have keys besides the union kind (such as title or discriminator), keep them without modifications
148158
schema = schema.copy()
149159
schema[union_kind] = handled
150160
return schema
151161

162+
@staticmethod
163+
def _simplify_nullable_union(cases: list[JsonSchema]) -> list[JsonSchema]:
164+
# TODO (v2): Remove this method, no longer used
165+
if len(cases) == 2 and {'type': 'null'} in cases:
166+
# Find the non-null schema
167+
non_null_schema = next(
168+
(item for item in cases if item != {'type': 'null'}),
169+
None,
170+
)
171+
if non_null_schema:
172+
# Create a new schema based on the non-null part, mark as nullable
173+
new_schema = deepcopy(non_null_schema)
174+
new_schema['nullable'] = True
175+
return [new_schema]
176+
else: # pragma: no cover
177+
# they are both null, so just return one of them
178+
return [cases[0]]
179+
180+
return cases
181+
152182

153183
class InlineDefsJsonSchemaTransformer(JsonSchemaTransformer):
154184
"""Transforms the JSON Schema to inline $defs."""

tests/test_utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
import os
77
from collections.abc import AsyncIterator
88
from importlib.metadata import distributions
9+
from typing import Any
910

1011
import pytest
1112
from inline_snapshot import snapshot
1213

1314
from pydantic_ai import UserError
15+
from pydantic_ai._json_schema import JsonSchemaTransformer
1416
from pydantic_ai._utils import (
1517
UNSET,
1618
PeekableAsyncStream,
@@ -541,3 +543,47 @@ def test_validate_empty_kwargs_preserves_order():
541543
assert '`first`' in error_msg
542544
assert '`second`' in error_msg
543545
assert '`third`' in error_msg
546+
547+
548+
def test_simplify_nullable_unions():
549+
"""Test the simplify_nullable_unions feature (deprecated, to be removed in v2)."""
550+
551+
# Create a concrete subclass for testing
552+
class TestTransformer(JsonSchemaTransformer):
553+
def transform(self, schema: dict[str, Any]) -> dict[str, Any]:
554+
return schema
555+
556+
# Test with simplify_nullable_unions=True
557+
schema_with_null = {
558+
'anyOf': [
559+
{'type': 'string'},
560+
{'type': 'null'},
561+
]
562+
}
563+
transformer = TestTransformer(schema_with_null, simplify_nullable_unions=True)
564+
result = transformer.walk()
565+
566+
# Should collapse to a single nullable string
567+
assert result == {'type': 'string', 'nullable': True}
568+
569+
# Test with simplify_nullable_unions=False (default)
570+
transformer2 = TestTransformer(schema_with_null, simplify_nullable_unions=False)
571+
result2 = transformer2.walk()
572+
573+
# Should keep the anyOf structure
574+
assert 'anyOf' in result2
575+
assert len(result2['anyOf']) == 2
576+
577+
# Test that non-nullable unions are unaffected
578+
schema_no_null = {
579+
'anyOf': [
580+
{'type': 'string'},
581+
{'type': 'number'},
582+
]
583+
}
584+
transformer3 = TestTransformer(schema_no_null, simplify_nullable_unions=True)
585+
result3 = transformer3.walk()
586+
587+
# Should keep anyOf since it's not nullable
588+
assert 'anyOf' in result3
589+
assert len(result3['anyOf']) == 2

0 commit comments

Comments
 (0)