diff --git a/google/genai/_interactions/__init__.py b/google/genai/_interactions/__init__.py index 02d45e4d6..f963648c6 100644 --- a/google/genai/_interactions/__init__.py +++ b/google/genai/_interactions/__init__.py @@ -52,6 +52,7 @@ GeminiNextGenAPIClientError, ) from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient +from ._client_adapter import GeminiNextGenAPIClientAdapter from ._utils._logs import setup_logging as _setup_logging __all__ = [ @@ -96,6 +97,7 @@ "DefaultHttpxClient", "DefaultAsyncHttpxClient", "DefaultAioHttpClient", + "GeminiNextGenAPIClientAdapter", ] if not _t.TYPE_CHECKING: diff --git a/google/genai/_interactions/_client.py b/google/genai/_interactions/_client.py index 6505df88e..24273a3ed 100644 --- a/google/genai/_interactions/_client.py +++ b/google/genai/_interactions/_client.py @@ -45,6 +45,8 @@ SyncAPIClient, AsyncAPIClient, ) +from ._client_adapter import GeminiNextGenAPIClientAdapter +from ._models import FinalRequestOptions if TYPE_CHECKING: from .resources import interactions @@ -66,12 +68,13 @@ class GeminiNextGenAPIClient(SyncAPIClient): # client options api_key: str | None api_version: str + client_adapter: GeminiNextGenAPIClientAdapter def __init__( self, *, api_key: str | None = None, - api_version: str | None = "v1beta", + api_version: str | None = "v1alpha", base_url: str | httpx.URL | None = None, timeout: float | Timeout | None | NotGiven = not_given, max_retries: int = DEFAULT_MAX_RETRIES, @@ -81,6 +84,7 @@ def __init__( # We provide a `DefaultHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`. # See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details. http_client: httpx.Client | None = None, + client_adapter: GeminiNextGenAPIClientAdapter, # Enable or disable schema validation for data returned by the API. # When enabled an error APIResponseValidationError is raised # if the API responds with invalid data for the expected schema. @@ -100,7 +104,7 @@ def __init__( self.api_key = api_key if api_version is None: - api_version = "v1beta" + api_version = "v1alpha" self.api_version = api_version if base_url is None: @@ -108,6 +112,8 @@ def __init__( if base_url is None: base_url = f"https://generativelanguage.googleapis.com" + self.client_adapter = client_adapter + super().__init__( version=__version__, base_url=base_url, @@ -144,9 +150,13 @@ def qs(self) -> Querystring: @override def auth_headers(self) -> dict[str, str]: api_key = self.api_key - if api_key is None: - return {} - return {"x-goog-api-key": api_key} + if api_key is not None: + return {"x-goog-api-key": api_key} + + if self.client_adapter.is_vertex_ai(): + return self.client_adapter.get_auth_headers() or {} + + return {} @property @override @@ -159,14 +169,24 @@ def default_headers(self) -> dict[str, str | Omit]: @override def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None: + if headers.get("Authorization") or custom_headers.get("Authorization") or isinstance(custom_headers.get("Authorization"), Omit): + return if self.api_key and headers.get("x-goog-api-key"): return - if isinstance(custom_headers.get("x-goog-api-key"), Omit): + if custom_headers.get("x-goog-api-key") or isinstance(custom_headers.get("x-goog-api-key"), Omit): return raise TypeError( '"Could not resolve authentication method. Expected the api_key to be set. Or for the `x-goog-api-key` headers to be explicitly omitted"' ) + + @override + def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: + if self.client_adapter.is_vertex_ai() and not options.url.startswith(f"/{self.api_version}/projects/"): + old_url = options.url[len(self.api_version) + 1:] + options.url = f"{self.api_version}/projects/{self.client_adapter.get_project()}/locations/{self.client_adapter.get_location()}{old_url}" + + return options def copy( self, @@ -181,6 +201,7 @@ def copy( set_default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, set_default_query: Mapping[str, object] | None = None, + client_adapter: GeminiNextGenAPIClientAdapter | NotGiven = not_given, _extra_kwargs: Mapping[str, Any] = {}, ) -> Self: """ @@ -214,6 +235,7 @@ def copy( max_retries=max_retries if is_given(max_retries) else self.max_retries, default_headers=headers, default_query=params, + client_adapter=client_adapter if is_given(client_adapter) else self.client_adapter, **_extra_kwargs, ) @@ -262,12 +284,13 @@ class AsyncGeminiNextGenAPIClient(AsyncAPIClient): # client options api_key: str | None api_version: str + client_adapter: GeminiNextGenAPIClientAdapter def __init__( self, *, api_key: str | None = None, - api_version: str | None = "v1beta", + api_version: str | None = "v1alpha", base_url: str | httpx.URL | None = None, timeout: float | Timeout | None | NotGiven = not_given, max_retries: int = DEFAULT_MAX_RETRIES, @@ -277,6 +300,7 @@ def __init__( # We provide a `DefaultAsyncHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`. # See the [httpx documentation](https://www.python-httpx.org/api/#asyncclient) for more details. http_client: httpx.AsyncClient | None = None, + client_adapter: GeminiNextGenAPIClientAdapter, # Enable or disable schema validation for data returned by the API. # When enabled an error APIResponseValidationError is raised # if the API responds with invalid data for the expected schema. @@ -296,7 +320,7 @@ def __init__( self.api_key = api_key if api_version is None: - api_version = "v1beta" + api_version = "v1alpha" self.api_version = api_version if base_url is None: @@ -304,6 +328,8 @@ def __init__( if base_url is None: base_url = f"https://generativelanguage.googleapis.com" + self.client_adapter = client_adapter + super().__init__( version=__version__, base_url=base_url, @@ -340,9 +366,13 @@ def qs(self) -> Querystring: @override def auth_headers(self) -> dict[str, str]: api_key = self.api_key - if api_key is None: - return {} - return {"x-goog-api-key": api_key} + if api_key is not None: + return {"x-goog-api-key": api_key} + + if self.client_adapter.is_vertex_ai(): + return self.client_adapter.get_auth_headers() or {} + + return {} @property @override @@ -355,14 +385,24 @@ def default_headers(self) -> dict[str, str | Omit]: @override def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None: + if headers.get("Authorization") or custom_headers.get("Authorization") or isinstance(custom_headers.get("Authorization"), Omit): + return if self.api_key and headers.get("x-goog-api-key"): return - if isinstance(custom_headers.get("x-goog-api-key"), Omit): + if custom_headers.get("x-goog-api-key") or isinstance(custom_headers.get("x-goog-api-key"), Omit): return raise TypeError( '"Could not resolve authentication method. Expected the api_key to be set. Or for the `x-goog-api-key` headers to be explicitly omitted"' ) + + @override + async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: + if self.client_adapter.is_vertex_ai() and not options.url.startswith(f"/{self.api_version}/projects/"): + old_url = options.url[len(self.api_version) + 1:] + options.url = f"{self.api_version}/projects/{self.client_adapter.get_project()}/locations/{self.client_adapter.get_location()}{old_url}" + + return options def copy( self, @@ -377,6 +417,7 @@ def copy( set_default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, set_default_query: Mapping[str, object] | None = None, + client_adapter: GeminiNextGenAPIClientAdapter | NotGiven = not_given, _extra_kwargs: Mapping[str, Any] = {}, ) -> Self: """ @@ -410,6 +451,7 @@ def copy( max_retries=max_retries if is_given(max_retries) else self.max_retries, default_headers=headers, default_query=params, + client_adapter=client_adapter if is_given(client_adapter) else self.client_adapter, **_extra_kwargs, ) diff --git a/google/genai/_interactions/_client_adapter.py b/google/genai/_interactions/_client_adapter.py new file mode 100644 index 000000000..be5d44b1c --- /dev/null +++ b/google/genai/_interactions/_client_adapter.py @@ -0,0 +1,21 @@ +import typing +from abc import ABC, abstractmethod + +from ._types import Headers + +class GeminiNextGenAPIClientAdapter(ABC): + @abstractmethod + def is_vertex_ai(self) -> bool: + ... + + @abstractmethod + def get_project(self) -> str | None: + ... + + @abstractmethod + def get_location(self) -> str | None: + ... + + @abstractmethod + def get_auth_headers(self) -> dict[str, str] | None: + ... diff --git a/google/genai/client.py b/google/genai/client.py index 269b592cf..686eb4af6 100644 --- a/google/genai/client.py +++ b/google/genai/client.py @@ -43,15 +43,13 @@ from . import _common -from ._interactions import AsyncGeminiNextGenAPIClient, DEFAULT_MAX_RETRIES, DefaultAioHttpClient, GeminiNextGenAPIClient -from ._interactions._models import FinalRequestOptions -from ._interactions._types import Headers -from ._interactions._utils import is_given +from ._interactions import AsyncGeminiNextGenAPIClient, DEFAULT_MAX_RETRIES, DefaultAioHttpClient, GeminiNextGenAPIClient, GeminiNextGenAPIClientAdapter + from ._interactions.resources import AsyncInteractionsResource as AsyncNextGenInteractionsResource, InteractionsResource as NextGenInteractionsResource _interactions_experimental_warned = False -class AsyncClient: +class AsyncClient(GeminiNextGenAPIClientAdapter): """Client for making asynchronous (non-blocking) requests.""" def __init__(self, api_client: BaseApiClient): @@ -68,6 +66,21 @@ def __init__(self, api_client: BaseApiClient): self._operations = AsyncOperations(self._api_client) self._nextgen_client_instance: Optional[AsyncGeminiNextGenAPIClient] = None + def is_vertex_ai(self) -> bool: + return self._api_client.vertexai or False + + def get_project(self) -> str | None: + return self._api_client.project + + def get_location(self) -> str | None: + return self._api_client.location + + def get_auth_headers(self) -> dict[str, str]: + access_token = asyncio.run(self._api_client._async_access_token()) + return { + "Authorization": f"Bearer {access_token}" + } + @property def _nextgen_client(self) -> AsyncGeminiNextGenAPIClient: if self._nextgen_client_instance is not None: @@ -121,6 +134,7 @@ def _nextgen_client(self) -> AsyncGeminiNextGenAPIClient: # uSDk expects ms, nextgen uses a httpx Timeout -> expects seconds. timeout=http_opts.timeout / 1000 if http_opts.timeout else None, max_retries=max_retries, + client_adapter=self ) client = self._nextgen_client_instance @@ -129,30 +143,6 @@ def _nextgen_client(self) -> AsyncGeminiNextGenAPIClient: client._vertex_project = self._api_client.project client._vertex_location = self._api_client.location - async def prepare_options(options: FinalRequestOptions) -> FinalRequestOptions: - headers = {} - if is_given(options.headers): - headers = {**options.headers} - - headers['Authorization'] = f'Bearer {await self._api_client._async_access_token()}' - if ( - self._api_client._credentials - and self._api_client._credentials.quota_project_id - ): - headers['x-goog-user-project'] = ( - self._api_client._credentials.quota_project_id - ) - options.headers = headers - - return options - - if self._api_client.project or self._api_client.location: - client._prepare_options = prepare_options # type: ignore[method-assign] - - def validate_headers(headers: Headers, custom_headers: Headers) -> None: - return - - client._validate_headers = validate_headers # type: ignore[method-assign] return self._nextgen_client_instance @property @@ -278,7 +268,7 @@ class DebugConfig(pydantic.BaseModel): ) -class Client: +class Client(GeminiNextGenAPIClientAdapter): """Client for making synchronous requests. Use this client to make a request to the Gemini Developer API or Vertex AI @@ -449,6 +439,20 @@ def _get_api_client( http_options=http_options, ) + def is_vertex_ai(self) -> bool: + return self._api_client.vertexai or False + + def get_project(self) -> str | None: + return self._api_client.project + + def get_location(self) -> str | None: + return self._api_client.location + + def get_auth_headers(self) -> dict[str, str]: + return { + "Authorization": f"Bearer {self._api_client._access_token()}" + } + @property def _nextgen_client(self) -> GeminiNextGenAPIClient: if self._nextgen_client_instance is not None: @@ -490,39 +494,15 @@ def _nextgen_client(self) -> GeminiNextGenAPIClient: # uSDk expects ms, nextgen uses a httpx Timeout -> expects seconds. timeout=http_opts.timeout / 1000 if http_opts.timeout else None, max_retries=max_retries, + client_adapter=self ) client = self._nextgen_client_instance - if self.vertexai: + if self._api_client.vertexai: client._is_vertex = True client._vertex_project = self._api_client.project client._vertex_location = self._api_client.location - def prepare_options(options: FinalRequestOptions) -> FinalRequestOptions: - headers = {} - if is_given(options.headers): - headers = {**options.headers} - options.headers = headers - - headers['Authorization'] = f'Bearer {self._api_client._access_token()}' - if ( - self._api_client._credentials - and self._api_client._credentials.quota_project_id - ): - headers['x-goog-user-project'] = ( - self._api_client._credentials.quota_project_id - ) - - return options - - if self._api_client.project or self._api_client.location: - client._prepare_options = prepare_options # type: ignore[method-assign] - - def validate_headers(headers: Headers, custom_headers: Headers) -> None: - return - - client._validate_headers = validate_headers # type: ignore[method-assign] - return self._nextgen_client_instance @property diff --git a/google/genai/tests/interactions/__init__.py b/google/genai/tests/interactions/__init__.py index e69de29bb..4b5d0a6a8 100644 --- a/google/genai/tests/interactions/__init__.py +++ b/google/genai/tests/interactions/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +"""Tests for the Google GenAI SDK's interactions module.""" \ No newline at end of file diff --git a/google/genai/tests/interactions/test_vertex_auth.py b/google/genai/tests/interactions/test_vertex_auth.py new file mode 100644 index 000000000..0ba391110 --- /dev/null +++ b/google/genai/tests/interactions/test_vertex_auth.py @@ -0,0 +1,222 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import typing +import pytest +from unittest.mock import Mock, MagicMock +from httpx import Request, Response +from ..._interactions import GeminiNextGenAPIClient + + +@pytest.fixture +def client_adapter(): + """ + Shared mock equivalent to the jasmine.SpyObj. + """ + adapter = MagicMock() + adapter.is_vertex_ai.return_value = False + adapter.get_project.return_value = "my-project" + adapter.get_location.return_value = "my-location" + adapter.get_auth_headers.return_value = {} + return adapter + + +class TestInteractionsRoutedToGemini: + @pytest.fixture(autouse=True) + def setup_client(self, client_adapter, monkeypatch): + client_adapter.is_vertex_ai.return_value = False + + client = GeminiNextGenAPIClient( + client_adapter=client_adapter, + base_url="https://my.base.host", + api_key="my-api-key", + api_version="somev1", + ) + + # Spy on client._client.send + def send_mock( + request: Request, + *args, + **kwargs + ): + return Response(request=request, status_code=200) + + self.send_mock = client._client.send = Mock(wraps=send_mock) + self.client_adapter = client_adapter + self.client = client + + def test_should_send_requests_to_existing_paths_without_client_auth_headers(self): + self.client.interactions.create( + agent="some-agent", + input="some input", + ) + + req: Request = self.send_mock.call_args_list[0][0][0] + + assert req.url == "https://my.base.host/somev1/interactions" + assert req.method.lower() == "post" + # in Gemini mode with apiKey, auth headers from adapter should NOT be called + self.client_adapter.get_auth_headers.assert_not_called() + + def test_should_retry_the_call(self): + def failing_fetch(request=Request, *args, **kwargs): + return Response( + request=request, + status_code=500, + headers={"retry-after-ms": "1"}, + ) + + self.client.max_retries = 4 + self.send_mock.side_effect = failing_fetch + + with pytest.raises(Exception): + self.client.interactions.create( + agent="some-agent", + input="some input", + ) + + # initial call + 4 retries + assert self.send_mock.call_count == 5 + + def test_should_not_invoke_client_auth_headers_if_manually_given(self): + # First call: manual Authorization header + self.client.api_key = None + self.client.interactions.create( + agent="some-agent", + input="some input", + extra_headers={ + "Authorization": "Bearer some-manual-token", + } + ) + + self.client_adapter.get_auth_headers.assert_not_called() + + req: Request = self.send_mock.call_args_list[0][0][0] + assert req.headers.get("Authorization") == "Bearer some-manual-token" + assert "x-goog-api-key" not in req.headers + + # Reset spies + self.send_mock.reset_mock() + self.client_adapter.get_auth_headers.reset_mock() + + # Second call: manual x-goog-api-key + self.client.interactions.create( + agent="some-agent", + input="some input", + extra_headers={ + "x-goog-api-key": "some-manual-key" + } + ) + + self.client_adapter.get_auth_headers.assert_not_called() + + req: Request = self.send_mock.call_args_list[0][0][0] + assert req.headers.get("x-goog-api-key") == "some-manual-key" + assert "Authorization" not in req.headers + + +class TestInteractionsRoutedToVertex: + @pytest.fixture(autouse=True) + def setup_client(self, client_adapter): + client_adapter.is_vertex_ai.return_value = True + client_adapter.get_auth_headers.return_value = { + "Authorization": "Bearer some-token", + } + + client = GeminiNextGenAPIClient( + client_adapter=client_adapter, + base_url="https://my.base.host", + api_version="somev1", + ) + + # Spy on client._client.send + def send_mock(request: Request, *args, **kwargs): + return Response(request=request, status_code=200) + + self.send_mock = client._client.send = Mock(wraps=send_mock) + self.client_adapter = client_adapter + self.client = client + + def test_should_send_requests_to_new_paths_with_client_auth_headers(self): + # Override auth headers for this test + self.client_adapter.get_auth_headers.return_value = { + "Authorization": "Bearer my-access-token", + } + + self.client.interactions.create( + agent="some-agent", + input="some input", + ) + + req: Request = self.send_mock.call_args_list[0][0][0] + + assert ( + req.url + == "https://my.base.host/somev1/projects/my-project/locations/my-location/interactions" + ) + assert req.method.lower() == "post" + self.client_adapter.get_auth_headers.assert_called() + + assert req.headers.get("Authorization") == "Bearer my-access-token" + + def test_should_retry_the_call(self): + def failing_fetch(request: Request, *args, **kwargs): + return Response( + request=request, + status_code=500, + headers={"retry-after-ms": "1"}, + ) + + self.send_mock.side_effect = failing_fetch + self.client.max_retries = 4 + + with pytest.raises(Exception): + self.client.interactions.create( + agent="some-agent", + input="some input", + ) + + assert self.send_mock.call_count == 5 + # Vertex path should re-fetch auth headers on each retry + assert self.client_adapter.get_auth_headers.call_count == 5 + + def test_should_override_client_auth_headers_if_manually_given(self): + # Manual Authorization header + self.client.interactions.create( + agent="some-agent", + input="some input", + extra_headers={ + "Authorization": "Bearer some-manual-token" + } + ) + + req: Request = self.send_mock.call_args_list[0][0][0] + assert req.headers.get("Authorization") == "Bearer some-manual-token" + + # Reset spies + self.send_mock.reset_mock() + self.client_adapter.get_auth_headers.reset_mock() + + # Manual x-goog-api-key header + self.client.interactions.create( + agent="some-agent", + input="some input", + extra_headers={ + "x-goog-api-key": "some-manual-key" + } + ) + + req: Request = self.send_mock.call_args_list[0][0][0] + assert req.headers.get("x-goog-api-key") == "some-manual-key"