-
Notifications
You must be signed in to change notification settings - Fork 88
Better code propagation, better Chains websocket API #1918
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
9cf8b4c
to
1a6cd89
Compare
|
||
async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: ... | ||
|
||
async def receive(self) -> Union[str, bytes]: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Context for this change is here: https://basetenlabs.slack.com/archives/C06CZ3RSXRU/p1757360869605229
1a6cd89
to
04bf0e5
Compare
control_app = APIRouter() | ||
|
||
|
||
class CloseableWebsocket(Protocol): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe overkill, and mentioned through comments below, but it's unfortunate that the incoming websocket (FastAPI) and outbound websockets (httpx) are different types with incompatible APIs.
We get around that separately below, but luckily close
is common, so I wrapped in a Protocol here
) -> None: | ||
while True: | ||
message = await client_ws.receive() | ||
if message.get("type") == "websocket.disconnect": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docs seem to suggest that the client closing the connection could either:
- Trigger an
StartletteWebSocketDisconnect
- Send an explicit
websocket.disconnect
message, that we have to handle
To be safe, we re-raise a StartletteWebSocketDisconnect
here and handle below
await _safe_close_ws(server_ws, logger) | ||
|
||
|
||
async def _attempt_websocket_proxy( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This refactor was mainly to improve on the 6 layer indent for proxy_ws
before, but there are some material changes above
else: | ||
logger.warning(f"Ungraceful websocket close: {exc}") | ||
finally: | ||
await _safe_close_ws(client_ws, logger) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have protections on double close, but I want this in the finally
just to make sure everything is cleaned up in all casese
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Had this question as well, you might want to add a comment above just in case!
): | ||
logger = client_ws.app.state.logger | ||
try: | ||
async with asyncio.TaskGroup() as tg: # type: ignore[attr-defined] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changing from asyncio.gather
to TaskGroup means the other task gets automatically cancelled when one raises an exception, compared to running until it attempts to receive on a closed websocket
async def _safe_close_ws( | ||
ws: CloseableWebsocket, | ||
logger: logging.Logger, | ||
code: int = 1000, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cosmetic: here and in other places, could we use a constant like WEBSOCKET_NORMAL_CLOSE
instead of 1000? Perhaps the websocket library already offers one.
else: | ||
logger.warning(f"Ungraceful websocket close: {exc}") | ||
finally: | ||
await _safe_close_ws(client_ws, logger) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Had this question as well, you might want to add a comment above just in case!
tg.create_task(forward_to_client(client_ws, server_ws)) | ||
tg.create_task(forward_to_server(client_ws, server_ws)) | ||
except ExceptionGroup as eg: # type: ignore[name-defined] # noqa: F821 | ||
exc = eg.exceptions[0] # NB(nikhil): Only care about the first one. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why the first one? Is it because the remaining ones are consequences of the first one? Worth mentioning that as well.
await asyncio.gather(forward_to_client(), forward_to_server()) | ||
finally: | ||
await _safe_close_ws(client_ws, logger) | ||
await _attempt_websocket_proxy(client_ws, proxy_client, logger) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for breaking this code block down into smaller functions!
await ws.accept() | ||
await self._model.websocket(ws) | ||
await _safe_close_websocket(ws, None, status_code=1000) | ||
await _safe_close_websocket(ws, status_code=1000, reason=None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we always specify arguments, do we still need default values for _safe_close_websocket()
? It makes the code DRY-er and more explicit.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately on L182 I have a failsafe close, but I can make code
without a default!
except Exception: | ||
await _safe_close_websocket( | ||
ws, errors.MODEL_ERROR_MESSAGE, status_code=1011 | ||
ws, status_code=1011, reason=errors.MODEL_ERROR_MESSAGE |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1011 might benefit from a constant string too.
): | ||
await proxy_ws(client_ws) | ||
proxy_task = asyncio.create_task(proxy_ws(client_ws)) | ||
await asyncio.sleep(0.5) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the interest of reducing test flakiness, is there any pattern we could use instead of a sleep()?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I actually don't need this - removed!
bf9d00f
to
c3ab8e7
Compare
4d328b1
to
0f94794
Compare
0f94794
to
7bd86ad
Compare
await self._websocket.close(code=code, reason=reason) | ||
|
||
async def receive(self) -> Union[str, bytes]: | ||
try: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is pulled from lines above, since we need to expose WebsocketDisconnect to end users.
This wrapper exists to try and hide FastAPI, but per a comment above we already knew we were leaking details via Exceptions. I'd rather not come up with a separate protocol for how users would have to think about catching websocket disconnects (especially since it would be different in Truss), so for now I just pass it through
@nnarayen should we change the docs here? |
Good call! https://github.com/basetenlabs/docs.baseten.co/pull/263 |
🚀 What
Heavily based off #1912, adds two small improvements:
1000
receive() -> str | bytes
💻 How
🔬 Testing
0.11.1rc1