Skip to content

Commit d9f9f0c

Browse files
committed
fix(guard): include method in cache key
1 parent dce28f8 commit d9f9f0c

File tree

3 files changed

+36
-11
lines changed

3 files changed

+36
-11
lines changed

engine/packages/guard-core/src/proxy_service.rs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,15 @@ pub type RoutingFn = Arc<
191191
>;
192192

193193
pub type CacheKeyFn = Arc<
194-
dyn for<'a> Fn(&'a str, &'a str, PortType, &'a hyper::HeaderMap) -> Result<u64> + Send + Sync,
194+
dyn for<'a> Fn(
195+
&'a str,
196+
&'a str,
197+
&'a hyper::Method,
198+
PortType,
199+
&'a hyper::HeaderMap,
200+
) -> Result<u64>
201+
+ Send
202+
+ Sync,
195203
>;
196204

197205
#[derive(Clone, Debug)]
@@ -375,6 +383,7 @@ impl ProxyState {
375383
&self,
376384
hostname: &str,
377385
path: &str,
386+
method: &hyper::Method,
378387
port_type: PortType,
379388
headers: &hyper::HeaderMap,
380389
ignore_cache: bool,
@@ -385,11 +394,13 @@ impl ProxyState {
385394
tracing::debug!(
386395
hostname = %hostname_only,
387396
path = %path,
397+
method = %method,
388398
port_type = ?port_type,
389399
"Resolving route for request"
390400
);
391401

392-
let cache_key = (self.cache_key_fn)(hostname_only, &path, port_type.clone(), headers)?;
402+
let cache_key =
403+
(self.cache_key_fn)(hostname_only, &path, method, port_type.clone(), headers)?;
393404

394405
// Check cache first
395406
if !ignore_cache {
@@ -700,6 +711,7 @@ impl ProxyService {
700711
.resolve_route(
701712
host,
702713
&path,
714+
req.method(),
703715
self.state.port_type.clone(),
704716
req.headers(),
705717
false,
@@ -922,6 +934,7 @@ impl ProxyService {
922934
.resolve_route(
923935
&host,
924936
&path,
937+
&req_parts.method,
925938
self.state.port_type.clone(),
926939
&req_parts.headers,
927940
true,
@@ -996,6 +1009,7 @@ impl ProxyService {
9961009
.resolve_route(
9971010
&host,
9981011
&path,
1012+
&req_parts.method,
9991013
self.state.port_type.clone(),
10001014
&req_parts.headers,
10011015
true,
@@ -1060,6 +1074,7 @@ impl ProxyService {
10601074
.resolve_route(
10611075
&host,
10621076
&path,
1077+
req_collected.method(),
10631078
self.state.port_type.clone(),
10641079
&req_headers,
10651080
true,
@@ -1155,8 +1170,9 @@ impl ProxyService {
11551170
.map(|x| x.to_string())
11561171
.unwrap_or_else(|| req.uri().path().to_string());
11571172

1158-
// Capture headers before request is consumed
1173+
// Capture headers and method before request is consumed
11591174
let req_headers = req.headers().clone();
1175+
let req_method = req.method().clone();
11601176
let ray_id = req.extensions().get::<RequestIds>().map(|x| x.ray_id);
11611177

11621178
// Get middleware config for this actor if it exists
@@ -1449,6 +1465,7 @@ impl ProxyService {
14491465
.resolve_route(
14501466
&req_host,
14511467
&req_path,
1468+
&req_method,
14521469
state.port_type.clone(),
14531470
&req_headers,
14541471
true,
@@ -1907,6 +1924,7 @@ impl ProxyService {
19071924
.resolve_route(
19081925
&req_host,
19091926
&req_path,
1927+
&req_method,
19101928
state.port_type.clone(),
19111929
&req_headers,
19121930
true,

engine/packages/guard/src/cache/actor.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@ use gas::prelude::*;
99
use crate::routing::pegboard_gateway::X_RIVET_ACTOR;
1010

1111
#[tracing::instrument(skip_all)]
12-
pub fn build_cache_key(target: &str, path: &str, headers: &hyper::HeaderMap) -> Result<u64> {
12+
pub fn build_cache_key(
13+
target: &str,
14+
path: &str,
15+
method: &hyper::Method,
16+
headers: &hyper::HeaderMap,
17+
) -> Result<u64> {
1318
// Check target
1419
ensure!(target == "actor", "wrong target");
1520

@@ -26,11 +31,12 @@ pub fn build_cache_key(target: &str, path: &str, headers: &hyper::HeaderMap) ->
2631
.context("invalid x-rivet-actor header")?;
2732
let actor_id = Id::parse(actor_id_str).context("invalid x-rivet-actor header")?;
2833

29-
// Create a hash using target, actor_id, and path
34+
// Create a hash using target, actor_id, path, and method
3035
let mut hasher = DefaultHasher::new();
3136
target.hash(&mut hasher);
3237
actor_id.hash(&mut hasher);
3338
path.hash(&mut hasher);
39+
method.as_str().hash(&mut hasher);
3440
let hash = hasher.finish();
3541

3642
Ok(hash)

engine/packages/guard/src/cache/mod.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,19 @@ use crate::routing::X_RIVET_TARGET;
1515
/// Creates the main cache key function that handles all incoming requests
1616
#[tracing::instrument(skip_all)]
1717
pub fn create_cache_key_function(_ctx: StandaloneCtx) -> CacheKeyFn {
18-
Arc::new(move |hostname, path, _port_type, headers| {
18+
Arc::new(move |hostname, path, method, _port_type, headers| {
1919
tracing::debug!("building cache key");
2020

2121
let target = match read_target(headers) {
2222
Ok(target) => target,
2323
Err(err) => {
2424
tracing::debug!(?err, "failed parsing target for cache key");
2525

26-
return Ok(host_path_cache_key(hostname, path));
26+
return Ok(host_path_method_cache_key(hostname, path, method));
2727
}
2828
};
2929

30-
let cache_key = match actor::build_cache_key(target, path, headers) {
30+
let cache_key = match actor::build_cache_key(target, path, method, headers) {
3131
Ok(key) => Some(key),
3232
Err(err) => {
3333
tracing::debug!(?err, "failed to create actor cache key");
@@ -36,11 +36,11 @@ pub fn create_cache_key_function(_ctx: StandaloneCtx) -> CacheKeyFn {
3636
}
3737
};
3838

39-
// Fallback to hostname + path hash if actor did not work
39+
// Fallback to hostname + path + method hash if actor did not work
4040
if let Some(cache_key) = cache_key {
4141
Ok(cache_key)
4242
} else {
43-
Ok(host_path_cache_key(hostname, path))
43+
Ok(host_path_method_cache_key(hostname, path, method))
4444
}
4545
})
4646
}
@@ -57,12 +57,13 @@ fn read_target(headers: &hyper::HeaderMap) -> Result<&str> {
5757
Ok(target.to_str()?)
5858
}
5959

60-
fn host_path_cache_key(hostname: &str, path: &str) -> u64 {
60+
fn host_path_method_cache_key(hostname: &str, path: &str, method: &hyper::Method) -> u64 {
6161
// Extract just the hostname, stripping the port if present
6262
let hostname_only = hostname.split(':').next().unwrap_or(hostname);
6363

6464
let mut hasher = DefaultHasher::new();
6565
hostname_only.hash(&mut hasher);
6666
path.hash(&mut hasher);
67+
method.as_str().hash(&mut hasher);
6768
hasher.finish()
6869
}

0 commit comments

Comments
 (0)