|
1 | 1 | import asyncio |
2 | 2 | import logging |
3 | | -from typing import Any, Callable, Dict |
| 3 | +from typing import Any, Callable, Dict, Optional, Protocol |
4 | 4 |
|
5 | 5 | import httpx |
6 | 6 | from fastapi import APIRouter, WebSocket |
7 | 7 | from fastapi.responses import JSONResponse, StreamingResponse |
| 8 | +from httpx_ws import AsyncWebSocketSession, WebSocketDisconnect, aconnect_ws |
8 | 9 | from httpx_ws import _exceptions as httpx_ws_exceptions |
9 | | -from httpx_ws import aconnect_ws |
10 | 10 | from starlette.requests import ClientDisconnect, Request |
11 | 11 | from starlette.responses import Response |
| 12 | +from starlette.websockets import WebSocketDisconnect as StartletteWebSocketDisconnect |
12 | 13 | from tenacity import RetryCallState, Retrying, retry_if_exception_type, wait_fixed |
13 | 14 | from wsproto.events import BytesMessage, TextMessage |
14 | 15 |
|
|
30 | 31 | control_app = APIRouter() |
31 | 32 |
|
32 | 33 |
|
| 34 | +class CloseableWebsocket(Protocol): |
| 35 | + async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: ... |
| 36 | + |
| 37 | + |
33 | 38 | @control_app.get("/") |
34 | 39 | def index(): |
35 | 40 | return {} |
@@ -118,48 +123,81 @@ def inference_retries( |
118 | 123 | yield attempt |
119 | 124 |
|
120 | 125 |
|
121 | | -async def _safe_close_ws(ws: WebSocket, logger: logging.Logger): |
| 126 | +async def _safe_close_ws( |
| 127 | + ws: CloseableWebsocket, |
| 128 | + logger: logging.Logger, |
| 129 | + code: int = 1000, |
| 130 | + reason: Optional[str] = None, |
| 131 | +): |
122 | 132 | try: |
123 | | - await ws.close() |
| 133 | + await ws.close(code, reason) |
124 | 134 | except RuntimeError as close_error: |
125 | 135 | logger.debug(f"Duplicate close of websocket: `{close_error}`.") |
126 | 136 |
|
127 | 137 |
|
| 138 | +async def forward_to_server( |
| 139 | + client_ws: WebSocket, server_ws: AsyncWebSocketSession |
| 140 | +) -> None: |
| 141 | + while True: |
| 142 | + message = await client_ws.receive() |
| 143 | + if message.get("type") == "websocket.disconnect": |
| 144 | + raise StartletteWebSocketDisconnect( |
| 145 | + message.get("code", 1000), message.get("reason") |
| 146 | + ) |
| 147 | + if "text" in message: |
| 148 | + await server_ws.send_text(message["text"]) |
| 149 | + elif "bytes" in message: |
| 150 | + await server_ws.send_bytes(message["bytes"]) |
| 151 | + |
| 152 | + |
| 153 | +async def forward_to_client(client_ws: WebSocket, server_ws: AsyncWebSocketSession): |
| 154 | + while True: |
| 155 | + message = await server_ws.receive() |
| 156 | + if isinstance(message, TextMessage): |
| 157 | + await client_ws.send_text(message.data) |
| 158 | + elif isinstance(message, BytesMessage): |
| 159 | + await client_ws.send_bytes(message.data) |
| 160 | + |
| 161 | + |
| 162 | +# NB(nikhil): _handle_websocket_forwarding uses some py311 specific syntax, but in newer |
| 163 | +# versions of truss we're guaranteed to be running the control server with at least that version. |
| 164 | +async def _handle_websocket_forwarding( |
| 165 | + client_ws: WebSocket, server_ws: AsyncWebSocketSession |
| 166 | +): |
| 167 | + logger = client_ws.app.state.logger |
| 168 | + try: |
| 169 | + async with asyncio.TaskGroup() as tg: # type: ignore[attr-defined] |
| 170 | + tg.create_task(forward_to_client(client_ws, server_ws)) |
| 171 | + tg.create_task(forward_to_server(client_ws, server_ws)) |
| 172 | + except ExceptionGroup as eg: # type: ignore[name-defined] # noqa: F821 |
| 173 | + exc = eg.exceptions[0] # NB(nikhil): Only care about the first one. |
| 174 | + if isinstance(exc, WebSocketDisconnect): |
| 175 | + await _safe_close_ws(client_ws, logger, exc.code, exc.reason) |
| 176 | + elif isinstance(exc, StartletteWebSocketDisconnect): |
| 177 | + await _safe_close_ws(server_ws, logger, exc.code, exc.reason) |
| 178 | + else: |
| 179 | + logger.warning(f"Ungraceful websocket close: {exc}") |
| 180 | + finally: |
| 181 | + await _safe_close_ws(client_ws, logger) |
| 182 | + await _safe_close_ws(server_ws, logger) |
| 183 | + |
| 184 | + |
| 185 | +async def _attempt_websocket_proxy( |
| 186 | + client_ws: WebSocket, proxy_client: httpx.AsyncClient, logger |
| 187 | +): |
| 188 | + async with aconnect_ws("/v1/websocket", proxy_client) as server_ws: # type: ignore |
| 189 | + await client_ws.accept() |
| 190 | + await _handle_websocket_forwarding(client_ws, server_ws) |
| 191 | + |
| 192 | + |
128 | 193 | async def proxy_ws(client_ws: WebSocket): |
129 | 194 | proxy_client: httpx.AsyncClient = client_ws.app.state.proxy_client |
130 | 195 | logger = client_ws.app.state.logger |
131 | 196 |
|
132 | 197 | for attempt in inference_retries(): |
133 | 198 | with attempt: |
134 | 199 | try: |
135 | | - async with aconnect_ws("/v1/websocket", proxy_client) as server_ws: # type: ignore |
136 | | - # Unfortunate, but FastAPI and httpx-ws have slightly different abstractions |
137 | | - # for sending data, so it's not easy to create a unified wrapper. |
138 | | - async def forward_to_server(): |
139 | | - while True: |
140 | | - message = await client_ws.receive() |
141 | | - if message.get("type") == "websocket.disconnect": |
142 | | - break |
143 | | - if "text" in message: |
144 | | - await server_ws.send_text(message["text"]) |
145 | | - elif "bytes" in message: |
146 | | - await server_ws.send_bytes(message["bytes"]) |
147 | | - |
148 | | - async def forward_to_client(): |
149 | | - while True: |
150 | | - message = await server_ws.receive() |
151 | | - if message is None: |
152 | | - break |
153 | | - if isinstance(message, TextMessage): |
154 | | - await client_ws.send_text(message.data) |
155 | | - elif isinstance(message, BytesMessage): |
156 | | - await client_ws.send_bytes(message.data) |
157 | | - |
158 | | - await client_ws.accept() |
159 | | - try: |
160 | | - await asyncio.gather(forward_to_client(), forward_to_server()) |
161 | | - finally: |
162 | | - await _safe_close_ws(client_ws, logger) |
| 200 | + await _attempt_websocket_proxy(client_ws, proxy_client, logger) |
163 | 201 | except httpx_ws_exceptions.HTTPXWSException as e: |
164 | 202 | logger.warning(f"WebSocket connection rejected: {e}") |
165 | 203 | await _safe_close_ws(client_ws, logger) |
|
0 commit comments