Skip to content

Commit 04bf0e5

Browse files
committed
Better code propagation, better Websocket websocket API
1 parent 62673f9 commit 04bf0e5

File tree

7 files changed

+106
-51
lines changed

7 files changed

+106
-51
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "truss"
3-
version = "0.11.0"
3+
version = "0.11.1rc1"
44
description = "A seamless bridge from model development to model delivery"
55
authors = [
66
{ name = "Pankaj Gupta", email = "no-reply@baseten.co" },

truss-chains/truss_chains/public_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,7 @@ class WebSocketProtocol(Protocol):
473473

474474
async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: ...
475475

476+
async def receive(self) -> Union[str, bytes]: ...
476477
async def receive_text(self) -> str: ...
477478
async def receive_bytes(self) -> bytes: ...
478479
async def receive_json(self) -> Any: ...

truss-chains/truss_chains/remote_chainlet/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import threading
1212
import time
1313
import traceback
14+
import typing
1415
from collections.abc import AsyncIterator
1516
from typing import (
1617
TYPE_CHECKING,
@@ -586,6 +587,13 @@ def __init__(self, websocket: "fastapi.WebSocket") -> None:
586587
async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
587588
await self._websocket.close(code=code, reason=reason)
588589

590+
async def receive(self) -> Union[str, bytes]:
591+
message = await self._websocket.receive()
592+
if message.get("text"):
593+
return typing.cast(str, message["text"])
594+
else:
595+
return typing.cast(bytes, message["bytes"])
596+
589597
async def receive_text(self) -> str:
590598
return await self._websocket.receive_text()
591599

truss/templates/control/control/endpoints.py

Lines changed: 70 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import asyncio
22
import logging
3-
from typing import Any, Callable, Dict
3+
from typing import Any, Callable, Dict, Optional, Protocol
44

55
import httpx
66
from fastapi import APIRouter, WebSocket
77
from fastapi.responses import JSONResponse, StreamingResponse
8+
from httpx_ws import AsyncWebSocketSession, WebSocketDisconnect, aconnect_ws
89
from httpx_ws import _exceptions as httpx_ws_exceptions
9-
from httpx_ws import aconnect_ws
1010
from starlette.requests import ClientDisconnect, Request
1111
from starlette.responses import Response
12+
from starlette.websockets import WebSocketDisconnect as StartletteWebSocketDisconnect
1213
from tenacity import RetryCallState, Retrying, retry_if_exception_type, wait_fixed
1314
from wsproto.events import BytesMessage, TextMessage
1415

@@ -30,6 +31,10 @@
3031
control_app = APIRouter()
3132

3233

34+
class CloseableWebsocket(Protocol):
35+
async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: ...
36+
37+
3338
@control_app.get("/")
3439
def index():
3540
return {}
@@ -118,48 +123,81 @@ def inference_retries(
118123
yield attempt
119124

120125

121-
async def _safe_close_ws(ws: WebSocket, logger: logging.Logger):
126+
async def _safe_close_ws(
127+
ws: CloseableWebsocket,
128+
logger: logging.Logger,
129+
code: int = 1000,
130+
reason: Optional[str] = None,
131+
):
122132
try:
123-
await ws.close()
133+
await ws.close(code, reason)
124134
except RuntimeError as close_error:
125135
logger.debug(f"Duplicate close of websocket: `{close_error}`.")
126136

127137

138+
async def forward_to_server(
139+
client_ws: WebSocket, server_ws: AsyncWebSocketSession
140+
) -> None:
141+
while True:
142+
message = await client_ws.receive()
143+
if message.get("type") == "websocket.disconnect":
144+
raise StartletteWebSocketDisconnect(
145+
message.get("code", 1000), message.get("reason")
146+
)
147+
if "text" in message:
148+
await server_ws.send_text(message["text"])
149+
elif "bytes" in message:
150+
await server_ws.send_bytes(message["bytes"])
151+
152+
153+
async def forward_to_client(client_ws: WebSocket, server_ws: AsyncWebSocketSession):
154+
while True:
155+
message = await server_ws.receive()
156+
if isinstance(message, TextMessage):
157+
await client_ws.send_text(message.data)
158+
elif isinstance(message, BytesMessage):
159+
await client_ws.send_bytes(message.data)
160+
161+
162+
# NB(nikhil): _handle_websocket_forwarding uses some py311 specific syntax, but in newer
163+
# versions of truss we're guaranteed to be running the control server with at least that version.
164+
async def _handle_websocket_forwarding(
165+
client_ws: WebSocket, server_ws: AsyncWebSocketSession
166+
):
167+
logger = client_ws.app.state.logger
168+
try:
169+
async with asyncio.TaskGroup() as tg: # type: ignore[attr-defined]
170+
tg.create_task(forward_to_client(client_ws, server_ws))
171+
tg.create_task(forward_to_server(client_ws, server_ws))
172+
except ExceptionGroup as eg: # type: ignore[name-defined] # noqa: F821
173+
exc = eg.exceptions[0] # NB(nikhil): Only care about the first one.
174+
if isinstance(exc, WebSocketDisconnect):
175+
await _safe_close_ws(client_ws, logger, exc.code, exc.reason)
176+
elif isinstance(exc, StartletteWebSocketDisconnect):
177+
await _safe_close_ws(server_ws, logger, exc.code, exc.reason)
178+
else:
179+
logger.warning(f"Ungraceful websocket close: {exc}")
180+
finally:
181+
await _safe_close_ws(client_ws, logger)
182+
await _safe_close_ws(server_ws, logger)
183+
184+
185+
async def _attempt_websocket_proxy(
186+
client_ws: WebSocket, proxy_client: httpx.AsyncClient, logger
187+
):
188+
async with aconnect_ws("/v1/websocket", proxy_client) as server_ws: # type: ignore
189+
await client_ws.accept()
190+
await _handle_websocket_forwarding(client_ws, server_ws)
191+
192+
128193
async def proxy_ws(client_ws: WebSocket):
129194
proxy_client: httpx.AsyncClient = client_ws.app.state.proxy_client
130195
logger = client_ws.app.state.logger
131196

132197
for attempt in inference_retries():
133198
with attempt:
134199
try:
135-
async with aconnect_ws("/v1/websocket", proxy_client) as server_ws: # type: ignore
136-
# Unfortunate, but FastAPI and httpx-ws have slightly different abstractions
137-
# for sending data, so it's not easy to create a unified wrapper.
138-
async def forward_to_server():
139-
while True:
140-
message = await client_ws.receive()
141-
if message.get("type") == "websocket.disconnect":
142-
break
143-
if "text" in message:
144-
await server_ws.send_text(message["text"])
145-
elif "bytes" in message:
146-
await server_ws.send_bytes(message["bytes"])
147-
148-
async def forward_to_client():
149-
while True:
150-
message = await server_ws.receive()
151-
if message is None:
152-
break
153-
if isinstance(message, TextMessage):
154-
await client_ws.send_text(message.data)
155-
elif isinstance(message, BytesMessage):
156-
await client_ws.send_bytes(message.data)
157-
158-
await client_ws.accept()
159-
try:
160-
await asyncio.gather(forward_to_client(), forward_to_server())
161-
finally:
162-
await _safe_close_ws(client_ws, logger)
200+
await _attempt_websocket_proxy(client_ws, proxy_client, logger)
163201
except httpx_ws_exceptions.HTTPXWSException as e:
164202
logger.warning(f"WebSocket connection rejected: {e}")
165203
await _safe_close_ws(client_ws, logger)

truss/templates/server/truss_server.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ async def parse_body(request: Request) -> bytes:
7676

7777

7878
async def _safe_close_websocket(
79-
ws: WebSocket, reason: Optional[str], status_code: int = 1000
79+
ws: WebSocket, status_code: int = 1000, reason: Optional[str] = None
8080
) -> None:
8181
try:
8282
await ws.close(code=status_code, reason=reason)
@@ -257,14 +257,14 @@ async def websocket(self, ws: WebSocket) -> None:
257257
try:
258258
await ws.accept()
259259
await self._model.websocket(ws)
260-
await _safe_close_websocket(ws, None, status_code=1000)
260+
await _safe_close_websocket(ws, status_code=1000, reason=None)
261261
except WebSocketDisconnect as ws_error:
262262
logging.info(
263263
f"Client terminated websocket connection: `{ws_error}`."
264264
)
265265
except Exception:
266266
await _safe_close_websocket(
267-
ws, errors.MODEL_ERROR_MESSAGE, status_code=1011
267+
ws, status_code=1011, reason=errors.MODEL_ERROR_MESSAGE
268268
)
269269
raise # Re raise to let `intercept_exceptions` deal with it.
270270

truss/tests/templates/control/control/test_endpoints.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from unittest.mock import AsyncMock, MagicMock, patch
1+
import asyncio
2+
from unittest.mock import AsyncMock, MagicMock, call, patch
23

34
import pytest
45
from fastapi import FastAPI, WebSocket
@@ -31,33 +32,40 @@ def client_ws(app):
3132

3233
@pytest.mark.asyncio
3334
async def test_proxy_ws_bidirectional_messaging(client_ws):
34-
"""Test that both directions of communication work and clean up properly"""
35-
client_ws.receive.side_effect = [
36-
{"type": "websocket.receive", "text": "msg1"},
37-
{"type": "websocket.receive", "text": "msg2"},
38-
{"type": "websocket.disconnect"},
39-
]
35+
client_queue = asyncio.Queue()
36+
client_ws.receive = client_queue.get
4037

38+
server_queue = asyncio.Queue()
4139
mock_server_ws = AsyncMock(spec=AsyncWebSocketSession)
42-
mock_server_ws.receive.side_effect = [
43-
TextMessage(data="response1"),
44-
TextMessage(data="response2"),
45-
None, # server closing connection
46-
]
40+
mock_server_ws.receive = server_queue.get
4741
mock_server_ws.__aenter__.return_value = mock_server_ws
4842
mock_server_ws.__aexit__.return_value = None
4943

44+
client_queue.put_nowait({"type": "websocket.receive", "text": "msg1"})
45+
client_queue.put_nowait({"type": "websocket.receive", "text": "msg2"})
46+
server_queue.put_nowait(TextMessage(data="response1"))
47+
server_queue.put_nowait(TextMessage(data="response2"))
48+
5049
with patch(
5150
"truss.templates.control.control.endpoints.aconnect_ws",
5251
return_value=mock_server_ws,
5352
):
54-
await proxy_ws(client_ws)
53+
proxy_task = asyncio.create_task(proxy_ws(client_ws))
54+
await asyncio.sleep(0.5)
55+
56+
client_queue.put_nowait(
57+
{"type": "websocket.disconnect", "code": 1002, "reason": "test-closure"}
58+
)
59+
60+
await proxy_task
5561

5662
assert mock_server_ws.send_text.call_count == 2
5763
assert mock_server_ws.send_text.call_args_list == [(("msg1",),), (("msg2",),)]
5864
assert client_ws.send_text.call_count == 2
5965
assert client_ws.send_text.call_args_list == [(("response1",),), (("response2",),)]
60-
client_ws.close.assert_called_once()
66+
67+
assert mock_server_ws.close.call_args_list[0] == call(1002, "test-closure")
68+
client_ws.close.assert_called()
6169

6270

6371
@pytest.mark.asyncio

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)