 import logging
 import os
-from typing import List
+import uuid
+from typing import List, Any
 
 from litellm import get_llm_provider
 
 from astra_assistants import patch, OpenAIWithDefaultKey
 from astra_assistants.astra_assistants_event_handler import AstraEventHandler
 from astra_assistants.tools.tool_interface import ToolInterface
 from astra_assistants.utils import env_var_is_missing, get_env_vars_for_provider
+from astra_assistants.mcp_openai_adapter import MCPOpenAIAAdapter
 
 logger = logging.getLogger(__name__)
 
 class AssistantManager:
-    def __init__(self, instructions: str = None, model: str = "gpt-4o", name: str = "managed_assistant", tools: List[ToolInterface] = None, thread_id: str = None, thread: str = None, assistant_id: str = None, client=None, tool_resources=None):
+    def __init__(self,
+                 instructions: str = None,
+                 model: str = "gpt-4o",
+                 name: str = "managed_assistant",
+                 tools: List[ToolInterface] = None,
+                 thread_id: str = None,
+                 thread: str = None,
+                 assistant_id: str = None,
+                 client=None,
+                 tool_resources=None,
+                 mcp_represenations=None
+                 ):
+
         if instructions is None and assistant_id is None:
             raise Exception("Instructions must be provided if assistant_id is not provided")
         if tools is None:
             tools = []
-        # Only patch if astra token is provided
+
+
+        self.tools = tools
+
+        # Initialize client using the provided client or the default based on environment tokens.
         if client is not None:
             self.client = client
         else:
@@ -31,7 +49,6 @@ def __init__(self, instructions: str = None, model: str = "gpt-4o", name: str =
             self.client = OpenAIWithDefaultKey()
         self.model = model
         self.instructions = instructions
-        self.tools = tools
         self.tool_resources = tool_resources
         self.name = name
         self.tool_call_arguments = None
@@ -48,9 +65,25 @@ def __init__(self, instructions: str = None, model: str = "gpt-4o", name: str =
         elif thread_id is not None:
             self.thread = self.client.beta.threads.retrieve(thread_id)
 
+
+        self.mcp_adapter = None
+        self.register_mcp(mcp_represenations)
+
         logger.info(f'assistant {self.assistant}')
         logger.info(f'thread {self.thread}')
 
+    def register_mcp(self, mcp_representations):
+        # If MCP representations are provided, convert them to tools using the adapter.
+        if mcp_representations is not None:
+            self.mcp_adapter = MCPOpenAIAAdapter(mcp_representations)
+
+            mcp_tools = self.mcp_adapter.get_tools()
+            self.tools.extend(mcp_tools)
+
+            schemas = self.mcp_adapter.get_json_schema_for_tools()
+            assistant = self.client.beta.assistants.update(assistant_id=self.assistant.id, tools=schemas)
+            self.assistant = assistant
+
     def get_client(self):
         return self.client
 
@@ -65,25 +98,24 @@ def create_assistant(self):
         for tool in self.tools:
             if hasattr(tool, 'to_function'):
                 tool_holder.append(tool.to_function())
-
         if len(tool_holder) == 0:
             tool_holder = self.tools
 
-        # Create and return the assistant
+        # Create and return the assistant with the combined tool definitions.
         self.assistant = self.client.beta.assistants.create(
             name=self.name,
             instructions=self.instructions,
             model=self.model,
             tools=tool_holder,
             tool_resources=self.tool_resources
         )
-        logger.debug("Assistant created:", self.assistant)
+        logger.debug("Assistant created: %s", self.assistant)
         return self.assistant
 
     def create_thread(self):
-        # Create and return a new thread
+        # Create and return a new thread.
         thread = self.client.beta.threads.create()
-        logger.debug("Thread generated:", thread)
+        logger.debug("Thread generated: %s", thread)
         return thread
 
     def stream_thread(self, content, tool_choice=None, thread_id: str = None, thread=None, additional_instructions=None):
@@ -112,7 +144,6 @@ def stream_thread(self, content, tool_choice = None, thread_id: str = None, thre
                 "event_handler": event_handler,
                 "additional_instructions": additional_instructions
             }
-            # Conditionally add 'tool_choice' if it's not None
             if tool_choice is not None:
                 args["tool_choice"] = tool_choice
 
@@ -121,8 +152,6 @@ def stream_thread(self, content, tool_choice = None, thread_id: str = None, thre
                 for text in stream.text_deltas:
                     yield text
 
-            tool_call_results = None
-            tool_call_arguments = None
             self.tool_call_arguments = event_handler.arguments
             if event_handler.stream is not None:
                 if event_handler.tool_call_results is not None:
@@ -133,7 +162,7 @@ def stream_thread(self, content, tool_choice = None, thread_id: str = None, thre
         except Exception as e:
             logger.error(e)
             raise e
-
+
     async def run_thread(self, content, tool=None, thread_id: str = None, thread=None, additional_instructions=None):
         if thread_id is not None:
             thread = self.client.beta.threads.retrieve(thread_id)
@@ -142,10 +171,15 @@ async def run_thread(self, content, tool = None, thread_id: str = None, thread =
 
         assistant = self.assistant
         event_handler = AstraEventHandler(self.client)
+
         tool_choice = None
         if tool is not None:
             event_handler.register_tool(tool)
             tool_choice = tool.tool_choice_object()
+
+        for tool in self.tools:
+            event_handler.register_tool(tool)
+
         try:
             self.client.beta.threads.messages.create(
                 thread_id=thread.id, role="user", content=content
@@ -156,33 +190,37 @@ async def run_thread(self, content, tool = None, thread_id: str = None, thread =
                 "event_handler": event_handler,
                 "additional_instructions": additional_instructions
             }
-            # Conditionally add 'tool_choice' if it's not None
             if tool_choice is not None:
                 args["tool_choice"] = tool_choice
 
             text = ""
-            with self.client.beta.threads.runs.create_and_stream(**args) as stream:
+            with self.client.beta.threads.runs.stream(**args) as stream:
                 for part in stream.text_deltas:
                     text += part
-
+
             tool_call_results = None
             if event_handler.stream is not None:
                 with event_handler.stream as stream:
                     for part in stream.text_deltas:
                         text += part
 
                     tool_call_results = event_handler.tool_call_results
-                    file_search = event_handler.file_search
+                    if tool_call_results is not None:
+                        file_search = event_handler.file_search
 
-                    tool_call_results['file_search'] = file_search
-                    tool_call_results['text'] = text
-                    tool_call_results['arguments'] = event_handler.arguments
+                        tool_call_results['file_search'] = file_search
+                        tool_call_results['text'] = text
+                        tool_call_results['arguments'] = event_handler.arguments
+                    else:
+                        print("event_handler.stream is not None but tool_call_results is None, bug?")
 
             logger.info(tool_call_results)
-            tool_call_results
             if tool_call_results is not None:
                 return tool_call_results
             return {"text": text, "file_search": event_handler.file_search}
         except Exception as e:
             logger.error(e)
-            raise e
+            raise e
+
+    def shutdown(self):
+        self.mcp_adapter.shutdown()
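
For context, a minimal usage sketch of the API this diff touches, not part of the change itself: the module path for AssistantManager, passing None for mcp_represenations, and the shape of the run_thread result are assumptions based only on the code above.

import asyncio

from astra_assistants.astra_assistants_manager import AssistantManager  # module path assumed

async def main():
    # Passing a list of MCP server representations here would route them through
    # register_mcp -> MCPOpenAIAAdapter; None skips the MCP wiring entirely.
    manager = AssistantManager(
        instructions="You are a helpful assistant.",
        mcp_represenations=None,
    )
    result = await manager.run_thread(content="Hello")
    print(result["text"])
    # shutdown() assumes register_mcp created an adapter, so only call it
    # when MCP representations were provided.
    if manager.mcp_adapter is not None:
        manager.shutdown()

asyncio.run(main())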