Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 37 additions & 18 deletions google/genai/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,12 +984,7 @@ def _access_token(self) -> str:
self.project = project

if self._credentials:
if self._credentials.expired or not self._credentials.token:
# Only refresh when it needs to. Default expiration is 3600 seconds.
refresh_auth(self._credentials)
if not self._credentials.token:
raise RuntimeError('Could not resolve API token from the environment')
return self._credentials.token # type: ignore[no-any-return]
return get_token_from_credentials(self, self._credentials) # type: ignore[no-any-return]
else:
raise RuntimeError('Could not resolve API token from the environment')

Expand Down Expand Up @@ -1034,18 +1029,10 @@ async def _async_access_token(self) -> Union[str, Any]:
self.project = project

if self._credentials:
if self._credentials.expired or not self._credentials.token:
# Only refresh when it needs to. Default expiration is 3600 seconds.
async_auth_lock = await self._get_async_auth_lock()
async with async_auth_lock:
if self._credentials.expired or not self._credentials.token:
# Double check that the credentials expired before refreshing.
await asyncio.to_thread(refresh_auth, self._credentials)

if not self._credentials.token:
raise RuntimeError('Could not resolve API token from the environment')

return self._credentials.token
return await async_get_token_from_credentials(
self,
self._credentials
) # type: ignore[no-any-return]
else:
raise RuntimeError('Could not resolve API token from the environment')

Expand Down Expand Up @@ -1925,3 +1912,35 @@ def __del__(self) -> None:
asyncio.get_running_loop().create_task(self.aclose())
except Exception: # pylint: disable=broad-except
pass

def get_token_from_credentials(
client: 'BaseApiClient',
credentials: google.auth.credentials.Credentials
) -> str:
"""Refreshes the authentication token for the given credentials."""
if credentials.expired or not credentials.token:
# Only refresh when it needs to. Default expiration is 3600 seconds.
refresh_auth(credentials)
if not credentials.token:
raise RuntimeError('Could not resolve API token from the environment')
return credentials.token # type: ignore[no-any-return]

async def async_get_token_from_credentials(
client: 'BaseApiClient',
credentials: google.auth.credentials.Credentials
) -> str:
"""Refreshes the authentication token for the given credentials."""
if credentials.expired or not credentials.token:
# Only refresh when it needs to. Default expiration is 3600 seconds.
async_auth_lock = await client._get_async_auth_lock()
async with async_auth_lock:
if credentials.expired or not credentials.token:
# Double check that the credentials expired before refreshing.
await asyncio.to_thread(refresh_auth, credentials)

if not credentials.token:
raise RuntimeError('Could not resolve API token from the environment')

return credentials.token # type: ignore[no-any-return]


224 changes: 224 additions & 0 deletions google/genai/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from typing import Any, Optional, Union
from urllib.parse import urlencode

import google.auth

