Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions google/genai/_interactions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
GeminiNextGenAPIClientError,
)
from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient
from ._client_adapter import GeminiNextGenAPIClientAdapter
from ._utils._logs import setup_logging as _setup_logging

__all__ = [
Expand Down Expand Up @@ -96,6 +97,7 @@
"DefaultHttpxClient",
"DefaultAsyncHttpxClient",
"DefaultAioHttpClient",
"GeminiNextGenAPIClientAdapter",
]

if not _t.TYPE_CHECKING:
Expand Down
66 changes: 54 additions & 12 deletions google/genai/_interactions/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
SyncAPIClient,
AsyncAPIClient,
)
from ._client_adapter import GeminiNextGenAPIClientAdapter
from ._models import FinalRequestOptions

if TYPE_CHECKING:
from .resources import interactions
Expand All @@ -66,12 +68,13 @@ class GeminiNextGenAPIClient(SyncAPIClient):
# client options
api_key: str | None
api_version: str
client_adapter: GeminiNextGenAPIClientAdapter

def __init__(
self,
*,
api_key: str | None = None,
api_version: str | None = "v1beta",
api_version: str | None = "v1alpha",
base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = not_given,
max_retries: int = DEFAULT_MAX_RETRIES,
Expand All @@ -81,6 +84,7 @@ def __init__(
# We provide a `DefaultHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`.
# See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
http_client: httpx.Client | None = None,
client_adapter: GeminiNextGenAPIClientAdapter,
# Enable or disable schema validation for data returned by the API.
# When enabled an error APIResponseValidationError is raised
# if the API responds with invalid data for the expected schema.
Expand All @@ -100,14 +104,16 @@ def __init__(
self.api_key = api_key

if api_version is None:
api_version = "v1beta"
api_version = "v1alpha"
self.api_version = api_version

if base_url is None:
base_url = os.environ.get("GEMINI_NEXT_GEN_API_BASE_URL")
if base_url is None:
base_url = f"https://generativelanguage.googleapis.com"

self.client_adapter = client_adapter

super().__init__(
version=__version__,
base_url=base_url,
Expand Down Expand Up @@ -144,9 +150,13 @@ def qs(self) -> Querystring:
@override
def auth_headers(self) -> dict[str, str]:
api_key = self.api_key
if api_key is None:
return {}
return {"x-goog-api-key": api_key}
if api_key is not None:
return {"x-goog-api-key": api_key}

if self.client_adapter.is_vertex_ai():
return self.client_adapter.get_auth_headers() or {}

return {}

@property
@override
Expand All @@ -159,14 +169,24 @@ def default_headers(self) -> dict[str, str | Omit]:

@override
def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
if headers.get("Authorization") or custom_headers.get("Authorization") or isinstance(custom_headers.get("Authorization"), Omit):
return
if self.api_key and headers.get("x-goog-api-key"):
return
if isinstance(custom_headers.get("x-goog-api-key"), Omit):
if custom_headers.get("x-goog-api-key") or isinstance(custom_headers.get("x-goog-api-key"), Omit):
return

raise TypeError(
'"Could not resolve authentication method. Expected the api_key to be set. Or for the `x-goog-api-key` headers to be explicitly omitted"'
)

@override
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
if self.client_adapter.is_vertex_ai() and not options.url.startswith(f"/{self.api_version}/projects/"):
old_url = options.url[len(self.api_version) + 1:]
options.url = f"{self.api_version}/projects/{self.client_adapter.get_project()}/locations/{self.client_adapter.get_location()}{old_url}"

return options

def copy(
self,
Expand All @@ -181,6 +201,7 @@ def copy(
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
client_adapter: GeminiNextGenAPIClientAdapter | NotGiven = not_given,
_extra_kwargs: Mapping[str, Any] = {},
) -> Self:
"""
Expand Down Expand Up @@ -214,6 +235,7 @@ def copy(
max_retries=max_retries if is_given(max_retries) else self.max_retries,
default_headers=headers,
default_query=params,
client_adapter=client_adapter if is_given(client_adapter) else self.client_adapter,
**_extra_kwargs,
)

Expand Down Expand Up @@ -262,12 +284,13 @@ class AsyncGeminiNextGenAPIClient(AsyncAPIClient):
# client options
api_key: str | None
api_version: str
client_adapter: GeminiNextGenAPIClientAdapter

def __init__(
self,
*,
api_key: str | None = None,
api_version: str | None = "v1beta",
api_version: str | None = "v1alpha",
base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = not_given,
max_retries: int = DEFAULT_MAX_RETRIES,
Expand All @@ -277,6 +300,7 @@ def __init__(
# We provide a `DefaultAsyncHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`.
# See the [httpx documentation](https://www.python-httpx.org/api/#asyncclient) for more details.
http_client: httpx.AsyncClient | None = None,
client_adapter: GeminiNextGenAPIClientAdapter,
# Enable or disable schema validation for data returned by the API.
# When enabled an error APIResponseValidationError is raised
# if the API responds with invalid data for the expected schema.
Expand All @@ -296,14 +320,16 @@ def __init__(
self.api_key = api_key

if api_version is None:
api_version = "v1beta"
api_version = "v1alpha"
self.api_version = api_version

if base_url is None:
base_url = os.environ.get("GEMINI_NEXT_GEN_API_BASE_URL")
if base_url is None:
base_url = f"https://generativelanguage.googleapis.com"

self.client_adapter = client_adapter

super().__init__(
version=__version__,
base_url=base_url,
Expand Down Expand Up @@ -340,9 +366,13 @@ def qs(self) -> Querystring:
@override
def auth_headers(self) -> dict[str, str]:
api_key = self.api_key
if api_key is None:
return {}
return {"x-goog-api-key": api_key}
if api_key is not None:
return {"x-goog-api-key": api_key}

if self.client_adapter.is_vertex_ai():
return self.client_adapter.get_auth_headers() or {}

return {}

@property
@override
Expand All @@ -355,14 +385,24 @@ def default_headers(self) -> dict[str, str | Omit]:

@override
def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
if headers.get("Authorization") or custom_headers.get("Authorization") or isinstance(custom_headers.get("Authorization"), Omit):
return
if self.api_key and headers.get("x-goog-api-key"):
return
if isinstance(custom_headers.get("x-goog-api-key"), Omit):
if custom_headers.get("x-goog-api-key") or isinstance(custom_headers.get("x-goog-api-key"), Omit):
return

raise TypeError(
'"Could not resolve authentication method. Expected the api_key to be set. Or for the `x-goog-api-key` headers to be explicitly omitted"'
)

@override
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
if self.client_adapter.is_vertex_ai() and not options.url.startswith(f"/{self.api_version}/projects/"):
old_url = options.url[len(self.api_version) + 1:]
options.url = f"{self.api_version}/projects/{self.client_adapter.get_project()}/locations/{self.client_adapter.get_location()}{old_url}"

return options

def copy(
self,
Expand All @@ -377,6 +417,7 @@ def copy(
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
client_adapter: GeminiNextGenAPIClientAdapter | NotGiven = not_given,
_extra_kwargs: Mapping[str, Any] = {},
) -> Self:
"""
Expand Down Expand Up @@ -410,6 +451,7 @@ def copy(
max_retries=max_retries if is_given(max_retries) else self.max_retries,
default_headers=headers,
default_query=params,
client_adapter=client_adapter if is_given(client_adapter) else self.client_adapter,
**_extra_kwargs,
)

Expand Down
21 changes: 21 additions & 0 deletions google/genai/_interactions/_client_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import typing
from abc import ABC, abstractmethod

from ._types import Headers

class GeminiNextGenAPIClientAdapter(ABC):
@abstractmethod
def is_vertex_ai(self) -> bool:
...

@abstractmethod
def get_project(self) -> str | None:
...

@abstractmethod
def get_location(self) -> str | None:
...

@abstractmethod
def get_auth_headers(self) -> dict[str, str] | None:
...
92 changes: 36 additions & 56 deletions google/genai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,13 @@

from . import _common

from ._interactions import AsyncGeminiNextGenAPIClient, DEFAULT_MAX_RETRIES, DefaultAioHttpClient, GeminiNextGenAPIClient
from ._interactions._models import FinalRequestOptions
from ._interactions._types import Headers
from ._interactions._utils import is_given
from ._interactions import AsyncGeminiNextGenAPIClient, DEFAULT_MAX_RETRIES, DefaultAioHttpClient, GeminiNextGenAPIClient, GeminiNextGenAPIClientAdapter

from ._interactions.resources import AsyncInteractionsResource as AsyncNextGenInteractionsResource, InteractionsResource as NextGenInteractionsResource
_interactions_experimental_warned = False


class AsyncClient:
class AsyncClient(GeminiNextGenAPIClientAdapter):
"""Client for making asynchronous (non-blocking) requests."""

def __init__(self, api_client: BaseApiClient):
Expand All @@ -68,6 +66,21 @@ def __init__(self, api_client: BaseApiClient):
self._operations = AsyncOperations(self._api_client)
self._nextgen_client_instance: Optional[AsyncGeminiNextGenAPIClient] = None

def is_vertex_ai(self) -> bool:
return self._api_client.vertexai or False

def get_project(self) -> str | None:
return self._api_client.project

def get_location(self) -> str | None:
return self._api_client.location

def get_auth_headers(self) -> dict[str, str]:
access_token = asyncio.run(self._api_client._async_access_token())
return {
"Authorization": f"Bearer {access_token}"
}

@property
def _nextgen_client(self) -> AsyncGeminiNextGenAPIClient:
if self._nextgen_client_instance is not None:
Expand Down Expand Up @@ -121,6 +134,7 @@ def _nextgen_client(self) -> AsyncGeminiNextGenAPIClient:
# uSDk expects ms, nextgen uses a httpx Timeout -> expects seconds.
timeout=http_opts.timeout / 1000 if http_opts.timeout else None,
max_retries=max_retries,
client_adapter=self
)

client = self._nextgen_client_instance
Expand All @@ -129,30 +143,6 @@ def _nextgen_client(self) -> AsyncGeminiNextGenAPIClient:
client._vertex_project = self._api_client.project
client._vertex_location = self._api_client.location

async def prepare_options(options: FinalRequestOptions) -> FinalRequestOptions:
headers = {}
if is_given(options.headers):
headers = {**options.headers}

headers['Authorization'] = f'Bearer {await self._api_client._async_access_token()}'
if (
self._api_client._credentials
and self._api_client._credentials.quota_project_id
):
headers['x-goog-user-project'] = (
self._api_client._credentials.quota_project_id
)
options.headers = headers

return options

if self._api_client.project or self._api_client.location:
client._prepare_options = prepare_options # type: ignore[method-assign]

def validate_headers(headers: Headers, custom_headers: Headers) -> None:
return

client._validate_headers = validate_headers # type: ignore[method-assign]
return self._nextgen_client_instance

@property
Expand Down Expand Up @@ -278,7 +268,7 @@ class DebugConfig(pydantic.BaseModel):
)


class Client:
class Client(GeminiNextGenAPIClientAdapter):
"""Client for making synchronous requests.

Use this client to make a request to the Gemini Developer API or Vertex AI
Expand Down Expand Up @@ -449,6 +439,20 @@ def _get_api_client(
http_options=http_options,
)

def is_vertex_ai(self) -> bool:
return self._api_client.vertexai or False

def get_project(self) -> str | None:
return self._api_client.project

def get_location(self) -> str | None:
return self._api_client.location

def get_auth_headers(self) -> dict[str, str]:
return {
"Authorization": f"Bearer {self._api_client._access_token()}"
}

@property
def _nextgen_client(self) -> GeminiNextGenAPIClient:
if self._nextgen_client_instance is not None:
Expand Down Expand Up @@ -490,39 +494,15 @@ def _nextgen_client(self) -> GeminiNextGenAPIClient:
# uSDk expects ms, nextgen uses a httpx Timeout -> expects seconds.
timeout=http_opts.timeout / 1000 if http_opts.timeout else None,
max_retries=max_retries,
client_adapter=self
)

client = self._nextgen_client_instance
if self.vertexai:
if self._api_client.vertexai:
client._is_vertex = True
client._vertex_project = self._api_client.project
client._vertex_location = self._api_client.location

def prepare_options(options: FinalRequestOptions) -> FinalRequestOptions:
headers = {}
if is_given(options.headers):
headers = {**options.headers}
options.headers = headers

headers['Authorization'] = f'Bearer {self._api_client._access_token()}'
if (
self._api_client._credentials
and self._api_client._credentials.quota_project_id
):
headers['x-goog-user-project'] = (
self._api_client._credentials.quota_project_id
)

return options

if self._api_client.project or self._api_client.location:
client._prepare_options = prepare_options # type: ignore[method-assign]

def validate_headers(headers: Headers, custom_headers: Headers) -> None:
return

client._validate_headers = validate_headers # type: ignore[method-assign]

return self._nextgen_client_instance

@property
Expand Down
Loading