
Commit 8addfa6

mcp adapter
1 parent 2fda450 commit 8addfa6

File tree

5 files changed: +515, -103 lines changed


client/astra_assistants/astra_assistants_manager.py

Lines changed: 60 additions & 22 deletions
@@ -1,23 +1,41 @@
 import logging
 import os
-from typing import List
+import uuid
+from typing import List, Any

 from litellm import get_llm_provider

 from astra_assistants import patch, OpenAIWithDefaultKey
 from astra_assistants.astra_assistants_event_handler import AstraEventHandler
 from astra_assistants.tools.tool_interface import ToolInterface
 from astra_assistants.utils import env_var_is_missing, get_env_vars_for_provider
+from astra_assistants.mcp_openai_adapter import MCPOpenAIAAdapter

 logger = logging.getLogger(__name__)

 class AssistantManager:
-    def __init__(self, instructions: str = None, model: str = "gpt-4o", name: str = "managed_assistant", tools: List[ToolInterface] = None, thread_id: str = None, thread: str = None, assistant_id: str = None, client = None, tool_resources = None):
+    def __init__(self,
+                 instructions: str = None,
+                 model: str = "gpt-4o",
+                 name: str = "managed_assistant",
+                 tools: List[ToolInterface] = None,
+                 thread_id: str = None,
+                 thread: str = None,
+                 assistant_id: str = None,
+                 client = None,
+                 tool_resources = None,
+                 mcp_represenations = None
+                 ):
+
         if instructions is None and assistant_id is None:
             raise Exception("Instructions must be provided if assistant_id is not provided")
         if tools is None:
             tools = []
-        # Only patch if astra token is provided
+
+
+        self.tools = tools
+
+        # Initialize client using the provided client or the default based on environment tokens.
         if client is not None:
             self.client = client
         else:
@@ -31,7 +49,6 @@ def __init__(self, instructions: str = None, model: str = "gpt-4o", name: str =
             self.client = OpenAIWithDefaultKey()
         self.model = model
         self.instructions = instructions
-        self.tools = tools
         self.tool_resources = tool_resources
         self.name = name
         self.tool_call_arguments = None
@@ -48,9 +65,25 @@ def __init__(self, instructions: str = None, model: str = "gpt-4o", name: str =
         elif thread_id is not None:
             self.thread = self.client.beta.threads.retrieve(thread_id)

+
+        self.mcp_adapter = None
+        self.register_mcp(mcp_represenations)
+
         logger.info(f'assistant {self.assistant}')
         logger.info(f'thread {self.thread}')

+    def register_mcp(self, mcp_representations):
+        # If MCP representations are provided, convert them to tools using the adapter.
+        if mcp_representations is not None:
+            self.mcp_adapter = MCPOpenAIAAdapter(mcp_representations)
+
+            mcp_tools = self.mcp_adapter.get_tools()
+            self.tools.extend(mcp_tools)
+
+            schemas = self.mcp_adapter.get_json_schema_for_tools()
+            assistant = self.client.beta.assistants.update(assistant_id=self.assistant.id, tools=schemas)
+            self.assistant = assistant
+
     def get_client(self):
         return self.client

@@ -65,25 +98,24 @@ def create_assistant(self):
         for tool in self.tools:
             if hasattr(tool, 'to_function'):
                 tool_holder.append(tool.to_function())
-
         if len(tool_holder) == 0:
             tool_holder = self.tools

-        # Create and return the assistant
+        # Create and return the assistant with the combined tool definitions.
         self.assistant = self.client.beta.assistants.create(
             name=self.name,
             instructions=self.instructions,
             model=self.model,
             tools=tool_holder,
             tool_resources=self.tool_resources
         )
-        logger.debug("Assistant created:", self.assistant)
+        logger.debug("Assistant created: %s", self.assistant)
         return self.assistant

     def create_thread(self):
-        # Create and return a new thread
+        # Create and return a new thread.
         thread = self.client.beta.threads.create()
-        logger.debug("Thread generated:", thread)
+        logger.debug("Thread generated: %s", thread)
         return thread

     def stream_thread(self, content, tool_choice = None, thread_id: str = None, thread = None, additional_instructions = None):
@@ -112,7 +144,6 @@ def stream_thread(self, content, tool_choice = None, thread_id: str = None, thre
                 "event_handler": event_handler,
                 "additional_instructions": additional_instructions
             }
-            # Conditionally add 'tool_choice' if it's not None
             if tool_choice is not None:
                 args["tool_choice"] = tool_choice

@@ -121,8 +152,6 @@ def stream_thread(self, content, tool_choice = None, thread_id: str = None, thre
                 for text in stream.text_deltas:
                     yield text

-            tool_call_results = None
-            tool_call_arguments = None
            self.tool_call_arguments = event_handler.arguments
            if event_handler.stream is not None:
                if event_handler.tool_call_results is not None:
@@ -133,7 +162,7 @@ def stream_thread(self, content, tool_choice = None, thread_id: str = None, thre
         except Exception as e:
             logger.error(e)
             raise e
-
+
     async def run_thread(self, content, tool = None, thread_id: str = None, thread = None, additional_instructions = None):
         if thread_id is not None:
             thread = self.client.beta.threads.retrieve(thread_id)
@@ -142,10 +171,15 @@ async def run_thread(self, content, tool = None, thread_id: str = None, thread =

         assistant = self.assistant
         event_handler = AstraEventHandler(self.client)
+
         tool_choice = None
         if tool is not None:
             event_handler.register_tool(tool)
             tool_choice = tool.tool_choice_object()
+
+        for tool in self.tools:
+            event_handler.register_tool(tool)
+
         try:
             self.client.beta.threads.messages.create(
                 thread_id=thread.id, role="user", content=content
@@ -156,33 +190,37 @@ async def run_thread(self, content, tool = None, thread_id: str = None, thread =
                 "event_handler": event_handler,
                 "additional_instructions": additional_instructions
             }
-            # Conditionally add 'tool_choice' if it's not None
             if tool_choice is not None:
                 args["tool_choice"] = tool_choice

             text = ""
-            with self.client.beta.threads.runs.create_and_stream(**args) as stream:
+            with self.client.beta.threads.runs.stream(**args) as stream:
                 for part in stream.text_deltas:
                     text += part
-
+
             tool_call_results = None
             if event_handler.stream is not None:
                 with event_handler.stream as stream:
                     for part in stream.text_deltas:
                         text += part

                     tool_call_results = event_handler.tool_call_results
-                    file_search = event_handler.file_search
+                    if tool_call_results is not None:
+                        file_search = event_handler.file_search

-                    tool_call_results['file_search'] = file_search
-                    tool_call_results['text'] = text
-                    tool_call_results['arguments'] = event_handler.arguments
+                        tool_call_results['file_search'] = file_search
+                        tool_call_results['text'] = text
+                        tool_call_results['arguments'] = event_handler.arguments
+                    else:
+                        print("event_handler.stream is not None but tool_call_results is None, bug?")

             logger.info(tool_call_results)
-            tool_call_results
             if tool_call_results is not None:
                 return tool_call_results
             return {"text": text, "file_search": event_handler.file_search}
         except Exception as e:
             logger.error(e)
-            raise e
+            raise e
+
+    def shutdown(self):
+        self.mcp_adapter.shutdown()
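Taken together, the changes above let an AssistantManager pull its tool list from one or more MCP servers. A minimal usage sketch follows (not part of this commit; the server command, prompt, and variable names are illustrative, and the usual provider API keys are assumed to be set in the environment):

import asyncio

from astra_assistants.astra_assistants_manager import AssistantManager
from astra_assistants.mcp_openai_adapter import MCPRepresentationStdio

# Hypothetical stdio MCP server command; any MCP server that speaks stdio works here.
mcp_reps = [
    MCPRepresentationStdio(
        type="stdio",
        command="npx -y @modelcontextprotocol/server-everything",
    )
]

manager = AssistantManager(
    instructions="You are a helpful assistant that can call MCP tools.",
    mcp_represenations=mcp_reps,  # parameter name as spelled in this commit
)

# run_thread is async; the MCP tools registered above are available to the run.
result = asyncio.run(manager.run_thread(content="Use one of your tools to answer."))
print(result["text"])

manager.shutdown()  # stops the adapter's background event loop and thread

Note that shutdown() delegates straight to the MCP adapter, so it is only meaningful when MCP representations were supplied; otherwise self.mcp_adapter stays None.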
client/astra_assistants/mcp_openai_adapter.py

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
+import asyncio
+import os
+import threading
+import re
+from abc import ABC
+from contextlib import AsyncExitStack
+from typing import List, Union, Optional, Literal, Dict, Any, Type
+
+from mcp.types import CallToolResult
+from pydantic import BaseModel, Field, create_model
+
+# Import the high-level MCP client interfaces from the official SDK.
+from mcp import ClientSession, StdioServerParameters
+from mcp.client.stdio import stdio_client
+
+from astra_assistants.tools.tool_interface import ToolInterface, ToolResult
+
+# --- MCP Representation Models ---
+
+class MCPRepresentationBase(BaseModel):
+    type: str
+
+class MCPRepresentationStdio(MCPRepresentationBase):
+    type: str = Literal["stdio"]
+    command: str
+    arguments: Optional[List[str]] = None
+    env_vars: Optional[List[str]] = None
+
+class MCPRepresentationSSE(MCPRepresentationBase):
+    type: str = Literal["sse"]
+    sse_url: str
+
+MCPRepresentation = Union[MCPRepresentationStdio, MCPRepresentationSSE]
+
+# --- Helper functions ---
+
+def generate_pydantic_model_from_schema(schema: Dict[str, Any], model_name: str = "DynamicModel") -> Type[BaseModel]:
+    fields = {}
+    properties = schema.get("properties", {})
+    required_fields = set(schema.get("required", []))
+    type_mapping = {
+        "string": str,
+        "integer": int,
+        "number": float,
+        "boolean": bool,
+        "array": list,
+        "object": dict
+    }
+    for field_name, field_schema in properties.items():
+        field_type = type_mapping.get(field_schema.get("type"), Any)
+        if field_name in required_fields:
+            fields[field_name] = (field_type, ...)
+        else:
+            fields[field_name] = (field_type, None)
+    return create_model(model_name, **fields)
+
+def to_camel_case(s: str) -> str:
+    return ''.join(word.capitalize() for word in re.split(r'[_-]', s))
+
+# --- MCP Tool Adapter (implements ToolInterface) ---
+
+class MCPToolAdapter(ToolInterface, ABC):
+    def __init__(self, representation: MCPRepresentation, mcp_session: ClientSession, mcp_tool):
+        self.representation = representation
+        self.mcp_session = mcp_session
+        self.mcp_tool = mcp_tool
+
+    def get_model(self):
+        return generate_pydantic_model_from_schema(
+            self.mcp_tool.inputSchema,
+            to_camel_case(self.mcp_tool.name)
+        )
+
+    def to_function(self) -> dict:
+        return {
+            "type": "function",
+            "function": {
+                "name": self.mcp_tool.name,
+                "description": self.mcp_tool.description,
+                "parameters": self.mcp_tool.inputSchema
+            }
+        }
+
+    def call(self, arguments: BaseModel) -> CallToolResult:
+        # Use the background loop to run the async call synchronously.
+        future = asyncio.run_coroutine_threadsafe(
+            self.mcp_session.call_tool(
+                self.mcp_tool.name,
+                arguments=arguments.model_dump()
+            ),
+            self.mcp_session_loop  # set below when session is created
+        )
+        return {"output": future.result().content[0].text}
+
+# --- MCP OpenAI Adapter ---
+
+class MCPOpenAIAAdapter:
+    """
+    This adapter connects to an MCP server using the official Python SDK (via stdio transport)
+    on a dedicated background thread. This allows synchronous methods (like call) to schedule
+    async work via asyncio.run_coroutine_threadsafe.
+    """
+    def __init__(
+        self,
+        mcp_representations: List[MCPRepresentation] = None,
+    ):
+        self.exit_stack = AsyncExitStack()
+        self.mcp_representations = mcp_representations or []
+        self.server_params = []
+        for rep in self.mcp_representations:
+            if rep.type == 'stdio':
+                env_vars = {"PATH": os.environ["PATH"]}
+                if rep.env_vars is not None:
+                    # Assume env_vars are provided as a dict-like mapping or as "KEY=VALUE" strings.
+                    for var in rep.env_vars:
+                        if "=" in var:
+                            key, value = var.split("=", 1)
+                            env_vars[key] = value
+                # Split command into executable and arguments.
+                parts = rep.command.split()
+                executable = parts[0]
+                initial_args = parts[1:]
+                combined_args = initial_args + (rep.arguments or [])
+                server_param = StdioServerParameters(
+                    command=executable,
+                    args=combined_args,
+                    env=env_vars,
+                )
+                self.server_params.append(server_param)
+            elif rep.type == 'sse':
+                self.server_params.append(rep.sse_url)
+        self.session: Optional[ClientSession] = None
+        self.tools: List[MCPToolAdapter] = []
+        self._bg_loop = asyncio.new_event_loop()
+        self._bg_thread = threading.Thread(target=self._run_bg_loop, daemon=True)
+        self._bg_thread.start()
+
+    def _run_bg_loop(self):
+        asyncio.set_event_loop(self._bg_loop)
+        self._bg_loop.run_forever()
+
+    def sync_connect(self):
+        """
+        Synchronously connect to the MCP server using the background loop.
+        This schedules the async connect() coroutine on the background loop.
+        """
+        for server_param in self.server_params:
+            asyncio.run_coroutine_threadsafe(self._connect(server_param), self._bg_loop).result()
+
+    async def _connect(self, server_param):
+        transport = await self.exit_stack.enter_async_context(stdio_client(server_param))
+        self.stdio, self.write = transport
+        # Create the session on the background loop.
+        self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
+        await self.session.initialize()
+        # Attach the background loop reference to each tool adapter.
+        result = await self.session.list_tools()
+        for rep in self.mcp_representations:
+            for tool in result.tools:
+                adapter = MCPToolAdapter(representation=rep, mcp_session=self.session, mcp_tool=tool)
+                # Set the event loop used by this session (i.e. the background loop)
+                adapter.mcp_session_loop = self._bg_loop
+                self.tools.append(adapter)
+
+    def get_tools(self) -> List[MCPToolAdapter]:
+        if self.session is None:
+            self.sync_connect()
+        return self.tools
+
+    def get_json_schema_for_tools(self) -> List[dict]:
+        # Since to_function() is synchronous, simply return the schemas.
+        return [tool_adapter.to_function() for tool_adapter in self.tools]
+
+    def shutdown(self):
+        """
+        Cleanly shuts down the background loop and thread.
+        """
+        # First, if session exists, schedule exit of the exit stack.
+        if self.session is not None:
+            future = asyncio.run_coroutine_threadsafe(self.exit_stack.aclose(), self._bg_loop)
+            try:
+                future.result(timeout=5)
+            except Exception as e:
+                print("Error during exit_stack.aclose():", e)
+            self.session = None
+        # Signal the background loop to stop.
+        self._bg_loop.call_soon_threadsafe(self._bg_loop.stop)
+        # Wait for the background thread to finish.
+        self._bg_thread.join(timeout=5)
+        # Close the loop.
+        self._bg_loop.close()
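For reference, a standalone sketch (not from this commit) of driving MCPOpenAIAAdapter directly; the server command is a placeholder, and the printed schemas are the OpenAI-style function definitions that the assistant update in register_mcp expects:

from astra_assistants.mcp_openai_adapter import MCPOpenAIAAdapter, MCPRepresentationStdio

# Placeholder command; substitute any MCP server that communicates over stdio.
rep = MCPRepresentationStdio(type="stdio", command="python my_mcp_server.py")

adapter = MCPOpenAIAAdapter([rep])

# get_tools() lazily connects on the background loop and wraps each MCP tool
# in an MCPToolAdapter that satisfies ToolInterface.
tools = adapter.get_tools()

# OpenAI-style function schemas, suitable for client.beta.assistants.update(..., tools=schemas).
schemas = adapter.get_json_schema_for_tools()
print(schemas)

adapter.shutdown()  # closes the MCP session and stops the background thread and loop

Running the MCP session on a dedicated background event loop is what lets the synchronous ToolInterface.call() method block on future.result() without needing an event loop in the caller's thread.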
