Skip to content

Commit 8ecc7c5

Browse files
committed
more tests passing
1 parent 91ed867 commit 8ecc7c5

File tree

3 files changed

+145
-6
lines changed

3 files changed

+145
-6
lines changed

src/mcpcat/modules/tools.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import Any
44

5-
from mcp import Tool
5+
from mcp import ServerResult, Tool
66
from mcp.server import FastMCP
77
from mcp.types import CallToolRequest, CallToolResult, ListToolsRequest, TextContent
88

@@ -25,7 +25,9 @@ async def wrapped_list_tools_handler(request: ListToolsRequest) -> ListToolsResu
2525
"""Intercept list_tools requests to add MCPCat tools and modify existing ones."""
2626
# Call the original handler to get the tools
2727
original_result = await original_list_tools_handler(request)
28-
tools_list = original_result.tools if hasattr(original_result, 'tools') else []
28+
if not original_result or not hasattr(original_result, 'root') or not hasattr(original_result.root, 'tools'):
29+
return original_result
30+
tools_list = original_result.root.tools
2931

3032
# Add report_missing tool if enabled
3133
if data.options.enableReportMissing:
@@ -77,7 +79,7 @@ async def wrapped_list_tools_handler(request: ListToolsRequest) -> ListToolsResu
7779
else:
7880
tool.inputSchema["required"] = ["context"]
7981

80-
return ListToolsResult(tools=tools_list)
82+
return ServerResult(ListToolsResult(tools=tools_list))
8183

