Skip to content

Commit cddc32b

Browse files
committed
remove some duplication
1 parent 42f2342 commit cddc32b

File tree

2 files changed

+28
-24
lines changed

2 files changed

+28
-24
lines changed

src/fenic/api/mcp/tool_generation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ def _auto_generate_sql_tool(
413413
raise ConfigurationError("Cannot create SQL tool: no datasets provided.")
414414

415415
def analyze_func(
416-
full_sql: Annotated[str, "Full SELECT SQL. Refer to DataFrames by name in braces, e.g., {orders}."]
416+
full_sql: Annotated[str, "Full SELECT SQL. Refer to DataFrames by name in braces, e.g., `SELECT * FROM {orders}`. JOINs between the provided datasets are allowed. SQL dialect: DuckDB. DDL/DML, CTEs, subqueries, UNION, and multiple top-level queries are not allowed"]
417417
) -> LogicalPlan:
418418
return session.sql(full_sql.strip(), **{spec.table_name: spec.df for spec in datasets})._logical_plan
419419

@@ -789,6 +789,7 @@ def _auto_generate_core_tools(
789789
tool_name=f"{tool_group_name} - Read",
790790
tool_description="\n\n".join([
791791
"Read rows from a single dataset. Use to sample data, or to execute simple queries over the data that do not require filtering or grouping.",
792+
"Use `include_columns` and `exclude_columns` to filter columns by name -- this is important to conserve token usage. Use the `Profile` tool to understand the columns and their sizes.",
792793
"Available datasets:",
793794
group_desc,
794795
]),

src/fenic/core/mcp/_server.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from functools import wraps
1616
from typing import Any, Callable, Dict, List, Optional, Union
1717

18+
import polars as pl
1819
from pydantic import BaseModel, ConfigDict
1920
from typing_extensions import Annotated, Literal
2021

@@ -48,7 +49,8 @@ class MCPResultSet(BaseModel):
4849

4950
table_schema: Optional[List[Dict[str, Any]]]
5051
rows: Union[List[Dict[str, Any]], str]
51-
row_count: int
52+
returned_result_count: int
53+
total_result_count: int
5254

5355
MCPTransport = Literal["http", "stdio"]
5456

@@ -135,6 +137,24 @@ def http_app(self, **kwargs):
135137
"""Create a Starlette ASGI app for the MCP server."""
136138
return self.mcp.http_app(**kwargs)
137139

140+
def _handle_result_set(self, pl_df: pl.DataFrame, effective_limit: Optional[int], table_format: TableFormat) -> MCPResultSet:
141+
"""Handle the result set from a logical plan."""
142+
original_result_count = len(pl_df)
143+
if effective_limit and original_result_count > effective_limit:
144+
pl_df = pl_df.limit(effective_limit)
145+
rows_list = pl_df.to_dicts()
146+
schema_fields = [{"name": name, "type": str(dtype)} for name, dtype in pl_df.schema.items()]
147+
result_set = MCPResultSet(
148+
table_schema=schema_fields,
149+
rows=rows_list,
150+
returned_result_count=len(rows_list),
151+
total_result_count=original_result_count,
152+
)
153+
if table_format == "markdown":
154+
result_set.rows = _render_markdown_preview(rows_list)
155+
result_set.table_schema = None
156+
return result_set
157+
138158
def _build_parameterized_tool(self, tool: ParameterizedToolDefinition):
139159
"""Build a keyword-argument tool function with per-field schema for FastMCP.
140160
@@ -162,24 +182,13 @@ async def tool_fn_wrapper(*args, **kwargs) -> MCPResultSet:
162182
bound_plan = bind_parameters(tool._parameterized_view, payload, tool.params)
163183
async with self._collect_semaphore:
164184
pl_df, metrics = await asyncio.to_thread(
165-
lambda: self.session_state.execution.collect(bound_plan, n=effective_limit)
185+
lambda: self.session_state.execution.collect(bound_plan)
166186
)
167187
logger.info(f"Completed query for {tool.name}")
168188
logger.info(metrics.get_summary())
169189
logger.debug(f"Query Details: {params_obj.model_dump_json()}")
170190

171-
rows_list = pl_df.to_dicts()
172-
schema_fields = [{"name": name, "type": str(dtype)} for name, dtype in pl_df.schema.items()]
173-
result_set = MCPResultSet(
174-
table_schema=schema_fields,
175-
rows=rows_list,
176-
row_count=len(rows_list),
177-
)
178-
if table_format == "markdown":
179-
result_set.rows = _render_markdown_preview(rows_list)
180-
result_set.table_schema = None
181-
182-
return result_set
191+
return self._handle_result_set(pl_df, effective_limit, table_format)
183192
except Exception as e:
184193
from fastmcp.exceptions import ToolError
185194
raise ToolError(f"Fenic server failed to execute tool {tool.name}. Underlying error: {e}") from e
@@ -263,19 +272,13 @@ async def wrapper(*args, **kwargs) -> MCPResultSet:
263272
# collections with a semaphore to protect the backend executor.
264273
async with self._collect_semaphore:
265274
pl_df, metrics = await asyncio.to_thread(
266-
lambda: self.session_state.execution.collect(bound_plan, n=effective_limit)
275+
lambda: self.session_state.execution.collect(bound_plan)
267276
)
268277
logger.info(f"Completed query for {tool.name}")
269278
logger.info(metrics.get_summary())
270279
logger.debug(f"Query Details: {args if args else kwargs}")
271-
rows_list = pl_df.to_dicts()
272-
schema_fields = [{"name": name, "type": str(dtype)} for name, dtype in pl_df.schema.items()]
273-
out = MCPResultSet(table_schema=schema_fields, rows=rows_list, row_count=len(rows_list))
274-
if table_format == "markdown":
275-
out.rows = _render_markdown_preview(rows_list)
276-
out.table_schema = None
277-
278-
return out
280+
281+
return self._handle_result_set(pl_df, effective_limit, table_format)
279282
except Exception as e:
280283
from fastmcp.exceptions import ToolError
281284
raise ToolError(f"Fenic server failed to execute tool {tool.name}. Underlying error: {e}") from e

0 commit comments

Comments
 (0)