From 28148fbc8099c07856dd517528a068859251c04f Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Tue, 4 Nov 2025 01:25:45 +0000 Subject: [PATCH] fix(guard): include method in cache key --- .../packages/guard-core/src/proxy_service.rs | 24 ++++++++++++++++--- .../packages/guard-core/tests/common/mod.rs | 1 + engine/packages/guard/src/cache/actor.rs | 10 ++++++-- engine/packages/guard/src/cache/mod.rs | 13 +++++----- 4 files changed, 37 insertions(+), 11 deletions(-) diff --git a/engine/packages/guard-core/src/proxy_service.rs b/engine/packages/guard-core/src/proxy_service.rs index 722227a2eb..0a47d055de 100644 --- a/engine/packages/guard-core/src/proxy_service.rs +++ b/engine/packages/guard-core/src/proxy_service.rs @@ -191,7 +191,15 @@ pub type RoutingFn = Arc< >; pub type CacheKeyFn = Arc< - dyn for<'a> Fn(&'a str, &'a str, PortType, &'a hyper::HeaderMap) -> Result + Send + Sync, + dyn for<'a> Fn( + &'a str, + &'a str, + &'a hyper::Method, + PortType, + &'a hyper::HeaderMap, + ) -> Result + + Send + + Sync, >; #[derive(Clone, Debug)] @@ -375,6 +383,7 @@ impl ProxyState { &self, hostname: &str, path: &str, + method: &hyper::Method, port_type: PortType, headers: &hyper::HeaderMap, ignore_cache: bool, @@ -385,11 +394,13 @@ impl ProxyState { tracing::debug!( hostname = %hostname_only, path = %path, + method = %method, port_type = ?port_type, "Resolving route for request" ); - let cache_key = (self.cache_key_fn)(hostname_only, &path, port_type.clone(), headers)?; + let cache_key = + (self.cache_key_fn)(hostname_only, &path, method, port_type.clone(), headers)?; // Check cache first if !ignore_cache { @@ -700,6 +711,7 @@ impl ProxyService { .resolve_route( host, &path, + req.method(), self.state.port_type.clone(), req.headers(), false, @@ -922,6 +934,7 @@ impl ProxyService { .resolve_route( &host, &path, + &req_parts.method, self.state.port_type.clone(), &req_parts.headers, true, @@ -996,6 +1009,7 @@ impl ProxyService { .resolve_route( &host, &path, + &req_parts.method, self.state.port_type.clone(), &req_parts.headers, true, @@ -1060,6 +1074,7 @@ impl ProxyService { .resolve_route( &host, &path, + req_collected.method(), self.state.port_type.clone(), &req_headers, true, @@ -1155,8 +1170,9 @@ impl ProxyService { .map(|x| x.to_string()) .unwrap_or_else(|| req.uri().path().to_string()); - // Capture headers before request is consumed + // Capture headers and method before request is consumed let req_headers = req.headers().clone(); + let req_method = req.method().clone(); let ray_id = req.extensions().get::().map(|x| x.ray_id); // Get middleware config for this actor if it exists @@ -1449,6 +1465,7 @@ impl ProxyService { .resolve_route( &req_host, &req_path, + &req_method, state.port_type.clone(), &req_headers, true, @@ -1907,6 +1924,7 @@ impl ProxyService { .resolve_route( &req_host, &req_path, + &req_method, state.port_type.clone(), &req_headers, true, diff --git a/engine/packages/guard-core/tests/common/mod.rs b/engine/packages/guard-core/tests/common/mod.rs index 898ba05aff..c385fe490a 100644 --- a/engine/packages/guard-core/tests/common/mod.rs +++ b/engine/packages/guard-core/tests/common/mod.rs @@ -478,6 +478,7 @@ pub fn create_test_cache_key_fn() -> CacheKeyFn { Arc::new( move |hostname: &str, path: &str, + _method: &hyper::Method, _port_type: rivet_guard_core::proxy_service::PortType, _headers: &hyper::HeaderMap| { // Extract just the hostname, stripping the port if present diff --git a/engine/packages/guard/src/cache/actor.rs b/engine/packages/guard/src/cache/actor.rs index 12f55a4509..a8fe53c566 100644 --- a/engine/packages/guard/src/cache/actor.rs +++ b/engine/packages/guard/src/cache/actor.rs @@ -9,7 +9,12 @@ use gas::prelude::*; use crate::routing::pegboard_gateway::X_RIVET_ACTOR; #[tracing::instrument(skip_all)] -pub fn build_cache_key(target: &str, path: &str, headers: &hyper::HeaderMap) -> Result { +pub fn build_cache_key( + target: &str, + path: &str, + method: &hyper::Method, + headers: &hyper::HeaderMap, +) -> Result { // Check target ensure!(target == "actor", "wrong target"); @@ -26,11 +31,12 @@ pub fn build_cache_key(target: &str, path: &str, headers: &hyper::HeaderMap) -> .context("invalid x-rivet-actor header")?; let actor_id = Id::parse(actor_id_str).context("invalid x-rivet-actor header")?; - // Create a hash using target, actor_id, and path + // Create a hash using target, actor_id, path, and method let mut hasher = DefaultHasher::new(); target.hash(&mut hasher); actor_id.hash(&mut hasher); path.hash(&mut hasher); + method.as_str().hash(&mut hasher); let hash = hasher.finish(); Ok(hash) diff --git a/engine/packages/guard/src/cache/mod.rs b/engine/packages/guard/src/cache/mod.rs index fba7f0a584..b4b0e48125 100644 --- a/engine/packages/guard/src/cache/mod.rs +++ b/engine/packages/guard/src/cache/mod.rs @@ -15,7 +15,7 @@ use crate::routing::X_RIVET_TARGET; /// Creates the main cache key function that handles all incoming requests #[tracing::instrument(skip_all)] pub fn create_cache_key_function(_ctx: StandaloneCtx) -> CacheKeyFn { - Arc::new(move |hostname, path, _port_type, headers| { + Arc::new(move |hostname, path, method, _port_type, headers| { tracing::debug!("building cache key"); let target = match read_target(headers) { @@ -23,11 +23,11 @@ pub fn create_cache_key_function(_ctx: StandaloneCtx) -> CacheKeyFn { Err(err) => { tracing::debug!(?err, "failed parsing target for cache key"); - return Ok(host_path_cache_key(hostname, path)); + return Ok(host_path_method_cache_key(hostname, path, method)); } }; - let cache_key = match actor::build_cache_key(target, path, headers) { + let cache_key = match actor::build_cache_key(target, path, method, headers) { Ok(key) => Some(key), Err(err) => { tracing::debug!(?err, "failed to create actor cache key"); @@ -36,11 +36,11 @@ pub fn create_cache_key_function(_ctx: StandaloneCtx) -> CacheKeyFn { } }; - // Fallback to hostname + path hash if actor did not work + // Fallback to hostname + path + method hash if actor did not work if let Some(cache_key) = cache_key { Ok(cache_key) } else { - Ok(host_path_cache_key(hostname, path)) + Ok(host_path_method_cache_key(hostname, path, method)) } }) } @@ -57,12 +57,13 @@ fn read_target(headers: &hyper::HeaderMap) -> Result<&str> { Ok(target.to_str()?) } -fn host_path_cache_key(hostname: &str, path: &str) -> u64 { +fn host_path_method_cache_key(hostname: &str, path: &str, method: &hyper::Method) -> u64 { // Extract just the hostname, stripping the port if present let hostname_only = hostname.split(':').next().unwrap_or(hostname); let mut hasher = DefaultHasher::new(); hostname_only.hash(&mut hasher); path.hash(&mut hasher); + method.as_str().hash(&mut hasher); hasher.finish() }