Skip to content

Commit 4f6e907

Browse files
Shreyas-MicrosoftShreyas-Microsoft
andauthored
feat: working agent management (#626)
* working agent management * fix pylint * fix test cases * fix pylint --------- Co-authored-by: Shreyas-Microsoft <v-swaikar@microsft.com>
1 parent 61b4421 commit 4f6e907

File tree

4 files changed

+146
-29
lines changed

4 files changed

+146
-29
lines changed

src/App/app.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,16 +81,25 @@ async def startup():
8181
app.wealth_advisor_agent = await AgentFactory.get_wealth_advisor_agent()
8282
logging.info("Wealth Advisor Agent initialized during application startup")
8383
app.search_agent = await AgentFactory.get_search_agent()
84-
logging.info(
85-
"Call Transcript Search Agent initialized during application startup"
86-
)
84+
logging.info("Call Transcript Search Agent initialized during application startup")
85+
app.sql_agent = await AgentFactory.get_sql_agent()
86+
logging.info("SQL Agent initialized during application startup")
8787

8888
@app.after_serving
8989
async def shutdown():
90-
await AgentFactory.delete_all_agent_instance()
91-
app.wealth_advisor_agent = None
92-
app.search_agent = None
93-
logging.info("Agents cleaned up during application shutdown")
90+
try:
91+
logging.info("Application shutdown initiated...")
92+
await AgentFactory.delete_all_agent_instance()
93+
if hasattr(app, 'wealth_advisor_agent'):
94+
app.wealth_advisor_agent = None
95+
if hasattr(app, 'search_agent'):
96+
app.search_agent = None
97+
if hasattr(app, 'sql_agent'):
98+
app.sql_agent = None
99+
logging.info("Agents cleaned up successfully")
100+
except Exception as e:
101+
logging.error(f"Error during shutdown: {e}")
102+
logging.exception("Detailed error during shutdown")
94103

95104
# app.secret_key = secrets.token_hex(16)
96105
# app.session_interface = SecureCookieSessionInterface()

src/App/backend/agents/agent_factory.py

Lines changed: 62 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"""
88

99
import asyncio
10+
import logging
1011
from typing import Optional
1112

1213
from azure.ai.projects import AIProjectClient
@@ -26,6 +27,7 @@ class AgentFactory:
2627
_lock = asyncio.Lock()
2728
_wealth_advisor_agent: Optional[AzureAIAgent] = None
2829
_search_agent: Optional[dict] = None
30+
_sql_agent: Optional[dict] = None
2931

3032
@classmethod
3133
async def get_wealth_advisor_agent(cls):
@@ -94,18 +96,67 @@ async def delete_all_agent_instance(cls):
9496
Delete the singleton AzureAIAgent instances if it exists.
9597
"""
9698
async with cls._lock:
97-
if cls._wealth_advisor_agent is not None:
98-
await cls._wealth_advisor_agent.client.agents.delete_agent(
99-
cls._wealth_advisor_agent.id
100-
)
101-
cls._wealth_advisor_agent = None
99+
logging.info("Starting agent deletion process...")
102100

101+
# Delete Wealth Advisor Agent
102+
if cls._wealth_advisor_agent is not None:
103+
try:
104+
agent_id = cls._wealth_advisor_agent.id
105+
logging.info(f"Deleting wealth advisor agent: {agent_id}")
106+
if hasattr(cls._wealth_advisor_agent, 'client') and cls._wealth_advisor_agent.client:
107+
await cls._wealth_advisor_agent.client.agents.delete_agent(agent_id)
108+
logging.info("Wealth advisor agent deleted successfully")
109+
else:
110+
logging.warning("Wealth advisor agent client is None")
111+
except Exception as e:
112+
logging.error(f"Error deleting wealth advisor agent: {e}")
113+
logging.exception("Detailed wealth advisor agent deletion error")
114+
finally:
115+
cls._wealth_advisor_agent = None
116+
117+
# Delete Search Agent
103118
if cls._search_agent is not None:
104-
cls._search_agent["client"].agents.delete_agent(
105-
cls._search_agent["agent"].id
106-
)
107-
cls._search_agent["client"].close()
108-
cls._search_agent = None
119+
try:
120+
agent_id = cls._search_agent['agent'].id
121+
logging.info(f"Deleting search agent: {agent_id}")
122+
if cls._search_agent.get("client") and hasattr(cls._search_agent["client"], "agents"):
123+
# AIProjectClient.agents.delete_agent is synchronous, don't await it
124+
cls._search_agent["client"].agents.delete_agent(agent_id)
125+
logging.info("Search agent deleted successfully")
126+
127+
# Close the client if it has a close method
128+
if hasattr(cls._search_agent["client"], "close"):
129+
cls._search_agent["client"].close()
130+
else:
131+
logging.warning("Search agent client is None or invalid")
132+
except Exception as e:
133+
logging.error(f"Error deleting search agent: {e}")
134+
logging.exception("Detailed search agent deletion error")
135+
finally:
136+
cls._search_agent = None
137+
138+
# Delete SQL Agent
139+
if cls._sql_agent is not None:
140+
try:
141+
agent_id = cls._sql_agent['agent'].id
142+
logging.info(f"Deleting SQL agent: {agent_id}")
143+
if cls._sql_agent.get("client") and hasattr(cls._sql_agent["client"], "agents"):
144+
# AIProjectClient.agents.delete_agent is synchronous, don't await it
145+
cls._sql_agent["client"].agents.delete_agent(agent_id)
146+
logging.info("SQL agent deleted successfully")
147+
148+
# Close the client if it has a close method
149+
if hasattr(cls._sql_agent["client"], "close"):
150+
cls._sql_agent["client"].close()
151+
else:
152+
logging.warning("SQL agent client is None or invalid")
153+
except Exception as e:
154+
logging.error(f"Error deleting SQL agent: {e}")
155+
logging.exception("Detailed SQL agent deletion error")
156+
finally:
157+
cls._sql_agent = None
158+
159+
logging.info("Agent deletion process completed")
109160

110161
@classmethod
111162
async def get_sql_agent(cls) -> dict:
@@ -114,7 +165,7 @@ async def get_sql_agent(cls) -> dict:
114165
This agent is used to generate T-SQL queries from natural language input.
115166
"""
116167
async with cls._lock:
117-
if not hasattr(cls, "_sql_agent") or cls._sql_agent is None:
168+
if cls._sql_agent is None:
118169

119170
agent_instructions = config.SQL_SYSTEM_PROMPT or """
120171
You are an expert assistant in generating T-SQL queries based on user questions.

src/App/backend/plugins/chat_with_data_plugin.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from azure.ai.projects import AIProjectClient
1212
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
1313
from semantic_kernel.functions.kernel_function_decorator import kernel_function
14+
from quart import current_app
1415

1516
from backend.common.config import config
1617
from backend.services.sqldb_service import get_connection
@@ -41,9 +42,14 @@ async def get_SQL_Response(
4142
if not input or not input.strip():
4243
return "Error: Query input is required"
4344

45+
thread = None
4446
try:
47+
# TEMPORARY: Use AgentFactory directly to debug the issue
48+
logging.info(f"Using AgentFactory directly for SQL agent for ClientId: {ClientId}")
4549
from backend.agents.agent_factory import AgentFactory
4650
agent_info = await AgentFactory.get_sql_agent()
51+
52+
logging.info(f"SQL agent retrieved: {agent_info is not None}")
4753
agent = agent_info["agent"]
4854
project_client = agent_info["client"]
4955

@@ -72,30 +78,42 @@ async def get_SQL_Response(
7278
role=MessageRole.AGENT
7379
)
7480
sql_query = message.text.value.strip() if message else None
81+
logging.info(f"Generated SQL query: {sql_query}")
7582

7683
if not sql_query:
7784
return "No SQL query was generated."
7885

7986
# Clean up triple backticks (if any)
8087
sql_query = sql_query.replace("```sql", "").replace("```", "")
88+
logging.info(f"Cleaned SQL query: {sql_query}")
8189

8290
# Execute the query
8391
conn = get_connection()
8492
cursor = conn.cursor()
8593
cursor.execute(sql_query)
8694
rows = cursor.fetchall()
95+
logging.info(f"Query returned {len(rows)} rows")
8796

8897
if not rows:
8998
result = "No data found for that client."
9099
else:
91100
result = "\n".join(str(row) for row in rows)
101+
logging.info(f"Result preview: {result[:200]}...")
92102

93103
conn.close()
94104

95105
return result[:20000] if len(result) > 20000 else result
96106
except Exception as e:
97107
logging.exception("Error in get_SQL_Response")
98108
return f"Error retrieving SQL data: {str(e)}"
109+
finally:
110+
if thread:
111+
try:
112+
logging.info(f"Attempting to delete thread {thread.id}")
113+
await project_client.agents.threads.delete(thread.id)
114+
logging.info(f"Thread {thread.id} deleted successfully")
115+
except Exception as e:
116+
logging.error(f"Error deleting thread {thread.id}: {str(e)}")
99117

100118
@kernel_function(
101119
name="ChatWithCallTranscripts",
@@ -114,12 +132,17 @@ async def get_answers_from_calltranscripts(
114132
if not question or not question.strip():
115133
return "Error: Question input is required"
116134

135+
thread = None
117136
try:
118137
response_text = ""
119138

120-
from backend.agents.agent_factory import AgentFactory
121-
122-
agent_info: dict = await AgentFactory.get_search_agent()
139+
# Use the singleton search agent from app context
140+
if not hasattr(current_app, 'search_agent') or current_app.search_agent is None:
141+
logging.error("Search agent not found in app context, falling back to AgentFactory")
142+
from backend.agents.agent_factory import AgentFactory
143+
agent_info = await AgentFactory.get_search_agent()
144+
else:
145+
agent_info = current_app.search_agent
123146

124147
agent: Agent = agent_info["agent"]
125148
project_client: AIProjectClient = agent_info["client"]
@@ -190,7 +213,11 @@ async def get_answers_from_calltranscripts(
190213

191214
finally:
192215
if thread:
193-
project_client.agents.threads.delete(thread.id)
216+
try:
217+
await project_client.agents.threads.delete(thread.id)
218+
logging.info(f"Thread {thread.id} deleted successfully")
219+
except Exception as e:
220+
logging.error(f"Error deleting thread {thread.id}: {str(e)}")
194221

195222
if not response_text.strip():
196223
return "No data found for that client."

src/App/tests/backend/plugins/test_chat_with_data_plugin.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -180,12 +180,14 @@ async def test_get_sql_response_openai_error(self, mock_get_sql_agent, mock_conf
180180
assert "OpenAI API error" in result
181181

182182
@pytest.mark.asyncio
183+
@patch("backend.plugins.chat_with_data_plugin.hasattr", return_value=False)
183184
@patch("backend.agents.agent_factory.AgentFactory.get_search_agent")
185+
@patch("backend.plugins.chat_with_data_plugin.config")
184186
async def test_get_answers_from_calltranscripts_success(
185-
self, mock_get_search_agent
187+
self, mock_config, mock_get_search_agent, mock_hasattr
186188
):
187189
"""Test successful retrieval of answers from call transcripts using AI Search Agent."""
188-
# Setup mocks for agent factory
190+
# Setup mocks for agent factory (fallback case when current_app.search_agent is None)
189191
mock_agent = MagicMock()
190192
mock_agent.id = "test-agent-id"
191193

@@ -195,6 +197,10 @@ async def test_get_answers_from_calltranscripts_success(
195197
"client": mock_project_client,
196198
}
197199

200+
# Mock config values
201+
mock_config.AZURE_SEARCH_INDEX = "test-index"
202+
mock_config.AZURE_SEARCH_CONNECTION_NAME = "test-connection"
203+
198204
# Mock project index creation
199205
mock_index = MagicMock()
200206
mock_index.name = "project-index-test"
@@ -229,7 +235,7 @@ async def test_get_answers_from_calltranscripts_success(
229235
assert "Based on call transcripts" in result
230236
assert "investment options" in result
231237

232-
# Verify agent factory was called
238+
# Verify agent factory was called (fallback case)
233239
mock_get_search_agent.assert_called_once()
234240

235241
# Verify project index was created/updated
@@ -249,9 +255,11 @@ async def test_get_answers_from_calltranscripts_success(
249255
mock_project_client.agents.runs.create_and_process.assert_called_once()
250256

251257
@pytest.mark.asyncio
258+
@patch("backend.plugins.chat_with_data_plugin.hasattr", return_value=False)
252259
@patch("backend.agents.agent_factory.AgentFactory.get_search_agent")
260+
@patch("backend.plugins.chat_with_data_plugin.config")
253261
async def test_get_answers_from_calltranscripts_no_results(
254-
self, mock_get_search_agent
262+
self, mock_config, mock_get_search_agent, mock_hasattr
255263
):
256264
"""Test call transcripts search with no results."""
257265
# Setup mocks for agent factory
@@ -264,6 +272,10 @@ async def test_get_answers_from_calltranscripts_no_results(
264272
"client": mock_project_client,
265273
}
266274

275+
# Mock config values
276+
mock_config.AZURE_SEARCH_INDEX = "test-index"
277+
mock_config.AZURE_SEARCH_CONNECTION_NAME = "test-connection"
278+
267279
# Mock project index creation
268280
mock_index = MagicMock()
269281
mock_index.name = "project-index-test"
@@ -295,9 +307,11 @@ async def test_get_answers_from_calltranscripts_no_results(
295307
assert "No data found for that client." in result
296308

297309
@pytest.mark.asyncio
310+
@patch("backend.plugins.chat_with_data_plugin.hasattr", return_value=False)
298311
@patch("backend.agents.agent_factory.AgentFactory.get_search_agent")
312+
@patch("backend.plugins.chat_with_data_plugin.config")
299313
async def test_get_answers_from_calltranscripts_openai_error(
300-
self, mock_get_search_agent
314+
self, mock_config, mock_get_search_agent, mock_hasattr
301315
):
302316
"""Test call transcripts with AI Search processing error."""
303317
# Setup mocks for agent factory
@@ -310,6 +324,10 @@ async def test_get_answers_from_calltranscripts_openai_error(
310324
"client": mock_project_client,
311325
}
312326

327+
# Mock config values
328+
mock_config.AZURE_SEARCH_INDEX = "test-index"
329+
mock_config.AZURE_SEARCH_CONNECTION_NAME = "test-connection"
330+
313331
# Mock project index creation
314332
mock_index = MagicMock()
315333
mock_index.name = "project-index-test"
@@ -336,9 +354,11 @@ async def test_get_answers_from_calltranscripts_openai_error(
336354
assert "Error retrieving data from call transcripts" in result
337355

338356
@pytest.mark.asyncio
357+
@patch("backend.plugins.chat_with_data_plugin.hasattr", return_value=False)
339358
@patch("backend.agents.agent_factory.AgentFactory.get_search_agent")
359+
@patch("backend.plugins.chat_with_data_plugin.config")
340360
async def test_get_answers_from_calltranscripts_failed_run(
341-
self, mock_get_search_agent
361+
self, mock_config, mock_get_search_agent, mock_hasattr
342362
):
343363
"""Test call transcripts with failed AI Search run."""
344364
# Setup mocks for agent factory
@@ -351,6 +371,10 @@ async def test_get_answers_from_calltranscripts_failed_run(
351371
"client": mock_project_client,
352372
}
353373

374+
# Mock config values
375+
mock_config.AZURE_SEARCH_INDEX = "test-index"
376+
mock_config.AZURE_SEARCH_CONNECTION_NAME = "test-connection"
377+
354378
# Mock project index creation
355379
mock_index = MagicMock()
356380
mock_index.name = "project-index-test"
@@ -378,9 +402,11 @@ async def test_get_answers_from_calltranscripts_failed_run(
378402
assert "Error retrieving data from call transcripts" in result
379403

380404
@pytest.mark.asyncio
405+
@patch("backend.plugins.chat_with_data_plugin.hasattr", return_value=False)
381406
@patch("backend.agents.agent_factory.AgentFactory.get_search_agent")
407+
@patch("backend.plugins.chat_with_data_plugin.config")
382408
async def test_get_answers_from_calltranscripts_empty_response(
383-
self, mock_get_search_agent
409+
self, mock_config, mock_get_search_agent, mock_hasattr
384410
):
385411
"""Test call transcripts with empty response text."""
386412
# Setup mocks for agent factory
@@ -393,6 +419,10 @@ async def test_get_answers_from_calltranscripts_empty_response(
393419
"client": mock_project_client,
394420
}
395421

422+
# Mock config values
423+
mock_config.AZURE_SEARCH_INDEX = "test-index"
424+
mock_config.AZURE_SEARCH_CONNECTION_NAME = "test-connection"
425+
396426
# Mock project index creation
397427
mock_index = MagicMock()
398428
mock_index.name = "project-index-test"

0 commit comments

Comments
 (0)