Skip to content

Revert "Fix a bug with handling FastAPI root_path parameter" #194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions fastapi_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,7 @@
str,
Doc(
"""
Path where the MCP server will be mounted.
Mount path is appended to the root path of FastAPI router, or to the prefix of APIRouter.
Defaults to '/mcp'.
Path where the MCP server will be mounted. Defaults to '/mcp'.
"""
),
] = "/mcp",
Expand Down Expand Up @@ -330,9 +328,14 @@
router = self.fastapi

# Build the base path correctly for the SSE transport
assert isinstance(router, (FastAPI, APIRouter)), f"Invalid router type: {type(router)}"
base_path = mount_path if isinstance(router, FastAPI) else router.prefix + mount_path
messages_path = f"{base_path}/messages/"
if isinstance(router, FastAPI):
base_path = router.root_path
elif isinstance(router, APIRouter):
base_path = self.fastapi.root_path + router.prefix

Check warning on line 334 in fastapi_mcp/server.py

View check run for this annotation

Codecov / codecov/patch

fastapi_mcp/server.py#L333-L334

Added lines #L333 - L334 were not covered by tests
else:
raise ValueError(f"Invalid router type: {type(router)}")

Check warning on line 336 in fastapi_mcp/server.py

View check run for this annotation

Codecov / codecov/patch

fastapi_mcp/server.py#L336

Added line #L336 was not covered by tests

messages_path = f"{base_path}{mount_path}/messages/"

sse_transport = FastApiSseTransport(messages_path)

Expand Down
9 changes: 5 additions & 4 deletions tests/fixtures/complex_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from fastapi import FastAPI, Query, Path, Body, Header, Cookie
import pytest

from tests.fixtures.conftest import make_fastapi_app_base

