Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions tests/unit/vertexai/genai/test_genai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@

import importlib
import pytest
from unittest import mock

from google.cloud import aiplatform
import vertexai
from vertexai._genai import client as vertexai_client
from google.cloud.aiplatform import initializer as aiplatform_initializer


Expand Down Expand Up @@ -66,3 +68,28 @@ def test_live_client(self):
def test_types(self):
assert vertexai.types is not None
assert vertexai.types.LLMMetric is not None

@pytest.mark.asyncio
@pytest.mark.usefixtures("google_auth_mock")
async def test_async_content_manager(self):
with mock.patch.object(
vertexai_client.AsyncClient, "aclose", autospec=True
) as mock_aclose:
async with vertexai.Client(
project=_TEST_PROJECT, location=_TEST_LOCATION
).aio as async_client:
assert isinstance(async_client, vertexai_client.AsyncClient)

mock_aclose.assert_called_once()

@pytest.mark.asyncio
@pytest.mark.usefixtures("google_auth_mock")
async def test_call_aclose_async_client(self):
with mock.patch.object(
vertexai_client.AsyncClient, "aclose", autospec=True
) as mock_aclose:
async_client = vertexai.Client(
project=_TEST_PROJECT, location=_TEST_LOCATION
).aio
await async_client.aclose()
mock_aclose.assert_called_once()
38 changes: 37 additions & 1 deletion vertexai/_genai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.
#

import asyncio
import importlib
from typing import Optional, Union, Any
from types import TracebackType

import google.auth
from google.cloud.aiplatform import version as aip_version
Expand Down Expand Up @@ -48,7 +50,7 @@ def _add_tracking_headers(headers: dict[str, str]) -> None:
class AsyncClient:
"""Async Gen AI Client for the Vertex SDK."""

def __init__(self, api_client: genai_client.Client):
def __init__(self, api_client: genai_client.BaseApiClient):
self._api_client = api_client
self._live = live.AsyncLive(self._api_client)
self._evals = None
Expand Down Expand Up @@ -132,6 +134,40 @@ def datasets(self):
)
return self._datasets.AsyncDatasets(self._api_client)

async def aclose(self) -> None:
"""Closes the async client explicitly.

Example usage:

from vertexai import Client

async_client = vertexai.Client(
project='my-project-id', location='us-central1'
).aio
prompt_1 = await async_client.prompts.create(...)
prompt_2 = await async_client.prompts.create(...)
# Close the client to release resources.
await async_client.aclose()
"""
await self._api_client.aclose()

async def __aenter__(self) -> "AsyncClient":
return self

async def __aexit__(
self,
exc_type: Optional[Exception],
exc_value: Optional[Exception],
traceback: Optional[TracebackType],
) -> None:
await self.aclose()

def __del__(self) -> None:
try:
asyncio.get_running_loop().create_task(self.aclose())
except Exception:
pass


class Client:
"""Gen AI Client for the Vertex SDK.
Expand Down
Loading