Skip to content

Commit d416df8

Browse files
committed
fix: allow for self-referencing pydantic schema
1 parent 62106f2 commit d416df8

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

fastapi_mcp/openapi/utils.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict
1+
from typing import Any, Dict, Optional, Set
22

33

44
def get_single_param_type_from_schema(param_schema: Dict[str, Any]) -> str:
@@ -16,17 +16,24 @@ def get_single_param_type_from_schema(param_schema: Dict[str, Any]) -> str:
1616
return param_schema.get("type", "string")
1717

1818

19-
def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dict[str, Any]) -> Dict[str, Any]:
19+
def resolve_schema_references(
20+
schema_part: Dict[str, Any],
21+
reference_schema: Dict[str, Any],
22+
seen: Optional[Set[str]] = None,
23+
) -> Dict[str, Any]:
2024
"""
2125
Resolve schema references in OpenAPI schemas.
2226
2327
Args:
2428
schema_part: The part of the schema being processed that may contain references
2529
reference_schema: The complete schema used to resolve references from
30+
seen: A set of already seen references to avoid infinite recursion
2631
2732
Returns:
2833
The schema with references resolved
2934
"""
35+
seen = seen or set()
36+
3037
# Make a copy to avoid modifying the input schema
3138
schema_part = schema_part.copy()
3239

@@ -35,6 +42,9 @@ def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dic
3542
ref_path = schema_part["$ref"]
3643
# Standard OpenAPI references are in the format "#/components/schemas/ModelName"
3744
if ref_path.startswith("#/components/schemas/"):
45+
if ref_path in seen:
46+
return {"$ref": ref_path}
47+
seen.add(ref_path)
3848
model_name = ref_path.split("/")[-1]
3949
if "components" in reference_schema and "schemas" in reference_schema["components"]:
4050
if model_name in reference_schema["components"]["schemas"]:
@@ -47,11 +57,12 @@ def resolve_schema_references(schema_part: Dict[str, Any], reference_schema: Dic
4757
# Recursively resolve references in all dictionary values
4858
for key, value in schema_part.items():
4959
if isinstance(value, dict):
50-
schema_part[key] = resolve_schema_references(value, reference_schema)
60+
schema_part[key] = resolve_schema_references(value, reference_schema, seen)
5161
elif isinstance(value, list):
5262
# Only process list items that are dictionaries since only they can contain refs
5363
schema_part[key] = [
54-
resolve_schema_references(item, reference_schema) if isinstance(item, dict) else item for item in value
64+
resolve_schema_references(item, reference_schema, seen) if isinstance(item, dict) else item
65+
for item in value
5566
]
5667

5768
return schema_part

tests/fixtures/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import annotations
12
from typing import Optional, List, Dict, Any
23
from datetime import datetime, date
34
from enum import Enum
@@ -95,6 +96,7 @@ class Product(BaseModel):
9596
updated_at: Optional[datetime] = None
9697
is_available: bool = True
9798
metadata: Dict[str, Any] = {}
99+
related_products: Optional[List[Product]] = None
98100

99101

100102
class OrderItem(BaseModel):

0 commit comments

Comments
 (0)