from .types import (
Product,
Customer,
Expand All @@ -21,9 +19,12 @@ def make_complex_fastapi_app(
example_product: Product,
example_customer: Customer,
example_order_response: OrderResponse,
parametrized_config: dict[str, Any] | None = None,
) -> FastAPI:
app = make_fastapi_app_base(parametrized_config=parametrized_config)
app = FastAPI(
title="Complex E-Commerce API",
description="A more complex API with nested models and various schemas",
version="1.0.0",
)

@app.get(
"/products",
Expand Down
12 changes: 0 additions & 12 deletions tests/fixtures/conftest.py

This file was deleted.

18 changes: 8 additions & 10 deletions tests/fixtures/simple_app.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from typing import Optional, List, Any
from typing import Optional, List

from fastapi import FastAPI, Query, Path, Body, HTTPException
import pytest

from tests.fixtures.conftest import make_fastapi_app_base

from .types import Item


def make_simple_fastapi_app(parametrized_config: dict[str, Any] | None = None) -> FastAPI:
app = make_fastapi_app_base(parametrized_config=parametrized_config)
def make_simple_fastapi_app() -> FastAPI:
app = FastAPI(
title="Test API",
description="A test API app for unit testing",
version="0.1.0",
)

items = [
Item(id=1, name="Item 1", price=10.0, tags=["tag1", "tag2"], description="Item 1 description"),
Item(id=2, name="Item 2", price=20.0, tags=["tag2", "tag3"]),
Expand Down Expand Up @@ -67,8 +70,3 @@ async def raise_error() -> None:
@pytest.fixture
def simple_fastapi_app() -> FastAPI:
return make_simple_fastapi_app()


@pytest.fixture
def simple_fastapi_app_with_root_path() -> FastAPI:
return make_simple_fastapi_app(parametrized_config={"root_path": "/api/v1"})
54 changes: 28 additions & 26 deletions tests/test_sse_real_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import threading
import coverage
from typing import AsyncGenerator, Generator
from fastapi import FastAPI
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp import InitializeResult
Expand All @@ -19,12 +18,26 @@
import uvicorn
from fastapi_mcp import FastApiMCP

from .fixtures.simple_app import make_simple_fastapi_app


HOST = "127.0.0.1"
SERVER_NAME = "Test MCP Server"


def run_server(server_port: int, fastapi_app: FastAPI) -> None:
@pytest.fixture
def server_port() -> int:
with socket.socket() as s:
s.bind((HOST, 0))
return s.getsockname()[1]


@pytest.fixture
def server_url(server_port: int) -> str:
return f"http://{HOST}:{server_port}"


def run_server(server_port: int) -> None:
# Initialize coverage for subprocesses
cov = None
if "COVERAGE_PROCESS_START" in os.environ:
Expand Down Expand Up @@ -59,15 +72,16 @@ def periodic_save():
save_thread.start()

# Configure the server
fastapi = make_simple_fastapi_app()
mcp = FastApiMCP(
fastapi_app,
fastapi,
name=SERVER_NAME,
description="Test description",
)
mcp.mount()

# Start the server
server = uvicorn.Server(config=uvicorn.Config(app=fastapi_app, host=HOST, port=server_port, log_level="error"))
server = uvicorn.Server(config=uvicorn.Config(app=fastapi, host=HOST, port=server_port, log_level="error"))
server.run()

# Give server time to start
Expand All @@ -80,24 +94,13 @@ def periodic_save():
cov.save()


@pytest.fixture(params=["simple_fastapi_app", "simple_fastapi_app_with_root_path"])
def server(request: pytest.FixtureRequest) -> Generator[str, None, None]:
@pytest.fixture()
def server(server_port: int) -> Generator[None, None, None]:
# Ensure COVERAGE_PROCESS_START is set in the environment for subprocesses
coverage_rc = os.path.abspath(".coveragerc")
os.environ["COVERAGE_PROCESS_START"] = coverage_rc

# Get a free port
with socket.socket() as s:
s.bind((HOST, 0))
server_port = s.getsockname()[1]

# Run the server in a subprocess
fastapi_app = request.getfixturevalue(request.param)
proc = multiprocessing.Process(
target=run_server,
kwargs={"server_port": server_port, "fastapi_app": fastapi_app},
daemon=True,
)
proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True)
proc.start()

# Wait for server to be running
Expand All @@ -114,8 +117,7 @@ def server(request: pytest.FixtureRequest) -> Generator[str, None, None]:
else:
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")

# Return the server URL
yield f"http://{HOST}:{server_port}{fastapi_app.root_path}"
yield

# Signal the server to stop - added graceful shutdown before kill
try:
Expand All @@ -132,8 +134,8 @@ def server(request: pytest.FixtureRequest) -> Generator[str, None, None]:


@pytest.fixture()
async def http_client(server: str) -> AsyncGenerator[httpx.AsyncClient, None]:
async with httpx.AsyncClient(base_url=server) as client:
async def http_client(server: None, server_url: str) -> AsyncGenerator[httpx.AsyncClient, None]:
async with httpx.AsyncClient(base_url=server_url) as client:
yield client


Expand Down Expand Up @@ -163,8 +165,8 @@ async def connection_test() -> None:


@pytest.mark.anyio
async def test_sse_basic_connection(server: str) -> None:
async with sse_client(server + "/mcp") as streams:
async def test_sse_basic_connection(server: None, server_url: str) -> None:
async with sse_client(server_url + "/mcp") as streams:
async with ClientSession(*streams) as session:
# Test initialization
result = await session.initialize()
Expand All @@ -177,8 +179,8 @@ async def test_sse_basic_connection(server: str) -> None:


@pytest.mark.anyio
async def test_sse_tool_call(server: str) -> None:
async with sse_client(server + "/mcp") as streams:
async def test_sse_tool_call(server: None, server_url: str) -> None:
async with sse_client(server_url + "/mcp") as streams:
async with ClientSession(*streams) as session:
await session.initialize()

Expand Down