diff --git a/fastapi_mcp/server.py b/fastapi_mcp/server.py index eea94ac..f5c4fc6 100644 --- a/fastapi_mcp/server.py +++ b/fastapi_mcp/server.py @@ -299,9 +299,7 @@ def mount( str, Doc( """ - Path where the MCP server will be mounted. - Mount path is appended to the root path of FastAPI router, or to the prefix of APIRouter. - Defaults to '/mcp'. + Path where the MCP server will be mounted. Defaults to '/mcp'. """ ), ] = "/mcp", @@ -330,9 +328,14 @@ def mount( router = self.fastapi # Build the base path correctly for the SSE transport - assert isinstance(router, (FastAPI, APIRouter)), f"Invalid router type: {type(router)}" - base_path = mount_path if isinstance(router, FastAPI) else router.prefix + mount_path - messages_path = f"{base_path}/messages/" + if isinstance(router, FastAPI): + base_path = router.root_path + elif isinstance(router, APIRouter): + base_path = self.fastapi.root_path + router.prefix + else: + raise ValueError(f"Invalid router type: {type(router)}") + + messages_path = f"{base_path}{mount_path}/messages/" sse_transport = FastApiSseTransport(messages_path) diff --git a/tests/fixtures/complex_app.py b/tests/fixtures/complex_app.py index d248308..72ba14d 100644 --- a/tests/fixtures/complex_app.py +++ b/tests/fixtures/complex_app.py @@ -4,8 +4,6 @@ from fastapi import FastAPI, Query, Path, Body, Header, Cookie import pytest -from tests.fixtures.conftest import make_fastapi_app_base - from .types import ( Product, Customer, @@ -21,9 +19,12 @@ def make_complex_fastapi_app( example_product: Product, example_customer: Customer, example_order_response: OrderResponse, - parametrized_config: dict[str, Any] | None = None, ) -> FastAPI: - app = make_fastapi_app_base(parametrized_config=parametrized_config) + app = FastAPI( + title="Complex E-Commerce API", + description="A more complex API with nested models and various schemas", + version="1.0.0", + ) @app.get( "/products", diff --git a/tests/fixtures/conftest.py b/tests/fixtures/conftest.py deleted file mode 100644 index ca0635a..0000000 --- a/tests/fixtures/conftest.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import Any -from fastapi import FastAPI - - -def make_fastapi_app_base(parametrized_config: dict[str, Any] | None = None) -> FastAPI: - fastapi_config: dict[str, Any] = { - "title": "Test API", - "description": "A test API app for unit testing", - "version": "0.1.0", - } - app = FastAPI(**fastapi_config | parametrized_config if parametrized_config is not None else {}) - return app diff --git a/tests/fixtures/simple_app.py b/tests/fixtures/simple_app.py index 5d8298a..2b21872 100644 --- a/tests/fixtures/simple_app.py +++ b/tests/fixtures/simple_app.py @@ -1,15 +1,18 @@ -from typing import Optional, List, Any +from typing import Optional, List from fastapi import FastAPI, Query, Path, Body, HTTPException import pytest -from tests.fixtures.conftest import make_fastapi_app_base - from .types import Item -def make_simple_fastapi_app(parametrized_config: dict[str, Any] | None = None) -> FastAPI: - app = make_fastapi_app_base(parametrized_config=parametrized_config) +def make_simple_fastapi_app() -> FastAPI: + app = FastAPI( + title="Test API", + description="A test API app for unit testing", + version="0.1.0", + ) + items = [ Item(id=1, name="Item 1", price=10.0, tags=["tag1", "tag2"], description="Item 1 description"), Item(id=2, name="Item 2", price=20.0, tags=["tag2", "tag3"]), @@ -67,8 +70,3 @@ async def raise_error() -> None: @pytest.fixture def simple_fastapi_app() -> FastAPI: return make_simple_fastapi_app() - - -@pytest.fixture -def simple_fastapi_app_with_root_path() -> FastAPI: - return make_simple_fastapi_app(parametrized_config={"root_path": "/api/v1"}) diff --git a/tests/test_sse_real_transport.py b/tests/test_sse_real_transport.py index 408e117..1ac307c 100644 --- a/tests/test_sse_real_transport.py +++ b/tests/test_sse_real_transport.py @@ -9,7 +9,6 @@ import threading import coverage from typing import AsyncGenerator, Generator -from fastapi import FastAPI from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp import InitializeResult @@ -19,12 +18,26 @@ import uvicorn from fastapi_mcp import FastApiMCP +from .fixtures.simple_app import make_simple_fastapi_app + HOST = "127.0.0.1" SERVER_NAME = "Test MCP Server" -def run_server(server_port: int, fastapi_app: FastAPI) -> None: +@pytest.fixture +def server_port() -> int: + with socket.socket() as s: + s.bind((HOST, 0)) + return s.getsockname()[1] + + +@pytest.fixture +def server_url(server_port: int) -> str: + return f"http://{HOST}:{server_port}" + + +def run_server(server_port: int) -> None: # Initialize coverage for subprocesses cov = None if "COVERAGE_PROCESS_START" in os.environ: @@ -59,15 +72,16 @@ def periodic_save(): save_thread.start() # Configure the server + fastapi = make_simple_fastapi_app() mcp = FastApiMCP( - fastapi_app, + fastapi, name=SERVER_NAME, description="Test description", ) mcp.mount() # Start the server - server = uvicorn.Server(config=uvicorn.Config(app=fastapi_app, host=HOST, port=server_port, log_level="error")) + server = uvicorn.Server(config=uvicorn.Config(app=fastapi, host=HOST, port=server_port, log_level="error")) server.run() # Give server time to start @@ -80,24 +94,13 @@ def periodic_save(): cov.save() -@pytest.fixture(params=["simple_fastapi_app", "simple_fastapi_app_with_root_path"]) -def server(request: pytest.FixtureRequest) -> Generator[str, None, None]: +@pytest.fixture() +def server(server_port: int) -> Generator[None, None, None]: # Ensure COVERAGE_PROCESS_START is set in the environment for subprocesses coverage_rc = os.path.abspath(".coveragerc") os.environ["COVERAGE_PROCESS_START"] = coverage_rc - # Get a free port - with socket.socket() as s: - s.bind((HOST, 0)) - server_port = s.getsockname()[1] - - # Run the server in a subprocess - fastapi_app = request.getfixturevalue(request.param) - proc = multiprocessing.Process( - target=run_server, - kwargs={"server_port": server_port, "fastapi_app": fastapi_app}, - daemon=True, - ) + proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) proc.start() # Wait for server to be running @@ -114,8 +117,7 @@ def server(request: pytest.FixtureRequest) -> Generator[str, None, None]: else: raise RuntimeError(f"Server failed to start after {max_attempts} attempts") - # Return the server URL - yield f"http://{HOST}:{server_port}{fastapi_app.root_path}" + yield # Signal the server to stop - added graceful shutdown before kill try: @@ -132,8 +134,8 @@ def server(request: pytest.FixtureRequest) -> Generator[str, None, None]: @pytest.fixture() -async def http_client(server: str) -> AsyncGenerator[httpx.AsyncClient, None]: - async with httpx.AsyncClient(base_url=server) as client: +async def http_client(server: None, server_url: str) -> AsyncGenerator[httpx.AsyncClient, None]: + async with httpx.AsyncClient(base_url=server_url) as client: yield client @@ -163,8 +165,8 @@ async def connection_test() -> None: @pytest.mark.anyio -async def test_sse_basic_connection(server: str) -> None: - async with sse_client(server + "/mcp") as streams: +async def test_sse_basic_connection(server: None, server_url: str) -> None: + async with sse_client(server_url + "/mcp") as streams: async with ClientSession(*streams) as session: # Test initialization result = await session.initialize() @@ -177,8 +179,8 @@ async def test_sse_basic_connection(server: str) -> None: @pytest.mark.anyio -async def test_sse_tool_call(server: str) -> None: - async with sse_client(server + "/mcp") as streams: +async def test_sse_tool_call(server: None, server_url: str) -> None: + async with sse_client(server_url + "/mcp") as streams: async with ClientSession(*streams) as session: await session.initialize()