Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions engine/packages/guard-core/src/proxy_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,15 @@ pub type RoutingFn = Arc<
>;

pub type CacheKeyFn = Arc<
dyn for<'a> Fn(&'a str, &'a str, PortType, &'a hyper::HeaderMap) -> Result<u64> + Send + Sync,
dyn for<'a> Fn(
&'a str,
&'a str,
&'a hyper::Method,
PortType,
&'a hyper::HeaderMap,
) -> Result<u64>
+ Send
+ Sync,
>;

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -375,6 +383,7 @@ impl ProxyState {
&self,
hostname: &str,
path: &str,
method: &hyper::Method,
port_type: PortType,
headers: &hyper::HeaderMap,
ignore_cache: bool,
Expand All @@ -385,11 +394,13 @@ impl ProxyState {
tracing::debug!(
hostname = %hostname_only,
path = %path,
method = %method,
port_type = ?port_type,
"Resolving route for request"
);

let cache_key = (self.cache_key_fn)(hostname_only, &path, port_type.clone(), headers)?;
let cache_key =
(self.cache_key_fn)(hostname_only, &path, method, port_type.clone(), headers)?;

// Check cache first
if !ignore_cache {
Expand Down Expand Up @@ -700,6 +711,7 @@ impl ProxyService {
.resolve_route(
host,
&path,
req.method(),
self.state.port_type.clone(),
req.headers(),
false,
Expand Down Expand Up @@ -922,6 +934,7 @@ impl ProxyService {
.resolve_route(
&host,
&path,
&req_parts.method,
self.state.port_type.clone(),
&req_parts.headers,
true,
Expand Down Expand Up @@ -996,6 +1009,7 @@ impl ProxyService {
.resolve_route(
&host,
&path,
&req_parts.method,
self.state.port_type.clone(),
&req_parts.headers,
true,
Expand Down Expand Up @@ -1060,6 +1074,7 @@ impl ProxyService {
.resolve_route(
&host,
&path,
req_collected.method(),
self.state.port_type.clone(),
&req_headers,
true,
Expand Down Expand Up @@ -1155,8 +1170,9 @@ impl ProxyService {
.map(|x| x.to_string())
.unwrap_or_else(|| req.uri().path().to_string());

// Capture headers before request is consumed
// Capture headers and method before request is consumed
let req_headers = req.headers().clone();
let req_method = req.method().clone();
let ray_id = req.extensions().get::<RequestIds>().map(|x| x.ray_id);

// Get middleware config for this actor if it exists
Expand Down Expand Up @@ -1449,6 +1465,7 @@ impl ProxyService {
.resolve_route(
&req_host,
&req_path,
&req_method,
state.port_type.clone(),
&req_headers,
true,
Expand Down Expand Up @@ -1907,6 +1924,7 @@ impl ProxyService {
.resolve_route(
&req_host,
&req_path,
&req_method,
state.port_type.clone(),
&req_headers,
true,
Expand Down
1 change: 1 addition & 0 deletions engine/packages/guard-core/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@ pub fn create_test_cache_key_fn() -> CacheKeyFn {
Arc::new(
move |hostname: &str,
path: &str,
_method: &hyper::Method,
_port_type: rivet_guard_core::proxy_service::PortType,
_headers: &hyper::HeaderMap| {
// Extract just the hostname, stripping the port if present
Expand Down
10 changes: 8 additions & 2 deletions engine/packages/guard/src/cache/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@ use gas::prelude::*;
use crate::routing::pegboard_gateway::X_RIVET_ACTOR;

#[tracing::instrument(skip_all)]
pub fn build_cache_key(target: &str, path: &str, headers: &hyper::HeaderMap) -> Result<u64> {
pub fn build_cache_key(
target: &str,
path: &str,
method: &hyper::Method,
headers: &hyper::HeaderMap,
) -> Result<u64> {
// Check target
ensure!(target == "actor", "wrong target");

Expand All @@ -26,11 +31,12 @@ pub fn build_cache_key(target: &str, path: &str, headers: &hyper::HeaderMap) ->
.context("invalid x-rivet-actor header")?;
let actor_id = Id::parse(actor_id_str).context("invalid x-rivet-actor header")?;

// Create a hash using target, actor_id, and path
// Create a hash using target, actor_id, path, and method
let mut hasher = DefaultHasher::new();
target.hash(&mut hasher);
actor_id.hash(&mut hasher);
path.hash(&mut hasher);
method.as_str().hash(&mut hasher);
let hash = hasher.finish();

Ok(hash)
Expand Down
13 changes: 7 additions & 6 deletions engine/packages/guard/src/cache/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@ use crate::routing::X_RIVET_TARGET;
/// Creates the main cache key function that handles all incoming requests
#[tracing::instrument(skip_all)]
pub fn create_cache_key_function(_ctx: StandaloneCtx) -> CacheKeyFn {
Arc::new(move |hostname, path, _port_type, headers| {
Arc::new(move |hostname, path, method, _port_type, headers| {
tracing::debug!("building cache key");

let target = match read_target(headers) {
Ok(target) => target,
Err(err) => {
tracing::debug!(?err, "failed parsing target for cache key");

return Ok(host_path_cache_key(hostname, path));
return Ok(host_path_method_cache_key(hostname, path, method));
}
};

let cache_key = match actor::build_cache_key(target, path, headers) {
let cache_key = match actor::build_cache_key(target, path, method, headers) {
Ok(key) => Some(key),
Err(err) => {
tracing::debug!(?err, "failed to create actor cache key");
Expand All @@ -36,11 +36,11 @@ pub fn create_cache_key_function(_ctx: StandaloneCtx) -> CacheKeyFn {
}
};

// Fallback to hostname + path hash if actor did not work
// Fallback to hostname + path + method hash if actor did not work
if let Some(cache_key) = cache_key {
Ok(cache_key)
} else {
Ok(host_path_cache_key(hostname, path))
Ok(host_path_method_cache_key(hostname, path, method))
}
})
}
Expand All @@ -57,12 +57,13 @@ fn read_target(headers: &hyper::HeaderMap) -> Result<&str> {
Ok(target.to_str()?)
}

fn host_path_cache_key(hostname: &str, path: &str) -> u64 {
fn host_path_method_cache_key(hostname: &str, path: &str, method: &hyper::Method) -> u64 {
// Extract just the hostname, stripping the port if present
let hostname_only = hostname.split(':').next().unwrap_or(hostname);

let mut hasher = DefaultHasher::new();
hostname_only.hash(&mut hasher);
path.hash(&mut hasher);
method.as_str().hash(&mut hasher);
hasher.finish()
}
Loading