Skip to content

Commit 3cacca3

Browse files
goffrieConvex, Inc.
authored andcommitted
Move ConductorState->AppState conversion to an extractor (#42840)
Introduces a new extractor `MtState`, which acts like the `State` extractor but is overridable. Also, use it in the actions middleware using `middleware::from_fn_with_state` rather than attaching the state using an extension. This is more typesafe than using extensions. GitOrigin-RevId: bee2c9bee725774ca9c8b3765474c13cdf0fe423
1 parent 993d622 commit 3cacca3

17 files changed

+184
-200
lines changed

crates/common/src/http/extract.rs

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
use std::time::Instant;
1+
use std::{
2+
future,
3+
time::Instant,
4+
};
25

36
use axum::{
47
extract::{
8+
FromRef,
59
FromRequest,
610
FromRequestParts,
711
Request,
@@ -15,9 +19,10 @@ use axum::{
1519
use bytes::Bytes;
1620
use errors::ErrorMetadata;
1721
use fastrace::{
18-
future::FutureExt,
22+
future::FutureExt as _,
1923
Span,
2024
};
25+
use futures::TryFutureExt as _;
2126
use http::HeaderMap;
2227
use serde::{
2328
de::DeserializeOwned,
@@ -145,3 +150,40 @@ where
145150
axum::Json(self.0).into_response()
146151
}
147152
}
153+
154+
/// Like `axum::extract::State`, but customizable
155+
pub struct MtState<T>(pub T);
156+
157+
impl<S, T> FromRequestParts<S> for MtState<T>
158+
where
159+
T: FromMtState<S>,
160+
S: Send + Sync,
161+
{
162+
type Rejection = HttpResponseError;
163+
164+
fn from_request_parts(
165+
parts: &mut Parts,
166+
state: &S,
167+
) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
168+
T::from_request_parts(parts, state).map_ok(MtState)
169+
}
170+
}
171+
172+
pub trait FromMtState<Outer>: Sized {
173+
fn from_request_parts(
174+
parts: &mut Parts,
175+
state: &Outer,
176+
) -> impl Future<Output = Result<Self, HttpResponseError>> + Send;
177+
}
178+
179+
impl<Outer, T: FromRef<Outer>> FromMtState<Outer> for T
180+
where
181+
T: Send + Sync,
182+
{
183+
fn from_request_parts(
184+
_parts: &mut Parts,
185+
state: &Outer,
186+
) -> impl Future<Output = Result<Self, HttpResponseError>> + Send {
187+
future::ready(Ok(T::from_ref(state)))
188+
}
189+
}

crates/local_backend/src/app_metrics.rs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
use axum::{
2-
extract::State,
3-
response::IntoResponse,
4-
};
1+
use axum::response::IntoResponse;
52
use common::{
63
components::{
74
ComponentFunctionPath,
@@ -10,6 +7,7 @@ use common::{
107
http::{
118
extract::{
129
Json,
10+
MtState,
1311
Query,
1412
},
1513
HttpResponseError,
@@ -39,7 +37,7 @@ pub(crate) struct UdfRateQueryArgs {
3937
}
4038

4139
pub(crate) async fn udf_rate(
42-
State(st): State<LocalAppState>,
40+
MtState(st): MtState<LocalAppState>,
4341
ExtractIdentity(identity): ExtractIdentity,
4442
Query(UdfRateQueryArgs {
4543
component_path,
@@ -69,7 +67,7 @@ pub(crate) struct TopKQueryArgs {
6967
}
7068

7169
pub(crate) async fn failure_percentage_top_k(
72-
State(st): State<LocalAppState>,
70+
MtState(st): MtState<LocalAppState>,
7371
ExtractIdentity(identity): ExtractIdentity,
7472
Query(TopKQueryArgs { window, k }): Query<TopKQueryArgs>,
7573
) -> Result<impl IntoResponse, HttpResponseError> {
@@ -85,7 +83,7 @@ pub(crate) async fn failure_percentage_top_k(
8583
}
8684

8785
pub(crate) async fn cache_hit_percentage_top_k(
88-
State(st): State<LocalAppState>,
86+
MtState(st): MtState<LocalAppState>,
8987
ExtractIdentity(identity): ExtractIdentity,
9088
Query(TopKQueryArgs { window, k }): Query<TopKQueryArgs>,
9189
) -> Result<impl IntoResponse, HttpResponseError> {
@@ -110,7 +108,7 @@ pub(crate) struct CacheHitPercentageQueryArgs {
110108
udf_type: Option<String>,
111109
}
112110
pub(crate) async fn cache_hit_percentage(
113-
State(st): State<LocalAppState>,
111+
MtState(st): MtState<LocalAppState>,
114112
ExtractIdentity(identity): ExtractIdentity,
115113
Query(query_args): Query<CacheHitPercentageQueryArgs>,
116114
) -> Result<impl IntoResponse, HttpResponseError> {
@@ -140,7 +138,7 @@ pub(crate) struct LatencyPercentilesQueryArgs {
140138
udf_type: Option<String>,
141139
}
142140
pub(crate) async fn latency_percentiles(
143-
State(st): State<LocalAppState>,
141+
MtState(st): MtState<LocalAppState>,
144142
ExtractIdentity(identity): ExtractIdentity,
145143
Query(query_args): Query<LatencyPercentilesQueryArgs>,
146144
) -> Result<impl IntoResponse, HttpResponseError> {
@@ -170,7 +168,7 @@ pub(crate) struct TableRateQueryArgs {
170168
window: String,
171169
}
172170
pub(crate) async fn table_rate(
173-
State(st): State<LocalAppState>,
171+
MtState(st): MtState<LocalAppState>,
174172
ExtractIdentity(identity): ExtractIdentity,
175173
Query(query_args): Query<TableRateQueryArgs>,
176174
) -> Result<impl IntoResponse, HttpResponseError> {
@@ -224,7 +222,7 @@ pub(crate) struct ScheduledJobLagArgs {
224222
window: String,
225223
}
226224
pub(crate) async fn scheduled_job_lag(
227-
State(st): State<LocalAppState>,
225+
MtState(st): MtState<LocalAppState>,
228226
ExtractIdentity(identity): ExtractIdentity,
229227
Query(query_args): Query<ScheduledJobLagArgs>,
230228
) -> Result<impl IntoResponse, HttpResponseError> {

crates/local_backend/src/authentication.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@ use anyhow::{
77
};
88
use authentication::extract_bearer_token;
99
use axum::{
10-
extract::{
11-
FromRef,
12-
FromRequestParts,
13-
},
10+
extract::FromRequestParts,
1411
RequestPartsExt,
1512
};
1613
use common::{
1714
http::{
18-
extract::Query,
15+
extract::{
16+
FromMtState,
17+
Query,
18+
},
1919
ExtractRequestId,
2020
ExtractResolvedHostname,
2121
HttpResponseError,
@@ -105,7 +105,7 @@ pub struct ExtractIdentity(pub Identity);
105105

106106
impl<S> FromRequestParts<S> for ExtractIdentity
107107
where
108-
LocalAppState: FromRef<S>,
108+
LocalAppState: FromMtState<S>,
109109
S: Send + Sync + Clone + 'static,
110110
{
111111
type Rejection = HttpResponseError;
@@ -116,7 +116,7 @@ where
116116
) -> Result<Self, Self::Rejection> {
117117
let token: AuthenticationToken =
118118
parts.extract::<ExtractAuthenticationToken>().await?.into();
119-
let st = LocalAppState::from_ref(st);
119+
let st = LocalAppState::from_request_parts(parts, st).await?;
120120

121121
Ok(Self(
122122
st.application

crates/local_backend/src/canonical_urls.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ use axum::{
66
response::IntoResponse,
77
};
88
use common::http::{
9-
extract::Json,
9+
extract::{
10+
Json,
11+
MtState,
12+
},
1013
HttpResponseError,
1114
RequestDestination,
1215
};
@@ -111,7 +114,7 @@ pub struct GetCanonicalUrlsResponse {
111114
),
112115
)]
113116
pub async fn get_canonical_urls(
114-
State(st): State<LocalAppState>,
117+
MtState(st): MtState<LocalAppState>,
115118
ExtractIdentity(identity): ExtractIdentity,
116119
) -> Result<impl IntoResponse, HttpResponseError> {
117120
must_be_admin(&identity)?;

crates/local_backend/src/dashboard.rs

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,16 @@ use application::{
55
};
66
use axum::{
77
debug_handler,
8-
extract::{
9-
FromRef,
10-
State,
11-
},
8+
extract::State,
129
response::IntoResponse,
1310
};
1411
use common::{
1512
components::ComponentId,
1613
http::{
1714
extract::{
15+
FromMtState,
1816
Json,
17+
MtState,
1918
Query,
2019
},
2120
ExtractClientVersion,
@@ -94,9 +93,8 @@ pub struct ShapesArgs {
9493
),
9594
responses((status = 200, body = serde_json::Value)),
9695
)]
97-
#[debug_handler]
9896
pub async fn shapes2(
99-
State(st): State<LocalAppState>,
97+
MtState(st): MtState<LocalAppState>,
10098
ExtractIdentity(identity): ExtractIdentity,
10199
Query(ShapesArgs { component }): Query<ShapesArgs>,
102100
) -> Result<impl IntoResponse, HttpResponseError> {
@@ -136,9 +134,8 @@ pub async fn shapes2(
136134
request_body = DeleteTableArgs,
137135
responses((status = 200)),
138136
)]
139-
#[debug_handler]
140137
pub async fn delete_tables(
141-
State(st): State<LocalAppState>,
138+
MtState(st): MtState<LocalAppState>,
142139
ExtractIdentity(identity): ExtractIdentity,
143140
Json(DeleteTableArgs {
144141
table_names,
@@ -168,9 +165,8 @@ pub async fn delete_tables(
168165
request_body = DeleteComponentArgs,
169166
responses((status = 200)),
170167
)]
171-
#[debug_handler]
172168
pub async fn delete_component(
173-
State(st): State<LocalAppState>,
169+
MtState(st): MtState<LocalAppState>,
174170
ExtractIdentity(identity): ExtractIdentity,
175171
Json(DeleteComponentArgs { component_id }): Json<DeleteComponentArgs>,
176172
) -> Result<impl IntoResponse, HttpResponseError> {
@@ -206,9 +202,8 @@ struct GetIndexesResponse {
206202
),
207203
responses((status = 200, body = GetIndexesResponse)),
208204
)]
209-
#[debug_handler]
210205
pub async fn get_indexes(
211-
State(st): State<LocalAppState>,
206+
MtState(st): MtState<LocalAppState>,
212207
ExtractIdentity(identity): ExtractIdentity,
213208
Query(GetIndexesArgs { component_id }): Query<GetIndexesArgs>,
214209
) -> Result<impl IntoResponse, HttpResponseError> {
@@ -245,9 +240,8 @@ pub struct GetSourceCodeArgs {
245240
),
246241
responses((status = 200, body = String)),
247242
)]
248-
#[debug_handler]
249243
pub async fn get_source_code(
250-
State(st): State<LocalAppState>,
244+
MtState(st): MtState<LocalAppState>,
251245
ExtractIdentity(identity): ExtractIdentity,
252246
Query(GetSourceCodeArgs { path, component }): Query<GetSourceCodeArgs>,
253247
) -> Result<impl IntoResponse, HttpResponseError> {
@@ -351,7 +345,7 @@ pub fn local_only_dashboard_router() -> OpenApiRouter<crate::LocalAppState> {
351345
// Routes with the same handlers for the local backend + closed source backend
352346
pub fn common_dashboard_api_router<S>() -> OpenApiRouter<S>
353347
where
354-
LocalAppState: FromRef<S>,
348+
LocalAppState: FromMtState<S>,
355349
S: Clone + Send + Sync + 'static,
356350
{
357351
OpenApiRouter::new()

crates/local_backend/src/deploy_config.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ use axum::{
1616
use common::{
1717
components::ComponentId,
1818
http::{
19-
extract::Json,
19+
extract::{
20+
Json,
21+
MtState,
22+
},
2023
HttpResponseError,
2124
},
2225
version::Version,
@@ -182,9 +185,8 @@ pub struct ModuleHashJson {
182185
environment: Option<String>,
183186
}
184187

185-
#[debug_handler]
186188
pub async fn get_config(
187-
State(st): State<LocalAppState>,
189+
MtState(st): MtState<LocalAppState>,
188190
Json(req): Json<GetConfigRequest>,
189191
) -> Result<impl IntoResponse, HttpResponseError> {
190192
let identity = must_be_admin_from_key(
@@ -212,9 +214,8 @@ pub async fn get_config(
212214
}))
213215
}
214216

215-
#[debug_handler]
216217
pub async fn get_config_hashes(
217-
State(st): State<LocalAppState>,
218+
MtState(st): MtState<LocalAppState>,
218219
Json(req): Json<GetConfigRequest>,
219220
) -> Result<impl IntoResponse, HttpResponseError> {
220221
let identity = must_be_admin_from_key(

crates/local_backend/src/deploy_config2.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ use common::{
2222
},
2323
bootstrap_model::components::definition::SerializedComponentDefinitionMetadata,
2424
http::{
25-
extract::Json,
25+
extract::{
26+
Json,
27+
MtState,
28+
},
2629
HttpResponseError,
2730
},
2831
};
@@ -226,9 +229,8 @@ pub async fn start_push(
226229
// it won’t start schema validation/index backfill). It can be used to determine
227230
// what will be the effects of a large push without starting work that can take
228231
// a long time on large instances.
229-
#[debug_handler]
230232
pub async fn evaluate_push(
231-
State(st): State<LocalAppState>,
233+
MtState(st): MtState<LocalAppState>,
232234
Json(req): Json<StartPushRequest>,
233235
) -> Result<impl IntoResponse, HttpResponseError> {
234236
let _identity = must_be_admin_from_key_with_write_access(
@@ -258,9 +260,8 @@ pub struct WaitForSchemaRequest {
258260
timeout_ms: Option<u32>,
259261
}
260262

261-
#[debug_handler]
262263
pub async fn wait_for_schema(
263-
State(st): State<LocalAppState>,
264+
MtState(st): MtState<LocalAppState>,
264265
Json(req): Json<WaitForSchemaRequest>,
265266
) -> Result<impl IntoResponse, HttpResponseError> {
266267
let identity = must_be_admin_from_key(
@@ -289,9 +290,8 @@ pub struct FinishPushRequest {
289290
dry_run: bool,
290291
}
291292

292-
#[debug_handler]
293293
pub async fn finish_push(
294-
State(st): State<LocalAppState>,
294+
MtState(st): MtState<LocalAppState>,
295295
Json(req): Json<FinishPushRequest>,
296296
) -> Result<impl IntoResponse, HttpResponseError> {
297297
let identity = must_be_admin_from_key_with_write_access(

crates/local_backend/src/environment_variables.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@ use axum::{
77
response::IntoResponse,
88
};
99
use common::http::{
10-
extract::Json,
10+
extract::{
11+
Json,
12+
MtState,
13+
},
1114
HttpResponseError,
1215
};
1316
use http::StatusCode;
@@ -134,7 +137,7 @@ pub struct ListEnvVarsResponse {
134137
),
135138
)]
136139
pub async fn list_environment_variables(
137-
State(st): State<LocalAppState>,
140+
MtState(st): MtState<LocalAppState>,
138141
ExtractIdentity(identity): ExtractIdentity,
139142
) -> Result<impl IntoResponse, HttpResponseError> {
140143
must_be_admin(&identity)?;

0 commit comments

Comments
 (0)