Skip to content

Commit 8644859

Browse files
committed
fix: allow for self-referencing pydantic schema
1 parent 62106f2 commit 8644859

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

fastapi_mcp/openapi/utils.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict
1+
from typing import Any, Dict, Optional, Set
22

33

44
def get_single_param_type_from_schema(param_schema: Dict[str, Any]) -> str:
@@ -16,7 +16,11 @@ def get_single_param_type_from_schema(param_schema: Dict[str, Any]) -> str:
1616
return param_schema.get("type", "string")
1717

1818

19-
def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dict[str, Any]) -> Dict[str, Any]:
19+
def resolve_schema_references(
20+
schema_part: Dict[str, Any],
21+
reference_schema: Dict[str, Any],
22+
seen: Optional[Set[str]] = None,
23+
) -> Dict[str, Any]:
2024
"""
2125
Resolve schema references in OpenAPI schemas.
2226
@@ -27,6 +31,8 @@ def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dic
2731
Returns:
2832
The schema with references resolved
2933
"""
34+
seen = seen or set()
35+
3036
# Make a copy to avoid modifying the input schema
3137
schema_part = schema_part.copy()
3238

@@ -35,6 +41,9 @@ def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dic
3541
ref_path = schema_part["$ref"]
3642
# Standard OpenAPI references are in the format "#/components/schemas/ModelName"
3743
if ref_path.startswith("#/components/schemas/"):
44+
if ref_path in seen:
45+
return {"$ref": ref_path}
46+
seen.add(ref_path)
3847
model_name = ref_path.split("/")[-1]
3948
if "components" in reference_schema and "schemas" in reference_schema["components"]:
4049
if model_name in reference_schema["components"]["schemas"]:
@@ -47,11 +56,12 @@ def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dic
4756
# Recursively resolve references in all dictionary values
4857
for key, value in schema_part.items():
4958
if isinstance(value, dict):
50-
schema_part[key] = resolve_schema_references(value, reference_schema)
59+
schema_part[key] = resolve_schema_references(value, reference_schema, seen)
5160
elif isinstance(value, list):
5261
# Only process list items that are dictionaries since only they can contain refs
5362
schema_part[key] = [
54-
resolve_schema_references(item, reference_schema) if isinstance(item, dict) else item for item in value
63+
resolve_schema_references(item, reference_schema, seen) if isinstance(item, dict) else item
64+
for item in value
5565
]
5666

5767
return schema_part

tests/fixtures/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import annotations
12
from typing import Optional, List, Dict, Any
23
from datetime import datetime, date
34
from enum import Enum
@@ -95,6 +96,7 @@ class Product(BaseModel):
9596
updated_at: Optional[datetime] = None
9697
is_available: bool = True
9798
metadata: Dict[str, Any] = {}
99+
related_products: Optional[List[Product]] = None
98100

99101

100102
class OrderItem(BaseModel):

0 commit comments

Comments
 (0)