Skip to content

Commit d9e9f6c

Browse files
Merge pull request #288 from crabnebula-dev/feat/add-cors-origin-api
feat: add API to allow an origin to be allowed by CORS
2 parents b34c109 + 90ef64e commit d9e9f6c

File tree

5 files changed

+116
-18
lines changed

5 files changed

+116
-18
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/devtools-core/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,5 @@ bytes = "1.5.0"
2929
ringbuf = "0.4.0-rc.3"
3030
async-stream = "0.3.5"
3131
http = "0.2"
32+
hyper = "0.14"
33+
tower = "0.4"

crates/devtools-core/src/aggregator.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,6 @@ impl<T, const CAP: usize> EventBuf<T, CAP> {
321321
}
322322

323323
/// Push an event into the buffer, overwriting the oldest event if the buffer is full.
324-
// TODO does it really make sense to track the dropped events here?
325324
pub fn push_overwrite(&mut self, item: T) {
326325
if self.inner.push_overwrite(item).is_some() {
327326
self.sent = self.sent.saturating_sub(1);

crates/devtools-core/src/server.rs

Lines changed: 104 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,24 @@ use devtools_wire_format::sources::sources_server::SourcesServer;
99
use devtools_wire_format::tauri::tauri_server;
1010
use devtools_wire_format::tauri::tauri_server::TauriServer;
1111
use futures::{FutureExt, TryStreamExt};
12+
use http::HeaderValue;
13+
use hyper::Body;
1214
use std::net::SocketAddr;
15+
use std::pin::Pin;
16+
use std::sync::{Arc, Mutex};
17+
use std::task::{Context, Poll};
1318
use tokio::sync::mpsc;
19+
use tonic::body::BoxBody;
1420
use tonic::codegen::http::Method;
1521
use tonic::codegen::tokio_stream::wrappers::ReceiverStream;
1622
use tonic::codegen::BoxStream;
1723
use tonic::{Request, Response, Status};
1824
use tonic_health::pb::health_server::{Health, HealthServer};
1925
use tonic_health::server::HealthReporter;
2026
use tonic_health::ServingStatus;
21-
use tower_http::cors::{AllowHeaders, CorsLayer};
27+
use tower::Service;
28+
use tower_http::cors::{AllowHeaders, AllowOrigin, CorsLayer};
29+
use tower_layer::Layer;
2230

2331
/// Default maximum capacity for the channel of events sent from a
2432
/// [`Server`] to each subscribed client.
@@ -28,15 +36,84 @@ use tower_http::cors::{AllowHeaders, CorsLayer};
2836
const DEFAULT_CLIENT_BUFFER_CAPACITY: usize = 1024 * 4;
2937

3038
/// The `gRPC` server that exposes the instrumenting API
31-
pub struct Server(
32-
tonic::transport::server::Router<tower_layer::Stack<CorsLayer, tower_layer::Identity>>,
33-
);
39+
pub struct Server {
40+
router: tonic::transport::server::Router<
41+
tower_layer::Stack<DynamicCorsLayer, tower_layer::Identity>,
42+
>,
43+
handle: ServerHandle,
44+
}
45+
46+
/// A handle to a server that is allowed to modify its properties (such as CORS allowed origins)
47+
#[allow(clippy::module_name_repetitions)]
48+
#[derive(Clone)]
49+
pub struct ServerHandle {
50+
allowed_origins: Arc<Mutex<Vec<AllowOrigin>>>,
51+
}
52+
53+
impl ServerHandle {
54+
/// Allow the given origin in the instrumentation server CORS.
55+
#[allow(clippy::missing_panics_doc)]
56+
pub fn allow_origin(&self, origin: impl Into<AllowOrigin>) {
57+
self.allowed_origins.lock().unwrap().push(origin.into());
58+
}
59+
}
3460

3561
struct InstrumentService {
3662
tx: mpsc::Sender<Command>,
3763
health_reporter: HealthReporter,
3864
}
3965

66+
#[derive(Clone)]
67+
struct DynamicCorsLayer {
68+
allowed_origins: Arc<Mutex<Vec<AllowOrigin>>>,
69+
}
70+
71+
impl<S> Layer<S> for DynamicCorsLayer {
72+
type Service = DynamicCors<S>;
73+
74+
fn layer(&self, service: S) -> Self::Service {
75+
DynamicCors {
76+
inner: service,
77+
allowed_origins: self.allowed_origins.clone(),
78+
}
79+
}
80+
}
81+
82+
#[derive(Debug, Clone)]
83+
struct DynamicCors<S> {
84+
inner: S,
85+
allowed_origins: Arc<Mutex<Vec<AllowOrigin>>>,
86+
}
87+
88+
type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
89+
90+
impl<S> Service<hyper::Request<Body>> for DynamicCors<S>
91+
where
92+
S: Service<hyper::Request<Body>, Response = hyper::Response<BoxBody>> + Clone + Send + 'static,
93+
S::Future: Send + 'static,
94+
{
95+
type Response = S::Response;
96+
type Error = S::Error;
97+
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
98+
99+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
100+
self.inner.poll_ready(cx)
101+
}
102+
103+
fn call(&mut self, req: hyper::Request<Body>) -> Self::Future {
104+
let mut cors = CorsLayer::new()
105+
// allow `GET` and `POST` when accessing the resource
106+
.allow_methods([Method::GET, Method::POST])
107+
.allow_headers(AllowHeaders::any());
108+
109+
for origin in &*self.allowed_origins.lock().unwrap() {
110+
cors = cors.allow_origin(origin.clone());
111+
}
112+
113+
Box::pin(cors.layer(self.inner.clone()).call(req))
114+
}
115+
}
116+
40117
impl Server {
41118
#[allow(clippy::missing_panics_doc)]
42119
pub fn new(
@@ -51,15 +128,22 @@ impl Server {
51128
.set_serving::<InstrumentServer<InstrumentService>>()
52129
.now_or_never();
53130

54-
let cors = CorsLayer::new()
55-
// allow `GET` and `POST` when accessing the resource
56-
.allow_methods([Method::GET, Method::POST])
57-
.allow_headers(AllowHeaders::any())
58-
.allow_origin(tower_http::cors::Any);
131+
let allowed_origins =
132+
Arc::new(Mutex::new(vec![
133+
if option_env!("__DEVTOOLS_LOCAL_DEVELOPMENT").is_some() {
134+
AllowOrigin::from(tower_http::cors::Any)
135+
} else {
136+
HeaderValue::from_str("https://devtools.crabnebula.dev")
137+
.unwrap()
138+
.into()
139+
},
140+
]));
59141

60142
let router = tonic::transport::Server::builder()
61143
.accept_http1(true)
62-
.layer(cors)
144+
.layer(DynamicCorsLayer {
145+
allowed_origins: allowed_origins.clone(),
146+
})
63147
.add_service(tonic_web::enable(health_service))
64148
.add_service(tonic_web::enable(InstrumentServer::new(
65149
InstrumentService {
@@ -71,7 +155,15 @@ impl Server {
71155
.add_service(tonic_web::enable(MetadataServer::new(metadata_server)))
72156
.add_service(tonic_web::enable(SourcesServer::new(sources_server)));
73157

74-
Self(router)
158+
Self {
159+
router,
160+
handle: ServerHandle { allowed_origins },
161+
}
162+
}
163+
164+
#[must_use]
165+
pub fn handle(&self) -> ServerHandle {
166+
self.handle.clone()
75167
}
76168

77169
/// Consumes this [`Server`] and returns a future that will execute the server.
@@ -82,7 +174,7 @@ impl Server {
82174
pub async fn run(self, addr: SocketAddr) -> crate::Result<()> {
83175
tracing::info!("Listening on {}", addr);
84176

85-
self.0.serve(addr).await?;
177+
self.router.serve(addr).await?;
86178

87179
Ok(())
88180
}

crates/devtools/src/lib.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ mod server;
33
use devtools_core::aggregator::Aggregator;
44
use devtools_core::layer::Layer;
55
use devtools_core::server::wire::tauri::tauri_server::TauriServer;
6-
use devtools_core::server::Server;
6+
use devtools_core::server::{Server, ServerHandle};
77
use devtools_core::Command;
88
pub use devtools_core::Error;
99
use devtools_core::{Result, Shared};
@@ -52,6 +52,7 @@ mod ios {
5252

5353
pub struct Devtools {
5454
pub connection: ConnectionInfo,
55+
pub server_handle: ServerHandle,
5556
}
5657

5758
fn init_plugin<R: Runtime>(
@@ -64,10 +65,6 @@ fn init_plugin<R: Runtime>(
6465
.setup(move |app_handle, _api| {
6566
let (mut health_reporter, health_service) = tonic_health::server::health_reporter();
6667

67-
app_handle.manage(Devtools {
68-
connection: connection_info(&addr),
69-
});
70-
7168
health_reporter
7269
.set_serving::<TauriServer<server::TauriService<R>>>()
7370
.now_or_never()
@@ -87,6 +84,12 @@ fn init_plugin<R: Runtime>(
8784
app_handle: app_handle.clone(),
8885
},
8986
);
87+
let server_handle = server.handle();
88+
89+
app_handle.manage(Devtools {
90+
connection: connection_info(&addr),
91+
server_handle,
92+
});
9093

9194
#[cfg(not(target_os = "ios"))]
9295
print_link(&addr);

0 commit comments

Comments
 (0)