10 changes: 9 additions & 1 deletion .fernignore
@@ -19,15 +19,23 @@ tests/unit/
tests/integrations/

# Custom Extensions & Clients
# - client.py: Custom DeepgramClient wrapper that adds access_token, telemetry, and websocket_client support
# - websocket_client.py: WebSocket abstraction layer for custom clients (e.g., AWS SageMaker)
# - extensions/: All custom extension modules (telemetry, instrumentation, socket types)
src/deepgram/client.py
src/deepgram/core/websocket_client.py
src/deepgram/extensions/

# Socket Client Implementations
# Enhanced with binary message support, comprehensive socket types, and send methods
# These wrap the raw WebSocket connections with Deepgram-specific message handling
src/deepgram/agent/v1/socket_client.py
src/deepgram/listen/v1/socket_client.py
src/deepgram/listen/v2/socket_client.py
src/deepgram/speak/v1/socket_client.py

# Bug Fixes
# Core Modifications
# - client_wrapper.py: Enhanced to support custom websocket_client parameter
# - listen/client.py: Bug fix for v2 client property access (pre-existing, not related to WebSocket work)
src/deepgram/listen/client.py
src/deepgram/core/client_wrapper.py
104 changes: 104 additions & 0 deletions sagemaker_bidi_api.md
@@ -0,0 +1,104 @@
# [Private Beta] SageMakerRuntime BiDi API Contract

### 1. HTTP/2 Request Syntax

```
POST https://runtime.sagemaker.us-east-2.amazonaws.com:8443/endpoints/<EndpointName>/invocations-bidirectional-stream

[Optional] X-Amzn-SageMaker-Target-Variant: <TargetVariant>
[Optional] X-Amzn-SageMaker-Model-Invocation-Path: <ModelInvocationPath>
[Optional] X-Amzn-SageMaker-Model-Query-String: <ModelQueryString>
{
"PayloadPart": {
"Bytes": <Blob>,
"DataType": <String: UTF8 | BINARY>,
"CompletionState": <String: PARTIAL | COMPLETE>,
"P": <String>
}
}
```
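
The three optional headers can be checked on the client before the stream is opened. The following is a minimal sketch, assuming the length limits and patterns listed in the parameter descriptions of this contract; the function names `build_optional_headers` and `container_url` are hypothetical, and treating each pattern as a full-string match is an assumption (the TargetVariant pattern has no trailing `$`).

```python
import re
from typing import Dict, Optional

# Patterns and maximum lengths copied from the parameter descriptions.
_TARGET_VARIANT = re.compile(r"^[a-zA-Z0-9](-*[a-zA-Z0-9])*")
_INVOCATION_PATH = re.compile(r"^[A-Za-z0-9\-._]+(?:/[A-Za-z0-9\-._]+)*$")
_QUERY_STRING = re.compile(
    r"^[a-zA-Z0-9][A-Za-z0-9_-]*=(?:[A-Za-z0-9._~\-]|%[0-9A-Fa-f]{2})+"
    r"(?:&[a-zA-Z0-9][A-Za-z0-9_-]*=(?:[A-Za-z0-9._~\-]|%[0-9A-Fa-f]{2})+)*$"
)


def build_optional_headers(
    target_variant: Optional[str] = None,
    invocation_path: Optional[str] = None,
    query_string: Optional[str] = None,
) -> Dict[str, str]:
    """Return only the headers that were supplied, after validating each one."""
    checks = [
        ("X-Amzn-SageMaker-Target-Variant", target_variant, _TARGET_VARIANT, 63),
        ("X-Amzn-SageMaker-Model-Invocation-Path", invocation_path, _INVOCATION_PATH, 100),
        ("X-Amzn-SageMaker-Model-Query-String", query_string, _QUERY_STRING, 2048),
    ]
    headers: Dict[str, str] = {}
    for name, value, pattern, max_len in checks:
        if value is None:
            continue  # optional header omitted
        # Assumption: the whole value must match, hence fullmatch().
        if len(value) > max_len or not pattern.fullmatch(value):
            raise ValueError(f"invalid value for {name}: {value!r}")
        headers[name] = value
    return headers


def container_url(
    invocation_path: Optional[str] = None,
    query_string: Optional[str] = None,
) -> str:
    """Compose the URL SageMaker is described as using to reach the container."""
    path = invocation_path or "invoke-bidi-stream"  # documented default
    url = f"ws://local:8081/{path}"
    return f"{url}?{query_string}" if query_string else url
```

For example, `container_url("v1/listen", "model=example")` yields `ws://local:8081/v1/listen?model=example`; the path and query values here are illustrative only.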

More detailed explanation **(you can skip this part if you are already familiar with SageMaker public APIs):**
⚠️ Potential issue | 🟡 Minor

Fix spelling errors in user-facing text.

The word "publi" should be "public" in both instances.

Apply this diff to correct the spelling:

- More detailed explanation **(you can skip this part if you are already familiar with SageMaker publi APIs):**
+ More detailed explanation **(you can skip this part if you are already familiar with SageMaker public APIs):**

Apply the same fix to line 81.

Also applies to: 81-81

🧰 Tools
🪛 LanguageTool

[grammar] ~21-~21: Ensure spelling is correct
Context: ...you are already familiar with SageMaker publi APIs):** - [Optional] TargetVariant: ...

(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)

🤖 Prompt for AI Agents
In sagemaker_bidi_api.md around lines 21 and 81 the user-facing text contains
the misspelling "publi" instead of "public"; update both occurrences to read
"public" so the phrases become "SageMaker public APIs" (or equivalent) ensuring
consistent, correct spelling in both locations.


- [Optional] TargetVariant:
  - Specify the production variant to send the inference request to when invoking an endpoint that is running two or more variants. Note that this parameter overrides the default behavior for the endpoint, which is to distribute the invocation traffic based on the variant weights.
  - Length Constraints: Maximum length of 63.
  - Pattern: **`^[a-zA-Z0-9](-*[a-zA-Z0-9])*`**
- [Optional] ModelInvocationPath:
  - SageMaker connects to the model container using the URL **`ws://local:8081/<ModelInvocationPath>`**. This parameter defaults to **`invoke-bidi-stream`** if not specified.
  - Length Constraints: Maximum length of 100.
  - Pattern: **`^[A-Za-z0-9\-._]+(?:/[A-Za-z0-9\-._]+)*$`**
- [Optional] ModelQueryString:
  - If specified, SageMaker appends it to the URL when connecting to model containers: **`ws://local:8081/<ModelInvocationPath>?<ModelQueryString>`**.
  - Length Constraints: Maximum length of 2048.
  - Pattern: **`^[a-zA-Z0-9][A-Za-z0-9_-]*=(?:[A-Za-z0-9._~\-]|%[0-9A-Fa-f]{2})+(?:&[a-zA-Z0-9][A-Za-z0-9_-]*=(?:[A-Za-z0-9._~\-]|%[0-9A-Fa-f]{2})+)*$`**
- [Required] RequestStream (the following data is sent in JSON format by the client to the service):
  - PayloadPart: A wrapper object for the input payload. A request stream consists of one or more PayloadParts. This object has the following fields:
    - [Required] Bytes:
      - Base64-encoded binary chunk.
    - [Optional] DataType:
      - Regex Pattern: **`^(UTF8)$|^(BINARY)$`**
      - If this field is null, SageMaker defaults the value to **`BINARY`**.
      - If **`DataType = UTF8`**, the binary chunk contains the raw bytes of a UTF-8 encoded string.
    - [Optional] CompletionStatus:
      - Regex Pattern: **`^(PARTIAL)$|^(COMPLETE)$`**
      - If this field is null, SageMaker defaults the value to **`COMPLETE`**.
      - A binary chunk may be fragmented across multiple PayloadParts. If **`CompletionStatus = PARTIAL`**, the current PayloadPart is incomplete and shall be aggregated with subsequent PayloadParts until one with **`CompletionStatus = COMPLETE`** is received.
      - An unfragmented binary chunk always has **`CompletionStatus = COMPLETE`**.
      - Note that SageMaker does not aggregate PayloadParts on customers’ behalf; the fragments are passed to model containers as-is.
    - [Optional] P:
      - Padding string that defends against token-length side-channel attacks.
⚠️ Potential issue | 🟡 Minor

Remove encoding artifacts from padding field descriptions.

Lines 50 and 102 contain zero-width or special characters that should be removed for clarity.

Apply this diff to clean up the encoding artifacts:

- - [Optional] P:
-   - Padding string defending against token length side channel attack .
+ - [Optional] P:
+   - Padding string defending against token length side channel attack.

Apply the same fix to line 102.

Also applies to: 102-102

🤖 Prompt for AI Agents
In sagemaker_bidi_api.md around lines 50 and 102, remove hidden
zero-width/special Unicode characters present in the padding field descriptions
(e.g., after words like "defending against" or elsewhere) so the text reads
normally; open those lines, delete any zero-width spaces or non-printing
characters, and save the file ensuring the descriptions use only regular
ASCII/UTF-8 printable characters and spacing.
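
The request-side fragmentation rules can be sketched in Python as follows. This is only an illustration under stated assumptions: `iter_payload_parts` and its `fragment_size` parameter are hypothetical names (the contract does not prescribe a fragment size), and the completion field is written as `CompletionState` to follow the request syntax block, whereas the field description calls it `CompletionStatus`.

```python
import base64
import json
from typing import Any, Dict, Iterator, Optional

# Assumption: key name taken from the syntax block; the prose says "CompletionStatus".
COMPLETION_FIELD = "CompletionState"


def iter_payload_parts(
    data: bytes,
    data_type: str = "BINARY",
    fragment_size: Optional[int] = None,
) -> Iterator[Dict[str, Any]]:
    """Yield PayloadPart events for one binary chunk, fragmenting if requested."""
    if data_type not in ("UTF8", "BINARY"):
        raise ValueError("data_type must be 'UTF8' or 'BINARY'")
    if fragment_size is not None and fragment_size <= 0:
        raise ValueError("fragment_size must be positive")

    # With no fragment size, the whole chunk goes out as a single part.
    size = fragment_size or max(len(data), 1)
    offsets = range(0, max(len(data), 1), size)
    last = offsets[-1]
    for offset in offsets:
        fragment = data[offset:offset + size]
        yield {
            "PayloadPart": {
                "Bytes": base64.b64encode(fragment).decode("ascii"),
                "DataType": data_type,
                # Only the final fragment is marked COMPLETE.
                COMPLETION_FIELD: "COMPLETE" if offset == last else "PARTIAL",
            }
        }


if __name__ == "__main__":
    for part in iter_payload_parts("hello world".encode("utf-8"), "UTF8", fragment_size=4):
        print(json.dumps(part))
```

Because SageMaker passes fragments through unchanged, the model container is the party that has to concatenate the `PARTIAL` parts; the sender only needs to mark the last fragment `COMPLETE`.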


### 2. HTTP/2 Response Syntax

The response event stream is a union: each event in the output stream is exactly one of the following:

- ResponsePayload: passed through from the model container to the customer; **_SageMaker never inspects it_**;
- InternalStreamFailure: a mid-stream server-side fault (similar to a 500 error code, but not the same);
- ModelStreamError: a mid-stream error originating inside the Model Container (similar to a 400 error code, but not the same).

```
HTTP/2 200
x-Amzn-Invoked-Production-Variant: <InvokedProductionVariant>

{
"PayloadPart": {
"Bytes": <Blob>,
"DataType": <String: UTF8 | BINARY>,
"CompletionState": <String: PARTIAL | COMPLETE>,
"P": <String>
},
"ModelStreamError": {
"ErrorCode": <String>,
"Message": <String>
},
"InternalStreamFailure": {
"Message": <String>
}
}
```

More detailed explanation **(you can skip this part if you are already familiar with SageMaker public APIs):**

- InvokedProductionVariant:
  - Identifies the production variant that was invoked.
  - Length Constraints: Maximum length of 1024.
  - Pattern: `\p{ASCII}*`
- ResponseStream contains three types of events: PayloadPart, ModelStreamError, or InternalStreamFailure. The following data is returned in JSON format by the service:
  - PayloadPart: A wrapper object for the output payload. A response stream consists of one or more PayloadParts. This object has the following fields:
    - [Required] Bytes:
      - Base64-encoded binary chunk.
    - [Optional] DataType:
      - Regex Pattern: **`^(UTF8)$|^(BINARY)$`**
      - If this field is null, SageMaker defaults the value to **`BINARY`**.
      - If **`DataType = UTF8`**, the binary chunk contains the raw bytes of a UTF-8 encoded string.
    - [Optional] CompletionStatus:
      - Regex Pattern: **`^(PARTIAL)$|^(COMPLETE)$`**
      - If this field is null, SageMaker defaults the value to **`COMPLETE`**.
      - A binary chunk may be fragmented across multiple PayloadParts. If **`CompletionStatus = PARTIAL`**, the current PayloadPart is incomplete and shall be aggregated with subsequent PayloadParts until one with **`CompletionStatus = COMPLETE`** is received.
      - An unfragmented binary chunk always has **`CompletionStatus = COMPLETE`**.
      - Note that SageMaker does not aggregate PayloadParts on customers’ behalf; the fragments are passed from the model container to clients as-is.
    - [Optional] P:
      - Padding string that defends against token-length side-channel attacks.
  - ModelStreamError: An error that originated in the model container while streaming the response body.
  - InternalStreamFailure: A fault that originated in the SageMaker platform while streaming the response body. The stream processing failed because of an unknown error, exception, or failure. Try your request again.
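
On the receiving side, a client has to tell the three event types apart and reassemble fragmented chunks itself, since SageMaker does not aggregate them. The sketch below assumes events have already been decoded into Python dictionaries shaped like the response syntax above; `iter_complete_chunks` and the two exception classes are hypothetical names, and the completion field is read as `CompletionState` to follow the syntax block (the field description calls it `CompletionStatus`).

```python
import base64
from typing import Any, Dict, Iterable, Iterator, Tuple

# Assumption: key name taken from the syntax block; the prose says "CompletionStatus".
COMPLETION_FIELD = "CompletionState"


class ModelStreamErrorReceived(Exception):
    """Raised when the stream carries a ModelStreamError event."""


class InternalStreamFailureReceived(Exception):
    """Raised when the stream carries an InternalStreamFailure event."""


def iter_complete_chunks(
    events: Iterable[Dict[str, Any]],
) -> Iterator[Tuple[str, bytes]]:
    """Yield (data_type, chunk) once per COMPLETE chunk, joining PARTIAL parts."""
    buffer = bytearray()
    for event in events:
        if "ModelStreamError" in event:
            err = event["ModelStreamError"]
            raise ModelStreamErrorReceived(
                f"{err.get('ErrorCode')}: {err.get('Message')}"
            )
        if "InternalStreamFailure" in event:
            raise InternalStreamFailureReceived(
                event["InternalStreamFailure"].get("Message")
            )

        part = event["PayloadPart"]
        buffer.extend(base64.b64decode(part["Bytes"]))
        # Null or missing fields fall back to the documented defaults.
        state = part.get(COMPLETION_FIELD) or "COMPLETE"
        if state == "COMPLETE":
            data_type = part.get("DataType") or "BINARY"
            yield data_type, bytes(buffer)
            buffer.clear()
```

A caller would typically decode the chunk as UTF-8 text when the yielded type is `UTF8` and treat it as raw audio or other binary data otherwise.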
76 changes: 53 additions & 23 deletions src/deepgram/client.py
@@ -16,7 +16,8 @@

from .base_client import AsyncBaseClient, BaseClient

from deepgram.core.client_wrapper import BaseClientWrapper
from deepgram.core.client_wrapper import AsyncClientWrapper, BaseClientWrapper, SyncClientWrapper
from deepgram.core.websocket_client import WebSocketFactory
from deepgram.extensions.core.instrumented_http import InstrumentedAsyncHttpClient, InstrumentedHttpClient
from deepgram.extensions.core.instrumented_socket import apply_websocket_instrumentation
from deepgram.extensions.core.telemetry_events import TelemetryHttpEvents, TelemetrySocketEvents
@@ -31,10 +32,11 @@ def _create_telemetry_context(session_id: str) -> Dict[str, Any]:
# Get package version
try:
from . import version

package_version = version.__version__
except ImportError:
package_version = "unknown"

return {
"package_name": "python-sdk",
"package_version": package_version,
@@ -55,15 +57,15 @@ def _create_telemetry_context(session_id: str) -> Dict[str, Any]:


def _setup_telemetry(
session_id: str,
telemetry_opt_out: bool,
session_id: str,
telemetry_opt_out: bool,
telemetry_handler: Optional[TelemetryHandler],
client_wrapper: BaseClientWrapper,
) -> Optional[TelemetryHandler]:
"""Setup telemetry for the client."""
if telemetry_opt_out:
return None

# Use provided handler or create default batching handler
if telemetry_handler is None:
try:
@@ -79,15 +81,15 @@ def _setup_telemetry(
except Exception:
# If we can't create the handler, disable telemetry
return None

# Setup HTTP instrumentation
try:
http_events = TelemetryHttpEvents(telemetry_handler)

# Replace the HTTP client with instrumented version
if hasattr(client_wrapper, 'httpx_client'):
if hasattr(client_wrapper, "httpx_client"):
original_client = client_wrapper.httpx_client
if hasattr(original_client, 'httpx_client'): # It's already our HttpClient
if hasattr(original_client, "httpx_client"): # It's already our HttpClient
instrumented_client = InstrumentedHttpClient(
delegate=original_client,
events=http_events,
@@ -96,7 +98,7 @@ def _setup_telemetry(
except Exception:
# If instrumentation fails, continue without it
pass

# Setup WebSocket instrumentation
try:
socket_events = TelemetrySocketEvents(telemetry_handler)
@@ -105,20 +107,20 @@ def _setup_telemetry(
except Exception:
# If WebSocket instrumentation fails, continue without it
pass

return telemetry_handler


def _setup_async_telemetry(
session_id: str,
telemetry_opt_out: bool,
session_id: str,
telemetry_opt_out: bool,
telemetry_handler: Optional[TelemetryHandler],
client_wrapper: BaseClientWrapper,
) -> Optional[TelemetryHandler]:
"""Setup telemetry for the async client."""
if telemetry_opt_out:
return None

# Use provided handler or create default batching handler
if telemetry_handler is None:
try:
@@ -134,15 +136,15 @@ def _setup_async_telemetry(
except Exception:
# If we can't create the handler, disable telemetry
return None

# Setup HTTP instrumentation
try:
http_events = TelemetryHttpEvents(telemetry_handler)

# Replace the HTTP client with instrumented version
if hasattr(client_wrapper, 'httpx_client'):
if hasattr(client_wrapper, "httpx_client"):
original_client = client_wrapper.httpx_client
if hasattr(original_client, 'httpx_client'): # It's already our AsyncHttpClient
if hasattr(original_client, "httpx_client"): # It's already our AsyncHttpClient
instrumented_client = InstrumentedAsyncHttpClient(
delegate=original_client,
events=http_events,
@@ -151,7 +153,7 @@ def _setup_async_telemetry(
except Exception:
# If instrumentation fails, continue without it
pass

# Setup WebSocket instrumentation
try:
socket_events = TelemetrySocketEvents(telemetry_handler)
@@ -160,7 +162,7 @@ def _setup_async_telemetry(
except Exception:
# If WebSocket instrumentation fails, continue without it
pass

return telemetry_handler


@@ -185,11 +187,13 @@ def _get_headers_with_bearer(_self: Any) -> Dict[str, str]:
if hasattr(client_wrapper, "httpx_client") and hasattr(client_wrapper.httpx_client, "base_headers"):
client_wrapper.httpx_client.base_headers = client_wrapper.get_headers


class DeepgramClient(BaseClient):
def __init__(self, *args, **kwargs) -> None:
access_token: Optional[str] = kwargs.pop("access_token", None)
telemetry_opt_out: bool = bool(kwargs.pop("telemetry_opt_out", True))
telemetry_handler: Optional[TelemetryHandler] = kwargs.pop("telemetry_handler", None)
websocket_client: Optional[WebSocketFactory] = kwargs.pop("websocket_client", None)

# Generate a session id up-front so it can be placed into headers for all transports
generated_session_id = str(uuid.uuid4())
@@ -210,9 +214,21 @@ def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.session_id = generated_session_id

# If a custom websocket_client is provided, recreate the client wrapper with it
if websocket_client is not None:
original_wrapper = self._client_wrapper
self._client_wrapper = SyncClientWrapper(
api_key=original_wrapper.api_key,
headers=original_wrapper.get_custom_headers(),
environment=original_wrapper.get_environment(),
timeout=original_wrapper.get_timeout(),
httpx_client=original_wrapper.httpx_client.httpx_client,
websocket_client=websocket_client,
)

if access_token is not None:
_apply_bearer_authorization_override(self._client_wrapper, access_token)

# Setup telemetry
self._telemetry_handler = _setup_telemetry(
session_id=generated_session_id,
@@ -221,11 +237,13 @@ def __init__(self, *args, **kwargs) -> None:
client_wrapper=self._client_wrapper,
)


class AsyncDeepgramClient(AsyncBaseClient):
def __init__(self, *args, **kwargs) -> None:
access_token: Optional[str] = kwargs.pop("access_token", None)
telemetry_opt_out: bool = bool(kwargs.pop("telemetry_opt_out", True))
telemetry_handler: Optional[TelemetryHandler] = kwargs.pop("telemetry_handler", None)
websocket_client: Optional[WebSocketFactory] = kwargs.pop("websocket_client", None)

# Generate a session id up-front so it can be placed into headers for all transports
generated_session_id = str(uuid.uuid4())
@@ -246,13 +264,25 @@ def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.session_id = generated_session_id

# If a custom websocket_client is provided, recreate the client wrapper with it
if websocket_client is not None:
original_wrapper = self._client_wrapper
self._client_wrapper = AsyncClientWrapper(
api_key=original_wrapper.api_key,
headers=original_wrapper.get_custom_headers(),
environment=original_wrapper.get_environment(),
timeout=original_wrapper.get_timeout(),
httpx_client=original_wrapper.httpx_client.httpx_client,
websocket_client=websocket_client,
)

if access_token is not None:
_apply_bearer_authorization_override(self._client_wrapper, access_token)

# Setup telemetry
self._telemetry_handler = _setup_async_telemetry(
session_id=generated_session_id,
telemetry_opt_out=telemetry_opt_out,
telemetry_handler=telemetry_handler,
client_wrapper=self._client_wrapper,
)
)