Skip to content

Commit 53dd2a8

Browse files
Merge pull request #631 from microsoft/psl-sfi-changesr2
fix: Refactor Azure Authentication and Update Infra Config
2 parents e96b8bf + 5edf74b commit 53dd2a8

File tree

4 files changed

+30
-146
lines changed

4 files changed

+30
-146
lines changed

src/App/app.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -82,25 +82,17 @@ async def startup():
8282
app.wealth_advisor_agent = await AgentFactory.get_wealth_advisor_agent()
8383
logging.info("Wealth Advisor Agent initialized during application startup")
8484
app.search_agent = await AgentFactory.get_search_agent()
85-
logging.info("Call Transcript Search Agent initialized during application startup")
86-
app.sql_agent = await AgentFactory.get_sql_agent()
87-
logging.info("SQL Agent initialized during application startup")
85+
logging.info(
86+
"Call Transcript Search Agent initialized during application startup"
87+
)
8888

8989
@app.after_serving
9090
async def shutdown():
91-
try:
92-
logging.info("Application shutdown initiated...")
93-
await AgentFactory.delete_all_agent_instance()
94-
if hasattr(app, 'wealth_advisor_agent'):
95-
app.wealth_advisor_agent = None
96-
if hasattr(app, 'search_agent'):
97-
app.search_agent = None
98-
if hasattr(app, 'sql_agent'):
99-
app.sql_agent = None
100-
logging.info("Agents cleaned up successfully")
101-
except Exception as e:
102-
logging.error(f"Error during shutdown: {e}")
103-
logging.exception("Detailed error during shutdown")
91+
await AgentFactory.delete_all_agent_instance()
92+
app.wealth_advisor_agent = None
93+
app.search_agent = None
94+
app.sql_agent = None
95+
logging.info("Agents cleaned up during application shutdown")
10496

10597
# app.secret_key = secrets.token_hex(16)
10698
# app.session_interface = SecureCookieSessionInterface()

src/App/backend/agents/agent_factory.py

Lines changed: 11 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
"""
88

99
import asyncio
10-
import logging
1110
from typing import Optional
1211

1312
from azure.ai.projects import AIProjectClient
@@ -27,7 +26,6 @@ class AgentFactory:
2726
_lock = asyncio.Lock()
2827
_wealth_advisor_agent: Optional[AzureAIAgent] = None
2928
_search_agent: Optional[dict] = None
30-
_sql_agent: Optional[dict] = None
3129

3230
@classmethod
3331
async def get_wealth_advisor_agent(cls):
@@ -96,67 +94,18 @@ async def delete_all_agent_instance(cls):
9694
Delete the singleton AzureAIAgent instances if it exists.
9795
"""
9896
async with cls._lock:
99-
logging.info("Starting agent deletion process...")
100-
101-
# Delete Wealth Advisor Agent
10297
if cls._wealth_advisor_agent is not None:
103-
try:
104-
agent_id = cls._wealth_advisor_agent.id
105-
logging.info(f"Deleting wealth advisor agent: {agent_id}")
106-
if hasattr(cls._wealth_advisor_agent, 'client') and cls._wealth_advisor_agent.client:
107-
await cls._wealth_advisor_agent.client.agents.delete_agent(agent_id)
108-
logging.info("Wealth advisor agent deleted successfully")
109-
else:
110-
logging.warning("Wealth advisor agent client is None")
111-
except Exception as e:
112-
logging.error(f"Error deleting wealth advisor agent: {e}")
113-
logging.exception("Detailed wealth advisor agent deletion error")
114-
finally:
115-
cls._wealth_advisor_agent = None
116-
117-
# Delete Search Agent
98+
await cls._wealth_advisor_agent.client.agents.delete_agent(
99+
cls._wealth_advisor_agent.id
100+
)
101+
cls._wealth_advisor_agent = None
102+
118103
if cls._search_agent is not None:
119-
try:
120-
agent_id = cls._search_agent['agent'].id
121-
logging.info(f"Deleting search agent: {agent_id}")
122-
if cls._search_agent.get("client") and hasattr(cls._search_agent["client"], "agents"):
123-
# AIProjectClient.agents.delete_agent is synchronous, don't await it
124-
cls._search_agent["client"].agents.delete_agent(agent_id)
125-
logging.info("Search agent deleted successfully")
126-
127-
# Close the client if it has a close method
128-
if hasattr(cls._search_agent["client"], "close"):
129-
cls._search_agent["client"].close()
130-
else:
131-
logging.warning("Search agent client is None or invalid")
132-
except Exception as e:
133-
logging.error(f"Error deleting search agent: {e}")
134-
logging.exception("Detailed search agent deletion error")
135-
finally:
136-
cls._search_agent = None
137-
138-
# Delete SQL Agent
139-
if cls._sql_agent is not None:
140-
try:
141-
agent_id = cls._sql_agent['agent'].id
142-
logging.info(f"Deleting SQL agent: {agent_id}")
143-
if cls._sql_agent.get("client") and hasattr(cls._sql_agent["client"], "agents"):
144-
# AIProjectClient.agents.delete_agent is synchronous, don't await it
145-
cls._sql_agent["client"].agents.delete_agent(agent_id)
146-
logging.info("SQL agent deleted successfully")
147-
148-
# Close the client if it has a close method
149-
if hasattr(cls._sql_agent["client"], "close"):
150-
cls._sql_agent["client"].close()
151-
else:
152-
logging.warning("SQL agent client is None or invalid")
153-
except Exception as e:
154-
logging.error(f"Error deleting SQL agent: {e}")
155-
logging.exception("Detailed SQL agent deletion error")
156-
finally:
157-
cls._sql_agent = None
158-
159-
logging.info("Agent deletion process completed")
104+
cls._search_agent["client"].agents.delete_agent(
105+
cls._search_agent["agent"].id
106+
)
107+
cls._search_agent["client"].close()
108+
cls._search_agent = None
160109

