44from abc import ABC , abstractmethod
55from copy import deepcopy
66from dataclasses import dataclass
7- from typing import Any , Literal
7+ from typing import Any , Literal , cast
88
99from .exceptions import UserError
1010
1111JsonSchema = dict [str , Any ]
1212
13+ __all__ = ['JsonSchemaTransformer' , 'InlineDefsJsonSchemaTransformer' , 'flatten_allof' ]
14+
1315
1416@dataclass (init = False )
1517class JsonSchemaTransformer (ABC ):
@@ -30,7 +32,9 @@ def __init__(
3032 self .schema = schema
3133
3234 self .strict = strict
33- self .is_strict_compatible = True # Can be set to False by subclasses to set `strict` on `ToolDefinition` when set not set by user explicitly
35+ # Can be set to False by subclasses to set `strict` on `ToolDefinition`
36+ # when not set explicitly by the user.
37+ self .is_strict_compatible = True
3438
3539 self .prefer_inlined_defs = prefer_inlined_defs
3640 self .simplify_nullable_unions = simplify_nullable_unions
@@ -188,3 +192,108 @@ def __init__(self, schema: JsonSchema, *, strict: bool | None = None):
188192
189193 def transform (self , schema : JsonSchema ) -> JsonSchema :
190194 return schema
195+
196+
197+ def _allof_is_object_like (member : JsonSchema ) -> bool :
198+ member_type = member .get ('type' )
199+ if member_type is None :
200+ keys = ('properties' , 'additionalProperties' , 'patternProperties' )
201+ return bool (any (k in member for k in keys ))
202+ return member_type == 'object'
203+
204+
205+ def _merge_additional_properties_values (values : list [Any ]) -> bool | JsonSchema :
206+ if any (isinstance (v , dict ) for v in values ):
207+ return True
208+ return False if values and all (v is False for v in values ) else True
209+
210+
211+ def _flatten_current_level (s : JsonSchema ) -> JsonSchema :
212+ raw_members = s .get ('allOf' )
213+ if not isinstance (raw_members , list ) or not raw_members :
214+ return s
215+
216+ members = cast (list [JsonSchema ], raw_members )
217+ for raw in members :
218+ if not isinstance (raw , dict ):
219+ return s
220+ if not all (_allof_is_object_like (member ) for member in members ):
221+ return s
222+
223+ processed_members = [_recurse_flatten_allof (member ) for member in members ]
224+ merged : JsonSchema = {k : v for k , v in s .items () if k != 'allOf' }
225+ merged ['type' ] = 'object'
226+
227+ properties : dict [str , JsonSchema ] = {}
228+ if isinstance (merged .get ('properties' ), dict ):
229+ properties .update (merged ['properties' ])
230+
231+ required : set [str ] = set (merged .get ('required' , []) or [])
232+ pattern_properties : dict [str , JsonSchema ] = dict (merged .get ('patternProperties' , {}) or {})
233+ additional_values : list [Any ] = []
234+
235+ for m in processed_members :
236+ if isinstance (m .get ('properties' ), dict ):
237+ properties .update (m ['properties' ])
238+ if isinstance (m .get ('required' ), list ):
239+ required .update (m ['required' ])
240+ if isinstance (m .get ('patternProperties' ), dict ):
241+ pattern_properties .update (m ['patternProperties' ])
242+ if 'additionalProperties' in m :
243+ additional_values .append (m ['additionalProperties' ])
244+
245+ if properties :
246+ merged ['properties' ] = {k : _recurse_flatten_allof (v ) for k , v in properties .items ()}
247+ if required :
248+ merged ['required' ] = sorted (required )
249+ if pattern_properties :
250+ merged ['patternProperties' ] = {k : _recurse_flatten_allof (v ) for k , v in pattern_properties .items ()}
251+
252+ if additional_values :
253+ merged ['additionalProperties' ] = _merge_additional_properties_values (additional_values )
254+
255+ return merged
256+
257+
258+ def _recurse_children (s : JsonSchema ) -> JsonSchema :
259+ t = s .get ('type' )
260+ if t == 'object' :
261+ if isinstance (s .get ('properties' ), dict ):
262+ s ['properties' ] = {
263+ k : _recurse_flatten_allof (cast (JsonSchema , v ))
264+ for k , v in s ['properties' ].items ()
265+ if isinstance (v , dict )
266+ }
267+ ap = s .get ('additionalProperties' )
268+ if isinstance (ap , dict ):
269+ ap_schema = cast (JsonSchema , ap )
270+ s ['additionalProperties' ] = _recurse_flatten_allof (ap_schema )
271+ if isinstance (s .get ('patternProperties' ), dict ):
272+ s ['patternProperties' ] = {
273+ k : _recurse_flatten_allof (cast (JsonSchema , v ))
274+ for k , v in s ['patternProperties' ].items ()
275+ if isinstance (v , dict )
276+ }
277+ elif t == 'array' :
278+ items = s .get ('items' )
279+ if isinstance (items , dict ):
280+ s ['items' ] = _recurse_flatten_allof (cast (JsonSchema , items ))
281+ return s
282+
283+
284+ def _recurse_flatten_allof (schema : JsonSchema ) -> JsonSchema :
285+ s = deepcopy (schema )
286+ s = _flatten_current_level (s )
287+ s = _recurse_children (s )
288+ return s
289+
290+
291+ def flatten_allof (schema : JsonSchema ) -> JsonSchema :
292+ """Flatten simple object-only allOf combinations by merging object members.
293+
294+ - Merges properties and unions required lists.
295+ - Combines additionalProperties conservatively: only False if all are False; otherwise True.
296+ - Recurses into nested object/array members.
297+ - Leaves non-object allOfs untouched.
298+ """
299+ return _recurse_flatten_allof (schema )
0 commit comments