8284
async def wrapped_call_tool_handler(request: CallToolRequest) -> CallToolResult:
8385
"""Intercept call_tool requests to add MCPCat tracking and handle special tools."""
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""Test that mcpcat.track preserves existing tools and only adds report_missing."""
2+
3+
import pytest
4+
5+
from mcpcat import track
6+
from mcpcat.types import MCPCatOptions
7+
8+
from .test_utils.client import create_test_client
9+
from .test_utils.todo_server import create_todo_server
10+
11+
12+
class TestPreserveExistingTools:
13+
"""Test that existing tools are preserved when tracking."""
14+
15+
async def test_track_preserves_all_existing_tools(self):
16+
"""Should preserve all existing tools and only add report_missing."""
17+
# Create server with existing tools
18+
server = create_todo_server()
19+
20+
# Get original tools before tracking
21+
async with create_test_client(server) as client:
22+
result = await client.list_tools()
23+
original_tools = result.tools
24+
original_tool_names = {tool.name for tool in original_tools}
25+
26+
# Track the server
27+
tracked_server = track(server)
28+
29+
# Get tools after tracking
30+
async with create_test_client(tracked_server) as client:
31+
result = await client.list_tools()
32+
tracked_tools = result.tools
33+
tracked_tool_names = {tool.name for tool in tracked_tools}
34+
35+
# Verify all original tools are preserved
36+
assert original_tool_names.issubset(tracked_tool_names), "Original tools should be preserved"
37+
38+
# Verify only report_missing was added
39+
added_tools = tracked_tool_names - original_tool_names
40+
assert added_tools == {"report_missing"}, "Only report_missing should be added"
41+
42+
# Verify the exact count
43+
assert len(tracked_tools) == len(original_tools) + 1, "Should have exactly one more tool"
44+
45+
# Verify original tools are still present
46+
assert "add_todo" in tracked_tool_names
47+
assert "list_todos" in tracked_tool_names
48+
assert "complete_todo" in tracked_tool_names
49+
50+
async def test_track_without_report_missing_preserves_exact_tools(self):
51+
"""Should preserve exact tool list when report_missing is disabled."""
52+
# Create server with existing tools
53+
server = create_todo_server()
54+
55+
# Get original tools before tracking
56+
async with create_test_client(server) as client:
57+
result = await client.list_tools()
58+
original_tools = result.tools
59+
original_tool_names = {tool.name for tool in original_tools}
60+
61+
# Track the server with report_missing disabled
62+
tracked_server = track(server, MCPCatOptions(enableReportMissing=False))
63+
64+
# Get tools after tracking
65+
async with create_test_client(tracked_server) as client:
66+
result = await client.list_tools()
67+
tracked_tools = result.tools
68+
tracked_tool_names = {tool.name for tool in tracked_tools}
69+
70+
# Verify tool lists are identical
71+
assert tracked_tool_names == original_tool_names, "Tool names should be identical"
72+
assert len(tracked_tools) == len(original_tools), "Tool count should be identical"
73+
74+
async def test_track_with_context_modifies_existing_tools_but_preserves_them(self):
75+
"""Should modify existing tools to add context but preserve all of them."""
76+
# Create server with existing tools
77+
server = create_todo_server()
78+
79+
# Get original tools before tracking
80+
async with create_test_client(server) as client:
81+
result = await client.list_tools()
82+
original_tools = result.tools
83+
original_tool_names = {tool.name for tool in original_tools}
84+
85+
# Track the server with context enabled
86+
tracked_server = track(server, MCPCatOptions(enableToolCallContext=True))
87+
88+
# Get tools after tracking
89+
async with create_test_client(tracked_server) as client:
90+
result = await client.list_tools()
91+
tracked_tools = result.tools
92+
93+
# Find modified tools (excluding report_missing)
94+
for tool in tracked_tools:
95+
if tool.name in original_tool_names:
96+
# Verify context was added to schema
97+
assert "context" in tool.inputSchema["properties"], f"Context should be added to {tool.name}"
98+
assert "context" in tool.inputSchema["required"], f"Context should be required for {tool.name}"
99+
elif tool.name == "report_missing":
100+
# Verify report_missing doesn't have context parameter
101+
assert "context" not in tool.inputSchema["properties"], "report_missing should not have context"
102+
103+
# Verify all original tools are still present
104+
tracked_tool_names = {tool.name for tool in tracked_tools}
105+
assert original_tool_names.issubset(tracked_tool_names), "All original tools should be preserved"
106+
107+
async def test_multiple_track_calls_do_not_duplicate_tools(self):
108+
"""Should not duplicate tools when track is called multiple times."""
109+
# Create server with existing tools
110+
server = create_todo_server()
111+
112+
# Track the server multiple times
113+
track(server)
114+
track(server) # Should be a no-op
115+
track(server) # Should be a no-op
116+
117+
# Get tools after multiple track calls
118+
async with create_test_client(server) as client:
119+
result = await client.list_tools()
120+
tools = result.tools
121+
tool_names = [tool.name for tool in tools]
122+
123+
# Count occurrences of each tool
124+
tool_counts = {}
125+
for name in tool_names:
126+
tool_counts[name] = tool_counts.get(name, 0) + 1
127+
128+
# Verify no duplicates
129+
for name, count in tool_counts.items():
130+
assert count == 1, f"Tool {name} should appear exactly once, but appears {count} times"
131+
132+
# Verify expected tools are present
133+
assert "add_todo" in tool_names
134+
assert "list_todos" in tool_names
135+
assert "complete_todo" in tool_names
136+
assert "report_missing" in tool_names
137+
assert len(tools) == 4, "Should have exactly 4 tools"

tests/test_utils/todo_server.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def create_todo_server() -> FastMCP:
2222
next_id = 1
2323

2424
@server.tool()
25-
def add_todo(text: str, context: str | None = None) -> str:
25+
def add_todo(text: str) -> str:
2626
"""Add a new todo item."""
2727
nonlocal next_id
2828
todo = Todo(str(next_id), text)
@@ -31,7 +31,7 @@ def add_todo(text: str, context: str | None = None) -> str:
3131
return f'Added todo: "{text}" with ID {todo.id}'
3232

3333
@server.tool()
34-
def list_todos(context: str | None = None) -> str:
34+
def list_todos() -> str:
3535
"""List all todo items."""
3636
if not todos:
3737
return "No todos found"
@@ -44,7 +44,7 @@ def list_todos(context: str | None = None) -> str:
4444
return "\n".join(todo_list)
4545

4646
@server.tool()
47-
def complete_todo(id: str, context: str | None = None) -> str:
47+
def complete_todo(id: str) -> str:
4848
"""Mark a todo item as completed."""
4949
for todo in todos:
5050
if todo.id == id:

0 commit comments

Comments
 (0)