From d5b8fd4f7d1faf956e4dbf764dbc09fa8717fa13 Mon Sep 17 00:00:00 2001 From: Luke Oliff Date: Fri, 14 Nov 2025 13:57:09 +0000 Subject: [PATCH] wip: parking custom websocket client --- .fernignore | 10 +- sagemaker_bidi_api.md | 104 ++++++++++++++++++ src/deepgram/client.py | 76 +++++++++---- src/deepgram/core/client_wrapper.py | 23 +++- src/deepgram/core/websocket_client.py | 151 ++++++++++++++++++++++++++ 5 files changed, 337 insertions(+), 27 deletions(-) create mode 100644 sagemaker_bidi_api.md create mode 100644 src/deepgram/core/websocket_client.py diff --git a/.fernignore b/.fernignore index dfd18a20..c5dcfaa3 100644 --- a/.fernignore +++ b/.fernignore @@ -19,15 +19,23 @@ tests/unit/ tests/integrations/ # Custom Extensions & Clients +# - client.py: Custom DeepgramClient wrapper that adds access_token, telemetry, and websocket_client support +# - websocket_client.py: WebSocket abstraction layer for custom clients (e.g., AWS SageMaker) +# - extensions/: All custom extension modules (telemetry, instrumentation, socket types) src/deepgram/client.py +src/deepgram/core/websocket_client.py src/deepgram/extensions/ # Socket Client Implementations +# Enhanced with binary message support, comprehensive socket types, and send methods +# These wrap the raw WebSocket connections with Deepgram-specific message handling src/deepgram/agent/v1/socket_client.py src/deepgram/listen/v1/socket_client.py src/deepgram/listen/v2/socket_client.py src/deepgram/speak/v1/socket_client.py -# Bug Fixes +# Core Modifications +# - client_wrapper.py: Enhanced to support custom websocket_client parameter +# - listen/client.py: Bug fix for v2 client property access (pre-existing, not related to WebSocket work) src/deepgram/listen/client.py src/deepgram/core/client_wrapper.py \ No newline at end of file diff --git a/sagemaker_bidi_api.md b/sagemaker_bidi_api.md new file mode 100644 index 00000000..5b8ecc00 --- /dev/null +++ b/sagemaker_bidi_api.md @@ -0,0 +1,104 @@ +# [Private Beta] 
SageMakerRuntime BiDi API Contract + +### 1. HTTP/2 Request Syntax + +``` +POST https://runtime.sagemaker.us-east-2.amazonaws.com:8443/endpoints/<EndpointName>/invocations-bidirectional-stream + +[Optional] X-Amzn-SageMaker-Target-Variant: <TargetVariant> +[Optional] X-Amzn-SageMaker-Model-Invocation-Path: <ModelInvocationPath> +[Optional] X-Amzn-SageMaker-Model-Query-String: <ModelQueryString> +{ + "PayloadPart": { + "Bytes": <blob>, + "DataType": <string>, + "CompletionState": <string>, + "P": <string> + } +} +``` + +More detailed explanation **(you can skip this part if you are already familiar with SageMaker public APIs):** + +- [Optional] TargetVariant: + - Specify the production variant to send the inference request to when invoking an endpoint that is running two or more variants. Note that this parameter overrides the default behavior for the endpoint, which is to distribute the invocation traffic based on the variant weights. + - Length Constraints: Maximum length of 63. + - Pattern: **`^[a-zA-Z0-9](-*[a-zA-Z0-9])*`** +- [Optional] ModelInvocationPath: + - SageMaker connects to the model container using the URL **`ws://local:8081/<ModelInvocationPath>`**. This parameter defaults to **`invoke-bidi-stream`** if not specified. + - Length Constraints: Maximum length of 100. + - Pattern: **`^[A-Za-z0-9\-._]+(?:/[A-Za-z0-9\-._]+)*$`** +- [Optional] ModelQueryString: + - If specified, SageMaker appends it to the URL when connecting to model containers: **`ws://local:8081/<ModelInvocationPath>?<ModelQueryString>`**. + - Length Constraints: Maximum length of 2048. + - Pattern: **`^[a-zA-Z0-9][A-Za-z0-9_-]*=(?:[A-Za-z0-9._~\-]|%[0-9A-Fa-f]{2})+(?:&[a-zA-Z0-9][A-Za-z0-9_-]*=(?:[A-Za-z0-9._~\-]|%[0-9A-Fa-f]{2})+)*$`** +- [Required] RequestStream (the following data is sent in JSON format by the client to the service): + - PayloadPart: A wrapper object of input payload. A request stream consists of one or more pieces of PayloadParts. This object has the following fields: + - [Required] Bytes: + - Base64 encoded binary chunk. 
+ - [Optional] DataType: + - Regex Pattern: **`^(UTF8)$|^(BINARY)$`** + - If this field is null, SageMaker defaults the value to **`BINARY`**. + - If **`DataType = UTF8`**, the binary chunk contains the raw bytes of a UTF-8 encoded string. + - [Optional] CompletionState: + - Regex Pattern: **`^(PARTIAL)$|^(COMPLETE)$`** + - If this field is null, SageMaker defaults the value to **`COMPLETE`**. + - A binary chunk may be fragmented across multiple PayloadParts. If **`CompletionState = PARTIAL`**, the current PayloadPart is incomplete and shall be aggregated with subsequent PayloadParts until one with **`CompletionState = COMPLETE`** is received. + - An unfragmented binary chunk always has **`CompletionState = COMPLETE`**. + - Note that SageMaker does not aggregate PayloadParts on customers’ behalf; the fragments are passed to model containers as-is. + - [Optional] P: + - Padding string defending against token length side channel attack. + +### 2. HTTP/2 Response Syntax + +The response event stream is a union: each event in the output stream is one (and only one) of the following: + +- ResponsePayload is passed all the way from model container to customer, **_which SageMaker never inspects_**; +- InternalStreamFailure: mid-stream server-side fault (similar to 500 error code, but not the same); +- ModelStreamError: mid-stream error originating from inside the Model Container (similar to 400 error code, but not the same). + +``` +HTTP/2 200 +x-Amzn-Invoked-Production-Variant: <InvokedProductionVariant> + +{ + "PayloadPart": { + "Bytes": <blob>, + "DataType": <string>, + "CompletionState": <string>, + "P": <string> + }, + "ModelStreamError": { + "ErrorCode": <string>, + "Message": <string> + }, + "InternalStreamFailure": { + "Message": <string> + } +} +``` + +More detailed explanation **(you can skip this part if you are already familiar with SageMaker public APIs):** + +- InvokedProductionVariant: + - Identifies the production variant that was invoked. + - Length Constraints: Maximum length of 1024. 
+ - Pattern: `\p{ASCII}*` +- ResponseStream contains three types of events: PayloadPart, ModelStreamError or InternalStreamFailure. The following data is returned in JSON format by the service: + - PayloadPart: A wrapper object of output payload. A response stream consists of one or more pieces of PayloadParts. This object has the following fields: + - [Required] Bytes: + - Base64 encoded binary chunk. + - [Optional] DataType: + - Regex Pattern: **`^(UTF8)$|^(BINARY)$`** + - If this field is null, SageMaker defaults the value to **`BINARY`**. + - If **`DataType = UTF8`**, the binary chunk contains the raw bytes of a UTF-8 encoded string. + - [Optional] CompletionState: + - Regex Pattern: **`^(PARTIAL)$|^(COMPLETE)$`** + - If this field is null, SageMaker defaults the value to **`COMPLETE`**. + - A binary chunk may be fragmented across multiple PayloadParts. If **`CompletionState = PARTIAL`**, the current PayloadPart is incomplete and shall be aggregated with subsequent PayloadParts until one with **`CompletionState = COMPLETE`** is received. + - An unfragmented binary chunk always has **`CompletionState = COMPLETE`**. + - Note that SageMaker does not aggregate PayloadParts on customers’ behalf; the fragments are passed from model container to clients as-is. + - [Optional] P: + - Padding string defending against token length side channel attack. + - ModelStreamError: An error originating from the model container while streaming the response body. + - InternalStreamFailure: A fault originating from the SageMaker platform while streaming the response body. The stream processing failed because of an unknown error, exception or failure. Try your request again. 
diff --git a/src/deepgram/client.py b/src/deepgram/client.py index 456d3129..1c2c54bc 100644 --- a/src/deepgram/client.py +++ b/src/deepgram/client.py @@ -16,7 +16,8 @@ from .base_client import AsyncBaseClient, BaseClient -from deepgram.core.client_wrapper import BaseClientWrapper +from deepgram.core.client_wrapper import AsyncClientWrapper, BaseClientWrapper, SyncClientWrapper +from deepgram.core.websocket_client import WebSocketFactory from deepgram.extensions.core.instrumented_http import InstrumentedAsyncHttpClient, InstrumentedHttpClient from deepgram.extensions.core.instrumented_socket import apply_websocket_instrumentation from deepgram.extensions.core.telemetry_events import TelemetryHttpEvents, TelemetrySocketEvents @@ -31,10 +32,11 @@ def _create_telemetry_context(session_id: str) -> Dict[str, Any]: # Get package version try: from . import version + package_version = version.__version__ except ImportError: package_version = "unknown" - + return { "package_name": "python-sdk", "package_version": package_version, @@ -55,15 +57,15 @@ def _create_telemetry_context(session_id: str) -> Dict[str, Any]: def _setup_telemetry( - session_id: str, - telemetry_opt_out: bool, + session_id: str, + telemetry_opt_out: bool, telemetry_handler: Optional[TelemetryHandler], client_wrapper: BaseClientWrapper, ) -> Optional[TelemetryHandler]: """Setup telemetry for the client.""" if telemetry_opt_out: return None - + # Use provided handler or create default batching handler if telemetry_handler is None: try: @@ -79,15 +81,15 @@ def _setup_telemetry( except Exception: # If we can't create the handler, disable telemetry return None - + # Setup HTTP instrumentation try: http_events = TelemetryHttpEvents(telemetry_handler) - + # Replace the HTTP client with instrumented version - if hasattr(client_wrapper, 'httpx_client'): + if hasattr(client_wrapper, "httpx_client"): original_client = client_wrapper.httpx_client - if hasattr(original_client, 'httpx_client'): # It's already our 
HttpClient + if hasattr(original_client, "httpx_client"): # It's already our HttpClient instrumented_client = InstrumentedHttpClient( delegate=original_client, events=http_events, @@ -96,7 +98,7 @@ def _setup_telemetry( except Exception: # If instrumentation fails, continue without it pass - + # Setup WebSocket instrumentation try: socket_events = TelemetrySocketEvents(telemetry_handler) @@ -105,20 +107,20 @@ def _setup_telemetry( except Exception: # If WebSocket instrumentation fails, continue without it pass - + return telemetry_handler def _setup_async_telemetry( - session_id: str, - telemetry_opt_out: bool, + session_id: str, + telemetry_opt_out: bool, telemetry_handler: Optional[TelemetryHandler], client_wrapper: BaseClientWrapper, ) -> Optional[TelemetryHandler]: """Setup telemetry for the async client.""" if telemetry_opt_out: return None - + # Use provided handler or create default batching handler if telemetry_handler is None: try: @@ -134,15 +136,15 @@ def _setup_async_telemetry( except Exception: # If we can't create the handler, disable telemetry return None - + # Setup HTTP instrumentation try: http_events = TelemetryHttpEvents(telemetry_handler) - + # Replace the HTTP client with instrumented version - if hasattr(client_wrapper, 'httpx_client'): + if hasattr(client_wrapper, "httpx_client"): original_client = client_wrapper.httpx_client - if hasattr(original_client, 'httpx_client'): # It's already our AsyncHttpClient + if hasattr(original_client, "httpx_client"): # It's already our AsyncHttpClient instrumented_client = InstrumentedAsyncHttpClient( delegate=original_client, events=http_events, @@ -151,7 +153,7 @@ def _setup_async_telemetry( except Exception: # If instrumentation fails, continue without it pass - + # Setup WebSocket instrumentation try: socket_events = TelemetrySocketEvents(telemetry_handler) @@ -160,7 +162,7 @@ def _setup_async_telemetry( except Exception: # If WebSocket instrumentation fails, continue without it pass - + return 
telemetry_handler @@ -185,11 +187,13 @@ def _get_headers_with_bearer(_self: Any) -> Dict[str, str]: if hasattr(client_wrapper, "httpx_client") and hasattr(client_wrapper.httpx_client, "base_headers"): client_wrapper.httpx_client.base_headers = client_wrapper.get_headers + class DeepgramClient(BaseClient): def __init__(self, *args, **kwargs) -> None: access_token: Optional[str] = kwargs.pop("access_token", None) telemetry_opt_out: bool = bool(kwargs.pop("telemetry_opt_out", True)) telemetry_handler: Optional[TelemetryHandler] = kwargs.pop("telemetry_handler", None) + websocket_client: Optional[WebSocketFactory] = kwargs.pop("websocket_client", None) # Generate a session id up-front so it can be placed into headers for all transports generated_session_id = str(uuid.uuid4()) @@ -210,9 +214,21 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.session_id = generated_session_id + # If a custom websocket_client is provided, recreate the client wrapper with it + if websocket_client is not None: + original_wrapper = self._client_wrapper + self._client_wrapper = SyncClientWrapper( + api_key=original_wrapper.api_key, + headers=original_wrapper.get_custom_headers(), + environment=original_wrapper.get_environment(), + timeout=original_wrapper.get_timeout(), + httpx_client=original_wrapper.httpx_client.httpx_client, + websocket_client=websocket_client, + ) + if access_token is not None: _apply_bearer_authorization_override(self._client_wrapper, access_token) - + # Setup telemetry self._telemetry_handler = _setup_telemetry( session_id=generated_session_id, @@ -221,11 +237,13 @@ def __init__(self, *args, **kwargs) -> None: client_wrapper=self._client_wrapper, ) + class AsyncDeepgramClient(AsyncBaseClient): def __init__(self, *args, **kwargs) -> None: access_token: Optional[str] = kwargs.pop("access_token", None) telemetry_opt_out: bool = bool(kwargs.pop("telemetry_opt_out", True)) telemetry_handler: Optional[TelemetryHandler] = 
kwargs.pop("telemetry_handler", None) + websocket_client: Optional[WebSocketFactory] = kwargs.pop("websocket_client", None) # Generate a session id up-front so it can be placed into headers for all transports generated_session_id = str(uuid.uuid4()) @@ -246,13 +264,25 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.session_id = generated_session_id + # If a custom websocket_client is provided, recreate the client wrapper with it + if websocket_client is not None: + original_wrapper = self._client_wrapper + self._client_wrapper = AsyncClientWrapper( + api_key=original_wrapper.api_key, + headers=original_wrapper.get_custom_headers(), + environment=original_wrapper.get_environment(), + timeout=original_wrapper.get_timeout(), + httpx_client=original_wrapper.httpx_client.httpx_client, + websocket_client=websocket_client, + ) + if access_token is not None: _apply_bearer_authorization_override(self._client_wrapper, access_token) - + # Setup telemetry self._telemetry_handler = _setup_async_telemetry( session_id=generated_session_id, telemetry_opt_out=telemetry_opt_out, telemetry_handler=telemetry_handler, client_wrapper=self._client_wrapper, - ) \ No newline at end of file + ) diff --git a/src/deepgram/core/client_wrapper.py b/src/deepgram/core/client_wrapper.py index 806d0202..c7ce0614 100644 --- a/src/deepgram/core/client_wrapper.py +++ b/src/deepgram/core/client_wrapper.py @@ -5,6 +5,7 @@ import httpx from ..environment import DeepgramClientEnvironment from .http_client import AsyncHttpClient, HttpClient +from .websocket_client import WebSocketFactory, get_default_factory class BaseClientWrapper: @@ -15,18 +16,20 @@ def __init__( headers: typing.Optional[typing.Dict[str, str]] = None, environment: DeepgramClientEnvironment, timeout: typing.Optional[float] = None, + websocket_client: typing.Optional[WebSocketFactory] = None, ): self.api_key = api_key self._headers = headers self._environment = environment self._timeout = timeout + 
self.websocket_client = websocket_client or get_default_factory() def get_headers(self) -> typing.Dict[str, str]: headers: typing.Dict[str, str] = { "X-Fern-Language": "Python", "X-Fern-SDK-Name": "deepgram", # x-release-please-start-version - "X-Fern-SDK-Version": "5.3.0", + "X-Fern-SDK-Version": "5.3.0", # x-release-please-end **(self.get_custom_headers() or {}), } @@ -52,8 +55,15 @@ def __init__( environment: DeepgramClientEnvironment, timeout: typing.Optional[float] = None, httpx_client: httpx.Client, + websocket_client: typing.Optional[WebSocketFactory] = None, ): - super().__init__(api_key=api_key, headers=headers, environment=environment, timeout=timeout) + super().__init__( + api_key=api_key, + headers=headers, + environment=environment, + timeout=timeout, + websocket_client=websocket_client, + ) self.httpx_client = HttpClient( httpx_client=httpx_client, base_headers=self.get_headers, base_timeout=self.get_timeout ) @@ -68,8 +78,15 @@ def __init__( environment: DeepgramClientEnvironment, timeout: typing.Optional[float] = None, httpx_client: httpx.AsyncClient, + websocket_client: typing.Optional[WebSocketFactory] = None, ): - super().__init__(api_key=api_key, headers=headers, environment=environment, timeout=timeout) + super().__init__( + api_key=api_key, + headers=headers, + environment=environment, + timeout=timeout, + websocket_client=websocket_client, + ) self.httpx_client = AsyncHttpClient( httpx_client=httpx_client, base_headers=self.get_headers, base_timeout=self.get_timeout ) diff --git a/src/deepgram/core/websocket_client.py b/src/deepgram/core/websocket_client.py new file mode 100644 index 00000000..cca19257 --- /dev/null +++ b/src/deepgram/core/websocket_client.py @@ -0,0 +1,151 @@ +""" +WebSocket client abstraction layer. + +This module provides a protocol-based abstraction for WebSocket clients, +allowing users to provide custom implementations (e.g., AWS SageMaker) +while maintaining a consistent interface for socket clients. 
+""" + +from contextlib import asynccontextmanager, contextmanager +from typing import Any, AsyncIterator, Dict, Iterator, Optional, Protocol, runtime_checkable + +import websockets.sync.client as websockets_sync_client + +try: + from websockets.legacy.client import WebSocketClientProtocol # type: ignore +except ImportError: + from websockets import WebSocketClientProtocol # type: ignore + +try: + from websockets.client import connect as websockets_client_connect # type: ignore +except ImportError: + from websockets.legacy.client import connect as websockets_client_connect # type: ignore + +import websockets.sync.connection as websockets_sync_connection + + +@runtime_checkable +class WebSocketProtocol(Protocol): + """ + Protocol defining the minimal interface required for WebSocket connections. + + Custom WebSocket implementations (e.g., AWS SageMaker) should implement + this protocol to be compatible with Deepgram socket clients. + """ + + async def send(self, message: Any) -> None: + """Send a message through the WebSocket.""" + ... + + async def recv(self) -> Any: + """Receive a message from the WebSocket.""" + ... + + def __aiter__(self): + """Async iteration over messages.""" + ... + + +@runtime_checkable +class SyncWebSocketProtocol(Protocol): + """ + Protocol defining the minimal interface required for synchronous WebSocket connections. + """ + + def send(self, message: Any) -> None: + """Send a message through the WebSocket.""" + ... + + def recv(self) -> Any: + """Receive a message from the WebSocket.""" + ... + + def __iter__(self): + """Iteration over messages.""" + ... + + +class WebSocketFactory(Protocol): + """ + Protocol for WebSocket connection factories. + + Users can implement this protocol to provide custom WebSocket clients + (e.g., for AWS SageMaker bidirectional streaming). 
+ """ + + @asynccontextmanager + async def connect_async( + self, + url: str, + headers: Optional[Dict[str, str]] = None, + ) -> AsyncIterator[WebSocketProtocol]: + """ + Create an async WebSocket connection. + + Args: + url: The WebSocket URL to connect to + headers: Optional headers to include in the connection request + + Yields: + A WebSocket connection implementing WebSocketProtocol + """ + ... + + @contextmanager + def connect_sync( + self, + url: str, + headers: Optional[Dict[str, str]] = None, + ) -> Iterator[SyncWebSocketProtocol]: + """ + Create a synchronous WebSocket connection. + + Args: + url: The WebSocket URL to connect to + headers: Optional headers to include in the connection request + + Yields: + A WebSocket connection implementing SyncWebSocketProtocol + """ + ... + + +class DefaultWebSocketFactory: + """ + Default WebSocket factory using the standard websockets library. + + This is the default implementation used when no custom factory is provided. + """ + + @asynccontextmanager + async def connect_async( + self, + url: str, + headers: Optional[Dict[str, str]] = None, + ) -> AsyncIterator[WebSocketClientProtocol]: + """ + Create an async WebSocket connection using the websockets library. + """ + async with websockets_client_connect(url, extra_headers=headers) as protocol: + yield protocol + + @contextmanager + def connect_sync( + self, + url: str, + headers: Optional[Dict[str, str]] = None, + ) -> Iterator[websockets_sync_connection.Connection]: + """ + Create a synchronous WebSocket connection using the websockets library. + """ + with websockets_sync_client.connect(url, additional_headers=headers) as protocol: + yield protocol + + +# Global default factory instance +_default_factory = DefaultWebSocketFactory() + + +def get_default_factory() -> DefaultWebSocketFactory: + """Get the default WebSocket factory.""" + return _default_factory