diff --git a/setup.py b/setup.py index bd3a7fe965..ce185536fb 100644 --- a/setup.py +++ b/setup.py @@ -168,6 +168,7 @@ "opentelemetry-exporter-otlp-proto-http < 2", "pydantic >= 2.11.1, < 3", "typing_extensions", + "google-cloud-iam", ] evaluation_extra_require = [ diff --git a/vertexai/_genai/sandboxes.py b/vertexai/_genai/sandboxes.py index 3e1d1181bb..c0032c033e 100644 --- a/vertexai/_genai/sandboxes.py +++ b/vertexai/_genai/sandboxes.py @@ -19,11 +19,16 @@ import json import logging import mimetypes +import secrets +import time from typing import Any, Iterator, Optional, Union from urllib.parse import urlencode +from google import genai +from google.cloud import iam_credentials_v1 from google.genai import _api_module from google.genai import _common +from google.genai import types as genai_types from google.genai._common import get_value_by_path as getv from google.genai._common import set_value_by_path as setv from google.genai.pagers import Pager @@ -704,6 +709,161 @@ def delete( """ return self._delete(name=name, config=config) + def generate_access_token( + self, + service_account_email: str, + sandbox_id: str, + port: str = "8080", + timeout: int = 3600, + ) -> str: + """Signs a JWT with a Google Cloud service account. + + Args: + service_account_email (str): + Required. The email of the service account to use for signing. + sandbox_id (str): + Required. The resource name of the sandbox to generate a token for. + port (str): + Optional. The port to use for the token. Defaults to "8080". + timeout (int): + Optional. The timeout in seconds for the token. Defaults to 3600. + + Returns: + str: The signed JWT. + """ + client = iam_credentials_v1.IAMCredentialsClient() + name = f"projects/-/serviceAccounts/{service_account_email}" + custom_claims = {"port": port, "sandbox_id": sandbox_id} + payload = { + "iat": int(time.time()), + "exp": int(time.time()) + timeout, + "iss": service_account_email, + "nonce": secrets.randbelow(1000000000) + 1, + "aud": "vmaas-proxy-api", # default audience for sandbox proxy + **custom_claims, + } + request = iam_credentials_v1.SignJwtRequest( + name=name, + payload=json.dumps(payload), + ) + response = client.sign_jwt(request=request) + return response.signed_jwt + + def send_command( + self, + *, + http_method: str, + access_token: str, + sandbox_environment: types.SandboxEnvironment, + path: str = None, + query_params: Optional[dict[str, object]] = None, + headers: Optional[dict[str, str]] = None, + request_dict: Optional[dict[str, object]] = None, + ) -> genai_types.HttpResponse: + """Sends a command to the sandbox. + + Args: + http_method (str): + Required. The HTTP method to use for the command. + access_token (str): + Required. The access token to use for authorization. + sandbox_environment (types.SandboxEnvironment): + Required. The sandbox environment to send the command to. + path (str): + Optional. The path to send the command to. + query_params (dict[str, object]): + Optional. The query parameters to include in the command. + headers (dict[str, str]): + Optional. The headers to include in the command. + request_dict (dict[str, object]): + Optional. The request body to include in the command. + + Returns: + genai_types.HttpResponse: The response from the sandbox. + """ + headers = headers or {} + request_dict = request_dict or {} + connection_info = sandbox_environment.connection_info + if not connection_info: + raise ValueError("Connection info is not available.") + if connection_info.load_balancer_hostname: + endpoint = "https://" + connection_info.load_balancer_hostname + elif connection_info.load_balancer_ip: + endpoint = "http://" + connection_info.load_balancer_ip + else: + raise ValueError("Load balancer hostname or ip is not available.") + + path = path or "" + if query_params: + path = f"{path}?{urlencode(query_params)}" + headers["Authorization"] = f"Bearer {access_token}" + endpoint = endpoint + path if path.startswith("/") else endpoint + "/" + path + http_options = genai_types.HttpOptions(headers=headers, base_url=endpoint) + http_client = genai.Client(vertexai=True, http_options=http_options) + # Full path is constructed in this function. The passed in path into request + # function will not be used. + response = http_client._api_client.request(http_method, path, request_dict) + return genai_types.HttpResponse( + headers=response.headers, + body=response.body, + ) + + def generate_browser_ws_headers( + self, + sandbox_environment: types.SandboxEnvironment, + service_account_email: str, + timeout: int = 3600, + ) -> tuple[str, dict[str, str]]: + """Generates the websocket upgrade headers for the browser. + + Args: + sandbox_environment (types.SandboxEnvironment): + Required. The sandbox environment to generate websocket headers for. + service_account_email (str): + Required. The email of the service account to use for signing. + timeout (int): + Optional. The timeout in seconds for the token. Defaults to 3600. + + Returns: + tuple[str, dict[str, str]]: A tuple containing the websocket URL and + the headers for websocket upgrade. + """ + sandbox_id = sandbox_environment.name + # port 8080 is the default port for http endpoint. + http_access_token = self.generate_access_token( + service_account_email, sandbox_id, "8080", timeout + ) + response = self.send_command( + http_method="GET", + access_token=http_access_token, + sandbox_environment=sandbox_environment, + path="/cdp_ws_endpoint", + ) + if not response: + raise ValueError("Failed to get the websocket endpoint.") + body_dict = json.loads(response.body) + ws_path = body_dict["endpoint"] + + ws_url = "wss://test-us-central1.autopush-sandbox.vertexai.goog" + if sandbox_environment and sandbox_environment.connection_info: + connection_info = sandbox_environment.connection_info + if connection_info.load_balancer_hostname: + ws_url = "wss://" + connection_info.load_balancer_hostname + elif connection_info.load_balancer_ip: + ws_url = "ws://" + connection_info.load_balancer_ip + else: + raise ValueError("Load balancer hostname or ip is not available.") + ws_url = ws_url + "/" + ws_path + + # port 9222 is the default port for the browser websocket endpoint. + ws_access_token = self.generate_access_token( + service_account_email, sandbox_id, "9222", timeout + ) + + headers = {} + headers["Sec-WebSocket-Protocol"] = f"binary, {ws_access_token}" + return ws_url, headers + class AsyncSandboxes(_api_module.BaseModule):