@nnarayen nnarayen commented Sep 9, 2025

🚀 What

Heavily based off #1912, adds two small improvements:

  1. Propagates websocket close codes in both directions instead of masking them with 1000 (normal closure)
  2. Improves the Chains websocket API by exposing the more generic receive() -> str | bytes

💻 How

🔬 Testing

  • Unit tests
  • Testing with 0.11.1rc1

@nnarayen nnarayen force-pushed the nikhil/propagate-error-codes branch 3 times, most recently from 9cf8b4c to 1a6cd89 Compare September 9, 2025 04:00

async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: ...

async def receive(self) -> Union[str, bytes]: ...

@nnarayen nnarayen force-pushed the nikhil/propagate-error-codes branch from 1a6cd89 to 04bf0e5 Compare September 9, 2025 04:08
control_app = APIRouter()


class CloseableWebsocket(Protocol):
Contributor Author

Maybe overkill, and also mentioned in comments below, but it's unfortunate that the incoming websocket (FastAPI) and outbound websockets (httpx) are different types with incompatible APIs.

We work around that separately below, but luckily close is common to both, so I wrapped it in a Protocol here
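
A minimal sketch of the Protocol idea, assuming the only shared requirement is an async close(code, reason). The two fake classes and the close_any helper are hypothetical stand-ins for illustration, not code from this PR:

```python
import asyncio
from typing import Optional, Protocol, Tuple


class CloseableWebsocket(Protocol):
    """Structural type: any object with a matching async close() qualifies."""

    async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: ...


class FakeInboundWs:
    """Hypothetical stand-in for the inbound (server-side) websocket."""

    def __init__(self) -> None:
        self.closed_with: Optional[Tuple[int, Optional[str]]] = None

    async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
        self.closed_with = (code, reason)


class FakeOutboundWs:
    """Hypothetical stand-in for the outbound (client-side) websocket."""

    def __init__(self) -> None:
        self.closed_with: Optional[Tuple[int, Optional[str]]] = None

    async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
        self.closed_with = (code, reason)


async def close_any(
    ws: CloseableWebsocket, code: int, reason: Optional[str] = None
) -> None:
    # Neither fake class inherits from the Protocol; matching the method
    # signature is enough for a type checker to accept both.
    await ws.close(code=code, reason=reason)
```

The benefit of the structural approach is that neither library type has to be subclassed or wrapped just to share a close helper.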

) -> None:
    while True:
        message = await client_ws.receive()
        if message.get("type") == "websocket.disconnect":
Contributor Author

The docs seem to suggest that a client closing the connection could either:

  • Trigger a StartletteWebSocketDisconnect
  • Send an explicit websocket.disconnect message that we have to handle

To be safe, we re-raise a StartletteWebSocketDisconnect here and handle it below
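
A rough sketch of that normalization, assuming ASGI-style message dicts with "type", "text", "bytes" and "code" keys. ClientDisconnect and ScriptedClientWs are hypothetical stand-ins so the snippet runs without a web framework; the real code raises the framework's own disconnect exception:

```python
import asyncio
from typing import Any, Dict, List, Optional, Tuple


class ClientDisconnect(Exception):
    """Hypothetical stand-in for the framework's disconnect exception."""

    def __init__(self, code: int = 1000, reason: Optional[str] = None) -> None:
        super().__init__(code, reason)
        self.code = code
        self.reason = reason


class ScriptedClientWs:
    """Hypothetical fake that replays a fixed list of ASGI-style messages."""

    def __init__(self, messages: List[Dict[str, Any]]) -> None:
        self._messages = list(messages)

    async def receive(self) -> Dict[str, Any]:
        return self._messages.pop(0)


async def forward_to_server(client_ws: ScriptedClientWs, sent: List[Any]) -> None:
    while True:
        message = await client_ws.receive()
        if message.get("type") == "websocket.disconnect":
            # Turn the explicit message into the exception path so that
            # both disconnect styles are handled in one place, and keep
            # the client's close code instead of defaulting to 1000.
            raise ClientDisconnect(message.get("code", 1000), message.get("reason"))
        if message.get("text") is not None:
            sent.append(message["text"])
        elif message.get("bytes") is not None:
            sent.append(message["bytes"])


async def demo() -> Tuple[List[Any], int]:
    sent: List[Any] = []
    ws = ScriptedClientWs(
        [
            {"type": "websocket.receive", "text": "hi"},
            {"type": "websocket.receive", "bytes": b"\x00"},
            {"type": "websocket.disconnect", "code": 1001},
        ]
    )
    try:
        await forward_to_server(ws, sent)
    except ClientDisconnect as exc:
        return sent, exc.code
    return sent, -1
```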

await _safe_close_ws(server_ws, logger)


async def _attempt_websocket_proxy(
Contributor Author

This refactor was mainly to reduce the six levels of indentation proxy_ws had before, but there are some material changes above

else:
logger.warning(f"Ungraceful websocket close: {exc}")
finally:
await _safe_close_ws(client_ws, logger)
Contributor Author

We have protections against double close, but I want this in the finally just to make sure everything is cleaned up in all cases

Contributor

Had this question as well, you might want to add a comment above just in case!
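
To illustrate why an unconditional close in finally can be safe, here is a small sketch. It assumes the close helper swallows the "already closed" error; FakeWs, safe_close and handle are hypothetical names, and the actual _safe_close_ws may guard against double close differently:

```python
import asyncio
import logging
from typing import Optional


class FakeWs:
    """Hypothetical websocket that raises if closed a second time."""

    def __init__(self) -> None:
        self.close_calls = 0

    async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
        self.close_calls += 1
        if self.close_calls > 1:
            raise RuntimeError("websocket already closed")


async def safe_close(
    ws: FakeWs, logger: logging.Logger, code: int = 1000, reason: Optional[str] = None
) -> None:
    """Close the websocket without letting a double-close error escape."""
    try:
        await ws.close(code=code, reason=reason)
    except RuntimeError as exc:
        logger.debug("Ignoring close on closed websocket: %s", exc)


async def handle(ws: FakeWs, logger: logging.Logger, fail: bool) -> None:
    try:
        if fail:
            raise ValueError("proxy failed")
    except ValueError:
        # Error path: close with an error code first.
        await safe_close(ws, logger, code=1011)
    finally:
        # Runs on every path. If the websocket was already closed above,
        # safe_close absorbs the error, so this is a harmless no-op.
        await safe_close(ws, logger)
```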

):
    logger = client_ws.app.state.logger
    try:
        async with asyncio.TaskGroup() as tg:  # type: ignore[attr-defined]
Contributor Author

Changing from asyncio.gather to TaskGroup means the other task is cancelled automatically when one raises an exception, instead of running until it attempts to receive on a closed websocket

async def _safe_close_ws(
    ws: CloseableWebsocket,
    logger: logging.Logger,
    code: int = 1000,
Contributor

Cosmetic: here and in other places, could we use a constant like WEBSOCKET_NORMAL_CLOSE instead of 1000? Perhaps the websocket library already offers one.
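
One possible shape for such constants, as a sketch. The enum and helper below are hypothetical names; the values follow the close codes defined in RFC 6455. Starlette also ships ready-made constants in starlette.status (for example WS_1000_NORMAL_CLOSURE and WS_1011_INTERNAL_ERROR), which might be enough on their own:

```python
from enum import IntEnum


class WsCloseCode(IntEnum):
    """Named websocket close codes (values from RFC 6455)."""

    NORMAL_CLOSURE = 1000
    GOING_AWAY = 1001
    INTERNAL_ERROR = 1011


def describe_close(code: int) -> str:
    """Return a readable name for a close code, for log messages."""
    try:
        return WsCloseCode(code).name
    except ValueError:
        return f"UNKNOWN_{code}"
```

Because IntEnum members compare equal to plain ints, they can be passed anywhere a numeric close code is expected.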

        tg.create_task(forward_to_client(client_ws, server_ws))
        tg.create_task(forward_to_server(client_ws, server_ws))
except ExceptionGroup as eg:  # type: ignore[name-defined] # noqa: F821
    exc = eg.exceptions[0]  # NB(nikhil): Only care about the first one.
Contributor

Why the first one? Is it because the remaining ones are consequences of the first one? Worth mentioning that as well.

await asyncio.gather(forward_to_client(), forward_to_server())
finally:
await _safe_close_ws(client_ws, logger)
await _attempt_websocket_proxy(client_ws, proxy_client, logger)
Contributor

Thanks for breaking this code block down into smaller functions!

await ws.accept()
await self._model.websocket(ws)
- await _safe_close_websocket(ws, None, status_code=1000)
+ await _safe_close_websocket(ws, status_code=1000, reason=None)
Contributor

If we always specify the arguments, do we still need default values for _safe_close_websocket()? Dropping them would make the code DRY-er and more explicit.

Contributor Author

@nnarayen nnarayen Sep 9, 2025

Unfortunately on L182 I have a failsafe close, but I can make code without a default!
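
One way to force explicit arguments, sketched under the assumption that every call site can pass both values. RecordingWs and safe_close_explicit are hypothetical names, not the PR's helper:

```python
import asyncio
from typing import List, Optional, Tuple


class RecordingWs:
    """Hypothetical websocket that records how it was closed."""

    def __init__(self) -> None:
        self.calls: List[Tuple[int, Optional[str]]] = []

    async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
        self.calls.append((code, reason))


async def safe_close_explicit(
    ws: RecordingWs, *, status_code: int, reason: Optional[str]
) -> None:
    # Keyword-only parameters without defaults: callers must spell out
    # both the close code and the reason at every call site.
    await ws.close(code=status_code, reason=reason)
```

Omitting either keyword raises a TypeError at call time, so a forgotten close code cannot silently fall back to 1000.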

except Exception:
await _safe_close_websocket(
- ws, errors.MODEL_ERROR_MESSAGE, status_code=1011
+ ws, status_code=1011, reason=errors.MODEL_ERROR_MESSAGE
Contributor

1011 might benefit from a named constant too.

):
- await proxy_ws(client_ws)
+ proxy_task = asyncio.create_task(proxy_ws(client_ws))
+ await asyncio.sleep(0.5)
Contributor

In the interest of reducing test flakiness, is there any pattern we could use instead of a sleep()?

Contributor Author

I think I actually don't need this - removed!
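
For tests that do need to wait on a background task, one common alternative to a fixed sleep is an explicit readiness signal with a timeout. This is a generic sketch with made-up names (fake_proxy, test_like_flow), not the test from this PR:

```python
import asyncio


async def fake_proxy(started: asyncio.Event, stop: asyncio.Event) -> str:
    started.set()      # signal that the task is up and running
    await stop.wait()  # keep running until the test says stop
    return "closed"


async def test_like_flow() -> str:
    started, stop = asyncio.Event(), asyncio.Event()
    task = asyncio.create_task(fake_proxy(started, stop))
    # Wait for the readiness signal instead of sleeping a fixed 0.5s;
    # the timeout only bounds how long a broken test can hang.
    await asyncio.wait_for(started.wait(), timeout=1.0)
    stop.set()
    return await asyncio.wait_for(task, timeout=1.0)
```

The test proceeds as soon as the task is ready, so it is both faster and less sensitive to scheduler timing than a sleep.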

@nnarayen nnarayen force-pushed the nikhil/propagate-error-codes branch 2 times, most recently from bf9d00f to c3ab8e7 Compare September 9, 2025 22:18
@nnarayen nnarayen force-pushed the nikhil/propagate-error-codes branch 4 times, most recently from 4d328b1 to 0f94794 Compare September 9, 2025 23:20
@nnarayen nnarayen force-pushed the nikhil/propagate-error-codes branch from 0f94794 to 7bd86ad Compare September 10, 2025 03:11
    await self._websocket.close(code=code, reason=reason)

async def receive(self) -> Union[str, bytes]:
    try:
Contributor Author

This is pulled from lines above, since we need to expose WebsocketDisconnect to end users.

This wrapper exists to try to hide FastAPI, but per a comment above we already knew we were leaking details via exceptions. I'd rather not come up with a separate protocol for how users should think about catching websocket disconnects (especially since it would be different in Truss), so for now I just pass it through
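
From the end user's side, the pass-through approach might look roughly like this. FakeWebSocketDisconnect and FakeChainsWebsocket are hypothetical stand-ins so the snippet runs without FastAPI; in real code the exception would be the framework's own disconnect exception:

```python
import asyncio
from typing import List, Optional, Tuple, Union


class FakeWebSocketDisconnect(Exception):
    """Hypothetical stand-in for the framework's disconnect exception."""

    def __init__(self, code: int = 1000, reason: Optional[str] = None) -> None:
        super().__init__(code, reason)
        self.code = code
        self.reason = reason


class FakeChainsWebsocket:
    """Hypothetical wrapper: receive() yields str or bytes, then raises."""

    def __init__(self, frames: List[Union[str, bytes]], close_code: int) -> None:
        self._frames = list(frames)
        self._close_code = close_code

    async def receive(self) -> Union[str, bytes]:
        if not self._frames:
            # The disconnect surfaces to user code as an exception,
            # carrying the peer's close code.
            raise FakeWebSocketDisconnect(self._close_code)
        return self._frames.pop(0)


async def consume(ws: FakeChainsWebsocket) -> Tuple[List[str], int]:
    kinds: List[str] = []
    try:
        while True:
            data = await ws.receive()
            # A single receive() covers both text and binary frames.
            kinds.append("text" if isinstance(data, str) else "bytes")
    except FakeWebSocketDisconnect as exc:
        return kinds, exc.code
```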

@nnarayen nnarayen merged commit ee8aed0 into main Sep 10, 2025
31 checks passed
@nnarayen nnarayen deleted the nikhil/propagate-error-codes branch September 10, 2025 03:22
@nnarayen nnarayen changed the title from "Better code propagation, better Websocket websocket API" to "Better code propagation, better Chains websocket API" Sep 10, 2025
@squidarth
Collaborator

@nnarayen should we change the docs here?

@nnarayen
Contributor Author

> @nnarayen should we change the docs here?

Good call! https://github.com/basetenlabs/docs.baseten.co/pull/263
