Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@
"opentelemetry-exporter-otlp-proto-http < 2",
"pydantic >= 2.11.1, < 3",
"typing_extensions",
"google-cloud-iam",
]

evaluation_extra_require = [
Expand Down
160 changes: 160 additions & 0 deletions vertexai/_genai/sandboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,16 @@
import json
import logging
import mimetypes
import secrets
import time
from typing import Any, Iterator, Optional, Union
from urllib.parse import urlencode

from google import genai
from google.cloud import iam_credentials_v1
from google.genai import _api_module
from google.genai import _common
from google.genai import types as genai_types
from google.genai._common import get_value_by_path as getv
from google.genai._common import set_value_by_path as setv
from google.genai.pagers import Pager
Expand Down Expand Up @@ -704,6 +709,161 @@ def delete(
"""
return self._delete(name=name, config=config)

def generate_access_token(
self,
service_account_email: str,
sandbox_id: str,
port: str = "8080",
timeout: int = 3600,
) -> str:
"""Signs a JWT with a Google Cloud service account.

Args:
service_account_email (str):
Required. The email of the service account to use for signing.
sandbox_id (str):
Required. The resource name of the sandbox to generate a token for.
port (str):
Optional. The port to use for the token. Defaults to "8080".
timeout (int):
Optional. The timeout in seconds for the token. Defaults to 3600.

Returns:
str: The signed JWT.
"""
client = iam_credentials_v1.IAMCredentialsClient()
name = f"projects/-/serviceAccounts/{service_account_email}"
custom_claims = {"port": port, "sandbox_id": sandbox_id}
payload = {
"iat": int(time.time()),
"exp": int(time.time()) + timeout,
"iss": service_account_email,
"nonce": secrets.randbelow(1000000000) + 1,
"aud": "vmaas-proxy-api", # default audience for sandbox proxy
**custom_claims,
}
request = iam_credentials_v1.SignJwtRequest(
name=name,
payload=json.dumps(payload),
)
response = client.sign_jwt(request=request)
return response.signed_jwt

def send_command(
self,
*,
http_method: str,
access_token: str,
sandbox_environment: types.SandboxEnvironment,
path: str = None,
query_params: Optional[dict[str, object]] = None,
headers: Optional[dict[str, str]] = None,
request_dict: Optional[dict[str, object]] = None,
) -> genai_types.HttpResponse:
"""Sends a command to the sandbox.

Args:
http_method (str):
Required. The HTTP method to use for the command.
access_token (str):
Required. The access token to use for authorization.
sandbox_environment (types.SandboxEnvironment):
Required. The sandbox environment to send the command to.
path (str):
Optional. The path to send the command to.
query_params (dict[str, object]):
Optional. The query parameters to include in the command.
headers (dict[str, str]):
Optional. The headers to include in the command.
request_dict (dict[str, object]):
Optional. The request body to include in the command.

Returns:
genai_types.HttpResponse: The response from the sandbox.
"""
headers = headers or {}
request_dict = request_dict or {}
connection_info = sandbox_environment.connection_info
if not connection_info:
raise ValueError("Connection info is not available.")
if connection_info.load_balancer_hostname:
endpoint = "https://" + connection_info.load_balancer_hostname
elif connection_info.load_balancer_ip:
endpoint = "http://" + connection_info.load_balancer_ip
else:
raise ValueError("Load balancer hostname or ip is not available.")

path = path or ""
if query_params:
path = f"{path}?{urlencode(query_params)}"
headers["Authorization"] = f"Bearer {access_token}"
endpoint = endpoint + path if path.startswith("/") else endpoint + "/" + path
http_options = genai_types.HttpOptions(headers=headers, base_url=endpoint)
http_client = genai.Client(vertexai=True, http_options=http_options)
# Full path is constructed in this function. The passed in path into request
# function will not be used.
response = http_client._api_client.request(http_method, path, request_dict)
return genai_types.HttpResponse(
headers=response.headers,
body=response.body,
)

def generate_browser_ws_headers(
self,
sandbox_environment: types.SandboxEnvironment,
service_account_email: str,
timeout: int = 3600,
) -> tuple[str, dict[str, str]]:
"""Generates the websocket upgrade headers for the browser.

Args:
sandbox_environment (types.SandboxEnvironment):
Required. The sandbox environment to generate websocket headers for.
service_account_email (str):
Required. The email of the service account to use for signing.
timeout (int):
Optional. The timeout in seconds for the token. Defaults to 3600.

Returns:
tuple[str, dict[str, str]]: A tuple containing the websocket URL and
the headers for websocket upgrade.
"""
sandbox_id = sandbox_environment.name
# port 8080 is the default port for http endpoint.
http_access_token = self.generate_access_token(
service_account_email, sandbox_id, "8080", timeout
)
response = self.send_command(
http_method="GET",
access_token=http_access_token,
sandbox_environment=sandbox_environment,
path="/cdp_ws_endpoint",
)
if not response:
raise ValueError("Failed to get the websocket endpoint.")
body_dict = json.loads(response.body)
ws_path = body_dict["endpoint"]

ws_url = "wss://test-us-central1.autopush-sandbox.vertexai.goog"
if sandbox_environment and sandbox_environment.connection_info:
connection_info = sandbox_environment.connection_info
if connection_info.load_balancer_hostname:
ws_url = "wss://" + connection_info.load_balancer_hostname
elif connection_info.load_balancer_ip:
ws_url = "ws://" + connection_info.load_balancer_ip
else:
raise ValueError("Load balancer hostname or ip is not available.")
ws_url = ws_url + "/" + ws_path

# port 9222 is the default port for the browser websocket endpoint.
ws_access_token = self.generate_access_token(
service_account_email, sandbox_id, "9222", timeout
)

headers = {}
headers["Sec-WebSocket-Protocol"] = f"binary, {ws_access_token}"
return ws_url, headers


class AsyncSandboxes(_api_module.BaseModule):

Expand Down
Loading