From bcf6b346977ed8d4383e45268fdf65aab866c0be Mon Sep 17 00:00:00 2001 From: wayslog Date: Fri, 7 Nov 2025 18:06:41 +0800 Subject: [PATCH 1/2] feat: add client cache --- Cargo.lock | 56 ++- Cargo.toml | 2 + src/cache/mod.rs | 808 ++++++++++++++++++++++++++++++++++++++++++ src/cache/tracker.rs | 340 ++++++++++++++++++ src/cluster/mod.rs | 48 ++- src/config/mod.rs | 229 ++++++++++++ src/lib.rs | 1 + src/metrics/mod.rs | 73 ++++ src/standalone/mod.rs | 36 ++ 9 files changed, 1591 insertions(+), 2 deletions(-) create mode 100644 src/cache/mod.rs create mode 100644 src/cache/tracker.rs diff --git a/Cargo.lock b/Cargo.lock index 0bb2460..99e7eac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,19 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "getrandom 0.3.4", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -98,6 +111,7 @@ checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" name = "aster-proxy" version = "1.3.4" dependencies = [ + "ahash", "anyhow", "arc-swap", "async-trait", @@ -113,6 +127,7 @@ dependencies = [ "prometheus", "rand", "serde", + "smallvec", "socket2 0.5.10", "sysinfo", "tokio", @@ -430,6 +445,18 @@ dependencies = [ "wasi", ] +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + [[package]] name = "gimli" version = "0.32.3" @@ -777,6 +804,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + [[package]] name = "rand" version = "0.8.5" @@ -804,7 +837,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.16", ] [[package]] @@ -1229,12 +1262,27 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "wasi" version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" +[[package]] +name = "wasip2" +version = "1.0.1+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +dependencies = [ + "wit-bindgen", +] + [[package]] name = "winapi" version = "0.3.9" @@ -1437,6 +1485,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "wit-bindgen" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" + [[package]] name = "zerocopy" version = "0.8.27" diff --git a/Cargo.toml b/Cargo.toml index e10e543..bf00c69 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,8 @@ md5 = "0.7" rand = "0.8" socket2 = "0.5" arc-swap = "1" +smallvec = "1.11" +ahash = "0.8" [profile.release] debug = true diff --git a/src/cache/mod.rs b/src/cache/mod.rs new file mode 100644 index 0000000..abaf2ad --- /dev/null +++ b/src/cache/mod.rs @@ -0,0 +1,808 @@ +use std::cmp::{max, Reverse}; +use std::collections::BinaryHeap; +use std::hash::{Hash, Hasher}; +use std::sync::atomic::{AtomicU8, Ordering}; +use std::sync::{Arc, Weak}; +use std::time::Duration; + +use ahash::{AHasher, RandomState}; +use anyhow::{anyhow, bail, Result}; +use arc_swap::ArcSwap; +use bytes::Bytes; +use hashbrown::HashMap; +use parking_lot::{Mutex, RwLock}; +use smallvec::SmallVec; +use tokio::sync::watch; +use tokio::task::JoinHandle; +use tokio::time::sleep; +use tracing::warn; + +use crate::config::ClientCacheConfig; +use crate::metrics; +use crate::protocol::redis::{RedisCommand, RespValue}; + +pub mod tracker; + +const STATE_DISABLED: u8 = 0; +const STATE_ENABLED: u8 = 1; +const STATE_DRAINING: u8 = 2; + +const MAX_MULTI_KEYS: usize = 64; +const MAX_HASH_FIELD_CACHED: usize = 64; + +/// Operational state for the cache, observable by trackers. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CacheState { + Disabled, + Enabled, + Draining, +} + +impl CacheState { + fn from_u8(value: u8) -> Self { + match value { + STATE_ENABLED => CacheState::Enabled, + STATE_DRAINING => CacheState::Draining, + _ => CacheState::Disabled, + } + } + + fn as_u8(self) -> u8 { + match self { + CacheState::Disabled => STATE_DISABLED, + CacheState::Enabled => STATE_ENABLED, + CacheState::Draining => STATE_DRAINING, + } + } +} + +/// High-performance local cache with RESP3 invalidation support. 
+pub struct ClientCache { + cluster: Arc, + state: AtomicU8, + resp3_ready: bool, + shards: ArcSwap>, + config: RwLock, + drain_handle: Mutex>>, + state_tx: watch::Sender, +} + +impl ClientCache { + pub fn new(cluster: Arc, config: ClientCacheConfig, resp3_ready: bool) -> Self { + let shard_count = config.shard_count.max(1); + let per_shard = entries_per_shard(config.max_entries, shard_count); + let shards = (0..shard_count) + .map(|_| CacheShard::new(per_shard)) + .collect::>(); + let initial_state = if config.enabled && resp3_ready { + CacheState::Enabled + } else { + CacheState::Disabled + }; + let (state_tx, _state_rx) = watch::channel(initial_state); + let cache = Self { + cluster, + state: AtomicU8::new(initial_state.as_u8()), + resp3_ready, + shards: ArcSwap::from_pointee(shards), + config: RwLock::new(ClientCacheConfig { enabled: initial_state == CacheState::Enabled, ..config }), + drain_handle: Mutex::new(None), + state_tx, + }; + if !cache.resp3_ready && cache.config.read().enabled { + warn!(cluster = %cache.cluster, "client cache enabled in config but backend RESP3 is unavailable; keeping disabled"); + cache.state.store(STATE_DISABLED, Ordering::SeqCst); + cache.state_tx.send_replace(CacheState::Disabled); + } + cache + } + + pub fn state(&self) -> CacheState { + CacheState::from_u8(self.state.load(Ordering::Relaxed)) + } + + pub fn subscribe(&self) -> watch::Receiver { + self.state_tx.subscribe() + } + + pub fn enable(self: &Arc) -> Result<()> { + if !self.resp3_ready { + bail!("client cache requires RESP3 backend support"); + } + let prev = self + .state + .swap(STATE_ENABLED, Ordering::SeqCst); + self.stop_drain_task(); + self.state_tx.send_replace(CacheState::Enabled); + if prev != STATE_ENABLED { + metrics::client_cache_state(self.cluster.as_ref(), "enabled"); + } + { + let mut cfg = self.config.write(); + cfg.enabled = true; + } + Ok(()) + } + + pub fn disable(self: &Arc) { + let prev = self + .state + .swap(STATE_DRAINING, Ordering::SeqCst); + if prev == STATE_DISABLED { + self.state_tx.send_replace(CacheState::Disabled); + return; + } + self.state_tx.send_replace(CacheState::Draining); + metrics::client_cache_state(self.cluster.as_ref(), "draining"); + { + let mut cfg = self.config.write(); + cfg.enabled = false; + } + self.start_drain_task(); + } + + pub fn lookup(&self, command: &RedisCommand) -> Option { + if self.state() != CacheState::Enabled { + return None; + } + match classify_read(command) { + CacheRead::Single { kind, key, field } => { + let hit = self.lookup_single(kind, key, field); + metrics::client_cache_lookup( + self.cluster.as_ref(), + kind.label(), + hit.is_some(), + ); + hit + } + CacheRead::Multi { keys } => { + let hit = self.lookup_multi(&keys); + metrics::client_cache_lookup(self.cluster.as_ref(), "multi", hit.is_some()); + hit.map(RespValue::Array) + } + CacheRead::Unsupported => None, + } + } + + pub fn store(&self, command: &RedisCommand, response: &RespValue) { + if self.state() != CacheState::Enabled { + return; + } + if response.is_error() { + return; + } + let snapshot = self.config.read().clone(); + match classify_read(command) { + CacheRead::Single { kind, key, field } => { + self.store_single(&snapshot, kind, key, field, response); + } + CacheRead::Multi { keys } => { + self.store_multi(&snapshot, &keys, response); + } + CacheRead::Unsupported => {} + } + } + + /// Returns true if this command is a cacheable read handled by the client cache. 
+ pub fn is_cacheable_read(command: &RedisCommand) -> bool { + matches!( + classify_read(command), + CacheRead::Single { .. } | CacheRead::Multi { .. } + ) + } + + /// Returns true if this command can invalidate cached keys. + pub fn is_invalidating_write(command: &RedisCommand) -> bool { + classify_write(command).is_some() + } + + pub fn invalidate_bytes>(&self, keys: &[B]) { + if keys.is_empty() { + return; + } + let shards = self.shards.load(); + let shard_total = shards.len().max(1); + drop(shards); + let mut removed = 0usize; + for key in keys { + removed += self.invalidate_primary(key.as_ref(), shard_total); + } + if removed > 0 { + metrics::client_cache_invalidate(self.cluster.as_ref(), removed); + } + } + + pub fn invalidate_command(&self, command: &RedisCommand) { + if let Some(action) = classify_write(command) { + match action { + CacheWrite::FlushAll => self.flush(), + CacheWrite::Keys(keys) => { + let total = self.shards.load().len().max(1); + for key in keys { + self.invalidate_primary(key.as_ref(), total); + } + } + } + } + } + + pub fn flush(&self) { + let shard_count = self.shards.load().len().max(1); + self.rebuild_shards(shard_count); + metrics::client_cache_state(self.cluster.as_ref(), "flushed"); + } + + pub fn set_max_entries(&self, value: usize) { + { + let mut cfg = self.config.write(); + cfg.max_entries = value.max(1); + } + let shards = self.shards.load(); + let per_shard = entries_per_shard(value, shards.len().max(1)); + for shard in shards.iter() { + shard.set_capacity(per_shard); + } + } + + pub fn set_max_value_bytes(&self, value: usize) { + let mut cfg = self.config.write(); + cfg.max_value_bytes = value.max(1); + } + + pub fn set_shard_count(self: &Arc, count: usize) { + let count = count.max(1).min(usize::MAX / 2); + { + let mut cfg = self.config.write(); + cfg.shard_count = count; + } + self.rebuild_shards(count); + } + + pub fn set_drain_batch(&self, value: usize) { + let mut cfg = self.config.write(); + cfg.drain_batch = value.max(1); + } + + pub fn set_drain_interval(&self, value: u64) { + let mut cfg = self.config.write(); + cfg.drain_interval_ms = value.max(1); + } + + fn lookup_single( + &self, + kind: CacheCommandKind, + key: &Bytes, + field: Option<&Bytes>, + ) -> Option { + let shards = self.shards.load(); + let index = shard_index(key.as_ref(), shards.len().max(1)); + shards[index].get(kind, key, field) + } + + fn lookup_multi(&self, keys: &[&Bytes]) -> Option> { + let mut results = Vec::with_capacity(keys.len()); + for key in keys { + if let Some(value) = self.lookup_single(CacheCommandKind::Value, key, None) { + results.push(value); + } else { + return None; + } + } + Some(results) + } + + fn store_single( + &self, + config: &ClientCacheConfig, + kind: CacheCommandKind, + key: &Bytes, + field: Option<&Bytes>, + response: &RespValue, + ) { + let normalized = match normalize_value(kind, response) { + Some(value) => value, + None => return, + }; + let footprint = resp_size(&normalized); + if footprint > config.max_value_bytes { + return; + } + let entry = CacheEntry::new(normalized, footprint); + let cache_key = CacheKey::new(kind, key.clone(), field.cloned()); + let shards = self.shards.load(); + let index = shard_index(cache_key.primary.as_ref(), shards.len().max(1)); + shards[index].put(cache_key, entry); + metrics::client_cache_store(self.cluster.as_ref(), kind.label()); + } + + fn store_multi( + &self, + config: &ClientCacheConfig, + keys: &[&Bytes], + response: &RespValue, + ) { + let values = match response.as_array() { + Some(values) if 
values.len() == keys.len() => values, + _ => return, + }; + for (key, value) in keys.iter().zip(values.iter()) { + self.store_single(config, CacheCommandKind::Value, key, None, value); + } + } + + fn invalidate_primary(&self, primary: &[u8], shard_total: usize) -> usize { + let shards = self.shards.load(); + let idx = shard_index(primary, shard_total); + shards[idx].remove_primary(primary) + } + + fn rebuild_shards(&self, count: usize) { + let cfg = self.config.read().clone(); + let per_shard = entries_per_shard(cfg.max_entries, count); + let shards = (0..count) + .map(|_| CacheShard::new(per_shard)) + .collect::>(); + self.shards.store(Arc::new(shards)); + } + + fn total_entries(&self) -> usize { + let shards = self.shards.load(); + shards.iter().map(|shard| shard.len()).sum() + } + + fn start_drain_task(self: &Arc) { + let mut guard = self.drain_handle.lock(); + if let Some(handle) = guard.take() { + handle.abort(); + } + let weak = Arc::downgrade(self); + let handle = tokio::spawn(async move { + while let Some(cache) = weak.upgrade() { + let batch = cache.config.read().drain_batch; + let interval = cache.config.read().drain_interval_ms; + let removed = cache.drain_once(batch); + if removed == 0 { + cache.finish_draining(); + break; + } + sleep(Duration::from_millis(interval)).await; + } + }); + *guard = Some(handle); + } + + fn stop_drain_task(&self) { + if let Some(handle) = self.drain_handle.lock().take() { + handle.abort(); + } + } + + fn finish_draining(&self) { + self.state.store(STATE_DISABLED, Ordering::SeqCst); + self.state_tx.send_replace(CacheState::Disabled); + metrics::client_cache_state(self.cluster.as_ref(), "disabled"); + } + + fn drain_once(&self, batch: usize) -> usize { + let shards = self.shards.load(); + let len = shards.len().max(1); + let per_shard = max(1, (batch.max(1) + len - 1) / len); + let mut removed = 0usize; + for shard in shards.iter() { + removed += shard.evict_batch(per_shard); + } + removed + } +} + +fn normalize_value(kind: CacheCommandKind, resp: &RespValue) -> Option { + match kind { + CacheCommandKind::Value => match resp { + RespValue::BulkString(_) | RespValue::SimpleString(_) | RespValue::Null | RespValue::NullBulk => + Some(resp.clone()), + _ => None, + }, + CacheCommandKind::HashField => match resp { + RespValue::BulkString(_) | RespValue::SimpleString(_) | RespValue::Null | RespValue::NullBulk => + Some(resp.clone()), + _ => None, + }, + } +} + +fn resp_size(value: &RespValue) -> usize { + match value { + RespValue::SimpleString(data) + | RespValue::BulkString(data) + | RespValue::Error(data) + | RespValue::BlobError(data) + | RespValue::Double(data) + | RespValue::BigNumber(data) => data.len(), + RespValue::Integer(_) => std::mem::size_of::(), + RespValue::Null + | RespValue::NullBulk + | RespValue::NullArray => 1, + RespValue::Boolean(_) => 1, + RespValue::Map(entries) | RespValue::Attribute(entries) => { + entries.iter().map(|(k, v)| resp_size(k) + resp_size(v)).sum() + } + RespValue::Array(values) | RespValue::Set(values) | RespValue::Push(values) => { + values.iter().map(resp_size).sum() + } + RespValue::VerbatimString { data, .. 
} => data.len(), + } +} + +fn shard_index(key: &[u8], shards: usize) -> usize { + let mut hasher = AHasher::default(); + hasher.write(key); + (hasher.finish() as usize) % shards.max(1) +} + +fn entries_per_shard(entries: usize, shards: usize) -> usize { + let shards = shards.max(1); + let per = (entries + shards - 1) / shards; + per.max(1) +} + +fn upper_name(input: &[u8]) -> Vec { + input.iter().map(|b| b.to_ascii_uppercase()).collect() +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum CacheCommandKind { + Value, + HashField, +} + +impl CacheCommandKind { + fn label(self) -> &'static str { + match self { + CacheCommandKind::Value => "value", + CacheCommandKind::HashField => "hash_field", + } + } +} + +#[derive(Debug, Clone)] +struct CacheEntry { + value: RespValue, + access: u64, + size: usize, +} + +impl CacheEntry { + fn new(value: RespValue, size: usize) -> Self { + Self { + value, + access: 0, + size, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct CacheKey { + kind: CacheCommandKind, + primary: Bytes, + secondary: Option, +} + +impl CacheKey { + fn new(kind: CacheCommandKind, primary: Bytes, secondary: Option) -> Self { + Self { + kind, + primary, + secondary, + } + } +} + +#[derive(Debug)] +struct CacheShard { + inner: Mutex, +} + +impl CacheShard { + fn new(capacity: usize) -> Self { + Self { + inner: Mutex::new(CacheShardInner::new(capacity)), + } + } + + fn len(&self) -> usize { + self.inner.lock().len() + } + + fn get( + &self, + kind: CacheCommandKind, + key: &Bytes, + field: Option<&Bytes>, + ) -> Option { + let mut guard = self.inner.lock(); + guard.touch(&CacheKey::new(kind, key.clone(), field.cloned())) + } + + fn put(&self, key: CacheKey, entry: CacheEntry) { + let mut guard = self.inner.lock(); + guard.insert(key, entry); + } + + fn remove_primary(&self, primary: &[u8]) -> usize { + let mut guard = self.inner.lock(); + guard.remove_primary(primary) + } + + fn set_capacity(&self, capacity: usize) { + let mut guard = self.inner.lock(); + guard.set_capacity(capacity); + } + + fn evict_batch(&self, batch: usize) -> usize { + let mut guard = self.inner.lock(); + guard.evict_batch(batch) + } +} + +#[derive(Debug)] +struct CacheShardInner { + entries: HashMap, + order: BinaryHeap>, + per_key: HashMap, SmallVec<[CacheKey; 4]>, RandomState>, + counter: u64, + capacity: usize, +} + +impl CacheShardInner { + fn new(capacity: usize) -> Self { + Self { + entries: HashMap::with_hasher(RandomState::new()), + order: BinaryHeap::new(), + per_key: HashMap::with_hasher(RandomState::new()), + counter: 0, + capacity: capacity.max(1), + } + } + + fn len(&self) -> usize { + self.entries.len() + } + + fn touch(&mut self, key: &CacheKey) -> Option { + if let Some(entry) = self.entries.get_mut(key) { + entry.access = self.next_access(); + self.order + .push(Reverse(HeapEntry::new(entry.access, key.clone()))); + Some(entry.value.clone()) + } else { + None + } + } + + fn insert(&mut self, key: CacheKey, mut entry: CacheEntry) { + entry.access = self.next_access(); + if let Some(old) = self.entries.insert(key.clone(), entry.clone()) { + self.detach(&key); + drop(old); + } + self.attach(key.clone()); + self.order + .push(Reverse(HeapEntry::new(entry.access, key.clone()))); + self.enforce_capacity(); + } + + fn remove_primary(&mut self, primary: &[u8]) -> usize { + let keys = match self.per_key.remove(primary) { + Some(keys) => keys, + None => return 0, + }; + let mut removed = 0usize; + for cache_key in keys { + if self.entries.remove(&cache_key).is_some() { + removed 
+= 1; + } + } + removed + } + + fn set_capacity(&mut self, capacity: usize) { + self.capacity = capacity.max(1); + self.enforce_capacity(); + } + + fn evict_batch(&mut self, batch: usize) -> usize { + let mut removed = 0usize; + for _ in 0..batch.max(1) { + if self.pop_lru().is_some() { + removed += 1; + } else { + break; + } + } + removed + } + + fn attach(&mut self, key: CacheKey) { + self.per_key + .entry(key.primary.to_vec()) + .or_default() + .push(key); + } + + fn detach(&mut self, key: &CacheKey) { + if let Some(list) = self.per_key.get_mut(key.primary.as_ref()) { + if let Some(pos) = list.iter().position(|existing| existing == key) { + list.swap_remove(pos); + } + if list.is_empty() { + self.per_key.remove(key.primary.as_ref()); + } + } + } + + fn enforce_capacity(&mut self) { + while self.entries.len() > self.capacity { + if self.pop_lru().is_none() { + break; + } + } + } + + fn pop_lru(&mut self) -> Option<(CacheKey, CacheEntry)> { + while let Some(Reverse(entry)) = self.order.pop() { + if let Some(stored) = self.entries.get(&entry.key) { + if stored.access == entry.access { + self.detach(&entry.key); + return self.entries.remove(&entry.key).map(|value| (entry.key, value)); + } + } + } + None + } + + fn next_access(&mut self) -> u64 { + self.counter = self.counter.wrapping_add(1); + self.counter + } +} + +#[derive(Debug, Clone)] +struct HeapEntry { + access: u64, + key: CacheKey, +} + +impl HeapEntry { + fn new(access: u64, key: CacheKey) -> Self { + Self { access, key } + } +} + +impl PartialEq for HeapEntry { + fn eq(&self, other: &Self) -> bool { + self.access == other.access && self.key == other.key + } +} + +impl Eq for HeapEntry {} + +impl PartialOrd for HeapEntry { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for HeapEntry { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.access.cmp(&other.access) + } +} + +#[derive(Debug)] +enum CacheRead<'a> { + Single { + kind: CacheCommandKind, + key: &'a Bytes, + field: Option<&'a Bytes>, + }, + Multi { + keys: SmallVec<[&'a Bytes; MAX_MULTI_KEYS]>, + }, + Unsupported, +} + +#[derive(Debug)] +enum CacheWrite<'a> { + FlushAll, + Keys(SmallVec<[&'a Bytes; MAX_MULTI_KEYS]>), +} + +fn classify_read(command: &RedisCommand) -> CacheRead<'_> { + let args = command.args(); + if args.is_empty() { + return CacheRead::Unsupported; + } + let upper = upper_name(command.command_name()); + match upper.as_slice() { + b"GET" => { + if args.len() < 2 { + return CacheRead::Unsupported; + } + CacheRead::Single { + kind: CacheCommandKind::Value, + key: &args[1], + field: None, + } + } + b"HGET" => { + if args.len() < 3 { + return CacheRead::Unsupported; + } + CacheRead::Single { + kind: CacheCommandKind::HashField, + key: &args[1], + field: Some(&args[2]), + } + } + b"MGET" => { + if args.len() < 2 || args.len() - 1 > MAX_MULTI_KEYS { + return CacheRead::Unsupported; + } + let mut keys = SmallVec::<[&Bytes; MAX_MULTI_KEYS]>::new(); + for key in &args[1..] { + keys.push(key); + } + CacheRead::Multi { keys } + } + _ => CacheRead::Unsupported, + } +} + +fn classify_write(command: &RedisCommand) -> Option> { + let args = command.args(); + if args.is_empty() { + return None; + } + let name = upper_name(command.command_name()); + match name.as_slice() { + b"FLUSHALL" | b"FLUSHDB" => Some(CacheWrite::FlushAll), + b"DEL" | b"UNLINK" => { + if args.len() < 2 { + return None; + } + let mut keys = SmallVec::<[&Bytes; MAX_MULTI_KEYS]>::new(); + for key in &args[1..] 
{ + if keys.len() >= MAX_MULTI_KEYS { + break; + } + keys.push(key); + } + Some(CacheWrite::Keys(keys)) + } + b"SET" | b"SETEX" | b"PSETEX" | b"SETNX" | b"HSET" | b"HDEL" => { + if args.len() < 2 { + return None; + } + let mut keys = SmallVec::<[&Bytes; 1]>::new(); + keys.push(&args[1]); + Some(CacheWrite::Keys(keys)) + } + b"MSET" => { + if args.len() < 3 { + return None; + } + let mut keys = SmallVec::<[&Bytes; MAX_MULTI_KEYS]>::new(); + let mut idx = 1; + while idx < args.len() { + if keys.len() >= MAX_MULTI_KEYS { + break; + } + keys.push(&args[idx]); + idx += 2; + } + Some(CacheWrite::Keys(keys)) + } + _ => None, + } +} diff --git a/src/cache/tracker.rs b/src/cache/tracker.rs new file mode 100644 index 0000000..d965288 --- /dev/null +++ b/src/cache/tracker.rs @@ -0,0 +1,340 @@ +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; +use std::time::Duration; + +use anyhow::{anyhow, Context, Result}; +use bytes::Bytes; +use futures::{SinkExt, StreamExt}; +use parking_lot::Mutex; +#[cfg(any(unix, windows))] +use socket2::{SockRef, TcpKeepalive}; +use tokio::net::TcpStream; +use tokio::sync::watch; +use tokio::task::JoinHandle; +use tokio::time::{sleep, timeout}; +use tokio_util::codec::Framed; +use tracing::{debug, warn}; + +use crate::auth::BackendAuth; +use crate::cache::{CacheState, ClientCache}; +use crate::config::ClusterRuntime; +use crate::protocol::redis::{RespCodec, RespValue, RespVersion}; + +const TRACKER_BACKOFF_START: Duration = Duration::from_millis(200); +const TRACKER_BACKOFF_MAX: Duration = Duration::from_secs(2); + +#[derive(Debug)] +pub struct CacheTrackerSet { + cluster: Arc, + cache: Arc, + runtime: Arc, + backend_auth: Option, + timeout_ms: u64, + handles: Mutex, TrackerHandle>>, +} + +#[derive(Debug)] +struct TrackerHandle { + shutdown: watch::Sender, + join: JoinHandle<()>, +} + +impl CacheTrackerSet { + pub fn new( + cluster: Arc, + cache: Arc, + runtime: Arc, + backend_auth: Option, + timeout_ms: u64, + ) -> Self { + Self { + cluster, + cache, + runtime, + backend_auth, + timeout_ms, + handles: Mutex::new(HashMap::new()), + } + } + + pub fn set_nodes(&self, nodes: Vec) { + let mut guard = self.handles.lock(); + let desired: HashSet> = nodes.iter().map(|n| Arc::::from(n.clone())).collect(); + + guard.retain(|addr, handle| { + if desired.contains(addr) { + true + } else { + let _ = handle.shutdown.send(true); + handle.join.abort(); + false + } + }); + + for node in desired { + if guard.contains_key(&node) { + continue; + } + let handle = self.spawn_tracker(node.clone()); + guard.insert(node, handle); + } + } + + fn spawn_tracker(&self, address: Arc) -> TrackerHandle { + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let cluster = self.cluster.clone(); + let cache = self.cache.clone(); + let runtime = self.runtime.clone(); + let auth = self.backend_auth.clone(); + let timeout_ms = self.timeout_ms; + let mut state_rx = cache.subscribe(); + let mut local_shutdown = shutdown_rx; + let join = tokio::spawn(async move { + let mut backoff = TRACKER_BACKOFF_START; + loop { + if *local_shutdown.borrow() { + break; + } + if wait_for_enabled(&mut state_rx, &mut local_shutdown).await { + match listen_once( + cluster.clone(), + cache.clone(), + runtime.clone(), + auth.clone(), + timeout_ms, + address.clone(), + &mut state_rx, + &mut local_shutdown, + ) + .await + { + Ok(()) => { + backoff = TRACKER_BACKOFF_START; + } + Err(err) => { + warn!( + cluster = %cluster, + backend = %address, + error = %err, + "client cache tracker failed" + ); + if 
wait_with_shutdown(backoff, &mut local_shutdown).await { + break; + } + backoff = (backoff * 2).min(TRACKER_BACKOFF_MAX); + } + } + } else { + break; + } + } + debug!(cluster = %cluster, backend = %address, "client cache tracker stopped"); + }); + TrackerHandle { + shutdown: shutdown_tx, + join, + } + } +} + +impl Drop for CacheTrackerSet { + fn drop(&mut self) { + let handles = self.handles.lock().drain().collect::>(); + for (_, handle) in handles { + let _ = handle.shutdown.send(true); + handle.join.abort(); + } + } +} + +async fn wait_for_enabled( + state_rx: &mut watch::Receiver, + shutdown: &mut watch::Receiver, +) -> bool { + loop { + if *shutdown.borrow() { + return false; + } + match *state_rx.borrow() { + CacheState::Enabled => return true, + CacheState::Disabled | CacheState::Draining => {} + } + tokio::select! { + biased; + _ = shutdown.changed() => { + if *shutdown.borrow() { + return false; + } + } + changed = state_rx.changed() => { + if changed.is_err() { + return false; + } + } + } + } +} + +async fn wait_with_shutdown(delay: Duration, shutdown: &mut watch::Receiver) -> bool { + tokio::select! { + _ = shutdown.changed() => true, + _ = sleep(delay) => *shutdown.borrow(), + } +} + +async fn listen_once( + cluster: Arc, + cache: Arc, + runtime: Arc, + backend_auth: Option, + timeout_ms: u64, + address: Arc, + state_rx: &mut watch::Receiver, + shutdown: &mut watch::Receiver, +) -> Result<()> { + let timeout_duration = runtime.request_timeout(timeout_ms); + let mut framed = open_stream(&address, timeout_duration, backend_auth.clone()).await?; + negotiate_resp3(&cluster, &address, timeout_duration, &mut framed, backend_auth.clone()).await?; + enable_tracking(&cluster, &address, timeout_duration, &mut framed).await?; + + loop { + tokio::select! 
{ + _ = shutdown.changed() => return Ok(()), + changed = state_rx.changed() => { + if changed.is_err() || *state_rx.borrow() != CacheState::Enabled { + return Ok(()); + } + } + frame = framed.next() => match frame { + Some(Ok(RespValue::Push(items))) => { + if let Some(keys) = parse_invalidation(&items) { + cache.invalidate_bytes(&keys); + } + } + Some(Ok(_)) => {} + Some(Err(err)) => return Err(err.into()), + None => return Err(anyhow!("tracking connection closed")), + } + } + } +} + +async fn open_stream( + address: &str, + timeout_duration: Duration, + backend_auth: Option, +) -> Result> { + let stream = timeout(timeout_duration, TcpStream::connect(address)) + .await + .with_context(|| format!("connect to {address} timed out"))??; + stream.set_nodelay(true).context("failed to enable TCP_NODELAY")?; + #[cfg(any(unix, windows))] + { + use socket2::{SockRef, TcpKeepalive}; + let keepalive = TcpKeepalive::new() + .with_time(Duration::from_secs(60)) + .with_interval(Duration::from_secs(60)); + if let Err(err) = SockRef::from(&stream).set_tcp_keepalive(&keepalive) { + warn!(backend = %address, error = %err, "failed to set tracker TCP keepalive"); + } + } + let mut framed = Framed::new(stream, RespCodec::default()); + if let Some(auth) = &backend_auth { + auth.apply_to_stream(&mut framed, timeout_duration, address) + .await?; + } + Ok(framed) +} + +async fn negotiate_resp3( + cluster: &str, + backend: &str, + timeout_duration: Duration, + framed: &mut Framed, + backend_auth: Option, +) -> Result<()> { + framed.codec_mut().set_version(RespVersion::Resp2); + let mut hello_parts = vec![ + RespValue::BulkString(Bytes::from_static(b"HELLO")), + RespValue::BulkString(Bytes::from_static(b"3")), + ]; + if let Some(auth) = backend_auth { + if let Some((username, password)) = auth.hello_credentials() { + hello_parts.push(RespValue::BulkString(Bytes::from_static(b"AUTH"))); + hello_parts.push(RespValue::BulkString(username)); + hello_parts.push(RespValue::BulkString(password)); + } + } + let hello = RespValue::Array(hello_parts); + timeout(timeout_duration, framed.send(hello)) + .await + .with_context(|| format!("failed to send HELLO to {backend}"))??; + match timeout(timeout_duration, framed.next()).await { + Ok(Some(Ok(resp))) => { + if resp.is_error() { + bail!("backend {backend} rejected HELLO for cluster {cluster}"); + } + } + Ok(Some(Err(err))) => return Err(err.context("HELLO handshake failed")), + Ok(None) => bail!("backend {backend} closed during HELLO"), + Err(_) => bail!("backend {backend} timed out waiting for HELLO reply"), + } + framed.codec_mut().set_version(RespVersion::Resp3); + Ok(()) +} + +async fn enable_tracking( + cluster: &str, + backend: &str, + timeout_duration: Duration, + framed: &mut Framed, +) -> Result<()> { + let command = RespValue::Array(vec![ + RespValue::BulkString(Bytes::from_static(b"CLIENT")), + RespValue::BulkString(Bytes::from_static(b"TRACKING")), + RespValue::BulkString(Bytes::from_static(b"ON")), + RespValue::BulkString(Bytes::from_static(b"BCAST")), + RespValue::BulkString(Bytes::from_static(b"NOLOOP")), + ]); + timeout(timeout_duration, framed.send(command)) + .await + .with_context(|| format!("failed to send CLIENT TRACKING to {backend}"))??; + match timeout(timeout_duration, framed.next()).await { + Ok(Some(Ok(resp))) => match resp { + RespValue::SimpleString(ref s) | RespValue::BulkString(ref s) + if s.eq_ignore_ascii_case(b"OK") => Ok(()), + RespValue::Error(err) => Err(anyhow!( + "backend {backend} rejected CLIENT TRACKING for cluster {cluster}: {}", + 
String::from_utf8_lossy(&err) + )), + other => Err(anyhow!( + "unexpected CLIENT TRACKING reply from {backend}: {:?}", + other + )), + }, + Ok(Some(Err(err))) => Err(err.context("CLIENT TRACKING failed")), + Ok(None) => Err(anyhow!("backend {backend} closed during CLIENT TRACKING")), + Err(_) => Err(anyhow!("backend {backend} timed out waiting for CLIENT TRACKING")), + } +} + +fn parse_invalidation(items: &[RespValue]) -> Option> { + if items.is_empty() { + return None; + } + let label = match &items[0] { + RespValue::SimpleString(data) | RespValue::BulkString(data) => data, + _ => return None, + }; + if !label.eq_ignore_ascii_case(b"invalidate") { + return None; + } + let mut keys = Vec::with_capacity(items.len().saturating_sub(1)); + for item in items.iter().skip(1) { + match item { + RespValue::BulkString(data) | RespValue::SimpleString(data) => keys.push(data.clone()), + _ => {} + } + } + Some(keys) +} diff --git a/src/cluster/mod.rs b/src/cluster/mod.rs index 36abfec..c0e4c7a 100644 --- a/src/cluster/mod.rs +++ b/src/cluster/mod.rs @@ -18,6 +18,7 @@ use tokio_util::codec::{Framed, FramedParts}; use tracing::{debug, info, warn}; use crate::auth::{AuthAction, BackendAuth, FrontendAuthenticator}; +use crate::cache::{tracker::CacheTrackerSet, ClientCache}; use crate::backend::client::{ClientId, FrontConnectionGuard}; use crate::backend::pool::{BackendNode, ConnectionPool, Connector, SessionCommand}; use crate::config::{ClusterConfig, ClusterRuntime, ConfigManager}; @@ -58,6 +59,8 @@ pub struct ClusterProxy { hotkey: Arc, listen_port: u16, seed_nodes: usize, + client_cache: Arc, + cache_trackers: Arc, } impl ClusterProxy { @@ -94,6 +97,16 @@ impl ClusterProxy { let hotkey = config_manager .hotkey_for(&config.name) .ok_or_else(|| anyhow!("missing hotkey state for cluster {}", config.name))?; + let client_cache = config_manager + .client_cache_for(&config.name) + .ok_or_else(|| anyhow!("missing client cache state for cluster {}", config.name))?; + let cache_trackers = Arc::new(CacheTrackerSet::new( + cluster.clone(), + client_cache.clone(), + runtime.clone(), + backend_auth.clone(), + REQUEST_TIMEOUT_MS, + )); let proxy = Self { cluster: cluster.clone(), hash_tag, @@ -109,6 +122,8 @@ impl ClusterProxy { hotkey, listen_port, seed_nodes: config.servers.len(), + client_cache, + cache_trackers: cache_trackers.clone(), }; // trigger an immediate topology fetch @@ -119,6 +134,7 @@ impl ClusterProxy { connector, proxy.slots.clone(), trigger_rx, + Some(cache_trackers), )); Ok(proxy) @@ -343,6 +359,18 @@ impl ClusterProxy { inflight += 1; continue; } + if let Some(hit) = self.client_cache.lookup(&cmd) { + self.hotkey.record_command(&cmd); + metrics::front_command( + self.cluster.as_ref(), + cmd.kind_label(), + true, + ); + let fut = async move { (hit, None) }; + pending.push_back(Box::pin(fut)); + inflight += 1; + continue; + } let requested_version = cmd.resp_version_request(); let guard = self.prepare_dispatch(client_id, cmd); let fut = async move { @@ -728,7 +756,11 @@ impl ClusterProxy { let slowlog = self.slowlog.clone(); let hotkey = self.hotkey.clone(); let kind_label = command.kind_label(); + let cache = self.client_cache.clone(); Box::pin(async move { + let cache_candidate = command.clone(); + let cacheable_read = ClientCache::is_cacheable_read(&cache_candidate); + let invalidating_write = ClientCache::is_invalidating_write(&cache_candidate); match dispatch_with_context( hash_tag, read_from_slave, @@ -744,10 +776,19 @@ impl ClusterProxy { { Ok(resp) => { let success = !matches!(resp, 
RespValue::Error(_)); + if success && cacheable_read { + cache.store(&cache_candidate, &resp); + } + if invalidating_write { + cache.invalidate_command(&cache_candidate); + } metrics::front_command(cluster.as_ref(), kind_label, success); resp } Err(err) => { + if invalidating_write { + cache.invalidate_command(&cache_candidate); + } metrics::global_error_incr(); metrics::front_command(cluster.as_ref(), kind_label, false); metrics::front_error(cluster.as_ref(), "dispatch"); @@ -1341,6 +1382,7 @@ async fn fetch_topology( connector: Arc, slots: Arc>, mut trigger: mpsc::UnboundedReceiver<()>, + tracker: Option>, ) { let mut ticker = tokio::time::interval(FETCH_INTERVAL); loop { @@ -1349,7 +1391,7 @@ async fn fetch_topology( _ = trigger.recv() => {}, } - if let Err(err) = fetch_once(&cluster, &seeds, connector.clone(), slots.clone()).await { + if let Err(err) = fetch_once(&cluster, &seeds, connector.clone(), slots.clone(), tracker.clone()).await { warn!(cluster = %cluster, error = %err, "failed to refresh cluster topology"); } } @@ -1360,6 +1402,7 @@ async fn fetch_once( seeds: &[String], connector: Arc, slots: Arc>, + tracker: Option>, ) -> Result<()> { let mut shuffled = seeds.to_vec(); { @@ -1371,6 +1414,9 @@ async fn fetch_once( match fetch_from_seed(&seed, connector.clone()).await { Ok(map) => { slots.send_replace(map.clone()); + if let Some(ref watchers) = tracker { + watchers.set_nodes(map.all_nodes()); + } info!(cluster = %cluster, seed = %seed, "cluster slots refreshed"); return Ok(()); } diff --git a/src/config/mod.rs b/src/config/mod.rs index af305c9..7991d01 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -11,6 +11,7 @@ use tokio::fs; use tracing::{info, warn}; use crate::auth::{AuthUserConfig, BackendAuthConfig, FrontendAuthConfig}; +use crate::cache::ClientCache; use crate::hotkey::{ Hotkey, HotkeyConfig, DEFAULT_DECAY, DEFAULT_HOTKEY_CAPACITY, DEFAULT_SAMPLE_EVERY, DEFAULT_SKETCH_DEPTH, DEFAULT_SKETCH_WIDTH, @@ -55,6 +56,26 @@ fn default_backend_resp_version() -> RespVersion { RespVersion::Resp2 } +fn default_client_cache_max_entries() -> usize { + 100_000 +} + +fn default_client_cache_max_value_bytes() -> usize { + 512 * 1024 +} + +fn default_client_cache_shards() -> usize { + 32 +} + +fn default_client_cache_drain_batch() -> usize { + 1024 +} + +fn default_client_cache_drain_interval_ms() -> u64 { + 50 +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Config { #[serde(default)] @@ -125,6 +146,56 @@ impl Default for CacheType { } } +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ClientCacheConfig { + #[serde(default)] + pub enabled: bool, + #[serde(default = "default_client_cache_max_entries")] + pub max_entries: usize, + #[serde(default = "default_client_cache_max_value_bytes")] + pub max_value_bytes: usize, + #[serde(default = "default_client_cache_shards")] + pub shard_count: usize, + #[serde(default = "default_client_cache_drain_batch")] + pub drain_batch: usize, + #[serde(default = "default_client_cache_drain_interval_ms")] + pub drain_interval_ms: u64, +} + +impl ClientCacheConfig { + pub fn ensure_valid(&self) -> Result<()> { + if self.max_entries == 0 { + bail!("client cache max_entries must be > 0"); + } + if self.max_value_bytes == 0 { + bail!("client cache max_value_bytes must be > 0"); + } + if self.shard_count == 0 { + bail!("client cache shard_count must be > 0"); + } + if self.drain_batch == 0 { + bail!("client cache drain_batch must be > 0"); + } + if self.drain_interval_ms == 0 { + bail!("client cache drain_interval_ms must 
be > 0"); + } + Ok(()) + } +} + +impl Default for ClientCacheConfig { + fn default() -> Self { + Self { + enabled: false, + max_entries: default_client_cache_max_entries(), + max_value_bytes: default_client_cache_max_value_bytes(), + shard_count: default_client_cache_shards(), + drain_batch: default_client_cache_drain_batch(), + drain_interval_ms: default_client_cache_drain_interval_ms(), + } + } +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ClusterConfig { pub name: String, @@ -183,6 +254,8 @@ pub struct ClusterConfig { pub hotkey_decay: f64, #[serde(default = "default_backend_resp_version")] pub backend_resp_version: RespVersion, + #[serde(default)] + pub client_cache: ClientCacheConfig, } impl ClusterConfig { @@ -201,6 +274,8 @@ impl ClusterConfig { bail!("cluster {} listen_addr cannot be empty", self.name); } + self.client_cache.ensure_valid()?; + parse_port(&self.listen_addr).with_context(|| { format!( "cluster {} listen_addr {} is not a valid address", @@ -366,6 +441,7 @@ struct ClusterEntry { runtime: Arc, slowlog: Arc, hotkey: Arc, + client_cache: Arc, } #[derive(Debug)] @@ -396,6 +472,12 @@ impl ConfigManager { decay: cluster.hotkey_decay, }; let hotkey = Arc::new(Hotkey::new(hotkey_config)); + let cluster_label: Arc = cluster.name.clone().into(); + let client_cache = Arc::new(ClientCache::new( + cluster_label.clone(), + cluster.client_cache.clone(), + cluster.backend_resp_version == RespVersion::Resp3, + )); clusters.insert( key, ClusterEntry { @@ -403,6 +485,7 @@ impl ConfigManager { runtime, slowlog, hotkey, + client_cache, }, ); } @@ -432,6 +515,12 @@ impl ConfigManager { .map(|entry| entry.hotkey.clone()) } + pub fn client_cache_for(&self, name: &str) -> Option> { + self.clusters + .get(&name.to_ascii_lowercase()) + .map(|entry| entry.client_cache.clone()) + } + pub async fn handle_command(&self, command: &RedisCommand) -> Option { if !command.command_name().eq_ignore_ascii_case(b"CONFIG") { return None; @@ -609,6 +698,79 @@ impl ConfigManager { "cluster hotkey_decay updated via CONFIG SET" ); } + ClusterField::ClientCacheEnabled => { + let enabled = parse_bool_flag(value, "client-cache-enabled")?; + if enabled { + entry + .client_cache + .enable() + .with_context(|| format!("cluster {} failed to enable client cache", cluster_name))?; + } else { + entry.client_cache.disable(); + } + let mut guard = self.config.write(); + guard.clusters_mut()[entry.index].client_cache.enabled = enabled; + info!( + cluster = cluster_name, + value = value, + "cluster client_cache.enabled updated via CONFIG SET" + ); + } + ClusterField::ClientCacheMaxEntries => { + let parsed = parse_positive_usize(value, "client-cache-max-entries")?; + entry.client_cache.set_max_entries(parsed); + let mut guard = self.config.write(); + guard.clusters_mut()[entry.index].client_cache.max_entries = parsed; + info!( + cluster = cluster_name, + value = value, + "cluster client_cache.max_entries updated via CONFIG SET" + ); + } + ClusterField::ClientCacheMaxValueBytes => { + let parsed = parse_positive_usize(value, "client-cache-max-value-bytes")?; + entry.client_cache.set_max_value_bytes(parsed); + let mut guard = self.config.write(); + guard.clusters_mut()[entry.index].client_cache.max_value_bytes = parsed; + info!( + cluster = cluster_name, + value = value, + "cluster client_cache.max_value_bytes updated via CONFIG SET" + ); + } + ClusterField::ClientCacheShardCount => { + let parsed = parse_positive_usize(value, "client-cache-shard-count")?; + entry.client_cache.set_shard_count(parsed); + let mut 
guard = self.config.write(); + guard.clusters_mut()[entry.index].client_cache.shard_count = parsed; + info!( + cluster = cluster_name, + value = value, + "cluster client_cache.shard_count updated via CONFIG SET" + ); + } + ClusterField::ClientCacheDrainBatch => { + let parsed = parse_positive_usize(value, "client-cache-drain-batch")?; + entry.client_cache.set_drain_batch(parsed); + let mut guard = self.config.write(); + guard.clusters_mut()[entry.index].client_cache.drain_batch = parsed; + info!( + cluster = cluster_name, + value = value, + "cluster client_cache.drain_batch updated via CONFIG SET" + ); + } + ClusterField::ClientCacheDrainIntervalMs => { + let parsed = parse_positive_u64(value, "client-cache-drain-interval-ms")?; + entry.client_cache.set_drain_interval(parsed); + let mut guard = self.config.write(); + guard.clusters_mut()[entry.index].client_cache.drain_interval_ms = parsed; + info!( + cluster = cluster_name, + value = value, + "cluster client_cache.drain_interval_ms updated via CONFIG SET" + ); + } } Ok(()) } @@ -666,6 +828,31 @@ impl ConfigManager { format!("cluster.{}.hotkey-decay", name), hotkey_cfg.decay.to_string(), )); + let cache_cfg = &cluster.client_cache; + entries.push(( + format!("cluster.{}.client-cache-enabled", name), + cache_cfg.enabled.to_string(), + )); + entries.push(( + format!("cluster.{}.client-cache-max-entries", name), + cache_cfg.max_entries.to_string(), + )); + entries.push(( + format!("cluster.{}.client-cache-max-value-bytes", name), + cache_cfg.max_value_bytes.to_string(), + )); + entries.push(( + format!("cluster.{}.client-cache-shard-count", name), + cache_cfg.shard_count.to_string(), + )); + entries.push(( + format!("cluster.{}.client-cache-drain-batch", name), + cache_cfg.drain_batch.to_string(), + )); + entries.push(( + format!("cluster.{}.client-cache-drain-interval-ms", name), + cache_cfg.drain_interval_ms.to_string(), + )); } } entries.sort_by(|a, b| a.0.cmp(&b.0)); @@ -707,6 +894,12 @@ fn parse_key(key: &str) -> Result<(String, ClusterField)> { "hotkey-sketch-depth" => ClusterField::HotkeySketchDepth, "hotkey-capacity" => ClusterField::HotkeyCapacity, "hotkey-decay" => ClusterField::HotkeyDecay, + "client-cache-enabled" => ClusterField::ClientCacheEnabled, + "client-cache-max-entries" => ClusterField::ClientCacheMaxEntries, + "client-cache-max-value-bytes" => ClusterField::ClientCacheMaxValueBytes, + "client-cache-shard-count" => ClusterField::ClientCacheShardCount, + "client-cache-drain-batch" => ClusterField::ClientCacheDrainBatch, + "client-cache-drain-interval-ms" => ClusterField::ClientCacheDrainIntervalMs, unknown => bail!("unknown cluster field '{}'", unknown), }; Ok((cluster.to_string(), field)) @@ -800,6 +993,36 @@ fn parse_hotkey_decay(value: &str) -> Result { Ok(parsed) } +fn parse_bool_flag(value: &str, field: &str) -> Result { + match value.trim().to_ascii_lowercase().as_str() { + "1" | "true" | "yes" | "on" => Ok(true), + "0" | "false" | "no" | "off" => Ok(false), + other => bail!("invalid {} value '{}'", field, other), + } +} + +fn parse_positive_usize(value: &str, field: &str) -> Result { + let parsed: i64 = value + .trim() + .parse() + .with_context(|| format!("invalid {} value '{}'", field, value))?; + if parsed <= 0 { + bail!("{} must be > 0", field); + } + usize::try_from(parsed).map_err(|_| anyhow!("{} is too large", field)) +} + +fn parse_positive_u64(value: &str, field: &str) -> Result { + let parsed: i64 = value + .trim() + .parse() + .with_context(|| format!("invalid {} value '{}'", field, value))?; + if parsed 
<= 0 { + bail!("{} must be > 0", field); + } + Ok(parsed as u64) +} + fn option_to_string(value: Option) -> String { value .map(|v| v.to_string()) @@ -830,6 +1053,12 @@ enum ClusterField { HotkeySketchDepth, HotkeyCapacity, HotkeyDecay, + ClientCacheEnabled, + ClientCacheMaxEntries, + ClientCacheMaxValueBytes, + ClientCacheShardCount, + ClientCacheDrainBatch, + ClientCacheDrainIntervalMs, } fn wildcard_match(pattern: &str, target: &str) -> bool { diff --git a/src/lib.rs b/src/lib.rs index 55fb3b7..68bfde6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,7 @@ use tracing::{info, warn}; use tracing_subscriber::{fmt, EnvFilter}; pub mod auth; +pub mod cache; pub mod backend; pub mod cluster; pub mod config; diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs index bbd4f71..c7b852c 100644 --- a/src/metrics/mod.rs +++ b/src/metrics/mod.rs @@ -179,6 +179,50 @@ static FRONT_COMMAND_TOTAL: Lazy = Lazy::new(|| { .expect("front command counter registration must succeed") }); +static CLIENT_CACHE_LOOKUP: Lazy = Lazy::new(|| { + register_int_counter_vec!( + opts!( + "aster_client_cache_lookup_total", + "client cache lookup results grouped by kind" + ), + &["cluster", "kind", "result"] + ) + .expect("client cache lookup counter must succeed") +}); + +static CLIENT_CACHE_STORE: Lazy = Lazy::new(|| { + register_int_counter_vec!( + opts!( + "aster_client_cache_store_total", + "client cache store operations grouped by kind" + ), + &["cluster", "kind"] + ) + .expect("client cache store counter must succeed") +}); + +static CLIENT_CACHE_INVALIDATE: Lazy = Lazy::new(|| { + register_int_counter_vec!( + opts!( + "aster_client_cache_invalidate_total", + "client cache invalidations grouped by cluster" + ), + &["cluster"] + ) + .expect("client cache invalidation counter must succeed") +}); + +static CLIENT_CACHE_STATE: Lazy = Lazy::new(|| { + register_int_counter_vec!( + opts!( + "aster_client_cache_state_total", + "client cache state transitions grouped by cluster" + ), + &["cluster", "state"] + ) + .expect("client cache state counter must succeed") +}); + static BACKEND_REQUEST_TOTAL: Lazy = Lazy::new(|| { register_int_counter_vec!( opts!( @@ -323,6 +367,35 @@ pub fn front_command(cluster: &str, kind: &str, success: bool) { .inc(); } +/// Record a client cache lookup result. +pub fn client_cache_lookup(cluster: &str, kind: &str, hit: bool) { + let result = if hit { "hit" } else { "miss" }; + CLIENT_CACHE_LOOKUP + .with_label_values(&[cluster, kind, result]) + .inc(); +} + +/// Record a client cache store/update event. +pub fn client_cache_store(cluster: &str, kind: &str) { + CLIENT_CACHE_STORE + .with_label_values(&[cluster, kind]) + .inc(); +} + +/// Record the number of keys invalidated from the client cache. +pub fn client_cache_invalidate(cluster: &str, count: usize) { + CLIENT_CACHE_INVALIDATE + .with_label_values(&[cluster]) + .inc_by(count as u64); +} + +/// Record a client cache state transition. +pub fn client_cache_state(cluster: &str, state: &str) { + CLIENT_CACHE_STATE + .with_label_values(&[cluster, state]) + .inc(); +} + /// Record a backend request outcome. 
pub fn backend_request_result(cluster: &str, backend: &str, result: &str) { BACKEND_REQUEST_TOTAL diff --git a/src/standalone/mod.rs b/src/standalone/mod.rs index 3167e5e..edd1bc0 100644 --- a/src/standalone/mod.rs +++ b/src/standalone/mod.rs @@ -18,6 +18,7 @@ use tokio_util::codec::{Framed, FramedParts}; use tracing::{debug, info, warn}; use crate::auth::{AuthAction, BackendAuth, FrontendAuthenticator}; +use crate::cache::{tracker::CacheTrackerSet, ClientCache}; use crate::backend::client::{ClientId, FrontConnectionGuard}; use crate::backend::pool::{BackendNode, ConnectionPool, Connector, SessionCommand}; use crate::config::{ClusterConfig, ClusterRuntime, ConfigManager}; @@ -55,6 +56,8 @@ pub struct StandaloneProxy { hotkey: Arc, listen_port: u16, backend_nodes: usize, + client_cache: Arc, + cache_trackers: Arc, } impl StandaloneProxy { @@ -96,6 +99,21 @@ impl StandaloneProxy { let hotkey = config_manager .hotkey_for(&config.name) .ok_or_else(|| anyhow!("missing hotkey state for cluster {}", config.name))?; + let client_cache = config_manager + .client_cache_for(&config.name) + .ok_or_else(|| anyhow!("missing client cache state for cluster {}", config.name))?; + let cache_trackers = Arc::new(CacheTrackerSet::new( + cluster.clone(), + client_cache.clone(), + runtime.clone(), + backend_auth.clone(), + DEFAULT_TIMEOUT_MS, + )); + let tracker_nodes = nodes + .iter() + .map(|entry| entry.backend.as_str().to_string()) + .collect(); + cache_trackers.set_nodes(tracker_nodes); Ok(Self { cluster, @@ -110,6 +128,8 @@ impl StandaloneProxy { hotkey, listen_port, backend_nodes, + client_cache, + cache_trackers, }) } @@ -468,6 +488,16 @@ impl StandaloneProxy { continue; } + if let Some(hit) = self.client_cache.lookup(&command) { + self.hotkey.record_command(&command); + metrics::front_command(self.cluster.as_ref(), kind_label, true); + framed.send(hit).await?; + continue; + } + + let cache_candidate = command.clone(); + let cacheable_read = ClientCache::is_cacheable_read(&cache_candidate); + let invalidating_write = ClientCache::is_invalidating_write(&cache_candidate); let requested_version = command.resp_version_request(); let response = match self.dispatch(client_id, command).await { Ok(resp) => resp, @@ -480,6 +510,12 @@ impl StandaloneProxy { }; let success = !response.is_error(); + if success && cacheable_read { + self.client_cache.store(&cache_candidate, &response); + } + if invalidating_write { + self.client_cache.invalidate_command(&cache_candidate); + } if success { if let Some(version) = requested_version { framed.codec_mut().set_version(version); From fad765bcd3cf671185cd60e6c8d7925f3a018a48 Mon Sep 17 00:00:00 2001 From: wayslog Date: Thu, 13 Nov 2025 19:16:40 +0800 Subject: [PATCH 2/2] ci: compile and test pass --- src/cache/mod.rs | 48 +++++++++++++++++++++---------------------- src/cache/tracker.rs | 3 +-- src/cluster/mod.rs | 4 ++-- src/standalone/mod.rs | 4 ++-- 4 files changed, 28 insertions(+), 31 deletions(-) diff --git a/src/cache/mod.rs b/src/cache/mod.rs index abaf2ad..e8f3df2 100644 --- a/src/cache/mod.rs +++ b/src/cache/mod.rs @@ -1,12 +1,13 @@ use std::cmp::{max, Reverse}; use std::collections::BinaryHeap; +use std::fmt; use std::hash::{Hash, Hasher}; use std::sync::atomic::{AtomicU8, Ordering}; -use std::sync::{Arc, Weak}; +use std::sync::Arc; use std::time::Duration; use ahash::{AHasher, RandomState}; -use anyhow::{anyhow, bail, Result}; +use anyhow::{bail, Result}; use arc_swap::ArcSwap; use bytes::Bytes; use hashbrown::HashMap; @@ -28,7 +29,6 @@ const STATE_ENABLED: u8 = 1; 
const STATE_DRAINING: u8 = 2; const MAX_MULTI_KEYS: usize = 64; -const MAX_HASH_FIELD_CACHED: usize = 64; /// Operational state for the cache, observable by trackers. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -67,6 +67,16 @@ pub struct ClientCache { state_tx: watch::Sender, } +impl fmt::Debug for ClientCache { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ClientCache") + .field("cluster", &self.cluster) + .field("state", &self.state()) + .field("resp3_ready", &self.resp3_ready) + .finish_non_exhaustive() + } +} + impl ClientCache { pub fn new(cluster: Arc, config: ClientCacheConfig, resp3_ready: bool) -> Self { let shard_count = config.shard_count.max(1); @@ -303,11 +313,10 @@ impl ClientCache { Some(value) => value, None => return, }; - let footprint = resp_size(&normalized); - if footprint > config.max_value_bytes { + if resp_size(&normalized) > config.max_value_bytes { return; } - let entry = CacheEntry::new(normalized, footprint); + let entry = CacheEntry::new(normalized); let cache_key = CacheKey::new(kind, key.clone(), field.cloned()); let shards = self.shards.load(); let index = shard_index(cache_key.primary.as_ref(), shards.len().max(1)); @@ -345,11 +354,6 @@ impl ClientCache { self.shards.store(Arc::new(shards)); } - fn total_entries(&self) -> usize { - let shards = self.shards.load(); - shards.iter().map(|shard| shard.len()).sum() - } - fn start_drain_task(self: &Arc) { let mut guard = self.drain_handle.lock(); if let Some(handle) = guard.take() { @@ -468,15 +472,13 @@ impl CacheCommandKind { struct CacheEntry { value: RespValue, access: u64, - size: usize, } impl CacheEntry { - fn new(value: RespValue, size: usize) -> Self { + fn new(value: RespValue) -> Self { Self { value, access: 0, - size, } } } @@ -510,10 +512,6 @@ impl CacheShard { } } - fn len(&self) -> usize { - self.inner.lock().len() - } - fn get( &self, kind: CacheCommandKind, @@ -565,15 +563,15 @@ impl CacheShardInner { } } - fn len(&self) -> usize { - self.entries.len() - } - fn touch(&mut self, key: &CacheKey) -> Option { + if !self.entries.contains_key(key) { + return None; + } + let next_access = self.next_access(); if let Some(entry) = self.entries.get_mut(key) { - entry.access = self.next_access(); + entry.access = next_access; self.order - .push(Reverse(HeapEntry::new(entry.access, key.clone()))); + .push(Reverse(HeapEntry::new(next_access, key.clone()))); Some(entry.value.clone()) } else { None @@ -784,7 +782,7 @@ fn classify_write(command: &RedisCommand) -> Option> { if args.len() < 2 { return None; } - let mut keys = SmallVec::<[&Bytes; 1]>::new(); + let mut keys = SmallVec::<[&Bytes; MAX_MULTI_KEYS]>::new(); keys.push(&args[1]); Some(CacheWrite::Keys(keys)) } diff --git a/src/cache/tracker.rs b/src/cache/tracker.rs index d965288..d271dcd 100644 --- a/src/cache/tracker.rs +++ b/src/cache/tracker.rs @@ -2,7 +2,7 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::time::Duration; -use anyhow::{anyhow, Context, Result}; +use anyhow::{anyhow, bail, Context, Result}; use bytes::Bytes; use futures::{SinkExt, StreamExt}; use parking_lot::Mutex; @@ -230,7 +230,6 @@ async fn open_stream( stream.set_nodelay(true).context("failed to enable TCP_NODELAY")?; #[cfg(any(unix, windows))] { - use socket2::{SockRef, TcpKeepalive}; let keepalive = TcpKeepalive::new() .with_time(Duration::from_secs(60)) .with_interval(Duration::from_secs(60)); diff --git a/src/cluster/mod.rs b/src/cluster/mod.rs index c0e4c7a..a3ba4bd 100644 --- a/src/cluster/mod.rs +++ 
b/src/cluster/mod.rs
@@ -60,7 +60,7 @@ pub struct ClusterProxy {
     listen_port: u16,
     seed_nodes: usize,
     client_cache: Arc<ClientCache>,
-    cache_trackers: Arc<CacheTrackerSet>,
+    _cache_trackers: Arc<CacheTrackerSet>,
 }
 
 impl ClusterProxy {
@@ -123,7 +123,7 @@ impl ClusterProxy {
             listen_port,
             seed_nodes: config.servers.len(),
             client_cache,
-            cache_trackers: cache_trackers.clone(),
+            _cache_trackers: cache_trackers.clone(),
         };
 
         // trigger an immediate topology fetch
diff --git a/src/standalone/mod.rs b/src/standalone/mod.rs
index edd1bc0..cb8f946 100644
--- a/src/standalone/mod.rs
+++ b/src/standalone/mod.rs
@@ -57,7 +57,7 @@ pub struct StandaloneProxy {
     listen_port: u16,
     backend_nodes: usize,
     client_cache: Arc<ClientCache>,
-    cache_trackers: Arc<CacheTrackerSet>,
+    _cache_trackers: Arc<CacheTrackerSet>,
 }
 
 impl StandaloneProxy {
@@ -129,7 +129,7 @@ impl StandaloneProxy {
             listen_port,
             backend_nodes,
             client_cache,
-            cache_trackers,
+            _cache_trackers: cache_trackers,
         })
     }
 
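
For reference (not part of the patch): the new knobs are wired into the existing CONFIG plumbing in src/config/mod.rs, so they can be adjusted at runtime through the proxy's CONFIG handler. A hypothetical session, assuming a cluster named "default" in the config file:

    CONFIG SET cluster.default.client-cache-enabled on
    CONFIG SET cluster.default.client-cache-max-entries 200000
    CONFIG GET cluster.default.client-cache-*

Enabling returns an error unless the cluster's backend_resp_version is RESP3, since ClientCache::enable() bails when RESP3 is unavailable; disabling moves the cache into the draining state, where a background task evicts entries in drain_batch chunks every drain_interval_ms before the cache reports itself disabled.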
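
Also for reference (not part of the patch): CacheShardInner's eviction uses a lazy-deletion LRU. Every insert or touch pushes a fresh (access, key) marker onto a min-heap, and pop_lru() discards markers whose access counter no longer matches the live entry, so nothing is ever removed from the heap eagerly. A minimal, std-only sketch of that pattern under those assumptions (names such as LazyLru are illustrative, not from this codebase):

use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap};

struct LazyLru {
    entries: HashMap<String, (u64, String)>,   // key -> (last access, value)
    order: BinaryHeap<Reverse<(u64, String)>>, // (access, key) markers, min-heap via Reverse
    counter: u64,
    capacity: usize,
}

impl LazyLru {
    fn new(capacity: usize) -> Self {
        Self {
            entries: HashMap::new(),
            order: BinaryHeap::new(),
            counter: 0,
            capacity: capacity.max(1),
        }
    }

    fn insert(&mut self, key: &str, value: &str) {
        self.counter += 1;
        self.entries
            .insert(key.to_string(), (self.counter, value.to_string()));
        self.order.push(Reverse((self.counter, key.to_string())));
        while self.entries.len() > self.capacity {
            if self.pop_lru().is_none() {
                break;
            }
        }
    }

    fn get(&mut self, key: &str) -> Option<String> {
        self.counter += 1;
        let counter = self.counter;
        let value = {
            let (access, value) = self.entries.get_mut(key)?;
            *access = counter; // refresh recency
            value.clone()
        };
        // Push a new marker; the old one becomes a stale tombstone in the heap.
        self.order.push(Reverse((counter, key.to_string())));
        Some(value)
    }

    fn pop_lru(&mut self) -> Option<String> {
        while let Some(Reverse((access, key))) = self.order.pop() {
            // Skip markers superseded by a later insert/get on the same key.
            if self.entries.get(&key).map(|(a, _)| *a) == Some(access) {
                self.entries.remove(&key);
                return Some(key);
            }
        }
        None
    }
}

fn main() {
    let mut lru = LazyLru::new(2);
    lru.insert("a", "1");
    lru.insert("b", "2");
    lru.get("a");         // "b" is now the least recently used entry
    lru.insert("c", "3"); // evicts "b"
    assert!(lru.get("b").is_none());
    assert!(lru.get("a").is_some() && lru.get("c").is_some());
}

The trade-off is that the heap grows with the number of touches between evictions (stale markers linger until popped), in exchange for O(log n) recency updates without an intrusive linked-list LRU.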