diff --git a/docker/config/integration.toml b/docker/config/integration.toml index ee3421e..596d205 100644 --- a/docker/config/integration.toml +++ b/docker/config/integration.toml @@ -8,6 +8,7 @@ servers = [ ] password = "front-standalone-secret" backend_password = "backend-standalone-secret" +backend_resp_version = "resp3" [[clusters]] name = "cluster" @@ -24,3 +25,4 @@ auth = { password = "front-cluster-secret", users = [ { username = "ops", password = "ops-secret" } ] } backend_password = "backend-cluster-secret" +backend_resp_version = "resp3" diff --git a/docker/integration-test.sh b/docker/integration-test.sh old mode 100644 new mode 100755 index 9e0b466..db061ab --- a/docker/integration-test.sh +++ b/docker/integration-test.sh @@ -41,6 +41,26 @@ wait_for_cluster() { wait_for_cluster +if ! hello_standalone="$(REDISCLI_AUTH="$STANDALONE_PASS" redis-cli -h aster-proxy -p 6380 --raw HELLO 3 AUTH default "$STANDALONE_PASS")"; then + echo "HELLO 3 handshake against standalone proxy failed" >&2 + exit 1 +fi +if ! printf "%s\n" "$hello_standalone" | awk 'BEGIN{found=0} {if(prev=="proto" && $0=="3"){found=1} prev=$0} END{exit(found?0:1)}'; then + echo "HELLO 3 response from standalone proxy missing proto=3:" >&2 + printf "%s\n" "$hello_standalone" >&2 + exit 1 +fi + +if ! hello_cluster="$(REDISCLI_AUTH="$CLUSTER_USER_PASS" redis-cli -h aster-proxy -p 6381 --user "$CLUSTER_USER" --raw HELLO 3 AUTH "$CLUSTER_USER" "$CLUSTER_USER_PASS")"; then + echo "HELLO 3 handshake against cluster proxy failed" >&2 + exit 1 +fi +if ! printf "%s\n" "$hello_cluster" | awk 'BEGIN{found=0} {if(prev=="proto" && $0=="3"){found=1} prev=$0} END{exit(found?0:1)}'; then + echo "HELLO 3 response from cluster proxy missing proto=3:" >&2 + printf "%s\n" "$hello_cluster" >&2 + exit 1 +fi + noauth_output="$(redis-cli -h aster-proxy -p 6380 PING 2>&1 || true)" if echo "$noauth_output" | grep -q "PONG"; then echo "Expected standalone proxy to require authentication" >&2 diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 06aa855..3ac7f70 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -100,6 +100,17 @@ impl BackendAuth { .expect("AUTH command is always valid") } + pub fn hello_credentials(&self) -> Option<(Bytes, Bytes)> { + match self.parts.len() { + 2 => Some(( + Bytes::from_static(DEFAULT_USER.as_bytes()), + self.parts[1].clone(), + )), + 3 => Some((self.parts[1].clone(), self.parts[2].clone())), + _ => None, + } + } + pub async fn apply_to_stream( &self, framed: &mut Framed, diff --git a/src/cluster/mod.rs b/src/cluster/mod.rs index 01b7170..36abfec 100644 --- a/src/cluster/mod.rs +++ b/src/cluster/mod.rs @@ -25,8 +25,8 @@ use crate::hotkey::Hotkey; use crate::info::{InfoContext, ProxyMode}; use crate::metrics; use crate::protocol::redis::{ - BlockingKind, MultiDispatch, RedisCommand, RespCodec, RespValue, SlotMap, SubCommand, - SubResponse, SubscriptionKind, SLOT_COUNT, + BlockingKind, MultiDispatch, RedisCommand, RespCodec, RespValue, RespVersion, SlotMap, + SubCommand, SubResponse, SubscriptionKind, SLOT_COUNT, }; use crate::slowlog::Slowlog; use crate::utils::{crc16, trim_hash_tag}; @@ -78,6 +78,7 @@ impl ClusterProxy { runtime.clone(), REQUEST_TIMEOUT_MS, backend_auth.clone(), + config.backend_resp_version, )); let pool = Arc::new(ConnectionPool::new(cluster.clone(), connector.clone())); let auth = config @@ -143,17 +144,25 @@ impl ClusterProxy { let client_id = ClientId::new(); let _guard = FrontConnectionGuard::new(&self.cluster); - let (mut sink, stream) = Framed::new(socket, RespCodec::default()).split(); + let framed = Framed::new(socket, RespCodec::default()); + let codec_handle = framed.codec().clone(); + let (mut sink, stream) = framed.split(); let mut stream = stream.fuse(); - let mut pending: FuturesOrdered> = FuturesOrdered::new(); + let mut pending: FuturesOrdered)>> = + FuturesOrdered::new(); + let mut _resp3_negotiated = codec_handle.version() == RespVersion::Resp3; let mut inflight = 0usize; let mut stream_closed = false; let mut auth_state = self.auth.as_ref().map(|auth| auth.new_session()); loop { tokio::select! { - Some(resp) = pending.next(), if inflight > 0 => { + Some((resp, version)) = pending.next(), if inflight > 0 => { inflight -= 1; + if let Some(version) = version { + codec_handle.set_version(version); + _resp3_negotiated = version == RespVersion::Resp3; + } sink.send(resp).await?; } frame_opt = stream.next(), if !stream_closed && inflight < PIPELINE_LIMIT => { @@ -168,14 +177,13 @@ impl ClusterProxy { cmd = new_cmd; } AuthAction::Reply(resp) => { - let fut = async move { resp }; + let fut = async move { (resp, None) }; pending.push_back(Box::pin(fut)); inflight += 1; continue; } } } - if matches!(cmd.as_subscription(), SubscriptionKind::Channel | SubscriptionKind::Pattern) { let kind_label = cmd.kind_label(); if cmd.args().len() <= 1 { @@ -230,8 +238,12 @@ impl ClusterProxy { } }; while inflight > 0 { - if let Some(resp) = pending.next().await { + if let Some((resp, version)) = pending.next().await { inflight -= 1; + if let Some(version) = version { + codec_handle.set_version(version); + _resp3_negotiated = version == RespVersion::Resp3; + } sink.send(resp).await?; } else { inflight = 0; @@ -291,7 +303,7 @@ impl ClusterProxy { kind_label, success, ); - let fut = async move { response }; + let fut = async move { (response, None) }; pending.push_back(Box::pin(fut)); inflight += 1; continue; @@ -302,7 +314,7 @@ impl ClusterProxy { cmd.kind_label(), true, ); - let fut = async move { response }; + let fut = async move { (response, None) }; pending.push_back(Box::pin(fut)); inflight += 1; continue; @@ -314,7 +326,7 @@ impl ClusterProxy { cmd.kind_label(), success, ); - let fut = async move { response }; + let fut = async move { (response, None) }; pending.push_back(Box::pin(fut)); inflight += 1; continue; @@ -326,13 +338,23 @@ impl ClusterProxy { cmd.kind_label(), success, ); - let fut = async move { response }; + let fut = async move { (response, None) }; pending.push_back(Box::pin(fut)); inflight += 1; continue; } + let requested_version = cmd.resp_version_request(); let guard = self.prepare_dispatch(client_id, cmd); - pending.push_back(Box::pin(guard)); + let fut = async move { + let resp = guard.await; + let version = if !resp.is_error() { + requested_version + } else { + None + }; + (resp, version) + }; + pending.push_back(Box::pin(fut)); inflight += 1; } Err(err) => { @@ -340,7 +362,7 @@ impl ClusterProxy { metrics::front_error(self.cluster.as_ref(), "parse"); metrics::front_command(self.cluster.as_ref(), "invalid", false); let message = Bytes::from(format!("ERR {err}")); - let fut = async move { RespValue::Error(message) }; + let fut = async move { (RespValue::Error(message), None) }; pending.push_back(Box::pin(fut)); inflight += 1; } @@ -364,7 +386,11 @@ impl ClusterProxy { } } - while let Some(resp) = pending.next().await { + while let Some((resp, version)) = pending.next().await { + if let Some(version) = version { + codec_handle.set_version(version); + _resp3_negotiated = version == RespVersion::Resp3; + } sink.send(resp).await?; } sink.close().await?; @@ -903,6 +929,7 @@ struct ClusterConnector { heartbeat_interval: Duration, reconnect_base_delay: Duration, max_reconnect_attempts: usize, + backend_resp_version: RespVersion, } impl ClusterConnector { @@ -910,6 +937,7 @@ impl ClusterConnector { runtime: Arc, default_timeout_ms: u64, backend_auth: Option, + backend_resp_version: RespVersion, ) -> Self { Self { runtime, @@ -918,6 +946,7 @@ impl ClusterConnector { heartbeat_interval: Duration::from_secs(30), reconnect_base_delay: Duration::from_millis(50), max_reconnect_attempts: 3, + backend_resp_version, } } @@ -950,11 +979,91 @@ impl ClusterConnector { Ok(framed) } + async fn negotiate_resp_version( + &self, + cluster: &str, + node: &BackendNode, + framed: &mut Framed, + ) -> Result { + framed.codec_mut().set_version(RespVersion::Resp2); + if self.backend_resp_version != RespVersion::Resp3 { + return Ok(RespVersion::Resp2); + } + + let timeout_duration = self.current_timeout(); + let mut hello_parts = vec![ + RespValue::BulkString(Bytes::from_static(b"HELLO")), + RespValue::BulkString(Bytes::from_static(b"3")), + ]; + if let Some(auth) = &self.backend_auth { + if let Some((username, password)) = auth.hello_credentials() { + hello_parts.push(RespValue::BulkString(Bytes::from_static(b"AUTH"))); + hello_parts.push(RespValue::BulkString(username)); + hello_parts.push(RespValue::BulkString(password)); + } + } + let hello = RespValue::Array(hello_parts); + + match timeout(timeout_duration, framed.send(hello)).await { + Ok(Ok(())) => {} + Ok(Err(err)) => { + metrics::backend_error(cluster, node.as_str(), "resp3_handshake"); + return Err(err.context(format!("failed to send RESP3 HELLO to {}", node.as_str()))); + } + Err(_) => { + metrics::backend_error(cluster, node.as_str(), "resp3_handshake"); + return Err(anyhow!( + "backend {} timed out sending RESP3 HELLO", + node.as_str() + )); + } + } + + let reply = match timeout(timeout_duration, framed.next()).await { + Ok(Some(Ok(value))) => value, + Ok(Some(Err(err))) => { + metrics::backend_error(cluster, node.as_str(), "resp3_handshake"); + return Err(err.context(format!( + "failed to read RESP3 HELLO reply from {}", + node.as_str() + ))); + } + Ok(None) => { + metrics::backend_error(cluster, node.as_str(), "resp3_handshake"); + return Err(anyhow!( + "backend {} closed connection during RESP3 HELLO", + node.as_str() + )); + } + Err(_) => { + metrics::backend_error(cluster, node.as_str(), "resp3_handshake"); + return Err(anyhow!( + "backend {} timed out waiting for RESP3 HELLO reply", + node.as_str() + )); + } + }; + + if reply.is_error() { + info!( + cluster = %cluster, + backend = %node.as_str(), + "backend rejected RESP3 HELLO; falling back to RESP2" + ); + framed.codec_mut().set_version(RespVersion::Resp2); + return Ok(RespVersion::Resp2); + } + + framed.codec_mut().set_version(RespVersion::Resp3); + Ok(RespVersion::Resp3) + } + async fn execute( &self, framed: &mut Framed, command: RedisCommand, ) -> Result { + let requested_version = command.resp_version_request(); let blocking = command.as_blocking(); if let Ok(name) = std::str::from_utf8(command.command_name()) { if name.eq_ignore_ascii_case("blpop") || name.eq_ignore_ascii_case("brpop") { @@ -966,7 +1075,7 @@ impl ClusterConnector { .await .context("timed out sending command")??; - match blocking { + let response = match blocking { BlockingKind::Queue { .. } | BlockingKind::Stream { .. } => match framed.next().await { Some(Ok(value)) => Ok(value), Some(Err(err)) => Err(err.into()), @@ -978,7 +1087,14 @@ impl ClusterConnector { Ok(None) => Err(anyhow!("backend closed connection")), Err(_) => Err(anyhow!("timed out waiting for response")), }, + }?; + + if let Some(version) = requested_version { + if !response.is_error() { + framed.codec_mut().set_version(version); + } } + Ok(response) } async fn connect_with_retry( @@ -990,7 +1106,7 @@ impl ClusterConnector { for attempt in 0..self.max_reconnect_attempts { let attempt_start = Instant::now(); match self.open_stream(node.as_str()).await { - Ok(stream) => { + Ok(mut stream) => { metrics::backend_probe_duration( cluster, node.as_str(), @@ -998,7 +1114,25 @@ impl ClusterConnector { attempt_start.elapsed(), ); metrics::backend_probe_result(cluster, node.as_str(), "connect", true); - return Ok(stream); + match self + .negotiate_resp_version(cluster, node, &mut stream) + .await + { + Ok(_) => return Ok(stream), + Err(err) => { + warn!( + cluster = %cluster, + backend = %node.as_str(), + attempt = attempt + 1, + error = %err, + "failed to negotiate RESP version with backend" + ); + last_error = Some(err); + if attempt + 1 < self.max_reconnect_attempts { + sleep(self.reconnect_base_delay).await; + } + } + } } Err(err) => { let elapsed = attempt_start.elapsed(); diff --git a/src/config/mod.rs b/src/config/mod.rs index 69b3179..af305c9 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -15,7 +15,7 @@ use crate::hotkey::{ Hotkey, HotkeyConfig, DEFAULT_DECAY, DEFAULT_HOTKEY_CAPACITY, DEFAULT_SAMPLE_EVERY, DEFAULT_SKETCH_DEPTH, DEFAULT_SKETCH_WIDTH, }; -use crate::protocol::redis::{RedisCommand, RespValue}; +use crate::protocol::redis::{RedisCommand, RespValue, RespVersion}; use crate::slowlog::Slowlog; /// Environment variable controlling the default worker thread count when a @@ -51,6 +51,10 @@ fn default_hotkey_decay() -> f64 { DEFAULT_DECAY } +fn default_backend_resp_version() -> RespVersion { + RespVersion::Resp2 +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Config { #[serde(default)] @@ -177,6 +181,8 @@ pub struct ClusterConfig { pub hotkey_capacity: usize, #[serde(default = "default_hotkey_decay")] pub hotkey_decay: f64, + #[serde(default = "default_backend_resp_version")] + pub backend_resp_version: RespVersion, } impl ClusterConfig { diff --git a/src/protocol/redis/codec.rs b/src/protocol/redis/codec.rs index 8870e74..88851e8 100644 --- a/src/protocol/redis/codec.rs +++ b/src/protocol/redis/codec.rs @@ -1,11 +1,61 @@ use anyhow::{anyhow, Result}; use bytes::{Buf, Bytes, BytesMut}; +use serde::{Deserialize, Serialize}; +use std::sync::atomic::{AtomicU8, Ordering}; +use std::sync::Arc; use tokio_util::codec::{Decoder, Encoder}; use super::types::RespValue; -#[derive(Debug, Default, Clone, Copy)] -pub struct RespCodec; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum RespVersion { + Resp2, + Resp3, +} + +impl RespVersion { + fn as_u8(self) -> u8 { + match self { + RespVersion::Resp2 => 2, + RespVersion::Resp3 => 3, + } + } + + fn from_u8(value: u8) -> Self { + match value { + 3 => RespVersion::Resp3, + _ => RespVersion::Resp2, + } + } +} + +#[derive(Debug, Clone)] +pub struct RespCodec { + version: Arc, +} + +impl Default for RespCodec { + fn default() -> Self { + Self { + version: Arc::new(AtomicU8::new(RespVersion::Resp2.as_u8())), + } + } +} + +impl RespCodec { + pub fn version(&self) -> RespVersion { + RespVersion::from_u8(self.version.load(Ordering::SeqCst)) + } + + pub fn set_version(&self, version: RespVersion) { + self.version.store(version.as_u8(), Ordering::SeqCst); + } + + pub fn upgrade_to_resp3(&self) { + self.set_version(RespVersion::Resp3); + } +} impl Decoder for RespCodec { type Item = RespValue; @@ -27,7 +77,8 @@ impl Encoder for RespCodec { type Error = anyhow::Error; fn encode(&mut self, item: RespValue, dst: &mut BytesMut) -> Result<()> { - write_value(&item, dst); + let version = self.version(); + write_value(&item, version, dst); Ok(()) } } @@ -75,7 +126,9 @@ fn parse_value(src: &[u8], pos: &mut usize) -> Result> { .map_err(|err| anyhow!("invalid integer: {err}"))?; Ok(Some(RespValue::Integer(value))) } - b'$' => { + b'$' => parse_bulk_string(src, pos, start, false), + b'*' => parse_array(src, pos, start), + b'_' => { let line = match read_line(src, pos)? { Some(line) => line, None => { @@ -83,57 +136,230 @@ fn parse_value(src: &[u8], pos: &mut usize) -> Result> { return Ok(None); } }; - let len_str = std::str::from_utf8(line)?; - let len = len_str - .parse::() - .map_err(|err| anyhow!("invalid bulk length: {err}"))?; - if len < 0 { - return Ok(Some(RespValue::NullBulk)); - } - let len = len as usize; - if *pos + len + 2 > src.len() { - *pos = start; - return Ok(None); + if !line.is_empty() { + return Err(anyhow!("invalid null frame payload")); } - let data = &src[*pos..*pos + len]; - *pos += len + 2; // skip data and CRLF - Ok(Some(RespValue::BulkString(Bytes::copy_from_slice(data)))) + Ok(Some(RespValue::Null)) } - b'*' => { - let mut local_pos = *pos; - let line = match read_line(src, &mut local_pos)? { + b'#' => { + let line = match read_line(src, pos)? { Some(line) => line, None => { *pos = start; return Ok(None); } }; - let len_str = std::str::from_utf8(line)?; - let len = len_str - .parse::() - .map_err(|err| anyhow!("invalid array length: {err}"))?; - if len < 0 { - *pos = local_pos; - return Ok(Some(RespValue::NullArray)); + match line { + b"t" | b"T" => Ok(Some(RespValue::Boolean(true))), + b"f" | b"F" => Ok(Some(RespValue::Boolean(false))), + _ => Err(anyhow!("invalid boolean literal '{:?}'", line)), } - let mut values = Vec::with_capacity(len as usize); - let mut element_pos = local_pos; - for _ in 0..len { - match parse_value(src, &mut element_pos)? { - Some(value) => values.push(value), - None => { - *pos = start; - return Ok(None); - } + } + b',' => { + let line = match read_line(src, pos)? { + Some(line) => line, + None => { + *pos = start; + return Ok(None); } - } - *pos = element_pos; - Ok(Some(RespValue::Array(values))) + }; + Ok(Some(RespValue::Double(Bytes::copy_from_slice(line)))) } + b'(' => { + let line = match read_line(src, pos)? { + Some(line) => line, + None => { + *pos = start; + return Ok(None); + } + }; + Ok(Some(RespValue::BigNumber(Bytes::copy_from_slice(line)))) + } + b'=' => parse_verbatim_string(src, pos, start), + b'!' => parse_bulk_string(src, pos, start, true), + b'%' => parse_map(src, pos, start, RespValue::Map), + b'~' => parse_collection(src, pos, start, RespValue::Set), + b'>' => parse_collection(src, pos, start, RespValue::Push), + b'|' => parse_map(src, pos, start, RespValue::Attribute), _ => Err(anyhow!("unsupported RESP prefix '{}'.", prefix as char)), } } +fn parse_bulk_string( + src: &[u8], + pos: &mut usize, + start: usize, + as_error: bool, +) -> Result> { + let line = match read_line(src, pos)? { + Some(line) => line, + None => { + *pos = start; + return Ok(None); + } + }; + let len = parse_length(line, "bulk string")?; + if len < 0 { + return Ok(Some(RespValue::NullBulk)); + } + let len = len as usize; + if *pos + len + 2 > src.len() { + *pos = start; + return Ok(None); + } + let data = &src[*pos..*pos + len]; + *pos += len + 2; + let payload = Bytes::copy_from_slice(data); + if as_error { + Ok(Some(RespValue::BlobError(payload))) + } else { + Ok(Some(RespValue::BulkString(payload))) + } +} + +fn parse_verbatim_string(src: &[u8], pos: &mut usize, start: usize) -> Result> { + let line = match read_line(src, pos)? { + Some(line) => line, + None => { + *pos = start; + return Ok(None); + } + }; + let len = parse_length(line, "verbatim string")?; + if len < 0 { + return Err(anyhow!("verbatim string must not be null")); + } + let len = len as usize; + if *pos + len + 2 > src.len() { + *pos = start; + return Ok(None); + } + let data = &src[*pos..*pos + len]; + *pos += len + 2; + if len < 4 || data[3] != b':' { + return Err(anyhow!("invalid verbatim string header")); + } + let mut format = [0u8; 3]; + format.copy_from_slice(&data[..3]); + let payload = Bytes::copy_from_slice(&data[4..]); + Ok(Some(RespValue::VerbatimString { + format, + data: payload, + })) +} + +fn parse_array(src: &[u8], pos: &mut usize, start: usize) -> Result> { + let mut local_pos = *pos; + let line = match read_line(src, &mut local_pos)? { + Some(line) => line, + None => { + *pos = start; + return Ok(None); + } + }; + let len = parse_length(line, "array")?; + if len < 0 { + *pos = local_pos; + return Ok(Some(RespValue::NullArray)); + } + let mut values = Vec::with_capacity(len as usize); + let mut element_pos = local_pos; + for _ in 0..len { + match parse_value(src, &mut element_pos)? { + Some(value) => values.push(value), + None => { + *pos = start; + return Ok(None); + } + } + } + *pos = element_pos; + Ok(Some(RespValue::Array(values))) +} + +fn parse_collection( + src: &[u8], + pos: &mut usize, + start: usize, + ctor: F, +) -> Result> +where + F: FnOnce(Vec) -> RespValue, +{ + let mut local_pos = *pos; + let line = match read_line(src, &mut local_pos)? { + Some(line) => line, + None => { + *pos = start; + return Ok(None); + } + }; + let len = parse_length(line, "collection")?; + if len < 0 { + *pos = local_pos; + return Ok(Some(RespValue::Null)); + } + let mut values = Vec::with_capacity(len as usize); + let mut element_pos = local_pos; + for _ in 0..len { + match parse_value(src, &mut element_pos)? { + Some(value) => values.push(value), + None => { + *pos = start; + return Ok(None); + } + } + } + *pos = element_pos; + Ok(Some(ctor(values))) +} + +fn parse_map(src: &[u8], pos: &mut usize, start: usize, ctor: F) -> Result> +where + F: FnOnce(Vec<(RespValue, RespValue)>) -> RespValue, +{ + let mut local_pos = *pos; + let line = match read_line(src, &mut local_pos)? { + Some(line) => line, + None => { + *pos = start; + return Ok(None); + } + }; + let len = parse_length(line, "map")?; + if len < 0 { + *pos = local_pos; + return Ok(Some(RespValue::Null)); + } + let mut entries = Vec::with_capacity(len as usize); + let mut element_pos = local_pos; + for _ in 0..len { + let key = match parse_value(src, &mut element_pos)? { + Some(value) => value, + None => { + *pos = start; + return Ok(None); + } + }; + let value = match parse_value(src, &mut element_pos)? { + Some(value) => value, + None => { + *pos = start; + return Ok(None); + } + }; + entries.push((key, value)); + } + *pos = element_pos; + Ok(Some(ctor(entries))) +} + +fn parse_length(bytes: &[u8], kind: &str) -> Result { + let text = std::str::from_utf8(bytes)?; + text.parse::() + .map_err(|err| anyhow!("invalid {kind} length: {err}")) +} + fn read_line<'a>(src: &'a [u8], pos: &mut usize) -> Result> { if *pos >= src.len() { return Ok(None); @@ -150,7 +376,7 @@ fn read_line<'a>(src: &'a [u8], pos: &mut usize) -> Result> { Ok(None) } -fn write_value(value: &RespValue, dst: &mut BytesMut) { +fn write_value(value: &RespValue, version: RespVersion, dst: &mut BytesMut) { match value { RespValue::SimpleString(data) => { dst.extend_from_slice(b"+"); @@ -162,31 +388,168 @@ fn write_value(value: &RespValue, dst: &mut BytesMut) { dst.extend_from_slice(data); dst.extend_from_slice(b"\r\n"); } - RespValue::Integer(value) => { - dst.extend_from_slice(b":"); - dst.extend_from_slice(value.to_string().as_bytes()); - dst.extend_from_slice(b"\r\n"); - } - RespValue::BulkString(data) => { - dst.extend_from_slice(b"$"); - dst.extend_from_slice(data.len().to_string().as_bytes()); - dst.extend_from_slice(b"\r\n"); - dst.extend_from_slice(data); - dst.extend_from_slice(b"\r\n"); - } - RespValue::NullBulk => { - dst.extend_from_slice(b"$-1\r\n"); - } - RespValue::Array(values) => { - dst.extend_from_slice(b"*"); - dst.extend_from_slice(values.len().to_string().as_bytes()); - dst.extend_from_slice(b"\r\n"); - for value in values { - write_value(value, dst); + RespValue::Integer(value) => write_integer(*value, dst), + RespValue::BulkString(data) => write_bulk(data, dst), + RespValue::NullBulk => dst.extend_from_slice(b"$-1\r\n"), + RespValue::Array(values) => write_aggregate(b'*', values.len(), values, version, dst), + RespValue::NullArray => dst.extend_from_slice(b"*-1\r\n"), + RespValue::Null => match version { + RespVersion::Resp3 => dst.extend_from_slice(b"_\r\n"), + RespVersion::Resp2 => dst.extend_from_slice(b"$-1\r\n"), + }, + RespValue::Boolean(flag) => match version { + RespVersion::Resp3 => { + dst.extend_from_slice(b"#"); + dst.extend_from_slice(if *flag { b"t" } else { b"f" }); + dst.extend_from_slice(b"\r\n"); } - } - RespValue::NullArray => { - dst.extend_from_slice(b"*-1\r\n"); - } + RespVersion::Resp2 => write_integer(if *flag { 1 } else { 0 }, dst), + }, + RespValue::Double(data) => match version { + RespVersion::Resp3 => { + dst.extend_from_slice(b","); + dst.extend_from_slice(data); + dst.extend_from_slice(b"\r\n"); + } + RespVersion::Resp2 => write_bulk(data, dst), + }, + RespValue::BigNumber(data) => match version { + RespVersion::Resp3 => { + dst.extend_from_slice(b"("); + dst.extend_from_slice(data); + dst.extend_from_slice(b"\r\n"); + } + RespVersion::Resp2 => write_bulk(data, dst), + }, + RespValue::VerbatimString { format, data } => match version { + RespVersion::Resp3 => { + let total_len = 4 + data.len(); + dst.extend_from_slice(b"="); + dst.extend_from_slice(total_len.to_string().as_bytes()); + dst.extend_from_slice(b"\r\n"); + dst.extend_from_slice(format); + dst.extend_from_slice(b":"); + dst.extend_from_slice(data); + dst.extend_from_slice(b"\r\n"); + } + RespVersion::Resp2 => write_bulk(data, dst), + }, + RespValue::BlobError(data) => match version { + RespVersion::Resp3 => { + dst.extend_from_slice(b"!"); + dst.extend_from_slice(data.len().to_string().as_bytes()); + dst.extend_from_slice(b"\r\n"); + dst.extend_from_slice(data); + dst.extend_from_slice(b"\r\n"); + } + RespVersion::Resp2 => { + dst.extend_from_slice(b"-"); + dst.extend_from_slice(data); + dst.extend_from_slice(b"\r\n"); + } + }, + RespValue::Map(entries) => match version { + RespVersion::Resp3 => write_map(b'%', entries, version, dst), + RespVersion::Resp2 => write_map_as_array(entries, version, dst), + }, + RespValue::Set(values) => match version { + RespVersion::Resp3 => write_aggregate(b'~', values.len(), values, version, dst), + RespVersion::Resp2 => write_aggregate(b'*', values.len(), values, version, dst), + }, + RespValue::Push(values) => match version { + RespVersion::Resp3 => write_aggregate(b'>', values.len(), values, version, dst), + RespVersion::Resp2 => write_aggregate(b'*', values.len(), values, version, dst), + }, + RespValue::Attribute(entries) => match version { + RespVersion::Resp3 => write_map(b'|', entries, version, dst), + RespVersion::Resp2 => write_map_as_array(entries, version, dst), + }, + } +} + +fn write_integer(value: i64, dst: &mut BytesMut) { + dst.extend_from_slice(b":"); + dst.extend_from_slice(value.to_string().as_bytes()); + dst.extend_from_slice(b"\r\n"); +} + +fn write_bulk(data: &[u8], dst: &mut BytesMut) { + dst.extend_from_slice(b"$"); + dst.extend_from_slice(data.len().to_string().as_bytes()); + dst.extend_from_slice(b"\r\n"); + dst.extend_from_slice(data); + dst.extend_from_slice(b"\r\n"); +} + +fn write_aggregate( + prefix: u8, + len: usize, + values: &[RespValue], + version: RespVersion, + dst: &mut BytesMut, +) { + dst.extend_from_slice(&[prefix]); + dst.extend_from_slice(len.to_string().as_bytes()); + dst.extend_from_slice(b"\r\n"); + for value in values { + write_value(value, version, dst); + } +} + +fn write_map( + prefix: u8, + entries: &[(RespValue, RespValue)], + version: RespVersion, + dst: &mut BytesMut, +) { + dst.extend_from_slice(&[prefix]); + dst.extend_from_slice(entries.len().to_string().as_bytes()); + dst.extend_from_slice(b"\r\n"); + for (key, value) in entries { + write_value(key, version, dst); + write_value(value, version, dst); + } +} + +fn write_map_as_array( + entries: &[(RespValue, RespValue)], + version: RespVersion, + dst: &mut BytesMut, +) { + dst.extend_from_slice(b"*"); + dst.extend_from_slice((entries.len() * 2).to_string().as_bytes()); + dst.extend_from_slice(b"\r\n"); + for (key, value) in entries { + write_value(key, version, dst); + write_value(value, version, dst); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn map_entry() -> RespValue { + RespValue::Map(vec![( + RespValue::SimpleString(Bytes::from_static(b"mode")), + RespValue::BulkString(Bytes::from_static(b"standalone")), + )]) + } + + #[test] + fn encodes_map_with_resp3_prefix() { + let mut codec = RespCodec::default(); + codec.upgrade_to_resp3(); + let mut buf = BytesMut::new(); + codec.encode(map_entry(), &mut buf).unwrap(); + assert_eq!(buf.as_ref(), b"%1\r\n+mode\r\n$10\r\nstandalone\r\n"); + } + + #[test] + fn encodes_map_as_array_in_resp2() { + let mut codec = RespCodec::default(); + let mut buf = BytesMut::new(); + codec.encode(map_entry(), &mut buf).unwrap(); + assert_eq!(buf.as_ref(), b"*2\r\n+mode\r\n$10\r\nstandalone\r\n"); } } diff --git a/src/protocol/redis/command.rs b/src/protocol/redis/command.rs index 0637315..1e4f02f 100644 --- a/src/protocol/redis/command.rs +++ b/src/protocol/redis/command.rs @@ -8,6 +8,7 @@ use crate::backend::pool::BackendRequest; use crate::metrics; use crate::utils::{crc16, trim_hash_tag}; +use super::codec::RespVersion; use super::types::RespValue; pub const SLOT_COUNT: u16 = 16384; @@ -81,6 +82,7 @@ impl RedisCommand { RespValue::Array(_) => { bail!("nested array arguments are not supported"); } + other => bail!("unsupported RESP argument type: {:?}", other), } } Self::new(parts) @@ -196,6 +198,19 @@ impl RedisCommand { _ => None, } } + + pub fn resp_version_request(&self) -> Option { + if !self.command_name().eq_ignore_ascii_case(b"HELLO") { + return None; + } + let version_arg = self.args().get(1)?; + let version_text = std::str::from_utf8(version_arg).ok()?; + match version_text.parse::().ok()? { + 2 => Some(RespVersion::Resp2), + 3 => Some(RespVersion::Resp3), + _ => None, + } + } } impl BackendRequest for RedisCommand { @@ -212,6 +227,40 @@ impl BackendRequest for RedisCommand { } } +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + + fn cmd(parts: &[&[u8]]) -> RedisCommand { + RedisCommand::new(parts.iter().map(|p| Bytes::copy_from_slice(p)).collect()).unwrap() + } + + #[test] + fn detects_resp3_request_from_hello() { + let command = cmd(&[b"HELLO", b"3"]); + assert_eq!(command.resp_version_request(), Some(RespVersion::Resp3)); + } + + #[test] + fn detects_resp2_request_from_hello() { + let command = cmd(&[b"HELLO", b"2"]); + assert_eq!(command.resp_version_request(), Some(RespVersion::Resp2)); + } + + #[test] + fn ignores_invalid_or_missing_version() { + let without_version = cmd(&[b"HELLO"]); + assert_eq!(without_version.resp_version_request(), None); + + let invalid_version = cmd(&[b"HELLO", b"foo"]); + assert_eq!(invalid_version.resp_version_request(), None); + + let other_command = cmd(&[b"PING"]); + assert_eq!(other_command.resp_version_request(), None); + } +} + impl fmt::Display for RedisCommand { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let args: Vec = self diff --git a/src/protocol/redis/mod.rs b/src/protocol/redis/mod.rs index 8d7a218..9c63e3a 100644 --- a/src/protocol/redis/mod.rs +++ b/src/protocol/redis/mod.rs @@ -3,7 +3,7 @@ mod command; mod slots; mod types; -pub use codec::RespCodec; +pub use codec::{RespCodec, RespVersion}; pub use command::{ Aggregator, BlockingKind, CommandKind, MultiDispatch, RedisCommand, RedisResponse, SubCommand, SubResponse, SubscriptionKind, SLOT_COUNT, diff --git a/src/protocol/redis/types.rs b/src/protocol/redis/types.rs index ba7f382..4073470 100644 --- a/src/protocol/redis/types.rs +++ b/src/protocol/redis/types.rs @@ -9,6 +9,16 @@ pub enum RespValue { NullBulk, Array(Vec), NullArray, + Null, + Boolean(bool), + Double(Bytes), + BigNumber(Bytes), + VerbatimString { format: [u8; 3], data: Bytes }, + BlobError(Bytes), + Map(Vec<(RespValue, RespValue)>), + Set(Vec), + Push(Vec), + Attribute(Vec<(RespValue, RespValue)>), } impl RespValue { @@ -28,8 +38,20 @@ impl RespValue { RespValue::Array(values) } + pub fn map(entries: Vec<(RespValue, RespValue)>) -> Self { + RespValue::Map(entries) + } + + pub fn null() -> Self { + RespValue::Null + } + + pub fn boolean(value: bool) -> Self { + RespValue::Boolean(value) + } + pub fn is_error(&self) -> bool { - matches!(self, RespValue::Error(_)) + matches!(self, RespValue::Error(_) | RespValue::BlobError(_)) } pub fn as_array(&self) -> Option<&[RespValue]> { diff --git a/src/standalone/mod.rs b/src/standalone/mod.rs index 5e427da..3167e5e 100644 --- a/src/standalone/mod.rs +++ b/src/standalone/mod.rs @@ -25,8 +25,8 @@ use crate::hotkey::Hotkey; use crate::info::{InfoContext, ProxyMode}; use crate::metrics; use crate::protocol::redis::{ - BlockingKind, MultiDispatch, RedisCommand, RedisResponse, RespCodec, RespValue, SubCommand, - SubResponse, SubscriptionKind, + BlockingKind, MultiDispatch, RedisCommand, RedisResponse, RespCodec, RespValue, RespVersion, + SubCommand, SubResponse, SubscriptionKind, }; use crate::slowlog::Slowlog; use crate::utils::trim_hash_tag; @@ -79,6 +79,7 @@ impl StandaloneProxy { runtime.clone(), DEFAULT_TIMEOUT_MS, backend_auth.clone(), + config.backend_resp_version, )); let auth = config .frontend_auth_users() @@ -467,6 +468,7 @@ impl StandaloneProxy { continue; } + let requested_version = command.resp_version_request(); let response = match self.dispatch(client_id, command).await { Ok(resp) => resp, Err(err) => { @@ -477,7 +479,12 @@ impl StandaloneProxy { } }; - let success = !matches!(response, RespValue::Error(_)); + let success = !response.is_error(); + if success { + if let Some(version) = requested_version { + framed.codec_mut().set_version(version); + } + } metrics::front_command(self.cluster.as_ref(), kind_label, success); framed.send(response).await?; } @@ -658,6 +665,7 @@ struct RedisConnector { max_reconnect_delay: Duration, heartbeat_interval: Duration, backend_auth: Option, + backend_resp_version: RespVersion, } impl RedisConnector { @@ -665,6 +673,7 @@ impl RedisConnector { runtime: Arc, default_timeout_ms: u64, backend_auth: Option, + backend_resp_version: RespVersion, ) -> Self { Self { runtime, @@ -673,6 +682,7 @@ impl RedisConnector { max_reconnect_delay: Duration::from_secs(2), heartbeat_interval: Duration::from_secs(20), backend_auth, + backend_resp_version, } } @@ -706,11 +716,91 @@ impl RedisConnector { Ok(framed) } + async fn negotiate_resp_version( + &self, + cluster: &str, + node: &BackendNode, + framed: &mut Framed, + ) -> Result { + framed.codec_mut().set_version(RespVersion::Resp2); + if self.backend_resp_version != RespVersion::Resp3 { + return Ok(RespVersion::Resp2); + } + + let timeout_duration = self.current_timeout(); + let mut hello_parts = vec![ + RespValue::BulkString(Bytes::from_static(b"HELLO")), + RespValue::BulkString(Bytes::from_static(b"3")), + ]; + if let Some(auth) = &self.backend_auth { + if let Some((username, password)) = auth.hello_credentials() { + hello_parts.push(RespValue::BulkString(Bytes::from_static(b"AUTH"))); + hello_parts.push(RespValue::BulkString(username)); + hello_parts.push(RespValue::BulkString(password)); + } + } + let hello = RespValue::Array(hello_parts); + + match timeout(timeout_duration, framed.send(hello)).await { + Ok(Ok(())) => {} + Ok(Err(err)) => { + metrics::backend_error(cluster, node.as_str(), "resp3_handshake"); + return Err(err.context(format!("failed to send RESP3 HELLO to {}", node.as_str()))); + } + Err(_) => { + metrics::backend_error(cluster, node.as_str(), "resp3_handshake"); + return Err(anyhow!( + "backend {} timed out sending RESP3 HELLO", + node.as_str() + )); + } + } + + let reply = match timeout(timeout_duration, framed.next()).await { + Ok(Some(Ok(value))) => value, + Ok(Some(Err(err))) => { + metrics::backend_error(cluster, node.as_str(), "resp3_handshake"); + return Err(err.context(format!( + "failed to read RESP3 HELLO reply from {}", + node.as_str() + ))); + } + Ok(None) => { + metrics::backend_error(cluster, node.as_str(), "resp3_handshake"); + return Err(anyhow!( + "backend {} closed connection during RESP3 HELLO", + node.as_str() + )); + } + Err(_) => { + metrics::backend_error(cluster, node.as_str(), "resp3_handshake"); + return Err(anyhow!( + "backend {} timed out waiting for RESP3 HELLO reply", + node.as_str() + )); + } + }; + + if reply.is_error() { + info!( + cluster = %cluster, + backend = %node.as_str(), + "backend rejected RESP3 HELLO; falling back to RESP2" + ); + framed.codec_mut().set_version(RespVersion::Resp2); + return Ok(RespVersion::Resp2); + } + + framed.codec_mut().set_version(RespVersion::Resp3); + Ok(RespVersion::Resp3) + } + async fn execute_request( &self, framed: &mut Framed, request: RedisCommand, ) -> Result { + let requested_version = request.resp_version_request(); let blocking = request.as_blocking(); let frame = request.to_resp(); let timeout_duration = self.current_timeout(); @@ -718,7 +808,7 @@ impl RedisConnector { .await .context("timed out while sending request")??; - match blocking { + let response = match blocking { BlockingKind::Queue { .. } | BlockingKind::Stream { .. } => match framed.next().await { Some(Ok(response)) => Ok(response), Some(Err(err)) => Err(err.into()), @@ -730,7 +820,14 @@ impl RedisConnector { Ok(None) => Err(anyhow!("backend closed connection")), Err(_) => Err(anyhow!("timed out waiting for backend reply")), }, + }?; + + if let Some(version) = requested_version { + if !response.is_error() { + framed.codec_mut().set_version(version); + } } + Ok(response) } async fn heartbeat(&self, framed: &mut Framed) -> Result<()> { @@ -802,7 +899,7 @@ impl Connector for RedisConnector { if connection.is_none() { let attempt_start = Instant::now(); match self.open_stream(&node).await { - Ok(stream) => { + Ok(mut stream) => { metrics::backend_probe_duration( &cluster, node.as_str(), @@ -810,8 +907,27 @@ impl Connector for RedisConnector { attempt_start.elapsed(), ); metrics::backend_probe_result(&cluster, node.as_str(), "connect", true); - connection = Some(stream); - current_delay = self.reconnect_delay; + match self + .negotiate_resp_version(&cluster, &node, &mut stream) + .await + { + Ok(_) => { + connection = Some(stream); + current_delay = self.reconnect_delay; + } + Err(err) => { + warn!( + cluster = %cluster, + backend = %node.as_str(), + error = %err, + "failed to negotiate RESP version with backend" + ); + let _ = respond_to.send(Err(err)); + current_delay = self.increase_delay(current_delay); + sleep(current_delay).await; + continue; + } + } } Err(err) => { let elapsed = attempt_start.elapsed();