from . import _api_client
from . import _api_module
from . import _common
from . import _extra_utils
Expand Down Expand Up @@ -149,6 +152,33 @@ def _ListFilesResponse_from_mldev(
return to_object


def _RegisterFilesParameters_to_mldev(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ['uris']) is not None:
setv(to_object, ['uris'], getv(from_object, ['uris']))

return to_object


def _RegisterFilesResponse_from_mldev(
from_object: Union[dict[str, Any], object],
parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
to_object: dict[str, Any] = {}
if getv(from_object, ['sdkHttpResponse']) is not None:
setv(
to_object, ['sdk_http_response'], getv(from_object, ['sdkHttpResponse'])
)

if getv(from_object, ['files']) is not None:
setv(to_object, ['files'], [item for item in getv(from_object, ['files'])])

return to_object


class Files(_api_module.BaseModule):

def _list(
Expand Down Expand Up @@ -402,6 +432,69 @@ def delete(
self._api_client._verify_response(return_value)
return return_value

def _register_files(
self,
*,
uris: list[str],
config: Optional[types.RegisterFilesConfigOrDict] = None,
) -> types.RegisterFilesResponse:
parameter_model = types._RegisterFilesParameters(
uris=uris,
config=config,
)

request_url_dict: Optional[dict[str, str]]
if self._api_client.vertexai:
raise ValueError(
'This method is only supported in the Gemini Developer client.'
)
else:
request_dict = _RegisterFilesParameters_to_mldev(parameter_model)
request_url_dict = request_dict.get('_url')
if request_url_dict:
path = 'files:register'.format_map(request_url_dict)
else:
path = 'files:register'

query_params = request_dict.get('_query')
if query_params:
path = f'{path}?{urlencode(query_params)}'
# TODO: remove the hack that pops config.
request_dict.pop('config', None)

http_options: Optional[types.HttpOptions] = None
if (
parameter_model.config is not None
and parameter_model.config.http_options is not None
):
http_options = parameter_model.config.http_options

request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.encode_unserializable_types(request_dict)

response = self._api_client.request(
'post', path, request_dict, http_options
)

if config is not None and getattr(
config, 'should_return_http_response', None
):
return_value = types.RegisterFilesResponse(sdk_http_response=response)
self._api_client._verify_response(return_value)
return return_value

response_dict = {} if not response.body else json.loads(response.body)

if not self._api_client.vertexai:
response_dict = _RegisterFilesResponse_from_mldev(response_dict)

return_value = types.RegisterFilesResponse._from_response(
response=response_dict, kwargs=parameter_model.model_dump()
)

self._api_client._verify_response(return_value)
return return_value

def upload(
self,
*,
Expand Down Expand Up @@ -559,6 +652,39 @@ def download(

return data

def register_files(
self,
*,
auth: google.auth.credentials.Credentials,
uris: list[str],
config: Optional[types.RegisterFilesConfigOrDict] = None,
) -> types.RegisterFilesResponse:
"""Registers gcs files with the file service."""
if not isinstance(auth, google.auth.credentials.Credentials):
raise ValueError(
'auth must be a google.auth.credentials.Credentials object.'
)
if config is None:
config = types.RegisterFilesConfig()
else:
config = types.RegisterFilesConfig.model_validate(config)
config = config.model_copy(deep=True)

http_options = config.http_options or types.HttpOptions()
headers = http_options.headers or {}
headers = {k.lower(): v for k, v in headers.items()}

token = _api_client.get_token_from_credentials(self._api_client, auth)
headers['authorization'] = f'Bearer {token}'

if auth.quota_project_id:
headers['x-goog-user-project'] = auth.quota_project_id

http_options.headers = headers
config.http_options = http_options

return self._register_files(uris=uris, config=config)

def list(
self, *, config: Optional[types.ListFilesConfigOrDict] = None
) -> Pager[types.File]:
Expand Down Expand Up @@ -845,6 +971,69 @@ async def delete(
self._api_client._verify_response(return_value)
return return_value

async def _register_files(
self,
*,
uris: list[str],
config: Optional[types.RegisterFilesConfigOrDict] = None,
) -> types.RegisterFilesResponse:
parameter_model = types._RegisterFilesParameters(
uris=uris,
config=config,
)

request_url_dict: Optional[dict[str, str]]
if self._api_client.vertexai:
raise ValueError(
'This method is only supported in the Gemini Developer client.'
)
else:
request_dict = _RegisterFilesParameters_to_mldev(parameter_model)
request_url_dict = request_dict.get('_url')
if request_url_dict:
path = 'files:register'.format_map(request_url_dict)
else:
path = 'files:register'

query_params = request_dict.get('_query')
if query_params:
path = f'{path}?{urlencode(query_params)}'
# TODO: remove the hack that pops config.
request_dict.pop('config', None)

http_options: Optional[types.HttpOptions] = None
if (
parameter_model.config is not None
and parameter_model.config.http_options is not None
):
http_options = parameter_model.config.http_options

request_dict = _common.convert_to_dict(request_dict)
request_dict = _common.encode_unserializable_types(request_dict)

response = await self._api_client.async_request(
'post', path, request_dict, http_options
)

if config is not None and getattr(
config, 'should_return_http_response', None
):
return_value = types.RegisterFilesResponse(sdk_http_response=response)
self._api_client._verify_response(return_value)
return return_value

response_dict = {} if not response.body else json.loads(response.body)

if not self._api_client.vertexai:
response_dict = _RegisterFilesResponse_from_mldev(response_dict)

return_value = types.RegisterFilesResponse._from_response(
response=response_dict, kwargs=parameter_model.model_dump()
)

self._api_client._verify_response(return_value)
return return_value

async def upload(
self,
*,
Expand Down Expand Up @@ -992,6 +1181,41 @@ async def download(

return data

async def register_files(
self,
*,
auth: google.auth.credentials.Credentials,
uris: list[str],
config: Optional[types.RegisterFilesConfigOrDict] = None,
) -> types.RegisterFilesResponse:
"""Registers gcs files with the file service."""
if not isinstance(auth, google.auth.credentials.Credentials):
raise ValueError(
'auth must be a google.auth.credentials.Credentials object.'
)
if config is None:
config = types.RegisterFilesConfig()
else:
config = types.RegisterFilesConfig.model_validate(config)
config = config.model_copy(deep=True)

http_options = config.http_options or types.HttpOptions()
headers = http_options.headers or {}
headers = {k.lower(): v for k, v in headers.items()}

token = await _api_client.async_get_token_from_credentials(
self._api_client, auth
)
headers['authorization'] = f'Bearer {token}'

if auth.quota_project_id:
headers['x-goog-user-project'] = auth.quota_project_id

http_options.headers = headers
config.http_options = http_options

return await self._register_files(uris=uris, config=config)

async def list(
self, *, config: Optional[types.ListFilesConfigOrDict] = None
) -> AsyncPager[types.File]:
Expand Down
Loading