Skip to content

Commit a06d25a

Browse files
authored
refactor(jwt): improve the implementation of jwt plugin and expose it to expressions (#534)
1 parent 41a436e commit a06d25a

31 files changed

+630
-158
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

bin/router/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ retry-policies = { workspace = true}
4444
reqwest-retry = { workspace = true }
4545
reqwest-middleware = { workspace = true }
4646
vrl = { workspace = true }
47+
serde_json = { workspace = true }
4748

4849
mimalloc = { version = "0.1.48", features = ["v3"] }
4950
moka = { version = "0.12.10", features = ["future"] }

bin/router/src/jwt/context.rs

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,44 @@
11
use std::collections::HashMap;
22

3-
use hive_router_plan_executor::execution::jwt_forward::JwtAuthForwardingPlan;
3+
use hive_router_plan_executor::execution::jwt_forward::JwtForwardingError;
44
use jsonwebtoken::TokenData;
55
use serde::{Deserialize, Serialize};
6-
use sonic_rs::Value;
7-
8-
use crate::jwt::errors::JwtForwardingError;
96

107
pub type JwtTokenPayload = TokenData<JwtClaims>;
118

129
#[derive(Debug, Clone)]
1310
pub struct JwtRequestContext {
14-
// The payload extracted from the JWT token, and the extensions key to inject it into the request
15-
pub token_payload: Option<(String, JwtTokenPayload)>,
11+
pub token_prefix: Option<String>,
12+
pub token_raw: String,
13+
pub token_payload: JwtTokenPayload,
1614
}
1715

18-
impl TryInto<Option<JwtAuthForwardingPlan>> for JwtRequestContext {
19-
type Error = JwtForwardingError;
16+
impl JwtRequestContext {
17+
pub fn get_claims_value(&self) -> Result<sonic_rs::Value, JwtForwardingError> {
18+
Ok(sonic_rs::to_value(&self.token_payload.claims)?)
19+
}
20+
21+
/// Extracts an optional "scope"/"scopes" field form the token's payload.
22+
/// Supports both space-delimited and array formats.
23+
pub fn extract_scopes(&self) -> Option<Vec<String>> {
24+
let map = &self.token_payload.claims.additional_claims;
25+
let maybe_scopes = map.get("scope").or_else(|| map.get("scopes"));
26+
27+
if let Some(serde_json::Value::String(scopes_str)) = maybe_scopes {
28+
return Some(scopes_str.split(' ').map(String::from).collect());
29+
}
2030

21-
fn try_into(self) -> Result<Option<JwtAuthForwardingPlan>, Self::Error> {
22-
if let Some((extension_field_name, payload)) = &self.token_payload {
23-
return Ok(Some(JwtAuthForwardingPlan {
24-
extension_field_name: extension_field_name.clone(),
25-
extension_field_value: sonic_rs::to_value(&payload.claims)?,
26-
}));
31+
if let Some(serde_json::Value::Array(scopes_arr)) = maybe_scopes {
32+
return Some(
33+
scopes_arr
34+
.iter()
35+
.filter_map(|s| s.as_str())
36+
.map(String::from)
37+
.collect::<Vec<_>>(),
38+
);
2739
}
2840

29-
Ok(None)
41+
None
3042
}
3143
}
3244

@@ -61,6 +73,8 @@ pub struct JwtClaims {
6173
#[serde(skip_serializing_if = "Option::is_none")]
6274
pub jti: Option<String>,
6375

76+
// we are using serde to deserialize the additional claims
77+
// because the jsonwebtoken crate is using `serde_json` internally, and the `sonic_rs::Value` is not recognized as valid type
6478
#[serde(flatten)]
65-
pub additional_claims: HashMap<String, Value>,
79+
pub additional_claims: HashMap<String, serde_json::Value>,
6680
}

bin/router/src/jwt/errors.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,3 @@ impl From<&JwtError> for GraphQLError {
9494
}
9595
}
9696
}
97-
98-
#[derive(Debug, thiserror::Error)]
99-
pub enum JwtForwardingError {
100-
#[error("failed to serialized jwt claims")]
101-
ClaimsSerializeError(#[from] sonic_rs::Error),
102-
#[error("failed to parse as valid header value")]
103-
ValueIsNotValidHeader(#[from] http::header::InvalidHeaderValue),
104-
}

bin/router/src/jwt/mod.rs

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ impl JwtAuthRuntime {
5151
Ok(instance)
5252
}
5353

54-
fn lookup(&self, req: &HttpRequest) -> Result<String, LookupError> {
54+
fn lookup(&self, req: &HttpRequest) -> Result<(Option<String>, String), LookupError> {
5555
for lookup_config in &self.config.lookup_locations {
5656
match lookup_config {
5757
JwtAuthPluginLookupLocation::Header { name, prefix } => {
@@ -73,14 +73,17 @@ impl JwtAuthRuntime {
7373
.and_then(|s| s.strip_prefix(prefix))
7474
{
7575
Some(stripped_value) => {
76-
return Ok(stripped_value.trim().to_string());
76+
return Ok((
77+
Some(prefix.to_string()),
78+
stripped_value.trim().to_string(),
79+
));
7780
}
7881
None => {
7982
return Err(LookupError::MismatchedPrefix);
8083
}
8184
},
8285
None => {
83-
return Ok(header_value.to_str().unwrap_or("").to_string());
86+
return Ok((None, header_value.to_str().unwrap_or("").to_string()));
8487
}
8588
}
8689
}
@@ -101,7 +104,7 @@ impl JwtAuthRuntime {
101104
let (cookie_name, cookie_value) = v.name_value_trimmed();
102105

103106
if cookie_name == name {
104-
return Ok(cookie_value.to_string());
107+
return Ok((None, cookie_value.to_string()));
105108
}
106109
}
107110
Err(e) => {
@@ -158,15 +161,15 @@ impl JwtAuthRuntime {
158161
&self,
159162
jwks: &Vec<Arc<JwkSet>>,
160163
req: &HttpRequest,
161-
) -> Result<(JwtTokenPayload, String), JwtError> {
164+
) -> Result<(JwtTokenPayload, Option<String>, String), JwtError> {
162165
match self.lookup(req) {
163-
Ok(token) => {
166+
Ok((maybe_prefix, token)) => {
164167
// First, we need to decode the header to determine which provider to use.
165168
let header = decode_header(&token).map_err(JwtError::InvalidJwtHeader)?;
166169
let jwk = self.find_matching_jwks(&header, jwks)?;
167170

168171
self.decode_and_validate_token(&token, &jwk.keys)
169-
.map(|token_data| (token_data, token))
172+
.map(|token_data| (token_data, maybe_prefix, token))
170173
}
171174
Err(e) => {
172175
warn!("jwt plugin failed to lookup token. error: {}", e);
@@ -266,25 +269,15 @@ impl JwtAuthRuntime {
266269
let valid_jwks = self.jwks.all();
267270

268271
match self.authenticate(&valid_jwks, request) {
269-
Ok((token_data, _token)) => {
270-
let mut jwt_ctx = JwtRequestContext {
271-
token_payload: None,
272-
};
273-
274-
if self.config.forward_claims_to_upstream_extensions.enabled {
275-
jwt_ctx.token_payload = Some((
276-
self.config
277-
.forward_claims_to_upstream_extensions
278-
.field_name
279-
.clone(),
280-
token_data,
281-
));
282-
}
283-
284-
request.extensions_mut().insert(jwt_ctx);
272+
Ok((token_payload, maybe_token_prefix, token)) => {
273+
request.extensions_mut().insert(JwtRequestContext {
274+
token_payload,
275+
token_raw: token,
276+
token_prefix: maybe_token_prefix,
277+
});
285278
}
286279
Err(e) => {
287-
warn!("jwt token error: {}", e);
280+
warn!("jwt token error: {:?}", e);
288281

289282
if self.config.require_authentication.is_some_and(|v| v) {
290283
return Err(e);

bin/router/src/lib.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,9 @@ pub async fn configure_app_from_config(
106106
router_config: HiveRouterConfig,
107107
bg_tasks_manager: &mut BackgroundTasksManager,
108108
) -> Result<(Arc<RouterSharedState>, Arc<SchemaState>), Box<dyn std::error::Error>> {
109-
let jwt_runtime = match &router_config.jwt {
110-
Some(jwt_config) => Some(JwtAuthRuntime::init(bg_tasks_manager, jwt_config).await?),
111-
None => None,
109+
let jwt_runtime = match router_config.jwt.is_jwt_auth_enabled() {
110+
true => Some(JwtAuthRuntime::init(bg_tasks_manager, &router_config.jwt).await?),
111+
false => None,
112112
};
113113

114114
let router_config_arc = Arc::new(router_config);

bin/router/src/pipeline/error.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::sync::Arc;
22

33
use graphql_tools::validation::utils::ValidationError;
44
use hive_router_plan_executor::{
5-
execution::error::PlanExecutionError,
5+
execution::{error::PlanExecutionError, jwt_forward::JwtForwardingError},
66
response::graphql_error::{GraphQLError, GraphQLErrorExtensions},
77
};
88
use hive_router_query_planner::{
@@ -15,12 +15,9 @@ use ntex::{
1515
};
1616
use serde::{Deserialize, Serialize};
1717

18-
use crate::{
19-
jwt::errors::JwtForwardingError,
20-
pipeline::{
21-
header::{RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON_STR},
22-
progressive_override::LabelEvaluationError,
23-
},
18+
use crate::pipeline::{
19+
header::{RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON_STR},
20+
progressive_override::LabelEvaluationError,
2421
};
2522

2623
#[derive(Debug)]

bin/router/src/pipeline/execution.rs

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
use std::collections::HashMap;
22
use std::sync::Arc;
33

4-
use crate::jwt::context::JwtRequestContext;
54
use crate::pipeline::coerce_variables::CoerceVariablesPayload;
65
use crate::pipeline::error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant};
76
use crate::pipeline::normalize::GraphQLNormalizationPayload;
87
use crate::schema_state::SupergraphData;
98
use crate::shared_state::RouterSharedState;
109
use hive_router_plan_executor::execute_query_plan;
10+
use hive_router_plan_executor::execution::client_request_details::ClientRequestDetails;
1111
use hive_router_plan_executor::execution::jwt_forward::JwtAuthForwardingPlan;
12-
use hive_router_plan_executor::execution::plan::{
13-
ClientRequestDetails, PlanExecutionOutput, QueryPlanExecutionContext,
14-
};
12+
use hive_router_plan_executor::execution::plan::{PlanExecutionOutput, QueryPlanExecutionContext};
1513
use hive_router_plan_executor::introspection::resolve::IntrospectionContext;
1614
use hive_router_query_planner::planner::plan_nodes::QueryPlan;
1715
use http::HeaderName;
@@ -67,15 +65,23 @@ pub async fn execute_plan(
6765
metadata: &supergraph.metadata,
6866
};
6967

70-
let jwt_context = {
71-
let req_extensions = req.extensions();
72-
req_extensions.get::<JwtRequestContext>().cloned()
73-
};
74-
let jwt_forward_plan: Option<JwtAuthForwardingPlan> = match jwt_context {
75-
Some(jwt_context) => jwt_context
76-
.try_into()
77-
.map_err(|e| req.new_pipeline_error(PipelineErrorVariant::JwtForwardingError(e)))?,
78-
None => None,
68+
let jwt_forward_plan: Option<JwtAuthForwardingPlan> = if app_state
69+
.router_config
70+
.jwt
71+
.is_jwt_extensions_forwarding_enabled()
72+
{
73+
client_request_details
74+
.jwt
75+
.build_forwarding_plan(
76+
&app_state
77+
.router_config
78+
.jwt
79+
.forward_claims_to_upstream_extensions
80+
.field_name,
81+
)
82+
.map_err(|e| req.new_pipeline_error(PipelineErrorVariant::JwtForwardingError(e)))?
83+
} else {
84+
None
7985
};
8086

8187
execute_query_plan(QueryPlanExecutionContext {

bin/router/src/pipeline/mod.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
use std::sync::Arc;
22

3-
use hive_router_plan_executor::execution::plan::{
4-
ClientRequestDetails, OperationDetails, PlanExecutionOutput,
3+
use hive_router_plan_executor::execution::{
4+
client_request_details::{ClientRequestDetails, JwtRequestDetails, OperationDetails},
5+
plan::PlanExecutionOutput,
56
};
67
use hive_router_query_planner::{
78
state::supergraph_state::OperationKind, utils::cancellation::CancellationToken,
@@ -13,6 +14,7 @@ use ntex::{
1314
};
1415

1516
use crate::{
17+
jwt::context::JwtRequestContext,
1618
pipeline::{
1719
coerce_variables::coerce_request_variables,
1820
csrf_prevention::perform_csrf_prevention,
@@ -101,6 +103,7 @@ pub async fn graphql_request_handler(
101103
}
102104

103105
#[inline]
106+
#[allow(clippy::await_holding_refcell_ref)]
104107
pub async fn execute_pipeline(
105108
req: &mut HttpRequest,
106109
body_bytes: Bytes,
@@ -129,6 +132,20 @@ pub async fn execute_pipeline(
129132
let query_plan_cancellation_token =
130133
CancellationToken::with_timeout(shared_state.router_config.query_planner.timeout);
131134

135+
let req_extensions = req.extensions();
136+
let jwt_context = req_extensions.get::<JwtRequestContext>();
137+
let jwt_request_details = match jwt_context {
138+
Some(jwt_context) => JwtRequestDetails::Authenticated {
139+
token: jwt_context.token_raw.as_str(),
140+
prefix: jwt_context.token_prefix.as_deref(),
141+
scopes: jwt_context.extract_scopes(),
142+
claims: &jwt_context
143+
.get_claims_value()
144+
.map_err(|e| req.new_pipeline_error(PipelineErrorVariant::JwtForwardingError(e)))?,
145+
},
146+
None => JwtRequestDetails::Unauthenticated,
147+
};
148+
132149
let client_request_details = ClientRequestDetails {
133150
method: req.method(),
134151
url: req.uri(),
@@ -143,6 +160,7 @@ pub async fn execute_pipeline(
143160
},
144161
query: &execution_request.query,
145162
},
163+
jwt: &jwt_request_details,
146164
};
147165

148166
let progressive_override_ctx = request_override_context(

bin/router/src/pipeline/progressive_override.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::collections::{BTreeMap, HashMap, HashSet};
22

33
use hive_router_config::override_labels::{LabelOverrideValue, OverrideLabelsConfig};
4-
use hive_router_plan_executor::execution::plan::ClientRequestDetails;
4+
use hive_router_plan_executor::execution::client_request_details::ClientRequestDetails;
55
use hive_router_query_planner::{
66
graph::{PlannerOverrideContext, PERCENTAGE_SCALE_FACTOR},
77
state::supergraph_state::SupergraphState,

0 commit comments

Comments
 (0)