@@ -5,30 +5,27 @@ use gas::prelude::*;
55use hyper:: header:: HeaderName ;
66use rivet_guard_core:: proxy_service:: { RouteConfig , RouteTarget , RoutingOutput , RoutingTimeout } ;
77
8- use super :: SEC_WEBSOCKET_PROTOCOL ;
8+ use super :: { SEC_WEBSOCKET_PROTOCOL , WS_PROTOCOL_ACTOR , WS_PROTOCOL_TOKEN , X_RIVET_TOKEN } ;
99use crate :: { errors, shared_state:: SharedState } ;
1010
1111const ACTOR_READY_TIMEOUT : Duration = Duration :: from_secs ( 10 ) ;
1212pub const X_RIVET_ACTOR : HeaderName = HeaderName :: from_static ( "x-rivet-actor" ) ;
13- const WS_PROTOCOL_ACTOR : & str = "rivet_actor." ;
1413
1514/// Route requests to actor services using path-based routing
1615#[ tracing:: instrument( skip_all) ]
1716pub async fn route_request_path_based (
1817 ctx : & StandaloneCtx ,
1918 shared_state : & SharedState ,
2019 actor_id_str : & str ,
21- _token : Option < & str > ,
20+ token : Option < & str > ,
2221 path : & str ,
2322 _headers : & hyper:: HeaderMap ,
2423 _is_websocket : bool ,
2524) -> Result < Option < RoutingOutput > > {
26- // NOTE: Token validation implemented in EE
27-
2825 // Parse actor ID
2926 let actor_id = Id :: parse ( actor_id_str) . context ( "invalid actor id in path" ) ?;
3027
31- route_request_inner ( ctx, shared_state, actor_id, path) . await
28+ route_request_inner ( ctx, shared_state, actor_id, path, token ) . await
3229}
3330
3431/// Route requests to actor services based on headers
@@ -47,28 +44,39 @@ pub async fn route_request(
4744 return Ok ( None ) ;
4845 }
4946
50- // Extract actor ID from WebSocket protocol or HTTP header
51- let actor_id_str = if is_websocket {
47+ // Extract actor ID and token from WebSocket protocol or HTTP headers
48+ let ( actor_id_str, token ) = if is_websocket {
5249 // For WebSocket, parse the sec-websocket-protocol header
53- headers
50+ let protocols_header = headers
5451 . get ( SEC_WEBSOCKET_PROTOCOL )
5552 . and_then ( |protocols| protocols. to_str ( ) . ok ( ) )
56- . and_then ( |protocols| {
57- // Parse protocols to find actor.{id}
58- protocols
59- . split ( ',' )
60- . map ( |p| p. trim ( ) )
61- . find_map ( |p| p. strip_prefix ( WS_PROTOCOL_ACTOR ) )
62- } )
53+ . ok_or_else ( || {
54+ crate :: errors:: MissingHeader {
55+ header : "sec-websocket-protocol" . to_string ( ) ,
56+ }
57+ . build ( )
58+ } ) ?;
59+
60+ let protocols: Vec < & str > = protocols_header. split ( ',' ) . map ( |p| p. trim ( ) ) . collect ( ) ;
61+
62+ let actor_id = protocols
63+ . iter ( )
64+ . find_map ( |p| p. strip_prefix ( WS_PROTOCOL_ACTOR ) )
6365 . ok_or_else ( || {
6466 crate :: errors:: MissingHeader {
6567 header : "`rivet_actor.*` protocol in sec-websocket-protocol" . to_string ( ) ,
6668 }
6769 . build ( )
68- } ) ?
70+ } ) ?;
71+
72+ let token = protocols
73+ . iter ( )
74+ . find_map ( |p| p. strip_prefix ( WS_PROTOCOL_TOKEN ) ) ;
75+
76+ ( actor_id, token)
6977 } else {
70- // For HTTP, use the x-rivet-actor header
71- headers
78+ // For HTTP, use headers
79+ let actor_id = headers
7280 . get ( X_RIVET_ACTOR )
7381 . map ( |x| x. to_str ( ) )
7482 . transpose ( )
@@ -78,21 +86,32 @@ pub async fn route_request(
7886 header : X_RIVET_ACTOR . to_string ( ) ,
7987 }
8088 . build ( )
81- } ) ?
89+ } ) ?;
90+
91+ let token = headers
92+ . get ( X_RIVET_TOKEN )
93+ . map ( |x| x. to_str ( ) )
94+ . transpose ( )
95+ . context ( "invalid x-rivet-token header" ) ?;
96+
97+ ( actor_id, token)
8298 } ;
8399
84100 // Find actor to route to
85101 let actor_id = Id :: parse ( actor_id_str) . context ( "invalid x-rivet-actor header" ) ?;
86102
87- route_request_inner ( ctx, shared_state, actor_id, path) . await
103+ route_request_inner ( ctx, shared_state, actor_id, path, token ) . await
88104}
89105
90106async fn route_request_inner (
91107 ctx : & StandaloneCtx ,
92108 shared_state : & SharedState ,
93109 actor_id : Id ,
94110 path : & str ,
111+ _token : Option < & str > ,
95112) -> Result < Option < RoutingOutput > > {
113+ // NOTE: Token validation implemented in EE
114+
96115 // Route to peer dc where the actor lives
97116 if actor_id. label ( ) != ctx. config ( ) . dc_label ( ) {
98117 tracing:: debug!( peer_dc_label=?actor_id. label( ) , "re-routing actor to peer dc" ) ;
0 commit comments