From 26e74d279cfe22c69d7e8a8c9e4d99111e21ea94 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 21 Nov 2025 14:15:21 -0800 Subject: [PATCH] chore: refactor live conection parameters into separate vertex and mldef functions. chore: move the credentials refresh into asyncio.to_thread (this func is async, requests should be too) The diff looks messy, but this really just splits the contents of the giant if/else in connect into _prepare_connection_vertex and _prepare_connection_mldev. No behavior changes expected. Existing tests pass, additional tests added for coverage. PiperOrigin-RevId: 835355573 --- google/genai/live.py | 282 ++++++++++++++++----------- google/genai/tests/live/test_live.py | 77 ++++++++ 2 files changed, 247 insertions(+), 112 deletions(-) diff --git a/google/genai/live.py b/google/genai/live.py index 0d8453fc5..2bc5f5aa4 100644 --- a/google/genai/live.py +++ b/google/genai/live.py @@ -929,76 +929,172 @@ async def connect( base_url = self._api_client._websocket_base_url() if isinstance(base_url, bytes): base_url = base_url.decode('utf-8') - transformed_model = t.t_model(self._api_client, model) # type: ignore parameter_model = await _t_live_connect_config(self._api_client, config) - if self._api_client.api_key and not self._api_client.vertexai: - version = self._api_client._http_options.api_version - api_key = self._api_client.api_key - method = 'BidiGenerateContent' - original_headers = self._api_client._http_options.headers - headers = original_headers.copy() if original_headers is not None else {} - if api_key.startswith('auth_tokens/'): + if self._api_client.vertexai: + uri, headers, request = await self._prepare_connection_vertex( + base_url=base_url, model=model, parameter_model=parameter_model + ) + else: + uri, headers, request = await self._prepare_connection_mldev( + base_url=base_url, model=model, parameter_model=parameter_model + ) + + if parameter_model.tools and _mcp_utils.has_mcp_tool_usage( + parameter_model.tools + ): + if headers is None: + headers = {} + _mcp_utils.set_mcp_usage_header(headers) + + async with ws_connect( + uri, additional_headers=headers, **self._api_client._websocket_ssl_ctx + ) as ws: + await ws.send(request) + try: + # websockets 14.0+ + raw_response = await ws.recv(decode=False) + except TypeError: + raw_response = await ws.recv() # type: ignore[assignment] + if raw_response: + try: + response = json.loads(raw_response) + except json.decoder.JSONDecodeError as e: + raise ValueError(f'Failed to parse response: {raw_response!r}') from e + else: + response = {} + + if self._api_client.vertexai: + response_dict = live_converters._LiveServerMessage_from_vertex(response) + else: + response_dict = response + + setup_response = types.LiveServerMessage._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + if setup_response.setup_complete: + session_id = setup_response.setup_complete.session_id + else: + session_id = None + yield AsyncSession( + api_client=self._api_client, + websocket=ws, + session_id=session_id, + ) + + async def _prepare_connection_mldev( + self, *, + base_url: str, + model: str, + parameter_model: types.LiveConnectConfig, + ) -> tuple[str, _common.StringDict, str]: + """Prepares live connection parameters for the MLDev API. + + Constructs the WebSocket URI, headers, and request body necessary + to establish a connection with the MLDev backend. + + Args: + base_url: The base URL for the WebSocket connection. + model: The name of the model to use. + parameter_model: Configuration parameters for the connection. + + Returns: + A tuple containing: + - uri: The WebSocket connection URI. + - headers: A dictionary of headers for the connection. + - request: The JSON-serialized request body. + + Raises: + ValueError: If an API key is not provided. + """ + transformed_model = t.t_model(self._api_client, model) # type: ignore + version = self._api_client._http_options.api_version + method = 'BidiGenerateContent' + original_headers = self._api_client._http_options.headers + headers = original_headers.copy() if original_headers is not None else {} + api_key = self._api_client.api_key + + if not api_key: + # this shouldn't happen + raise ValueError('Genai live connection requires an API key.') + + if api_key.startswith('auth_tokens/'): + method = 'BidiGenerateContentConstrained' + headers['Authorization'] = f'Token {api_key}' + warnings.warn( + message=( + "The SDK's ephemeral token support is experimental, and may" + ' change in future versions.' + ), + category=errors.ExperimentalWarning, + ) + if version != 'v1alpha': warnings.warn( message=( - "The SDK's ephemeral token support is experimental, and may" - ' change in future versions.' + "The SDK's ephemeral token support is in v1alpha only." + 'Please use client = genai.Client(api_key=token.name, ' + 'http_options=types.HttpOptions(api_version="v1alpha"))' + ' before session connection.' ), category=errors.ExperimentalWarning, ) - method = 'BidiGenerateContentConstrained' - headers['Authorization'] = f'Token {api_key}' - if version != 'v1alpha': - warnings.warn( - message=( - "The SDK's ephemeral token support is in v1alpha only." - 'Please use client = genai.Client(api_key=token.name, ' - 'http_options=types.HttpOptions(api_version="v1alpha"))' - ' before session connection.' - ), - category=errors.ExperimentalWarning, - ) - uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.{method}' - - request_dict = _common.convert_to_dict( - live_converters._LiveConnectParameters_to_mldev( - api_client=self._api_client, - from_object=types.LiveConnectParameters( - model=transformed_model, - config=parameter_model, - ).model_dump(exclude_none=True), - ) - ) - del request_dict['config'] - setv(request_dict, ['setup', 'model'], transformed_model) + uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.{method}' - request = json.dumps(request_dict) - elif self._api_client.api_key and self._api_client.vertexai: - # Headers already contains api key for express mode. - api_key = self._api_client.api_key - version = self._api_client._http_options.api_version - uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent' - original_headers = self._api_client._http_options.headers - headers = original_headers.copy() if original_headers is not None else {} - - request_dict = _common.convert_to_dict( - live_converters._LiveConnectParameters_to_vertex( - api_client=self._api_client, - from_object=types.LiveConnectParameters( - model=transformed_model, - config=parameter_model, - ).model_dump(exclude_none=True), - ) - ) - del request_dict['config'] + request_dict = _common.convert_to_dict( + live_converters._LiveConnectParameters_to_mldev( + api_client=self._api_client, + from_object=types.LiveConnectParameters( + model=transformed_model, + config=parameter_model, + ).model_dump(exclude_none=True), + ) + ) + del request_dict['config'] + + setv(request_dict, ['setup', 'model'], transformed_model) + + return uri, headers, json.dumps(request_dict) + + + async def _prepare_connection_vertex( + self, *, + base_url: str, + model: str, + parameter_model: types.LiveConnectConfig, + ) -> tuple[str, _common.StringDict, str]: + """Prepares live connection parameters for the Vertex AI API. - setv(request_dict, ['setup', 'model'], transformed_model) + Constructs the WebSocket URI, headers, and request body necessary + to establish a connection with the Vertex AI backend. Handles + authentication using either an API key or default credentials. - request = json.dumps(request_dict) + Args: + base_url: The base URL for the WebSocket connection. + model: The name of the model to use. + parameter_model: Configuration parameters for the connection. + + Returns: + A tuple containing: + - uri: The WebSocket connection URI. + - headers: A dictionary of headers for the connection. + - request: The JSON-serialized request body. + + Raises: + ValueError: If project and location are not provided when + default credentials are used. + """ + transformed_model = t.t_model(self._api_client, model) # type: ignore + version = self._api_client._http_options.api_version + original_headers = self._api_client._http_options.headers + headers = ( + original_headers.copy() if original_headers is not None else {} + ) + if api_key := self._api_client.api_key: + # Headers already contains api key + uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent' else: - version = self._api_client._http_options.api_version has_sufficient_auth = ( self._api_client.project and self._api_client.location ) @@ -1028,13 +1124,9 @@ async def connect( # Need to refresh credentials to populate those if not (creds.token and creds.valid): auth_req = google.auth.transport.requests.Request() # type: ignore - creds.refresh(auth_req) + await asyncio.to_thread(creds.refresh, auth_req) bearer_token = creds.token - original_headers = self._api_client._http_options.headers - headers = ( - original_headers.copy() if original_headers is not None else {} - ) if not headers.get('Authorization'): headers['Authorization'] = f'Bearer {bearer_token}' @@ -1044,17 +1136,22 @@ async def connect( transformed_model = ( f'projects/{project}/locations/{location}/' + transformed_model ) - request_dict = _common.convert_to_dict( - live_converters._LiveConnectParameters_to_vertex( - api_client=self._api_client, - from_object=types.LiveConnectParameters( - model=transformed_model, - config=parameter_model, - ).model_dump(exclude_none=True), - ) - ) - del request_dict['config'] + request_dict = _common.convert_to_dict( + live_converters._LiveConnectParameters_to_vertex( + api_client=self._api_client, + from_object=types.LiveConnectParameters( + model=transformed_model, + config=parameter_model, + ).model_dump(exclude_none=True), + ) + ) + del request_dict['config'] + + if api_key is None: + # Refactor note: I'm surprised the two paths are different, you'd have + # to test every model to be sure. The goal of this refactor is to not + # change any behavior so leaving it as is. if ( getv( request_dict, ['setup', 'generationConfig', 'responseModalities'] @@ -1067,49 +1164,10 @@ async def connect( ['AUDIO'], ) - request = json.dumps(request_dict) + return uri, headers, json.dumps(request_dict) - if parameter_model.tools and _mcp_utils.has_mcp_tool_usage( - parameter_model.tools - ): - if headers is None: - headers = {} - _mcp_utils.set_mcp_usage_header(headers) - async with ws_connect( - uri, additional_headers=headers, **self._api_client._websocket_ssl_ctx - ) as ws: - await ws.send(request) - try: - # websockets 14.0+ - raw_response = await ws.recv(decode=False) - except TypeError: - raw_response = await ws.recv() # type: ignore[assignment] - if raw_response: - try: - response = json.loads(raw_response) - except json.decoder.JSONDecodeError: - raise ValueError(f'Failed to parse response: {raw_response!r}') - else: - response = {} - if self._api_client.vertexai: - response_dict = live_converters._LiveServerMessage_from_vertex(response) - else: - response_dict = response - - setup_response = types.LiveServerMessage._from_response( - response=response_dict, kwargs=parameter_model.model_dump() - ) - if setup_response.setup_complete: - session_id = setup_response.setup_complete.session_id - else: - session_id = None - yield AsyncSession( - api_client=self._api_client, - websocket=ws, - session_id=session_id, - ) async def _t_live_connect_config( diff --git a/google/genai/tests/live/test_live.py b/google/genai/tests/live/test_live.py index 9ffebb726..a814aebda 100644 --- a/google/genai/tests/live/test_live.py +++ b/google/genai/tests/live/test_live.py @@ -29,6 +29,8 @@ import warnings import certifi +import google.auth +from google.auth.transport import requests from google.oauth2.credentials import Credentials import pytest from websockets import client @@ -85,6 +87,23 @@ }] +class FakeCredentials(Credentials): + def __init__(self, token='fake_token', valid=True): + super().__init__(token='placeholder') + self.token = token + self._valid = valid + self.refresh_called = False + + def refresh(self, request): + self.token = 'refreshed_token' + self._valid = True + self.refresh_called = True + + @property + def valid(self): + return self._valid + + def get_current_weather(location: str, unit: str): """Get the current weather in a city.""" return 15 if unit == 'C' else 59 @@ -2073,3 +2092,61 @@ async def mock_connect(uri, additional_headers=None, **kwargs): assert 'x-goog-api-key' in capture['headers'], "x-goog-api-key is missing from headers" assert capture['headers']['x-goog-api-key'] == 'TEST_API_KEY' assert 'BidiGenerateContent' in capture['uri'] + + + +@pytest.mark.asyncio +async def test_prepare_connection_vertex_with_api_key(mock_websocket): + # Test the branch where api_key is present in vertexai + client = Client(vertexai=True, api_key="test_api_key") + capture = {} + + @contextlib.asynccontextmanager + async def mock_ws_connect(uri, additional_headers=None, **kwargs): + capture['uri'] = uri + capture['headers'] = additional_headers + yield mock_websocket + + with patch.object(live, 'ws_connect', new=mock_ws_connect): + live_module = client.aio.live + async with live_module.connect(model='test-model'): + pass + + headers = capture['headers'] + uri = capture['uri'] + assert 'x-goog-api-key' in headers + assert headers['x-goog-api-key'] == "test_api_key" + # Authorization header should not be added by this method if api_key is used + assert 'Authorization' not in headers + assert "BidiGenerateContent" in uri + + +@pytest.mark.asyncio +async def test_prepare_connection_vertex_refresh_creds(mock_websocket): + # Test the branch where credentials need refreshing + fake_creds = FakeCredentials(token=None, valid=False) + capture = {} + + @contextlib.asynccontextmanager + async def mock_ws_connect(uri, additional_headers=None, **kwargs): + capture['uri'] = uri + capture['headers'] = additional_headers + yield mock_websocket + + with ( + patch.object(google.auth, 'default', return_value=(fake_creds, "test-project")), + patch.object(requests, 'Request', return_value=Mock()), + patch.object(live, 'ws_connect', new=mock_ws_connect) + ): + client = Client(vertexai=True, project="test-project", + location="us-central1") + live_module = client.aio.live + async with live_module.connect(model='test-model'): + pass + + headers = capture['headers'] + uri = capture['uri'] + assert fake_creds.refresh_called + assert 'Authorization' in headers + assert headers['Authorization'] == f'Bearer refreshed_token' + assert "BidiGenerateContent" in uri