@@ -9,16 +9,24 @@ use devtools_wire_format::sources::sources_server::SourcesServer;
9
9
use devtools_wire_format:: tauri:: tauri_server;
10
10
use devtools_wire_format:: tauri:: tauri_server:: TauriServer ;
11
11
use futures:: { FutureExt , TryStreamExt } ;
12
+ use http:: HeaderValue ;
13
+ use hyper:: Body ;
12
14
use std:: net:: SocketAddr ;
15
+ use std:: pin:: Pin ;
16
+ use std:: sync:: { Arc , Mutex } ;
17
+ use std:: task:: { Context , Poll } ;
13
18
use tokio:: sync:: mpsc;
19
+ use tonic:: body:: BoxBody ;
14
20
use tonic:: codegen:: http:: Method ;
15
21
use tonic:: codegen:: tokio_stream:: wrappers:: ReceiverStream ;
16
22
use tonic:: codegen:: BoxStream ;
17
23
use tonic:: { Request , Response , Status } ;
18
24
use tonic_health:: pb:: health_server:: { Health , HealthServer } ;
19
25
use tonic_health:: server:: HealthReporter ;
20
26
use tonic_health:: ServingStatus ;
21
- use tower_http:: cors:: { AllowHeaders , CorsLayer } ;
27
+ use tower:: Service ;
28
+ use tower_http:: cors:: { AllowHeaders , AllowOrigin , CorsLayer } ;
29
+ use tower_layer:: Layer ;
22
30
23
31
/// Default maximum capacity for the channel of events sent from a
24
32
/// [`Server`] to each subscribed client.
@@ -28,15 +36,84 @@ use tower_http::cors::{AllowHeaders, CorsLayer};
28
36
const DEFAULT_CLIENT_BUFFER_CAPACITY : usize = 1024 * 4 ;
29
37
30
38
/// The `gRPC` server that exposes the instrumenting API
31
- pub struct Server (
32
- tonic:: transport:: server:: Router < tower_layer:: Stack < CorsLayer , tower_layer:: Identity > > ,
33
- ) ;
39
+ pub struct Server {
40
+ router : tonic:: transport:: server:: Router <
41
+ tower_layer:: Stack < DynamicCorsLayer , tower_layer:: Identity > ,
42
+ > ,
43
+ handle : ServerHandle ,
44
+ }
45
+
46
+ /// A handle to a server that is allowed to modify its properties (such as CORS allowed origins)
47
+ #[ allow( clippy:: module_name_repetitions) ]
48
+ #[ derive( Clone ) ]
49
+ pub struct ServerHandle {
50
+ allowed_origins : Arc < Mutex < Vec < AllowOrigin > > > ,
51
+ }
52
+
53
+ impl ServerHandle {
54
+ /// Allow the given origin in the instrumentation server CORS.
55
+ #[ allow( clippy:: missing_panics_doc) ]
56
+ pub fn allow_origin ( & self , origin : impl Into < AllowOrigin > ) {
57
+ self . allowed_origins . lock ( ) . unwrap ( ) . push ( origin. into ( ) ) ;
58
+ }
59
+ }
34
60
35
61
struct InstrumentService {
36
62
tx : mpsc:: Sender < Command > ,
37
63
health_reporter : HealthReporter ,
38
64
}
39
65
66
+ #[ derive( Clone ) ]
67
+ struct DynamicCorsLayer {
68
+ allowed_origins : Arc < Mutex < Vec < AllowOrigin > > > ,
69
+ }
70
+
71
+ impl < S > Layer < S > for DynamicCorsLayer {
72
+ type Service = DynamicCors < S > ;
73
+
74
+ fn layer ( & self , service : S ) -> Self :: Service {
75
+ DynamicCors {
76
+ inner : service,
77
+ allowed_origins : self . allowed_origins . clone ( ) ,
78
+ }
79
+ }
80
+ }
81
+
82
+ #[ derive( Debug , Clone ) ]
83
+ struct DynamicCors < S > {
84
+ inner : S ,
85
+ allowed_origins : Arc < Mutex < Vec < AllowOrigin > > > ,
86
+ }
87
+
88
+ type BoxFuture < ' a , T > = Pin < Box < dyn std:: future:: Future < Output = T > + Send + ' a > > ;
89
+
90
+ impl < S > Service < hyper:: Request < Body > > for DynamicCors < S >
91
+ where
92
+ S : Service < hyper:: Request < Body > , Response = hyper:: Response < BoxBody > > + Clone + Send + ' static ,
93
+ S :: Future : Send + ' static ,
94
+ {
95
+ type Response = S :: Response ;
96
+ type Error = S :: Error ;
97
+ type Future = BoxFuture < ' static , Result < Self :: Response , Self :: Error > > ;
98
+
99
+ fn poll_ready ( & mut self , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , Self :: Error > > {
100
+ self . inner . poll_ready ( cx)
101
+ }
102
+
103
+ fn call ( & mut self , req : hyper:: Request < Body > ) -> Self :: Future {
104
+ let mut cors = CorsLayer :: new ( )
105
+ // allow `GET` and `POST` when accessing the resource
106
+ . allow_methods ( [ Method :: GET , Method :: POST ] )
107
+ . allow_headers ( AllowHeaders :: any ( ) ) ;
108
+
109
+ for origin in & * self . allowed_origins . lock ( ) . unwrap ( ) {
110
+ cors = cors. allow_origin ( origin. clone ( ) ) ;
111
+ }
112
+
113
+ Box :: pin ( cors. layer ( self . inner . clone ( ) ) . call ( req) )
114
+ }
115
+ }
116
+
40
117
impl Server {
41
118
#[ allow( clippy:: missing_panics_doc) ]
42
119
pub fn new (
@@ -51,15 +128,22 @@ impl Server {
51
128
. set_serving :: < InstrumentServer < InstrumentService > > ( )
52
129
. now_or_never ( ) ;
53
130
54
- let cors = CorsLayer :: new ( )
55
- // allow `GET` and `POST` when accessing the resource
56
- . allow_methods ( [ Method :: GET , Method :: POST ] )
57
- . allow_headers ( AllowHeaders :: any ( ) )
58
- . allow_origin ( tower_http:: cors:: Any ) ;
131
+ let allowed_origins =
132
+ Arc :: new ( Mutex :: new ( vec ! [
133
+ if option_env!( "__DEVTOOLS_LOCAL_DEVELOPMENT" ) . is_some( ) {
134
+ AllowOrigin :: from( tower_http:: cors:: Any )
135
+ } else {
136
+ HeaderValue :: from_str( "https://devtools.crabnebula.dev" )
137
+ . unwrap( )
138
+ . into( )
139
+ } ,
140
+ ] ) ) ;
59
141
60
142
let router = tonic:: transport:: Server :: builder ( )
61
143
. accept_http1 ( true )
62
- . layer ( cors)
144
+ . layer ( DynamicCorsLayer {
145
+ allowed_origins : allowed_origins. clone ( ) ,
146
+ } )
63
147
. add_service ( tonic_web:: enable ( health_service) )
64
148
. add_service ( tonic_web:: enable ( InstrumentServer :: new (
65
149
InstrumentService {
@@ -71,7 +155,15 @@ impl Server {
71
155
. add_service ( tonic_web:: enable ( MetadataServer :: new ( metadata_server) ) )
72
156
. add_service ( tonic_web:: enable ( SourcesServer :: new ( sources_server) ) ) ;
73
157
74
- Self ( router)
158
+ Self {
159
+ router,
160
+ handle : ServerHandle { allowed_origins } ,
161
+ }
162
+ }
163
+
164
+ #[ must_use]
165
+ pub fn handle ( & self ) -> ServerHandle {
166
+ self . handle . clone ( )
75
167
}
76
168
77
169
/// Consumes this [`Server`] and returns a future that will execute the server.
@@ -82,7 +174,7 @@ impl Server {
82
174
pub async fn run ( self , addr : SocketAddr ) -> crate :: Result < ( ) > {
83
175
tracing:: info!( "Listening on {}" , addr) ;
84
176
85
- self . 0 . serve ( addr) . await ?;
177
+ self . router . serve ( addr) . await ?;
86
178
87
179
Ok ( ( ) )
88
180
}
0 commit comments