161110
@classmethod
162111
async def get_sql_agent(cls) -> dict:
@@ -165,7 +114,7 @@ async def get_sql_agent(cls) -> dict:
165114
This agent is used to generate T-SQL queries from natural language input.
166115
"""
167116
async with cls._lock:
168-
if cls._sql_agent is None:
117+
if not hasattr(cls, "_sql_agent") or cls._sql_agent is None:
169118

170119
agent_instructions = config.SQL_SYSTEM_PROMPT or """
171120
You are an expert assistant in generating T-SQL queries based on user questions.

src/App/backend/plugins/chat_with_data_plugin.py

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from azure.identity import get_bearer_token_provider
1313
from backend.helpers.azure_credential_utils import get_azure_credential
1414
from semantic_kernel.functions.kernel_function_decorator import kernel_function
15-
from quart import current_app
1615

1716
from backend.common.config import config
1817
from backend.services.sqldb_service import get_connection
@@ -43,14 +42,9 @@ async def get_SQL_Response(
4342
if not input or not input.strip():
4443
return "Error: Query input is required"
4544

46-
thread = None
4745
try:
48-
# TEMPORARY: Use AgentFactory directly to debug the issue
49-
logging.info(f"Using AgentFactory directly for SQL agent for ClientId: {ClientId}")
5046
from backend.agents.agent_factory import AgentFactory
5147
agent_info = await AgentFactory.get_sql_agent()
52-
53-
logging.info(f"SQL agent retrieved: {agent_info is not None}")
5448
agent = agent_info["agent"]
5549
project_client = agent_info["client"]
5650

@@ -79,42 +73,30 @@ async def get_SQL_Response(
7973
role=MessageRole.AGENT
8074
)
8175
sql_query = message.text.value.strip() if message else None
82-
logging.info(f"Generated SQL query: {sql_query}")
8376

8477
if not sql_query:
8578
return "No SQL query was generated."
8679

8780
# Clean up triple backticks (if any)
8881
sql_query = sql_query.replace("```sql", "").replace("```", "")
89-
logging.info(f"Cleaned SQL query: {sql_query}")
9082

9183
# Execute the query
9284
conn = get_connection()
9385
cursor = conn.cursor()
9486
cursor.execute(sql_query)
9587
rows = cursor.fetchall()
96-
logging.info(f"Query returned {len(rows)} rows")
9788

9889
if not rows:
9990
result = "No data found for that client."
10091
else:
10192
result = "\n".join(str(row) for row in rows)
102-
logging.info(f"Result preview: {result[:200]}...")
10393

10494
conn.close()
10595

10696
return result[:20000] if len(result) > 20000 else result
10797
except Exception as e:
10898
logging.exception("Error in get_SQL_Response")
10999
return f"Error retrieving SQL data: {str(e)}"
110-
finally:
111-
if thread:
112-
try:
113-
logging.info(f"Attempting to delete thread {thread.id}")
114-
await project_client.agents.threads.delete(thread.id)
115-
logging.info(f"Thread {thread.id} deleted successfully")
116-
except Exception as e:
117-
logging.error(f"Error deleting thread {thread.id}: {str(e)}")
118100

119101
@kernel_function(
120102
name="ChatWithCallTranscripts",
@@ -133,17 +115,12 @@ async def get_answers_from_calltranscripts(
133115
if not question or not question.strip():
134116
return "Error: Question input is required"
135117

136-
thread = None
137118
try:
138119
response_text = ""
139120

140-
# Use the singleton search agent from app context
141-
if not hasattr(current_app, 'search_agent') or current_app.search_agent is None:
142-
logging.error("Search agent not found in app context, falling back to AgentFactory")
143-
from backend.agents.agent_factory import AgentFactory
144-
agent_info = await AgentFactory.get_search_agent()
145-
else:
146-
agent_info = current_app.search_agent
121+
from backend.agents.agent_factory import AgentFactory
122+
123+
agent_info: dict = await AgentFactory.get_search_agent()
147124

148125
agent: Agent = agent_info["agent"]
149126
project_client: AIProjectClient = agent_info["client"]
@@ -214,11 +191,7 @@ async def get_answers_from_calltranscripts(
214191

215192
finally:
216193
if thread:
217-
try:
218-
await project_client.agents.threads.delete(thread.id)
219-
logging.info(f"Thread {thread.id} deleted successfully")
220-
except Exception as e:
221-
logging.error(f"Error deleting thread {thread.id}: {str(e)}")
194+
project_client.agents.threads.delete(thread.id)
222195

223196
if not response_text.strip():
224197
return "No data found for that client."

src/App/tests/backend/plugins/test_chat_with_data_plugin.py

Lines changed: 7 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -180,14 +180,12 @@ async def test_get_sql_response_openai_error(self, mock_get_sql_agent, mock_conf
180180
assert "OpenAI API error" in result
181181

182182
@pytest.mark.asyncio
183-
@patch("backend.plugins.chat_with_data_plugin.hasattr", return_value=False)
184183
@patch("backend.agents.agent_factory.AgentFactory.get_search_agent")
185-
@patch("backend.plugins.chat_with_data_plugin.config")
186184
async def test_get_answers_from_calltranscripts_success(
187-
self, mock_config, mock_get_search_agent, mock_hasattr
185+
self, mock_get_search_agent
188186
):
189187
"""Test successful retrieval of answers from call transcripts using AI Search Agent."""
190-
# Setup mocks for agent factory (fallback case when current_app.search_agent is None)
188+
# Setup mocks for agent factory
191189
mock_agent = MagicMock()
192190
mock_agent.id = "test-agent-id"
193191

@@ -197,10 +195,6 @@ async def test_get_answers_from_calltranscripts_success(
197195
"client": mock_project_client,
198196
}
199197

200-
# Mock config values
201-
mock_config.AZURE_SEARCH_INDEX = "test-index"
202-
mock_config.AZURE_SEARCH_CONNECTION_NAME = "test-connection"
203-
204198
# Mock project index creation
205199
mock_index = MagicMock()
206200
mock_index.name = "project-index-test"
@@ -235,7 +229,7 @@ async def test_get_answers_from_calltranscripts_success(
235229
assert "Based on call transcripts" in result
236230
assert "investment options" in result
237231

238-
# Verify agent factory was called (fallback case)
232+
# Verify agent factory was called
239233
mock_get_search_agent.assert_called_once()
240234

241235
# Verify project index was created/updated
@@ -255,11 +249,9 @@ async def test_get_answers_from_calltranscripts_success(
255249
mock_project_client.agents.runs.create_and_process.assert_called_once()
256250

257251
@pytest.mark.asyncio
258-
@patch("backend.plugins.chat_with_data_plugin.hasattr", return_value=False)
259252
@patch("backend.agents.agent_factory.AgentFactory.get_search_agent")
260-
@patch("backend.plugins.chat_with_data_plugin.config")
261253
async def test_get_answers_from_calltranscripts_no_results(
262-
self, mock_config, mock_get_search_agent, mock_hasattr
254+
self, mock_get_search_agent
263255
):
264256
"""Test call transcripts search with no results."""
265257
# Setup mocks for agent factory
@@ -272,10 +264,6 @@ async def test_get_answers_from_calltranscripts_no_results(
272264
"client": mock_project_client,
273265
}
274266

275-
# Mock config values
276-
mock_config.AZURE_SEARCH_INDEX = "test-index"
277-
mock_config.AZURE_SEARCH_CONNECTION_NAME = "test-connection"
278-
279267
# Mock project index creation
280268
mock_index = MagicMock()
281269
mock_index.name = "project-index-test"
@@ -307,11 +295,9 @@ async def test_get_answers_from_calltranscripts_no_results(
307295
assert "No data found for that client." in result
308296

309297
@pytest.mark.asyncio
310-
@patch("backend.plugins.chat_with_data_plugin.hasattr", return_value=False)
311298
@patch("backend.agents.agent_factory.AgentFactory.get_search_agent")
312-
@patch("backend.plugins.chat_with_data_plugin.config")
313299
async def test_get_answers_from_calltranscripts_openai_error(
314-
self, mock_config, mock_get_search_agent, mock_hasattr
300+
self, mock_get_search_agent
315301
):
316302
"""Test call transcripts with AI Search processing error."""
317303
# Setup mocks for agent factory
@@ -324,10 +310,6 @@ async def test_get_answers_from_calltranscripts_openai_error(
324310
"client": mock_project_client,
325311
}
326312

327-
# Mock config values
328-
mock_config.AZURE_SEARCH_INDEX = "test-index"
329-
mock_config.AZURE_SEARCH_CONNECTION_NAME = "test-connection"
330-
331313
# Mock project index creation
332314
mock_index = MagicMock()
333315
mock_index.name = "project-index-test"
@@ -354,11 +336,9 @@ async def test_get_answers_from_calltranscripts_openai_error(
354336
assert "Error retrieving data from call transcripts" in result
355337

356338
@pytest.mark.asyncio
357-
@patch("backend.plugins.chat_with_data_plugin.hasattr", return_value=False)
358339
@patch("backend.agents.agent_factory.AgentFactory.get_search_agent")
359-
@patch("backend.plugins.chat_with_data_plugin.config")
360340
async def test_get_answers_from_calltranscripts_failed_run(
361-
self, mock_config, mock_get_search_agent, mock_hasattr
341+
self, mock_get_search_agent
362342
):
363343
"""Test call transcripts with failed AI Search run."""
364344
# Setup mocks for agent factory
@@ -371,10 +351,6 @@ async def test_get_answers_from_calltranscripts_failed_run(
371351
"client": mock_project_client,
372352
}
373353

374-
# Mock config values
375-
mock_config.AZURE_SEARCH_INDEX = "test-index"
376-
mock_config.AZURE_SEARCH_CONNECTION_NAME = "test-connection"
377-
378354
# Mock project index creation
379355
mock_index = MagicMock()
380356
mock_index.name = "project-index-test"
@@ -402,11 +378,9 @@ async def test_get_answers_from_calltranscripts_failed_run(
402378
assert "Error retrieving data from call transcripts" in result
403379

404380
@pytest.mark.asyncio
405-
@patch("backend.plugins.chat_with_data_plugin.hasattr", return_value=False)
406381
@patch("backend.agents.agent_factory.AgentFactory.get_search_agent")
407-
@patch("backend.plugins.chat_with_data_plugin.config")
408382
async def test_get_answers_from_calltranscripts_empty_response(
409-
self, mock_config, mock_get_search_agent, mock_hasattr
383+
self, mock_get_search_agent
410384
):
411385
"""Test call transcripts with empty response text."""
412386
# Setup mocks for agent factory
@@ -419,10 +393,6 @@ async def test_get_answers_from_calltranscripts_empty_response(
419393
"client": mock_project_client,
420394
}
421395

422-
# Mock config values
423-
mock_config.AZURE_SEARCH_INDEX = "test-index"
424-
mock_config.AZURE_SEARCH_CONNECTION_NAME = "test-connection"
425-
426396
# Mock project index creation
427397
mock_index = MagicMock()
428398
mock_index.name = "project-index-test"

0 commit comments

Comments
 (0)