diff --git a/tests/unit/vertexai/genai/test_evals.py b/tests/unit/vertexai/genai/test_evals.py index 5e7b03f57f..2c6442818d 100644 --- a/tests/unit/vertexai/genai/test_evals.py +++ b/tests/unit/vertexai/genai/test_evals.py @@ -1070,9 +1070,7 @@ def test_run_inference_with_agent_engine_and_session_inputs_dict( ) mock_agent_engine = mock.Mock() - mock_agent_engine.async_create_session = mock.AsyncMock( - return_value={"id": "session1"} - ) + mock_agent_engine.create_session.return_value = {"id": "session1"} stream_query_return_value = [ { "id": "1", @@ -1088,13 +1086,7 @@ def test_run_inference_with_agent_engine_and_session_inputs_dict( }, ] - async def _async_iterator(iterable): - for item in iterable: - yield item - - mock_agent_engine.async_stream_query.return_value = _async_iterator( - stream_query_return_value - ) + mock_agent_engine.stream_query.return_value = iter(stream_query_return_value) mock_vertexai_client.return_value.agent_engines.get.return_value = ( mock_agent_engine ) @@ -1108,10 +1100,10 @@ async def _async_iterator(iterable): mock_vertexai_client.return_value.agent_engines.get.assert_called_once_with( name="projects/test-project/locations/us-central1/reasoningEngines/123" ) - mock_agent_engine.async_create_session.assert_called_once_with( + mock_agent_engine.create_session.assert_called_once_with( user_id="123", state={"a": "1"} ) - mock_agent_engine.async_stream_query.assert_called_once_with( + mock_agent_engine.stream_query.assert_called_once_with( user_id="123", session_id="session1", message="agent prompt" ) @@ -1162,9 +1154,7 @@ def test_run_inference_with_agent_engine_and_session_inputs_literal_string( ) mock_agent_engine = mock.Mock() - mock_agent_engine.async_create_session = mock.AsyncMock( - return_value={"id": "session1"} - ) + mock_agent_engine.create_session.return_value = {"id": "session1"} stream_query_return_value = [ { "id": "1", @@ -1180,13 +1170,7 @@ def test_run_inference_with_agent_engine_and_session_inputs_literal_string( }, ] - async def _async_iterator(iterable): - for item in iterable: - yield item - - mock_agent_engine.async_stream_query.return_value = _async_iterator( - stream_query_return_value - ) + mock_agent_engine.stream_query.return_value = iter(stream_query_return_value) mock_vertexai_client.return_value.agent_engines.get.return_value = ( mock_agent_engine ) @@ -1200,10 +1184,10 @@ async def _async_iterator(iterable): mock_vertexai_client.return_value.agent_engines.get.assert_called_once_with( name="projects/test-project/locations/us-central1/reasoningEngines/123" ) - mock_agent_engine.async_create_session.assert_called_once_with( + mock_agent_engine.create_session.assert_called_once_with( user_id="123", state={"a": "1"} ) - mock_agent_engine.async_stream_query.assert_called_once_with( + mock_agent_engine.stream_query.assert_called_once_with( user_id="123", session_id="session1", message="agent prompt" ) diff --git a/vertexai/_genai/_evals_common.py b/vertexai/_genai/_evals_common.py index 8388784272..341ae88a7f 100644 --- a/vertexai/_genai/_evals_common.py +++ b/vertexai/_genai/_evals_common.py @@ -278,12 +278,10 @@ def agent_run_wrapper( and type(agent_engine).__name__ == "AgentEngine" ): agent_engine_instance = agent_engine - return asyncio.run( - inference_fn_arg( - row=row_arg, - contents=contents_arg, - agent_engine=agent_engine_instance, - ) + return inference_fn_arg( + row=row_arg, + contents=contents_arg, + agent_engine=agent_engine_instance, ) future = executor.submit( @@ -1265,7 +1263,7 @@ def _run_agent( ) -async def _execute_agent_run_with_retry( +def _execute_agent_run_with_retry( row: pd.Series, contents: Union[genai_types.ContentListUnion, genai_types.ContentListUnionDict], agent_engine: types.AgentEngine, @@ -1287,7 +1285,7 @@ async def _execute_agent_run_with_retry( ) user_id = session_inputs.user_id session_state = session_inputs.state - session = await agent_engine.async_create_session( + session = agent_engine.create_session( user_id=user_id, state=session_state, ) @@ -1298,7 +1296,7 @@ async def _execute_agent_run_with_retry( for attempt in range(max_retries): try: responses = [] - async for event in agent_engine.async_stream_query( + for event in agent_engine.stream_query( user_id=user_id, session_id=session["id"], message=contents, @@ -1317,7 +1315,7 @@ async def _execute_agent_run_with_retry( ) if attempt == max_retries - 1: return {"error": f"Resource exhausted after retries: {e}"} - await asyncio.sleep(2**attempt) + time.sleep(2**attempt) except Exception as e: # pylint: disable=broad-exception-caught logger.error( "Unexpected error during generate_content on attempt %d/%d: %s", @@ -1328,7 +1326,7 @@ async def _execute_agent_run_with_retry( if attempt == max_retries - 1: return {"error": f"Failed after retries: {e}"} - await asyncio.sleep(1) + time.sleep(1) return {"error": f"Failed to get agent run results after {max_retries} retries"}