Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 170 additions & 112 deletions google/genai/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,76 +929,172 @@ async def connect(
base_url = self._api_client._websocket_base_url()
if isinstance(base_url, bytes):
base_url = base_url.decode('utf-8')
transformed_model = t.t_model(self._api_client, model) # type: ignore

parameter_model = await _t_live_connect_config(self._api_client, config)

if self._api_client.api_key and not self._api_client.vertexai:
version = self._api_client._http_options.api_version
api_key = self._api_client.api_key
method = 'BidiGenerateContent'
original_headers = self._api_client._http_options.headers
headers = original_headers.copy() if original_headers is not None else {}
if api_key.startswith('auth_tokens/'):
if self._api_client.vertexai:
uri, headers, request = await self._prepare_connection_vertex(
base_url=base_url, model=model, parameter_model=parameter_model
)
else:
uri, headers, request = await self._prepare_connection_mldev(
base_url=base_url, model=model, parameter_model=parameter_model
)

if parameter_model.tools and _mcp_utils.has_mcp_tool_usage(
parameter_model.tools
):
if headers is None:
headers = {}
_mcp_utils.set_mcp_usage_header(headers)

async with ws_connect(
uri, additional_headers=headers, **self._api_client._websocket_ssl_ctx
) as ws:
await ws.send(request)
try:
# websockets 14.0+
raw_response = await ws.recv(decode=False)
except TypeError:
raw_response = await ws.recv() # type: ignore[assignment]
if raw_response:
try:
response = json.loads(raw_response)
except json.decoder.JSONDecodeError as e:
raise ValueError(f'Failed to parse response: {raw_response!r}') from e
else:
response = {}

if self._api_client.vertexai:
response_dict = live_converters._LiveServerMessage_from_vertex(response)
else:
response_dict = response

setup_response = types.LiveServerMessage._from_response(
response=response_dict, kwargs=parameter_model.model_dump()
)
if setup_response.setup_complete:
session_id = setup_response.setup_complete.session_id
else:
session_id = None
yield AsyncSession(
api_client=self._api_client,
websocket=ws,
session_id=session_id,
)

async def _prepare_connection_mldev(
self, *,
base_url: str,
model: str,
parameter_model: types.LiveConnectConfig,
) -> tuple[str, _common.StringDict, str]:
"""Prepares live connection parameters for the MLDev API.

Constructs the WebSocket URI, headers, and request body necessary
to establish a connection with the MLDev backend.

Args:
base_url: The base URL for the WebSocket connection.
model: The name of the model to use.
parameter_model: Configuration parameters for the connection.

Returns:
A tuple containing:
- uri: The WebSocket connection URI.
- headers: A dictionary of headers for the connection.
- request: The JSON-serialized request body.

Raises:
ValueError: If an API key is not provided.
"""
transformed_model = t.t_model(self._api_client, model) # type: ignore
version = self._api_client._http_options.api_version
method = 'BidiGenerateContent'
original_headers = self._api_client._http_options.headers
headers = original_headers.copy() if original_headers is not None else {}
api_key = self._api_client.api_key

if not api_key:
# this shouldn't happen
raise ValueError('Genai live connection requires an API key.')

if api_key.startswith('auth_tokens/'):
method = 'BidiGenerateContentConstrained'
headers['Authorization'] = f'Token {api_key}'
warnings.warn(
message=(
"The SDK's ephemeral token support is experimental, and may"
' change in future versions.'
),
category=errors.ExperimentalWarning,
)
if version != 'v1alpha':
warnings.warn(
message=(
"The SDK's ephemeral token support is experimental, and may"
' change in future versions.'
"The SDK's ephemeral token support is in v1alpha only."
'Please use client = genai.Client(api_key=token.name, '
'http_options=types.HttpOptions(api_version="v1alpha"))'
' before session connection.'
),
category=errors.ExperimentalWarning,
)
method = 'BidiGenerateContentConstrained'
headers['Authorization'] = f'Token {api_key}'
if version != 'v1alpha':
warnings.warn(
message=(
"The SDK's ephemeral token support is in v1alpha only."
'Please use client = genai.Client(api_key=token.name, '
'http_options=types.HttpOptions(api_version="v1alpha"))'
' before session connection.'
),
category=errors.ExperimentalWarning,
)
uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.{method}'

request_dict = _common.convert_to_dict(
live_converters._LiveConnectParameters_to_mldev(
api_client=self._api_client,
from_object=types.LiveConnectParameters(
model=transformed_model,
config=parameter_model,
).model_dump(exclude_none=True),
)
)
del request_dict['config']

setv(request_dict, ['setup', 'model'], transformed_model)
uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.{method}'

request = json.dumps(request_dict)
elif self._api_client.api_key and self._api_client.vertexai:
# Headers already contains api key for express mode.
api_key = self._api_client.api_key
version = self._api_client._http_options.api_version
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
original_headers = self._api_client._http_options.headers
headers = original_headers.copy() if original_headers is not None else {}

request_dict = _common.convert_to_dict(
live_converters._LiveConnectParameters_to_vertex(
api_client=self._api_client,
from_object=types.LiveConnectParameters(
model=transformed_model,
config=parameter_model,
).model_dump(exclude_none=True),
)
)
del request_dict['config']
request_dict = _common.convert_to_dict(
live_converters._LiveConnectParameters_to_mldev(
api_client=self._api_client,
from_object=types.LiveConnectParameters(
model=transformed_model,
config=parameter_model,
).model_dump(exclude_none=True),
)
)
del request_dict['config']

setv(request_dict, ['setup', 'model'], transformed_model)

return uri, headers, json.dumps(request_dict)


async def _prepare_connection_vertex(
self, *,
base_url: str,
model: str,
parameter_model: types.LiveConnectConfig,
) -> tuple[str, _common.StringDict, str]:
"""Prepares live connection parameters for the Vertex AI API.

setv(request_dict, ['setup', 'model'], transformed_model)
Constructs the WebSocket URI, headers, and request body necessary
to establish a connection with the Vertex AI backend. Handles
authentication using either an API key or default credentials.

request = json.dumps(request_dict)
Args:
base_url: The base URL for the WebSocket connection.
model: The name of the model to use.
parameter_model: Configuration parameters for the connection.

Returns:
A tuple containing:
- uri: The WebSocket connection URI.
- headers: A dictionary of headers for the connection.
- request: The JSON-serialized request body.

Raises:
ValueError: If project and location are not provided when
default credentials are used.
"""
transformed_model = t.t_model(self._api_client, model) # type: ignore
version = self._api_client._http_options.api_version
original_headers = self._api_client._http_options.headers
headers = (
original_headers.copy() if original_headers is not None else {}
)
if api_key := self._api_client.api_key:
# Headers already contains api key
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
else:
version = self._api_client._http_options.api_version
has_sufficient_auth = (
self._api_client.project and self._api_client.location
)
Expand Down Expand Up @@ -1028,13 +1124,9 @@ async def connect(
# Need to refresh credentials to populate those
if not (creds.token and creds.valid):
auth_req = google.auth.transport.requests.Request() # type: ignore
creds.refresh(auth_req)
await asyncio.to_thread(creds.refresh, auth_req)
bearer_token = creds.token

original_headers = self._api_client._http_options.headers
headers = (
original_headers.copy() if original_headers is not None else {}
)
if not headers.get('Authorization'):
headers['Authorization'] = f'Bearer {bearer_token}'

Expand All @@ -1044,17 +1136,22 @@ async def connect(
transformed_model = (
f'projects/{project}/locations/{location}/' + transformed_model
)
request_dict = _common.convert_to_dict(
live_converters._LiveConnectParameters_to_vertex(
api_client=self._api_client,
from_object=types.LiveConnectParameters(
model=transformed_model,
config=parameter_model,
).model_dump(exclude_none=True),
)
)
del request_dict['config']

request_dict = _common.convert_to_dict(
live_converters._LiveConnectParameters_to_vertex(
api_client=self._api_client,
from_object=types.LiveConnectParameters(
model=transformed_model,
config=parameter_model,
).model_dump(exclude_none=True),
)
)
del request_dict['config']

if api_key is None:
# Refactor note: I'm surprised the two paths are different, you'd have
# to test every model to be sure. The goal of this refactor is to not
# change any behavior so leaving it as is.
if (
getv(
request_dict, ['setup', 'generationConfig', 'responseModalities']
Expand All @@ -1067,49 +1164,10 @@ async def connect(
['AUDIO'],
)

request = json.dumps(request_dict)
return uri, headers, json.dumps(request_dict)

if parameter_model.tools and _mcp_utils.has_mcp_tool_usage(
parameter_model.tools
):
if headers is None:
headers = {}
_mcp_utils.set_mcp_usage_header(headers)

async with ws_connect(
uri, additional_headers=headers, **self._api_client._websocket_ssl_ctx
) as ws:
await ws.send(request)
try:
# websockets 14.0+
raw_response = await ws.recv(decode=False)
except TypeError:
raw_response = await ws.recv() # type: ignore[assignment]
if raw_response:
try:
response = json.loads(raw_response)
except json.decoder.JSONDecodeError:
raise ValueError(f'Failed to parse response: {raw_response!r}')
else:
response = {}

if self._api_client.vertexai:
response_dict = live_converters._LiveServerMessage_from_vertex(response)
else:
response_dict = response

setup_response = types.LiveServerMessage._from_response(
response=response_dict, kwargs=parameter_model.model_dump()
)
if setup_response.setup_complete:
session_id = setup_response.setup_complete.session_id
else:
session_id = None
yield AsyncSession(
api_client=self._api_client,
websocket=ws,
session_id=session_id,
)


async def _t_live_connect_config(
Expand Down
Loading