From bb694abcf981b7bb262d8b4eaa5049e7abda19a2 Mon Sep 17 00:00:00 2001 From: wayslog Date: Mon, 27 Oct 2025 16:14:18 +0800 Subject: [PATCH 1/4] feat: add resp3 support --- src/cluster/mod.rs | 44 ++- src/protocol/redis/codec.rs | 497 +++++++++++++++++++++++++++++----- src/protocol/redis/command.rs | 49 ++++ src/protocol/redis/mod.rs | 2 +- src/protocol/redis/types.rs | 22 ++ src/standalone/mod.rs | 18 +- 6 files changed, 559 insertions(+), 73 deletions(-) diff --git a/src/cluster/mod.rs b/src/cluster/mod.rs index 01b7170..59b5fa9 100644 --- a/src/cluster/mod.rs +++ b/src/cluster/mod.rs @@ -25,8 +25,8 @@ use crate::hotkey::Hotkey; use crate::info::{InfoContext, ProxyMode}; use crate::metrics; use crate::protocol::redis::{ - BlockingKind, MultiDispatch, RedisCommand, RespCodec, RespValue, SlotMap, SubCommand, - SubResponse, SubscriptionKind, SLOT_COUNT, + BlockingKind, MultiDispatch, RedisCommand, RespCodec, RespValue, RespVersion, SlotMap, + SubCommand, SubResponse, SubscriptionKind, SLOT_COUNT, }; use crate::slowlog::Slowlog; use crate::utils::{crc16, trim_hash_tag}; @@ -143,9 +143,12 @@ impl ClusterProxy { let client_id = ClientId::new(); let _guard = FrontConnectionGuard::new(&self.cluster); - let (mut sink, stream) = Framed::new(socket, RespCodec::default()).split(); + let framed = Framed::new(socket, RespCodec::default()); + let codec_handle = framed.codec().clone(); + let (mut sink, stream) = framed.split(); let mut stream = stream.fuse(); let mut pending: FuturesOrdered> = FuturesOrdered::new(); + let mut pending_versions: VecDeque> = VecDeque::new(); let mut inflight = 0usize; let mut stream_closed = false; let mut auth_state = self.auth.as_ref().map(|auth| auth.new_session()); @@ -154,6 +157,15 @@ impl ClusterProxy { tokio::select! { Some(resp) = pending.next(), if inflight > 0 => { inflight -= 1; + let version_hint = pending_versions + .pop_front() + .unwrap_or(None); + let success = !resp.is_error(); + if success { + if let Some(version) = version_hint { + codec_handle.set_version(version); + } + } sink.send(resp).await?; } frame_opt = stream.next(), if !stream_closed && inflight < PIPELINE_LIMIT => { @@ -169,6 +181,7 @@ impl ClusterProxy { } AuthAction::Reply(resp) => { let fut = async move { resp }; + pending_versions.push_back(None); pending.push_back(Box::pin(fut)); inflight += 1; continue; @@ -176,6 +189,8 @@ impl ClusterProxy { } } + let requested_version = cmd.resp_version_request(); + if matches!(cmd.as_subscription(), SubscriptionKind::Channel | SubscriptionKind::Pattern) { let kind_label = cmd.kind_label(); if cmd.args().len() <= 1 { @@ -260,6 +275,7 @@ impl ClusterProxy { sink = new_sink; stream = new_stream.fuse(); pending = FuturesOrdered::new(); + pending_versions = VecDeque::new(); inflight = 0; stream_closed = false; continue; @@ -292,6 +308,7 @@ impl ClusterProxy { success, ); let fut = async move { response }; + pending_versions.push_back(None); pending.push_back(Box::pin(fut)); inflight += 1; continue; @@ -303,6 +320,7 @@ impl ClusterProxy { true, ); let fut = async move { response }; + pending_versions.push_back(None); pending.push_back(Box::pin(fut)); inflight += 1; continue; @@ -315,6 +333,7 @@ impl ClusterProxy { success, ); let fut = async move { response }; + pending_versions.push_back(None); pending.push_back(Box::pin(fut)); inflight += 1; continue; @@ -327,11 +346,13 @@ impl ClusterProxy { success, ); let fut = async move { response }; + pending_versions.push_back(None); pending.push_back(Box::pin(fut)); inflight += 1; continue; } let guard = self.prepare_dispatch(client_id, cmd); + pending_versions.push_back(requested_version); pending.push_back(Box::pin(guard)); inflight += 1; } @@ -341,6 +362,7 @@ impl ClusterProxy { metrics::front_command(self.cluster.as_ref(), "invalid", false); let message = Bytes::from(format!("ERR {err}")); let fut = async move { RespValue::Error(message) }; + pending_versions.push_back(None); pending.push_back(Box::pin(fut)); inflight += 1; } @@ -365,6 +387,12 @@ impl ClusterProxy { } while let Some(resp) = pending.next().await { + let version_hint = pending_versions.pop_front().unwrap_or(None); + if !resp.is_error() { + if let Some(version) = version_hint { + codec_handle.set_version(version); + } + } sink.send(resp).await?; } sink.close().await?; @@ -955,6 +983,7 @@ impl ClusterConnector { framed: &mut Framed, command: RedisCommand, ) -> Result { + let requested_version = command.resp_version_request(); let blocking = command.as_blocking(); if let Ok(name) = std::str::from_utf8(command.command_name()) { if name.eq_ignore_ascii_case("blpop") || name.eq_ignore_ascii_case("brpop") { @@ -966,7 +995,7 @@ impl ClusterConnector { .await .context("timed out sending command")??; - match blocking { + let response = match blocking { BlockingKind::Queue { .. } | BlockingKind::Stream { .. } => match framed.next().await { Some(Ok(value)) => Ok(value), Some(Err(err)) => Err(err.into()), @@ -978,7 +1007,14 @@ impl ClusterConnector { Ok(None) => Err(anyhow!("backend closed connection")), Err(_) => Err(anyhow!("timed out waiting for response")), }, + }?; + + if let Some(version) = requested_version { + if !response.is_error() { + framed.codec_mut().set_version(version); + } } + Ok(response) } async fn connect_with_retry( diff --git a/src/protocol/redis/codec.rs b/src/protocol/redis/codec.rs index 8870e74..84d1668 100644 --- a/src/protocol/redis/codec.rs +++ b/src/protocol/redis/codec.rs @@ -1,11 +1,59 @@ use anyhow::{anyhow, Result}; use bytes::{Buf, Bytes, BytesMut}; +use std::sync::atomic::{AtomicU8, Ordering}; +use std::sync::Arc; use tokio_util::codec::{Decoder, Encoder}; use super::types::RespValue; -#[derive(Debug, Default, Clone, Copy)] -pub struct RespCodec; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RespVersion { + Resp2, + Resp3, +} + +impl RespVersion { + fn as_u8(self) -> u8 { + match self { + RespVersion::Resp2 => 2, + RespVersion::Resp3 => 3, + } + } + + fn from_u8(value: u8) -> Self { + match value { + 3 => RespVersion::Resp3, + _ => RespVersion::Resp2, + } + } +} + +#[derive(Debug, Clone)] +pub struct RespCodec { + version: Arc, +} + +impl Default for RespCodec { + fn default() -> Self { + Self { + version: Arc::new(AtomicU8::new(RespVersion::Resp2.as_u8())), + } + } +} + +impl RespCodec { + pub fn version(&self) -> RespVersion { + RespVersion::from_u8(self.version.load(Ordering::SeqCst)) + } + + pub fn set_version(&self, version: RespVersion) { + self.version.store(version.as_u8(), Ordering::SeqCst); + } + + pub fn upgrade_to_resp3(&self) { + self.set_version(RespVersion::Resp3); + } +} impl Decoder for RespCodec { type Item = RespValue; @@ -27,7 +75,8 @@ impl Encoder for RespCodec { type Error = anyhow::Error; fn encode(&mut self, item: RespValue, dst: &mut BytesMut) -> Result<()> { - write_value(&item, dst); + let version = self.version(); + write_value(&item, version, dst); Ok(()) } } @@ -75,7 +124,9 @@ fn parse_value(src: &[u8], pos: &mut usize) -> Result> { .map_err(|err| anyhow!("invalid integer: {err}"))?; Ok(Some(RespValue::Integer(value))) } - b'$' => { + b'$' => parse_bulk_string(src, pos, start, false), + b'*' => parse_array(src, pos, start), + b'_' => { let line = match read_line(src, pos)? { Some(line) => line, None => { @@ -83,57 +134,232 @@ fn parse_value(src: &[u8], pos: &mut usize) -> Result> { return Ok(None); } }; - let len_str = std::str::from_utf8(line)?; - let len = len_str - .parse::() - .map_err(|err| anyhow!("invalid bulk length: {err}"))?; - if len < 0 { - return Ok(Some(RespValue::NullBulk)); - } - let len = len as usize; - if *pos + len + 2 > src.len() { - *pos = start; - return Ok(None); + if !line.is_empty() { + return Err(anyhow!("invalid null frame payload")); } - let data = &src[*pos..*pos + len]; - *pos += len + 2; // skip data and CRLF - Ok(Some(RespValue::BulkString(Bytes::copy_from_slice(data)))) + Ok(Some(RespValue::Null)) } - b'*' => { - let mut local_pos = *pos; - let line = match read_line(src, &mut local_pos)? { + b'#' => { + let line = match read_line(src, pos)? { Some(line) => line, None => { *pos = start; return Ok(None); } }; - let len_str = std::str::from_utf8(line)?; - let len = len_str - .parse::() - .map_err(|err| anyhow!("invalid array length: {err}"))?; - if len < 0 { - *pos = local_pos; - return Ok(Some(RespValue::NullArray)); + match line { + b"t" | b"T" => Ok(Some(RespValue::Boolean(true))), + b"f" | b"F" => Ok(Some(RespValue::Boolean(false))), + _ => Err(anyhow!("invalid boolean literal '{:?}'", line)), } - let mut values = Vec::with_capacity(len as usize); - let mut element_pos = local_pos; - for _ in 0..len { - match parse_value(src, &mut element_pos)? { - Some(value) => values.push(value), - None => { - *pos = start; - return Ok(None); - } + } + b',' => { + let line = match read_line(src, pos)? { + Some(line) => line, + None => { + *pos = start; + return Ok(None); } - } - *pos = element_pos; - Ok(Some(RespValue::Array(values))) + }; + Ok(Some(RespValue::Double(Bytes::copy_from_slice(line)))) } + b'(' => { + let line = match read_line(src, pos)? { + Some(line) => line, + None => { + *pos = start; + return Ok(None); + } + }; + Ok(Some(RespValue::BigNumber(Bytes::copy_from_slice(line)))) + } + b'=' => parse_verbatim_string(src, pos, start), + b'!' => parse_bulk_string(src, pos, start, true), + b'%' => parse_map(src, pos, start, RespValue::Map), + b'~' => parse_collection(src, pos, start, RespValue::Set), + b'>' => parse_collection(src, pos, start, RespValue::Push), + b'|' => parse_map(src, pos, start, RespValue::Attribute), _ => Err(anyhow!("unsupported RESP prefix '{}'.", prefix as char)), } } +fn parse_bulk_string( + src: &[u8], + pos: &mut usize, + start: usize, + as_error: bool, +) -> Result> { + let line = match read_line(src, pos)? { + Some(line) => line, + None => { + *pos = start; + return Ok(None); + } + }; + let len = parse_length(line, "bulk string")?; + if len < 0 { + return Ok(Some(RespValue::NullBulk)); + } + let len = len as usize; + if *pos + len + 2 > src.len() { + *pos = start; + return Ok(None); + } + let data = &src[*pos..*pos + len]; + *pos += len + 2; + let payload = Bytes::copy_from_slice(data); + if as_error { + Ok(Some(RespValue::BlobError(payload))) + } else { + Ok(Some(RespValue::BulkString(payload))) + } +} + +fn parse_verbatim_string(src: &[u8], pos: &mut usize, start: usize) -> Result> { + let line = match read_line(src, pos)? { + Some(line) => line, + None => { + *pos = start; + return Ok(None); + } + }; + let len = parse_length(line, "verbatim string")?; + if len < 0 { + return Err(anyhow!("verbatim string must not be null")); + } + let len = len as usize; + if *pos + len + 2 > src.len() { + *pos = start; + return Ok(None); + } + let data = &src[*pos..*pos + len]; + *pos += len + 2; + if len < 4 || data[3] != b':' { + return Err(anyhow!("invalid verbatim string header")); + } + let mut format = [0u8; 3]; + format.copy_from_slice(&data[..3]); + let payload = Bytes::copy_from_slice(&data[4..]); + Ok(Some(RespValue::VerbatimString { format, data: payload })) +} + +fn parse_array(src: &[u8], pos: &mut usize, start: usize) -> Result> { + let mut local_pos = *pos; + let line = match read_line(src, &mut local_pos)? { + Some(line) => line, + None => { + *pos = start; + return Ok(None); + } + }; + let len = parse_length(line, "array")?; + if len < 0 { + *pos = local_pos; + return Ok(Some(RespValue::NullArray)); + } + let mut values = Vec::with_capacity(len as usize); + let mut element_pos = local_pos; + for _ in 0..len { + match parse_value(src, &mut element_pos)? { + Some(value) => values.push(value), + None => { + *pos = start; + return Ok(None); + } + } + } + *pos = element_pos; + Ok(Some(RespValue::Array(values))) +} + +fn parse_collection( + src: &[u8], + pos: &mut usize, + start: usize, + ctor: F, +) -> Result> +where + F: FnOnce(Vec) -> RespValue, +{ + let mut local_pos = *pos; + let line = match read_line(src, &mut local_pos)? { + Some(line) => line, + None => { + *pos = start; + return Ok(None); + } + }; + let len = parse_length(line, "collection")?; + if len < 0 { + *pos = local_pos; + return Ok(Some(RespValue::Null)); + } + let mut values = Vec::with_capacity(len as usize); + let mut element_pos = local_pos; + for _ in 0..len { + match parse_value(src, &mut element_pos)? { + Some(value) => values.push(value), + None => { + *pos = start; + return Ok(None); + } + } + } + *pos = element_pos; + Ok(Some(ctor(values))) +} + +fn parse_map( + src: &[u8], + pos: &mut usize, + start: usize, + ctor: F, +) -> Result> +where + F: FnOnce(Vec<(RespValue, RespValue)>) -> RespValue, +{ + let mut local_pos = *pos; + let line = match read_line(src, &mut local_pos)? { + Some(line) => line, + None => { + *pos = start; + return Ok(None); + } + }; + let len = parse_length(line, "map")?; + if len < 0 { + *pos = local_pos; + return Ok(Some(RespValue::Null)); + } + let mut entries = Vec::with_capacity(len as usize); + let mut element_pos = local_pos; + for _ in 0..len { + let key = match parse_value(src, &mut element_pos)? { + Some(value) => value, + None => { + *pos = start; + return Ok(None); + } + }; + let value = match parse_value(src, &mut element_pos)? { + Some(value) => value, + None => { + *pos = start; + return Ok(None); + } + }; + entries.push((key, value)); + } + *pos = element_pos; + Ok(Some(ctor(entries))) +} + +fn parse_length(bytes: &[u8], kind: &str) -> Result { + let text = std::str::from_utf8(bytes)?; + text.parse::() + .map_err(|err| anyhow!("invalid {kind} length: {err}")) +} + fn read_line<'a>(src: &'a [u8], pos: &mut usize) -> Result> { if *pos >= src.len() { return Ok(None); @@ -150,7 +376,7 @@ fn read_line<'a>(src: &'a [u8], pos: &mut usize) -> Result> { Ok(None) } -fn write_value(value: &RespValue, dst: &mut BytesMut) { +fn write_value(value: &RespValue, version: RespVersion, dst: &mut BytesMut) { match value { RespValue::SimpleString(data) => { dst.extend_from_slice(b"+"); @@ -162,31 +388,170 @@ fn write_value(value: &RespValue, dst: &mut BytesMut) { dst.extend_from_slice(data); dst.extend_from_slice(b"\r\n"); } - RespValue::Integer(value) => { - dst.extend_from_slice(b":"); - dst.extend_from_slice(value.to_string().as_bytes()); - dst.extend_from_slice(b"\r\n"); - } - RespValue::BulkString(data) => { - dst.extend_from_slice(b"$"); - dst.extend_from_slice(data.len().to_string().as_bytes()); - dst.extend_from_slice(b"\r\n"); - dst.extend_from_slice(data); - dst.extend_from_slice(b"\r\n"); - } - RespValue::NullBulk => { - dst.extend_from_slice(b"$-1\r\n"); - } - RespValue::Array(values) => { - dst.extend_from_slice(b"*"); - dst.extend_from_slice(values.len().to_string().as_bytes()); - dst.extend_from_slice(b"\r\n"); - for value in values { - write_value(value, dst); + RespValue::Integer(value) => write_integer(*value, dst), + RespValue::BulkString(data) => write_bulk(data, dst), + RespValue::NullBulk => dst.extend_from_slice(b"$-1\r\n"), + RespValue::Array(values) => write_aggregate(b'*', values.len(), values, version, dst), + RespValue::NullArray => dst.extend_from_slice(b"*-1\r\n"), + RespValue::Null => match version { + RespVersion::Resp3 => dst.extend_from_slice(b"_\r\n"), + RespVersion::Resp2 => dst.extend_from_slice(b"$-1\r\n"), + }, + RespValue::Boolean(flag) => match version { + RespVersion::Resp3 => { + dst.extend_from_slice(b"#"); + dst.extend_from_slice(if *flag { b"t" } else { b"f" }); + dst.extend_from_slice(b"\r\n"); } - } - RespValue::NullArray => { - dst.extend_from_slice(b"*-1\r\n"); - } + RespVersion::Resp2 => write_integer(if *flag { 1 } else { 0 }, dst), + }, + RespValue::Double(data) => match version { + RespVersion::Resp3 => { + dst.extend_from_slice(b","); + dst.extend_from_slice(data); + dst.extend_from_slice(b"\r\n"); + } + RespVersion::Resp2 => write_bulk(data, dst), + }, + RespValue::BigNumber(data) => match version { + RespVersion::Resp3 => { + dst.extend_from_slice(b"("); + dst.extend_from_slice(data); + dst.extend_from_slice(b"\r\n"); + } + RespVersion::Resp2 => write_bulk(data, dst), + }, + RespValue::VerbatimString { format, data } => match version { + RespVersion::Resp3 => { + let total_len = 4 + data.len(); + dst.extend_from_slice(b"="); + dst.extend_from_slice(total_len.to_string().as_bytes()); + dst.extend_from_slice(b"\r\n"); + dst.extend_from_slice(format); + dst.extend_from_slice(b":"); + dst.extend_from_slice(data); + dst.extend_from_slice(b"\r\n"); + } + RespVersion::Resp2 => write_bulk(data, dst), + }, + RespValue::BlobError(data) => match version { + RespVersion::Resp3 => { + dst.extend_from_slice(b"!"); + dst.extend_from_slice(data.len().to_string().as_bytes()); + dst.extend_from_slice(b"\r\n"); + dst.extend_from_slice(data); + dst.extend_from_slice(b"\r\n"); + } + RespVersion::Resp2 => { + dst.extend_from_slice(b"-"); + dst.extend_from_slice(data); + dst.extend_from_slice(b"\r\n"); + } + }, + RespValue::Map(entries) => match version { + RespVersion::Resp3 => write_map(b'%', entries, version, dst), + RespVersion::Resp2 => write_map_as_array(entries, version, dst), + }, + RespValue::Set(values) => match version { + RespVersion::Resp3 => write_aggregate(b'~', values.len(), values, version, dst), + RespVersion::Resp2 => write_aggregate(b'*', values.len(), values, version, dst), + }, + RespValue::Push(values) => match version { + RespVersion::Resp3 => write_aggregate(b'>', values.len(), values, version, dst), + RespVersion::Resp2 => write_aggregate(b'*', values.len(), values, version, dst), + }, + RespValue::Attribute(entries) => match version { + RespVersion::Resp3 => write_map(b'|', entries, version, dst), + RespVersion::Resp2 => write_map_as_array(entries, version, dst), + }, + } +} + +fn write_integer(value: i64, dst: &mut BytesMut) { + dst.extend_from_slice(b":"); + dst.extend_from_slice(value.to_string().as_bytes()); + dst.extend_from_slice(b"\r\n"); +} + +fn write_bulk(data: &[u8], dst: &mut BytesMut) { + dst.extend_from_slice(b"$"); + dst.extend_from_slice(data.len().to_string().as_bytes()); + dst.extend_from_slice(b"\r\n"); + dst.extend_from_slice(data); + dst.extend_from_slice(b"\r\n"); +} + +fn write_aggregate( + prefix: u8, + len: usize, + values: &[RespValue], + version: RespVersion, + dst: &mut BytesMut, +) { + dst.extend_from_slice(&[prefix]); + dst.extend_from_slice(len.to_string().as_bytes()); + dst.extend_from_slice(b"\r\n"); + for value in values { + write_value(value, version, dst); + } +} + +fn write_map( + prefix: u8, + entries: &[(RespValue, RespValue)], + version: RespVersion, + dst: &mut BytesMut, +) { + dst.extend_from_slice(&[prefix]); + dst.extend_from_slice(entries.len().to_string().as_bytes()); + dst.extend_from_slice(b"\r\n"); + for (key, value) in entries { + write_value(key, version, dst); + write_value(value, version, dst); + } +} + +fn write_map_as_array(entries: &[(RespValue, RespValue)], version: RespVersion, dst: &mut BytesMut) { + dst.extend_from_slice(b"*"); + dst.extend_from_slice((entries.len() * 2).to_string().as_bytes()); + dst.extend_from_slice(b"\r\n"); + for (key, value) in entries { + write_value(key, version, dst); + write_value(value, version, dst); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn map_entry() -> RespValue { + RespValue::Map(vec![( + RespValue::SimpleString(Bytes::from_static(b"mode")), + RespValue::BulkString(Bytes::from_static(b"standalone")), + )]) + } + + #[test] + fn encodes_map_with_resp3_prefix() { + let mut codec = RespCodec::default(); + codec.upgrade_to_resp3(); + let mut buf = BytesMut::new(); + codec.encode(map_entry(), &mut buf).unwrap(); + assert_eq!( + buf.as_ref(), + b"%1\r\n+mode\r\n$10\r\nstandalone\r\n" + ); + } + + #[test] + fn encodes_map_as_array_in_resp2() { + let mut codec = RespCodec::default(); + let mut buf = BytesMut::new(); + codec.encode(map_entry(), &mut buf).unwrap(); + assert_eq!( + buf.as_ref(), + b"*2\r\n+mode\r\n$10\r\nstandalone\r\n" + ); } } diff --git a/src/protocol/redis/command.rs b/src/protocol/redis/command.rs index 0637315..1e4f02f 100644 --- a/src/protocol/redis/command.rs +++ b/src/protocol/redis/command.rs @@ -8,6 +8,7 @@ use crate::backend::pool::BackendRequest; use crate::metrics; use crate::utils::{crc16, trim_hash_tag}; +use super::codec::RespVersion; use super::types::RespValue; pub const SLOT_COUNT: u16 = 16384; @@ -81,6 +82,7 @@ impl RedisCommand { RespValue::Array(_) => { bail!("nested array arguments are not supported"); } + other => bail!("unsupported RESP argument type: {:?}", other), } } Self::new(parts) @@ -196,6 +198,19 @@ impl RedisCommand { _ => None, } } + + pub fn resp_version_request(&self) -> Option { + if !self.command_name().eq_ignore_ascii_case(b"HELLO") { + return None; + } + let version_arg = self.args().get(1)?; + let version_text = std::str::from_utf8(version_arg).ok()?; + match version_text.parse::().ok()? { + 2 => Some(RespVersion::Resp2), + 3 => Some(RespVersion::Resp3), + _ => None, + } + } } impl BackendRequest for RedisCommand { @@ -212,6 +227,40 @@ impl BackendRequest for RedisCommand { } } +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + + fn cmd(parts: &[&[u8]]) -> RedisCommand { + RedisCommand::new(parts.iter().map(|p| Bytes::copy_from_slice(p)).collect()).unwrap() + } + + #[test] + fn detects_resp3_request_from_hello() { + let command = cmd(&[b"HELLO", b"3"]); + assert_eq!(command.resp_version_request(), Some(RespVersion::Resp3)); + } + + #[test] + fn detects_resp2_request_from_hello() { + let command = cmd(&[b"HELLO", b"2"]); + assert_eq!(command.resp_version_request(), Some(RespVersion::Resp2)); + } + + #[test] + fn ignores_invalid_or_missing_version() { + let without_version = cmd(&[b"HELLO"]); + assert_eq!(without_version.resp_version_request(), None); + + let invalid_version = cmd(&[b"HELLO", b"foo"]); + assert_eq!(invalid_version.resp_version_request(), None); + + let other_command = cmd(&[b"PING"]); + assert_eq!(other_command.resp_version_request(), None); + } +} + impl fmt::Display for RedisCommand { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let args: Vec = self diff --git a/src/protocol/redis/mod.rs b/src/protocol/redis/mod.rs index 8d7a218..9c63e3a 100644 --- a/src/protocol/redis/mod.rs +++ b/src/protocol/redis/mod.rs @@ -3,7 +3,7 @@ mod command; mod slots; mod types; -pub use codec::RespCodec; +pub use codec::{RespCodec, RespVersion}; pub use command::{ Aggregator, BlockingKind, CommandKind, MultiDispatch, RedisCommand, RedisResponse, SubCommand, SubResponse, SubscriptionKind, SLOT_COUNT, diff --git a/src/protocol/redis/types.rs b/src/protocol/redis/types.rs index ba7f382..dce22bc 100644 --- a/src/protocol/redis/types.rs +++ b/src/protocol/redis/types.rs @@ -9,6 +9,16 @@ pub enum RespValue { NullBulk, Array(Vec), NullArray, + Null, + Boolean(bool), + Double(Bytes), + BigNumber(Bytes), + VerbatimString { format: [u8; 3], data: Bytes }, + BlobError(Bytes), + Map(Vec<(RespValue, RespValue)>), + Set(Vec), + Push(Vec), + Attribute(Vec<(RespValue, RespValue)>), } impl RespValue { @@ -28,6 +38,18 @@ impl RespValue { RespValue::Array(values) } + pub fn map(entries: Vec<(RespValue, RespValue)>) -> Self { + RespValue::Map(entries) + } + + pub fn null() -> Self { + RespValue::Null + } + + pub fn boolean(value: bool) -> Self { + RespValue::Boolean(value) + } + pub fn is_error(&self) -> bool { matches!(self, RespValue::Error(_)) } diff --git a/src/standalone/mod.rs b/src/standalone/mod.rs index 5e427da..9525e83 100644 --- a/src/standalone/mod.rs +++ b/src/standalone/mod.rs @@ -467,6 +467,7 @@ impl StandaloneProxy { continue; } + let requested_version = command.resp_version_request(); let response = match self.dispatch(client_id, command).await { Ok(resp) => resp, Err(err) => { @@ -477,7 +478,12 @@ impl StandaloneProxy { } }; - let success = !matches!(response, RespValue::Error(_)); + let success = !response.is_error(); + if success { + if let Some(version) = requested_version { + framed.codec_mut().set_version(version); + } + } metrics::front_command(self.cluster.as_ref(), kind_label, success); framed.send(response).await?; } @@ -711,6 +717,7 @@ impl RedisConnector { framed: &mut Framed, request: RedisCommand, ) -> Result { + let requested_version = request.resp_version_request(); let blocking = request.as_blocking(); let frame = request.to_resp(); let timeout_duration = self.current_timeout(); @@ -718,7 +725,7 @@ impl RedisConnector { .await .context("timed out while sending request")??; - match blocking { + let response = match blocking { BlockingKind::Queue { .. } | BlockingKind::Stream { .. } => match framed.next().await { Some(Ok(response)) => Ok(response), Some(Err(err)) => Err(err.into()), @@ -730,7 +737,14 @@ impl RedisConnector { Ok(None) => Err(anyhow!("backend closed connection")), Err(_) => Err(anyhow!("timed out waiting for backend reply")), }, + }?; + + if let Some(version) = requested_version { + if !response.is_error() { + framed.codec_mut().set_version(version); + } } + Ok(response) } async fn heartbeat(&self, framed: &mut Framed) -> Result<()> { From 8bd5700c8cd36a7b99bb2eb5d659057cbeaa236e Mon Sep 17 00:00:00 2001 From: wayslog Date: Thu, 30 Oct 2025 16:13:48 +0800 Subject: [PATCH 2/4] fixed: add support of upgrade/downgrade --- src/cluster/mod.rs | 69 ++++++++++++++++++++++------------------------ 1 file changed, 33 insertions(+), 36 deletions(-) diff --git a/src/cluster/mod.rs b/src/cluster/mod.rs index 59b5fa9..1212652 100644 --- a/src/cluster/mod.rs +++ b/src/cluster/mod.rs @@ -147,24 +147,20 @@ impl ClusterProxy { let codec_handle = framed.codec().clone(); let (mut sink, stream) = framed.split(); let mut stream = stream.fuse(); - let mut pending: FuturesOrdered> = FuturesOrdered::new(); - let mut pending_versions: VecDeque> = VecDeque::new(); + let mut pending: FuturesOrdered)>> = + FuturesOrdered::new(); + let mut resp3_negotiated = codec_handle.version() == RespVersion::Resp3; let mut inflight = 0usize; let mut stream_closed = false; let mut auth_state = self.auth.as_ref().map(|auth| auth.new_session()); loop { tokio::select! { - Some(resp) = pending.next(), if inflight > 0 => { + Some((resp, version)) = pending.next(), if inflight > 0 => { inflight -= 1; - let version_hint = pending_versions - .pop_front() - .unwrap_or(None); - let success = !resp.is_error(); - if success { - if let Some(version) = version_hint { - codec_handle.set_version(version); - } + if let Some(version) = version { + codec_handle.set_version(version); + resp3_negotiated = version == RespVersion::Resp3; } sink.send(resp).await?; } @@ -180,17 +176,13 @@ impl ClusterProxy { cmd = new_cmd; } AuthAction::Reply(resp) => { - let fut = async move { resp }; - pending_versions.push_back(None); + let fut = async move { (resp, None) }; pending.push_back(Box::pin(fut)); inflight += 1; continue; } } } - - let requested_version = cmd.resp_version_request(); - if matches!(cmd.as_subscription(), SubscriptionKind::Channel | SubscriptionKind::Pattern) { let kind_label = cmd.kind_label(); if cmd.args().len() <= 1 { @@ -245,8 +237,12 @@ impl ClusterProxy { } }; while inflight > 0 { - if let Some(resp) = pending.next().await { + if let Some((resp, version)) = pending.next().await { inflight -= 1; + if let Some(version) = version { + codec_handle.set_version(version); + resp3_negotiated = version == RespVersion::Resp3; + } sink.send(resp).await?; } else { inflight = 0; @@ -275,7 +271,6 @@ impl ClusterProxy { sink = new_sink; stream = new_stream.fuse(); pending = FuturesOrdered::new(); - pending_versions = VecDeque::new(); inflight = 0; stream_closed = false; continue; @@ -307,8 +302,7 @@ impl ClusterProxy { kind_label, success, ); - let fut = async move { response }; - pending_versions.push_back(None); + let fut = async move { (response, None) }; pending.push_back(Box::pin(fut)); inflight += 1; continue; @@ -319,8 +313,7 @@ impl ClusterProxy { cmd.kind_label(), true, ); - let fut = async move { response }; - pending_versions.push_back(None); + let fut = async move { (response, None) }; pending.push_back(Box::pin(fut)); inflight += 1; continue; @@ -332,8 +325,7 @@ impl ClusterProxy { cmd.kind_label(), success, ); - let fut = async move { response }; - pending_versions.push_back(None); + let fut = async move { (response, None) }; pending.push_back(Box::pin(fut)); inflight += 1; continue; @@ -345,15 +337,23 @@ impl ClusterProxy { cmd.kind_label(), success, ); - let fut = async move { response }; - pending_versions.push_back(None); + let fut = async move { (response, None) }; pending.push_back(Box::pin(fut)); inflight += 1; continue; } + let requested_version = cmd.resp_version_request(); let guard = self.prepare_dispatch(client_id, cmd); - pending_versions.push_back(requested_version); - pending.push_back(Box::pin(guard)); + let fut = async move { + let resp = guard.await; + let version = if !resp.is_error() { + requested_version + } else { + None + }; + (resp, version) + }; + pending.push_back(Box::pin(fut)); inflight += 1; } Err(err) => { @@ -361,8 +361,7 @@ impl ClusterProxy { metrics::front_error(self.cluster.as_ref(), "parse"); metrics::front_command(self.cluster.as_ref(), "invalid", false); let message = Bytes::from(format!("ERR {err}")); - let fut = async move { RespValue::Error(message) }; - pending_versions.push_back(None); + let fut = async move { (RespValue::Error(message), None) }; pending.push_back(Box::pin(fut)); inflight += 1; } @@ -386,12 +385,10 @@ impl ClusterProxy { } } - while let Some(resp) = pending.next().await { - let version_hint = pending_versions.pop_front().unwrap_or(None); - if !resp.is_error() { - if let Some(version) = version_hint { - codec_handle.set_version(version); - } + while let Some((resp, version)) = pending.next().await { + if let Some(version) = version { + codec_handle.set_version(version); + resp3_negotiated = version == RespVersion::Resp3; } sink.send(resp).await?; } From 1219ca7a9f67f4385ad52aacf3743b2c867419ba Mon Sep 17 00:00:00 2001 From: wayslog Date: Thu, 6 Nov 2025 15:49:10 +0800 Subject: [PATCH 3/4] fix: unable upgrade resp3 of error reply --- src/protocol/redis/types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/protocol/redis/types.rs b/src/protocol/redis/types.rs index dce22bc..4073470 100644 --- a/src/protocol/redis/types.rs +++ b/src/protocol/redis/types.rs @@ -51,7 +51,7 @@ impl RespValue { } pub fn is_error(&self) -> bool { - matches!(self, RespValue::Error(_)) + matches!(self, RespValue::Error(_) | RespValue::BlobError(_)) } pub fn as_array(&self) -> Option<&[RespValue]> { From d1c5e2c9da9217d5f1886fc0ef32a0aa36157e09 Mon Sep 17 00:00:00 2001 From: wayslog Date: Thu, 6 Nov 2025 18:22:32 +0800 Subject: [PATCH 4/4] fix: add backend resp3 impl --- docker/config/integration.toml | 2 + docker/integration-test.sh | 20 ++++++ src/auth/mod.rs | 11 ++++ src/cluster/mod.rs | 113 +++++++++++++++++++++++++++++++-- src/config/mod.rs | 8 ++- src/protocol/redis/codec.rs | 32 +++++----- src/standalone/mod.rs | 112 ++++++++++++++++++++++++++++++-- 7 files changed, 269 insertions(+), 29 deletions(-) mode change 100644 => 100755 docker/integration-test.sh diff --git a/docker/config/integration.toml b/docker/config/integration.toml index ee3421e..596d205 100644 --- a/docker/config/integration.toml +++ b/docker/config/integration.toml @@ -8,6 +8,7 @@ servers = [ ] password = "front-standalone-secret" backend_password = "backend-standalone-secret" +backend_resp_version = "resp3" [[clusters]] name = "cluster" @@ -24,3 +25,4 @@ auth = { password = "front-cluster-secret", users = [ { username = "ops", password = "ops-secret" } ] } backend_password = "backend-cluster-secret" +backend_resp_version = "resp3" diff --git a/docker/integration-test.sh b/docker/integration-test.sh old mode 100644 new mode 100755 index 9e0b466..db061ab --- a/docker/integration-test.sh +++ b/docker/integration-test.sh @@ -41,6 +41,26 @@ wait_for_cluster() { wait_for_cluster +if ! hello_standalone="$(REDISCLI_AUTH="$STANDALONE_PASS" redis-cli -h aster-proxy -p 6380 --raw HELLO 3 AUTH default "$STANDALONE_PASS")"; then + echo "HELLO 3 handshake against standalone proxy failed" >&2 + exit 1 +fi +if ! printf "%s\n" "$hello_standalone" | awk 'BEGIN{found=0} {if(prev=="proto" && $0=="3"){found=1} prev=$0} END{exit(found?0:1)}'; then + echo "HELLO 3 response from standalone proxy missing proto=3:" >&2 + printf "%s\n" "$hello_standalone" >&2 + exit 1 +fi + +if ! hello_cluster="$(REDISCLI_AUTH="$CLUSTER_USER_PASS" redis-cli -h aster-proxy -p 6381 --user "$CLUSTER_USER" --raw HELLO 3 AUTH "$CLUSTER_USER" "$CLUSTER_USER_PASS")"; then + echo "HELLO 3 handshake against cluster proxy failed" >&2 + exit 1 +fi +if ! printf "%s\n" "$hello_cluster" | awk 'BEGIN{found=0} {if(prev=="proto" && $0=="3"){found=1} prev=$0} END{exit(found?0:1)}'; then + echo "HELLO 3 response from cluster proxy missing proto=3:" >&2 + printf "%s\n" "$hello_cluster" >&2 + exit 1 +fi + noauth_output="$(redis-cli -h aster-proxy -p 6380 PING 2>&1 || true)" if echo "$noauth_output" | grep -q "PONG"; then echo "Expected standalone proxy to require authentication" >&2 diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 06aa855..3ac7f70 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -100,6 +100,17 @@ impl BackendAuth { .expect("AUTH command is always valid") } + pub fn hello_credentials(&self) -> Option<(Bytes, Bytes)> { + match self.parts.len() { + 2 => Some(( + Bytes::from_static(DEFAULT_USER.as_bytes()), + self.parts[1].clone(), + )), + 3 => Some((self.parts[1].clone(), self.parts[2].clone())), + _ => None, + } + } + pub async fn apply_to_stream( &self, framed: &mut Framed, diff --git a/src/cluster/mod.rs b/src/cluster/mod.rs index 1212652..36abfec 100644 --- a/src/cluster/mod.rs +++ b/src/cluster/mod.rs @@ -78,6 +78,7 @@ impl ClusterProxy { runtime.clone(), REQUEST_TIMEOUT_MS, backend_auth.clone(), + config.backend_resp_version, )); let pool = Arc::new(ConnectionPool::new(cluster.clone(), connector.clone())); let auth = config @@ -149,7 +150,7 @@ impl ClusterProxy { let mut stream = stream.fuse(); let mut pending: FuturesOrdered)>> = FuturesOrdered::new(); - let mut resp3_negotiated = codec_handle.version() == RespVersion::Resp3; + let mut _resp3_negotiated = codec_handle.version() == RespVersion::Resp3; let mut inflight = 0usize; let mut stream_closed = false; let mut auth_state = self.auth.as_ref().map(|auth| auth.new_session()); @@ -160,7 +161,7 @@ impl ClusterProxy { inflight -= 1; if let Some(version) = version { codec_handle.set_version(version); - resp3_negotiated = version == RespVersion::Resp3; + _resp3_negotiated = version == RespVersion::Resp3; } sink.send(resp).await?; } @@ -241,7 +242,7 @@ impl ClusterProxy { inflight -= 1; if let Some(version) = version { codec_handle.set_version(version); - resp3_negotiated = version == RespVersion::Resp3; + _resp3_negotiated = version == RespVersion::Resp3; } sink.send(resp).await?; } else { @@ -388,7 +389,7 @@ impl ClusterProxy { while let Some((resp, version)) = pending.next().await { if let Some(version) = version { codec_handle.set_version(version); - resp3_negotiated = version == RespVersion::Resp3; + _resp3_negotiated = version == RespVersion::Resp3; } sink.send(resp).await?; } @@ -928,6 +929,7 @@ struct ClusterConnector { heartbeat_interval: Duration, reconnect_base_delay: Duration, max_reconnect_attempts: usize, + backend_resp_version: RespVersion, } impl ClusterConnector { @@ -935,6 +937,7 @@ impl ClusterConnector { runtime: Arc, default_timeout_ms: u64, backend_auth: Option, + backend_resp_version: RespVersion, ) -> Self { Self { runtime, @@ -943,6 +946,7 @@ impl ClusterConnector { heartbeat_interval: Duration::from_secs(30), reconnect_base_delay: Duration::from_millis(50), max_reconnect_attempts: 3, + backend_resp_version, } } @@ -975,6 +979,85 @@ impl ClusterConnector { Ok(framed) } + async fn negotiate_resp_version( + &self, + cluster: &str, + node: &BackendNode, + framed: &mut Framed, + ) -> Result { + framed.codec_mut().set_version(RespVersion::Resp2); + if self.backend_resp_version != RespVersion::Resp3 { + return Ok(RespVersion::Resp2); + } + + let timeout_duration = self.current_timeout(); + let mut hello_parts = vec![ + RespValue::BulkString(Bytes::from_static(b"HELLO")), + RespValue::BulkString(Bytes::from_static(b"3")), + ]; + if let Some(auth) = &self.backend_auth { + if let Some((username, password)) = auth.hello_credentials() { + hello_parts.push(RespValue::BulkString(Bytes::from_static(b"AUTH"))); + hello_parts.push(RespValue::BulkString(username)); + hello_parts.push(RespValue::BulkString(password)); + } + } + let hello = RespValue::Array(hello_parts); + + match timeout(timeout_duration, framed.send(hello)).await { + Ok(Ok(())) => {} + Ok(Err(err)) => { + metrics::backend_error(cluster, node.as_str(), "resp3_handshake"); + return Err(err.context(format!("failed to send RESP3 HELLO to {}", node.as_str()))); + } + Err(_) => { + metrics::backend_error(cluster, node.as_str(), "resp3_handshake"); + return Err(anyhow!( + "backend {} timed out sending RESP3 HELLO", + node.as_str() + )); + } + } + + let reply = match timeout(timeout_duration, framed.next()).await { + Ok(Some(Ok(value))) => value, + Ok(Some(Err(err))) => { + metrics::backend_error(cluster, node.as_str(), "resp3_handshake"); + return Err(err.context(format!( + "failed to read RESP3 HELLO reply from {}", + node.as_str() + ))); + } + Ok(None) => { + metrics::backend_error(cluster, node.as_str(), "resp3_handshake"); + return Err(anyhow!( + "backend {} closed connection during RESP3 HELLO", + node.as_str() + )); + } + Err(_) => { + metrics::backend_error(cluster, node.as_str(), "resp3_handshake"); + return Err(anyhow!( + "backend {} timed out waiting for RESP3 HELLO reply", + node.as_str() + )); + } + }; + + if reply.is_error() { + info!( + cluster = %cluster, + backend = %node.as_str(), + "backend rejected RESP3 HELLO; falling back to RESP2" + ); + framed.codec_mut().set_version(RespVersion::Resp2); + return Ok(RespVersion::Resp2); + } + + framed.codec_mut().set_version(RespVersion::Resp3); + Ok(RespVersion::Resp3) + } + async fn execute( &self, framed: &mut Framed, @@ -1023,7 +1106,7 @@ impl ClusterConnector { for attempt in 0..self.max_reconnect_attempts { let attempt_start = Instant::now(); match self.open_stream(node.as_str()).await { - Ok(stream) => { + Ok(mut stream) => { metrics::backend_probe_duration( cluster, node.as_str(), @@ -1031,7 +1114,25 @@ impl ClusterConnector { attempt_start.elapsed(), ); metrics::backend_probe_result(cluster, node.as_str(), "connect", true); - return Ok(stream); + match self + .negotiate_resp_version(cluster, node, &mut stream) + .await + { + Ok(_) => return Ok(stream), + Err(err) => { + warn!( + cluster = %cluster, + backend = %node.as_str(), + attempt = attempt + 1, + error = %err, + "failed to negotiate RESP version with backend" + ); + last_error = Some(err); + if attempt + 1 < self.max_reconnect_attempts { + sleep(self.reconnect_base_delay).await; + } + } + } } Err(err) => { let elapsed = attempt_start.elapsed(); diff --git a/src/config/mod.rs b/src/config/mod.rs index 69b3179..af305c9 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -15,7 +15,7 @@ use crate::hotkey::{ Hotkey, HotkeyConfig, DEFAULT_DECAY, DEFAULT_HOTKEY_CAPACITY, DEFAULT_SAMPLE_EVERY, DEFAULT_SKETCH_DEPTH, DEFAULT_SKETCH_WIDTH, }; -use crate::protocol::redis::{RedisCommand, RespValue}; +use crate::protocol::redis::{RedisCommand, RespValue, RespVersion}; use crate::slowlog::Slowlog; /// Environment variable controlling the default worker thread count when a @@ -51,6 +51,10 @@ fn default_hotkey_decay() -> f64 { DEFAULT_DECAY } +fn default_backend_resp_version() -> RespVersion { + RespVersion::Resp2 +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Config { #[serde(default)] @@ -177,6 +181,8 @@ pub struct ClusterConfig { pub hotkey_capacity: usize, #[serde(default = "default_hotkey_decay")] pub hotkey_decay: f64, + #[serde(default = "default_backend_resp_version")] + pub backend_resp_version: RespVersion, } impl ClusterConfig { diff --git a/src/protocol/redis/codec.rs b/src/protocol/redis/codec.rs index 84d1668..88851e8 100644 --- a/src/protocol/redis/codec.rs +++ b/src/protocol/redis/codec.rs @@ -1,12 +1,14 @@ use anyhow::{anyhow, Result}; use bytes::{Buf, Bytes, BytesMut}; +use serde::{Deserialize, Serialize}; use std::sync::atomic::{AtomicU8, Ordering}; use std::sync::Arc; use tokio_util::codec::{Decoder, Encoder}; use super::types::RespValue; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] pub enum RespVersion { Resp2, Resp3, @@ -240,7 +242,10 @@ fn parse_verbatim_string(src: &[u8], pos: &mut usize, start: usize) -> Result Result> { @@ -309,12 +314,7 @@ where Ok(Some(ctor(values))) } -fn parse_map( - src: &[u8], - pos: &mut usize, - start: usize, - ctor: F, -) -> Result> +fn parse_map(src: &[u8], pos: &mut usize, start: usize, ctor: F) -> Result> where F: FnOnce(Vec<(RespValue, RespValue)>) -> RespValue, { @@ -511,7 +511,11 @@ fn write_map( } } -fn write_map_as_array(entries: &[(RespValue, RespValue)], version: RespVersion, dst: &mut BytesMut) { +fn write_map_as_array( + entries: &[(RespValue, RespValue)], + version: RespVersion, + dst: &mut BytesMut, +) { dst.extend_from_slice(b"*"); dst.extend_from_slice((entries.len() * 2).to_string().as_bytes()); dst.extend_from_slice(b"\r\n"); @@ -538,10 +542,7 @@ mod tests { codec.upgrade_to_resp3(); let mut buf = BytesMut::new(); codec.encode(map_entry(), &mut buf).unwrap(); - assert_eq!( - buf.as_ref(), - b"%1\r\n+mode\r\n$10\r\nstandalone\r\n" - ); + assert_eq!(buf.as_ref(), b"%1\r\n+mode\r\n$10\r\nstandalone\r\n"); } #[test] @@ -549,9 +550,6 @@ mod tests { let mut codec = RespCodec::default(); let mut buf = BytesMut::new(); codec.encode(map_entry(), &mut buf).unwrap(); - assert_eq!( - buf.as_ref(), - b"*2\r\n+mode\r\n$10\r\nstandalone\r\n" - ); + assert_eq!(buf.as_ref(), b"*2\r\n+mode\r\n$10\r\nstandalone\r\n"); } } diff --git a/src/standalone/mod.rs b/src/standalone/mod.rs index 9525e83..3167e5e 100644 --- a/src/standalone/mod.rs +++ b/src/standalone/mod.rs @@ -25,8 +25,8 @@ use crate::hotkey::Hotkey; use crate::info::{InfoContext, ProxyMode}; use crate::metrics; use crate::protocol::redis::{ - BlockingKind, MultiDispatch, RedisCommand, RedisResponse, RespCodec, RespValue, SubCommand, - SubResponse, SubscriptionKind, + BlockingKind, MultiDispatch, RedisCommand, RedisResponse, RespCodec, RespValue, RespVersion, + SubCommand, SubResponse, SubscriptionKind, }; use crate::slowlog::Slowlog; use crate::utils::trim_hash_tag; @@ -79,6 +79,7 @@ impl StandaloneProxy { runtime.clone(), DEFAULT_TIMEOUT_MS, backend_auth.clone(), + config.backend_resp_version, )); let auth = config .frontend_auth_users() @@ -664,6 +665,7 @@ struct RedisConnector { max_reconnect_delay: Duration, heartbeat_interval: Duration, backend_auth: Option, + backend_resp_version: RespVersion, } impl RedisConnector { @@ -671,6 +673,7 @@ impl RedisConnector { runtime: Arc, default_timeout_ms: u64, backend_auth: Option, + backend_resp_version: RespVersion, ) -> Self { Self { runtime, @@ -679,6 +682,7 @@ impl RedisConnector { max_reconnect_delay: Duration::from_secs(2), heartbeat_interval: Duration::from_secs(20), backend_auth, + backend_resp_version, } } @@ -712,6 +716,85 @@ impl RedisConnector { Ok(framed) } + async fn negotiate_resp_version( + &self, + cluster: &str, + node: &BackendNode, + framed: &mut Framed, + ) -> Result { + framed.codec_mut().set_version(RespVersion::Resp2); + if self.backend_resp_version != RespVersion::Resp3 { + return Ok(RespVersion::Resp2); + } + + let timeout_duration = self.current_timeout(); + let mut hello_parts = vec![ + RespValue::BulkString(Bytes::from_static(b"HELLO")), + RespValue::BulkString(Bytes::from_static(b"3")), + ]; + if let Some(auth) = &self.backend_auth { + if let Some((username, password)) = auth.hello_credentials() { + hello_parts.push(RespValue::BulkString(Bytes::from_static(b"AUTH"))); + hello_parts.push(RespValue::BulkString(username)); + hello_parts.push(RespValue::BulkString(password)); + } + } + let hello = RespValue::Array(hello_parts); + + match timeout(timeout_duration, framed.send(hello)).await { + Ok(Ok(())) => {} + Ok(Err(err)) => { + metrics::backend_error(cluster, node.as_str(), "resp3_handshake"); + return Err(err.context(format!("failed to send RESP3 HELLO to {}", node.as_str()))); + } + Err(_) => { + metrics::backend_error(cluster, node.as_str(), "resp3_handshake"); + return Err(anyhow!( + "backend {} timed out sending RESP3 HELLO", + node.as_str() + )); + } + } + + let reply = match timeout(timeout_duration, framed.next()).await { + Ok(Some(Ok(value))) => value, + Ok(Some(Err(err))) => { + metrics::backend_error(cluster, node.as_str(), "resp3_handshake"); + return Err(err.context(format!( + "failed to read RESP3 HELLO reply from {}", + node.as_str() + ))); + } + Ok(None) => { + metrics::backend_error(cluster, node.as_str(), "resp3_handshake"); + return Err(anyhow!( + "backend {} closed connection during RESP3 HELLO", + node.as_str() + )); + } + Err(_) => { + metrics::backend_error(cluster, node.as_str(), "resp3_handshake"); + return Err(anyhow!( + "backend {} timed out waiting for RESP3 HELLO reply", + node.as_str() + )); + } + }; + + if reply.is_error() { + info!( + cluster = %cluster, + backend = %node.as_str(), + "backend rejected RESP3 HELLO; falling back to RESP2" + ); + framed.codec_mut().set_version(RespVersion::Resp2); + return Ok(RespVersion::Resp2); + } + + framed.codec_mut().set_version(RespVersion::Resp3); + Ok(RespVersion::Resp3) + } + async fn execute_request( &self, framed: &mut Framed, @@ -816,7 +899,7 @@ impl Connector for RedisConnector { if connection.is_none() { let attempt_start = Instant::now(); match self.open_stream(&node).await { - Ok(stream) => { + Ok(mut stream) => { metrics::backend_probe_duration( &cluster, node.as_str(), @@ -824,8 +907,27 @@ impl Connector for RedisConnector { attempt_start.elapsed(), ); metrics::backend_probe_result(&cluster, node.as_str(), "connect", true); - connection = Some(stream); - current_delay = self.reconnect_delay; + match self + .negotiate_resp_version(&cluster, &node, &mut stream) + .await + { + Ok(_) => { + connection = Some(stream); + current_delay = self.reconnect_delay; + } + Err(err) => { + warn!( + cluster = %cluster, + backend = %node.as_str(), + error = %err, + "failed to negotiate RESP version with backend" + ); + let _ = respond_to.send(Err(err)); + current_delay = self.increase_delay(current_delay); + sleep(current_delay).await; + continue; + } + } } Err(err) => { let elapsed = attempt_start.elapsed();