diff --git a/tests/unit/vertexai/genai/test_genai_client.py b/tests/unit/vertexai/genai/test_genai_client.py index 2e1437fb42..850bd0c501 100644 --- a/tests/unit/vertexai/genai/test_genai_client.py +++ b/tests/unit/vertexai/genai/test_genai_client.py @@ -17,9 +17,11 @@ import importlib import pytest +from unittest import mock from google.cloud import aiplatform import vertexai +from vertexai._genai import client as vertexai_client from google.cloud.aiplatform import initializer as aiplatform_initializer @@ -66,3 +68,28 @@ def test_live_client(self): def test_types(self): assert vertexai.types is not None assert vertexai.types.LLMMetric is not None + + @pytest.mark.asyncio + @pytest.mark.usefixtures("google_auth_mock") + async def test_async_content_manager(self): + with mock.patch.object( + vertexai_client.AsyncClient, "aclose", autospec=True + ) as mock_aclose: + async with vertexai.Client( + project=_TEST_PROJECT, location=_TEST_LOCATION + ).aio as async_client: + assert isinstance(async_client, vertexai_client.AsyncClient) + + mock_aclose.assert_called_once() + + @pytest.mark.asyncio + @pytest.mark.usefixtures("google_auth_mock") + async def test_call_aclose_async_client(self): + with mock.patch.object( + vertexai_client.AsyncClient, "aclose", autospec=True + ) as mock_aclose: + async_client = vertexai.Client( + project=_TEST_PROJECT, location=_TEST_LOCATION + ).aio + await async_client.aclose() + mock_aclose.assert_called_once() diff --git a/vertexai/_genai/client.py b/vertexai/_genai/client.py index 3418ff7c5a..e0ad1afb22 100644 --- a/vertexai/_genai/client.py +++ b/vertexai/_genai/client.py @@ -13,8 +13,10 @@ # limitations under the License. # +import asyncio import importlib from typing import Optional, Union, Any +from types import TracebackType import google.auth from google.cloud.aiplatform import version as aip_version @@ -48,7 +50,7 @@ def _add_tracking_headers(headers: dict[str, str]) -> None: class AsyncClient: """Async Gen AI Client for the Vertex SDK.""" - def __init__(self, api_client: genai_client.Client): + def __init__(self, api_client: genai_client.BaseApiClient): self._api_client = api_client self._live = live.AsyncLive(self._api_client) self._evals = None @@ -132,6 +134,40 @@ def datasets(self): ) return self._datasets.AsyncDatasets(self._api_client) + async def aclose(self) -> None: + """Closes the async client explicitly. + + Example usage: + + from vertexai import Client + + async_client = vertexai.Client( + project='my-project-id', location='us-central1' + ).aio + prompt_1 = await async_client.prompts.create(...) + prompt_2 = await async_client.prompts.create(...) + # Close the client to release resources. + await async_client.aclose() + """ + await self._api_client.aclose() + + async def __aenter__(self) -> "AsyncClient": + return self + + async def __aexit__( + self, + exc_type: Optional[Exception], + exc_value: Optional[Exception], + traceback: Optional[TracebackType], + ) -> None: + await self.aclose() + + def __del__(self) -> None: + try: + asyncio.get_running_loop().create_task(self.aclose()) + except Exception: + pass + class Client: """Gen AI Client for the Vertex SDK.