diff --git a/Cargo.lock b/Cargo.lock index 6e7f81ec..346f80e7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1084,6 +1084,17 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "backon" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cffb0e931875b666fc4fcb20fee52e9bbd1ef836fd9e9e04ec21555f9f85f7ef" +dependencies = [ + "fastrand", + "gloo-timers", + "tokio", +] + [[package]] name = "backtrace" version = "0.3.75" @@ -1667,6 +1678,7 @@ dependencies = [ "alloy-node-bindings", "anyhow", "async-trait", + "backon", "chrono", "serde", "serde_json", @@ -1917,6 +1929,18 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" +[[package]] +name = "gloo-timers" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbb143cf96099802033e0d4f4963b19fd2e0b728bcf076cd9cf7f6634f092994" +dependencies = [ + "futures-channel", + "futures-core", + "js-sys", + "wasm-bindgen", +] + [[package]] name = "group" version = "0.13.0" diff --git a/Cargo.toml b/Cargo.toml index 59d28d39..71f1bf85 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,7 @@ tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } test-log = { version = "0.2.18", features = ["trace"] } hex = "0.4" +backon = "1.5.2" [package] name = "event-scanner" @@ -67,6 +68,7 @@ chrono.workspace = true alloy-node-bindings.workspace = true tokio-stream.workspace = true tracing.workspace = true +backon.workspace = true [dev-dependencies] tracing-subscriber.workspace = true diff --git a/src/block_range_scanner.rs b/src/block_range_scanner.rs index 6acf1eb0..e1064152 100644 --- a/src/block_range_scanner.rs +++ b/src/block_range_scanner.rs @@ -41,11 +41,6 @@ //! error!("Received error from subscription: {e}"); //! match e { //! ScannerError::ServiceShutdown => break, -//! ScannerError::WebSocketConnectionFailed(_) => { -//! error!( -//! "WebSocket connection failed, continuing to listen for reconnection" -//! ); -//! } //! _ => { //! error!("Non-fatal error, continuing: {e}"); //! } @@ -63,16 +58,20 @@ //! } //! ``` -use std::{cmp::Ordering, ops::RangeInclusive}; +use std::{cmp::Ordering, ops::RangeInclusive, time::Duration}; use tokio::{ - join, sync::{mpsc, oneshot}, + try_join, }; use tokio_stream::{StreamExt, wrappers::ReceiverStream}; use crate::{ ScannerMessage, error::ScannerError, + robust_provider::{ + DEFAULT_MAX_RETRIES, DEFAULT_MAX_TIMEOUT, DEFAULT_RETRY_INTERVAL, + Error as RobustProviderError, RobustProvider, + }, types::{ScannerStatus, TryStream}, }; use alloy::{ @@ -80,7 +79,7 @@ use alloy::{ eips::BlockNumberOrTag, network::{BlockResponse, Network, primitives::HeaderResponse}, primitives::{B256, BlockNumber}, - providers::{Provider, RootProvider}, + providers::RootProvider, pubsub::Subscription, rpc::client::ClientBuilder, transports::{ @@ -113,6 +112,12 @@ impl PartialEq> for Message { } } +impl From for Message { + fn from(error: RobustProviderError) -> Self { + Message::Error(error.into()) + } +} + impl From> for Message { fn from(error: RpcError) -> Self { Message::Error(error.into()) @@ -128,6 +133,9 @@ impl From for Message { #[derive(Clone, Copy)] pub struct BlockRangeScanner { pub max_block_range: u64, + pub max_timeout: Duration, + pub max_retries: usize, + pub retry_interval: Duration, } impl Default for BlockRangeScanner { @@ -139,7 +147,12 @@ impl Default for BlockRangeScanner { impl BlockRangeScanner { #[must_use] pub fn new() -> Self { - Self { max_block_range: DEFAULT_MAX_BLOCK_RANGE } + Self { + max_block_range: DEFAULT_MAX_BLOCK_RANGE, + max_timeout: DEFAULT_MAX_TIMEOUT, + max_retries: DEFAULT_MAX_RETRIES, + retry_interval: DEFAULT_RETRY_INTERVAL, + } } #[must_use] @@ -148,6 +161,24 @@ impl BlockRangeScanner { self } + #[must_use] + pub fn with_max_timeout(mut self, rpc_timeout: Duration) -> Self { + self.max_timeout = rpc_timeout; + self + } + + #[must_use] + pub fn with_max_retries(mut self, rpc_max_retries: usize) -> Self { + self.max_retries = rpc_max_retries; + self + } + + #[must_use] + pub fn with_retry_interval(mut self, rpc_retry_interval: Duration) -> Self { + self.retry_interval = rpc_retry_interval; + self + } + /// Connects to the provider via WebSocket /// /// # Errors @@ -182,19 +213,26 @@ impl BlockRangeScanner { /// Returns an error if the connection fails #[must_use] pub fn connect(self, provider: RootProvider) -> ConnectedBlockRangeScanner { - ConnectedBlockRangeScanner { provider, max_block_range: self.max_block_range } + let robust_provider = RobustProvider::new(provider) + .max_timeout(self.max_timeout) + .max_retries(self.max_retries) + .retry_interval(self.retry_interval); + ConnectedBlockRangeScanner { + provider: robust_provider, + max_block_range: self.max_block_range, + } } } pub struct ConnectedBlockRangeScanner { - provider: RootProvider, + provider: RobustProvider, max_block_range: u64, } impl ConnectedBlockRangeScanner { - /// Returns the underlying Provider. + /// Returns the `RobustProvider` #[must_use] - pub fn provider(&self) -> &RootProvider { + pub fn provider(&self) -> &RobustProvider { &self.provider } @@ -240,7 +278,7 @@ pub enum Command { } struct Service { - provider: RootProvider, + provider: RobustProvider, max_block_range: u64, error_count: u64, command_receiver: mpsc::Receiver, @@ -248,7 +286,7 @@ struct Service { } impl Service { - pub fn new(provider: RootProvider, max_block_range: u64) -> (Self, mpsc::Sender) { + pub fn new(provider: RobustProvider, max_block_range: u64) -> (Self, mpsc::Sender) { let (cmd_tx, cmd_rx) = mpsc::channel(100); let service = Self { @@ -351,10 +389,8 @@ impl Service { self.provider.get_block_by_number(end_height) )?; - let start_block_num = - start_block.ok_or_else(|| ScannerError::BlockNotFound(start_height))?.header().number(); - let end_block_num = - end_block.ok_or_else(|| ScannerError::BlockNotFound(end_height))?.header().number(); + let start_block_num = start_block.header().number(); + let end_block_num = end_block.header().number(); let (start_block_num, end_block_num) = match start_block_num.cmp(&end_block_num) { Ordering::Greater => (end_block_num, start_block_num), @@ -388,23 +424,14 @@ impl Service { let get_start_block = async || -> Result { let block = match start_height { BlockNumberOrTag::Number(num) => num, - block_tag => provider - .get_block_by_number(block_tag) - .await? - .ok_or_else(|| ScannerError::BlockNotFound(block_tag))? - .header() - .number(), + block_tag => provider.get_block_by_number(block_tag).await?.header().number(), }; Ok(block) }; let get_latest_block = async || -> Result { - let block = provider - .get_block_by_number(BlockNumberOrTag::Latest) - .await? - .ok_or_else(|| ScannerError::BlockNotFound(BlockNumberOrTag::Latest))? - .header() - .number(); + let block = + provider.get_block_by_number(BlockNumberOrTag::Latest).await?.header().number(); Ok(block) }; @@ -496,13 +523,10 @@ impl Service { let max_block_range = self.max_block_range; let provider = self.provider.clone(); - let (start_block, end_block) = join!( + let (start_block, end_block) = try_join!( self.provider.get_block_by_number(start_height), self.provider.get_block_by_number(end_height), - ); - - let start_block = start_block?.ok_or(ScannerError::BlockNotFound(start_height))?; - let end_block = end_block?.ok_or(ScannerError::BlockNotFound(end_height))?; + )?; // normalize block range let (from, to) = match start_block.header().number().cmp(&end_block.header().number()) { @@ -529,7 +553,7 @@ impl Service { to: N::BlockResponse, max_block_range: u64, sender: &mpsc::Sender, - provider: &RootProvider, + provider: &RobustProvider, ) { let mut batch_count = 0; @@ -584,12 +608,10 @@ impl Service { batch_from = from; // store the updated end block hash tip_hash = match provider.get_block_by_number(from.into()).await { - Ok(block) => block - .unwrap_or_else(|| { - panic!("Block with number '{from}' should exist post-reorg") - }) - .header() - .hash(), + Ok(block) => block.header().hash(), + Err(RobustProviderError::BlockNotFound(_)) => { + panic!("Block with number '{from}' should exist post-reorg"); + } Err(e) => { error!(error = %e, "Terminal RPC call error, shutting down"); _ = sender.try_stream(e); @@ -644,9 +666,9 @@ impl Service { info!(batch_count = batch_count, "Historical sync completed"); } - async fn stream_live_blocks>( + async fn stream_live_blocks( mut range_start: BlockNumber, - provider: P, + provider: RobustProvider, sender: mpsc::Sender, block_confirmations: u64, max_block_range: u64, @@ -749,22 +771,22 @@ impl Service { } async fn get_block_subscription( - provider: &impl Provider, + provider: &RobustProvider, ) -> Result, ScannerError> { - let ws_stream = provider - .subscribe_blocks() - .await - .map_err(|_| ScannerError::WebSocketConnectionFailed(1))?; - + let ws_stream = provider.subscribe_blocks().await?; Ok(ws_stream) } } async fn reorg_detected( - provider: &RootProvider, + provider: &RobustProvider, hash_to_check: B256, -) -> Result> { - Ok(provider.get_block_by_hash(hash_to_check).await?.is_none()) +) -> Result { + match provider.get_block_by_hash(hash_to_check).await { + Ok(_) => Ok(false), + Err(RobustProviderError::BlockNotFound(_)) => Ok(true), + Err(e) => Err(e.into()), + } } pub struct BlockRangeScannerClient { @@ -913,6 +935,7 @@ mod tests { use super::*; use crate::{assert_closed, assert_empty, assert_next}; use alloy::{ + eips::BlockId, network::Ethereum, providers::{ProviderBuilder, ext::AnvilApi}, rpc::types::anvil::ReorgOptions, @@ -1365,13 +1388,15 @@ mod tests { #[tokio::test] async fn try_send_forwards_errors_to_subscribers() { - let (tx, mut rx) = mpsc::channel(1); + let (tx, mut rx) = mpsc::channel::(1); - _ = tx.try_stream(ScannerError::WebSocketConnectionFailed(4)).await; + _ = tx.try_stream(ScannerError::BlockNotFound(4.into())).await; assert!(matches!( rx.recv().await, - Some(Message::Error(ScannerError::WebSocketConnectionFailed(4))) + Some(ScannerMessage::Error(ScannerError::BlockNotFound(BlockId::Number( + BlockNumberOrTag::Number(4) + )))) )); } @@ -1566,7 +1591,10 @@ mod tests { let stream = client.rewind(0, 999).await; - assert!(matches!(stream, Err(ScannerError::BlockNotFound(BlockNumberOrTag::Number(999))))); + assert!(matches!( + stream, + Err(ScannerError::BlockNotFound(BlockId::Number(BlockNumberOrTag::Number(999)))) + )); Ok(()) } diff --git a/src/error.rs b/src/error.rs index 3d849a11..cc67dc13 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,16 +1,15 @@ use std::sync::Arc; use alloy::{ - eips::BlockNumberOrTag, - transports::{RpcError, TransportErrorKind, http::reqwest}, + eips::BlockId, + transports::{RpcError, TransportErrorKind}, }; use thiserror::Error; +use crate::robust_provider::Error as RobustProviderError; + #[derive(Error, Debug, Clone)] pub enum ScannerError { - #[error("HTTP request failed: {0}")] - HttpError(Arc), - // #[error("WebSocket error: {0}")] // WebSocketError(#[from] tokio_tungstenite::tungstenite::Error), #[error("Serialization error: {0}")] @@ -31,16 +30,23 @@ pub enum ScannerError { #[error("Historical sync failed: {0}")] HistoricalSyncError(String), - #[error("WebSocket connection failed after {0} attempts")] - WebSocketConnectionFailed(usize), + #[error("Block not found, Block Id: {0}")] + BlockNotFound(BlockId), + + #[error("Operation timed out")] + Timeout, - #[error("Block not found, block number: {0}")] - BlockNotFound(BlockNumberOrTag), + #[error("RPC call failed after exhausting all retry attempts: {0}")] + RetryFailure(Arc>), } -impl From for ScannerError { - fn from(error: reqwest::Error) -> Self { - ScannerError::HttpError(Arc::new(error)) +impl From for ScannerError { + fn from(error: RobustProviderError) -> ScannerError { + match error { + RobustProviderError::Timeout => ScannerError::Timeout, + RobustProviderError::RetryFailure(err) => ScannerError::RetryFailure(err), + RobustProviderError::BlockNotFound(block) => ScannerError::BlockNotFound(block), + } } } diff --git a/src/event_scanner/message.rs b/src/event_scanner/message.rs index 5f916388..ebd1081a 100644 --- a/src/event_scanner/message.rs +++ b/src/event_scanner/message.rs @@ -1,6 +1,6 @@ use alloy::{rpc::types::Log, sol_types::SolEvent}; -use crate::{ScannerError, ScannerMessage}; +use crate::{ScannerError, ScannerMessage, robust_provider::Error as RobustProviderError}; pub type Message = ScannerMessage, ScannerError>; @@ -10,6 +10,13 @@ impl From> for Message { } } +impl From for Message { + fn from(error: RobustProviderError) -> Message { + let scanner_error: ScannerError = error.into(); + scanner_error.into() + } +} + impl PartialEq> for Message { fn eq(&self, other: &Vec) -> bool { self.eq(&other.as_slice()) diff --git a/src/event_scanner/scanner/common.rs b/src/event_scanner/scanner/common.rs index 83b734b1..6bae3d9f 100644 --- a/src/event_scanner/scanner/common.rs +++ b/src/event_scanner/scanner/common.rs @@ -3,13 +3,12 @@ use std::ops::RangeInclusive; use crate::{ block_range_scanner::{MAX_BUFFERED_MESSAGES, Message as BlockRangeMessage}, event_scanner::{filter::EventFilter, listener::EventListener}, + robust_provider::{Error as RobustProviderError, RobustProvider}, types::TryStream, }; use alloy::{ network::Network, - providers::{Provider, RootProvider}, rpc::types::{Filter, Log}, - transports::{RpcError, TransportErrorKind}, }; use tokio::{ sync::broadcast::{self, Sender, error::RecvError}, @@ -48,7 +47,7 @@ pub enum ConsumerMode { /// Assumes it is running in a separate tokio task, so as to be non-blocking. pub async fn handle_stream + Unpin>( mut stream: S, - provider: &RootProvider, + provider: &RobustProvider, listeners: &[EventListener], mode: ConsumerMode, ) { @@ -72,7 +71,7 @@ pub async fn handle_stream + Unp #[must_use] pub fn spawn_log_consumers( - provider: &RootProvider, + provider: &RobustProvider, listeners: &[EventListener], range_tx: &Sender, mode: ConsumerMode, @@ -166,8 +165,8 @@ async fn get_logs( range: RangeInclusive, event_filter: &EventFilter, log_filter: &Filter, - provider: &RootProvider, -) -> Result, RpcError> { + provider: &RobustProvider, +) -> Result, RobustProviderError> { let log_filter = log_filter.clone().from_block(*range.start()).to_block(*range.end()); match provider.get_logs(&log_filter).await { diff --git a/src/event_scanner/scanner/sync/from_latest.rs b/src/event_scanner/scanner/sync/from_latest.rs index 91626c18..e266b699 100644 --- a/src/event_scanner/scanner/sync/from_latest.rs +++ b/src/event_scanner/scanner/sync/from_latest.rs @@ -2,7 +2,6 @@ use alloy::{ consensus::BlockHeader, eips::BlockNumberOrTag, network::{BlockResponse, Network}, - providers::Provider, }; use tokio::sync::mpsc; @@ -57,12 +56,8 @@ impl EventScanner { // This is used to determine the starting point for the rewind stream and the live // stream. We do this before starting the streams to avoid a race condition // where the latest block changes while we're setting up the streams. - let latest_block = provider - .get_block_by_number(BlockNumberOrTag::Latest) - .await? - .ok_or(ScannerError::BlockNotFound(BlockNumberOrTag::Latest))? - .header() - .number(); + let latest_block = + provider.get_block_by_number(BlockNumberOrTag::Latest).await?.header().number(); // Setup rewind and live streams to run in parallel. let rewind_stream = client.rewind(BlockNumberOrTag::Earliest, latest_block).await?; diff --git a/src/lib.rs b/src/lib.rs index cb699aad..1b374057 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod block_range_scanner; +mod robust_provider; #[cfg(any(test, feature = "test-utils"))] pub mod test_utils; diff --git a/src/robust_provider.rs b/src/robust_provider.rs new file mode 100644 index 00000000..f66158d4 --- /dev/null +++ b/src/robust_provider.rs @@ -0,0 +1,289 @@ +use std::{future::Future, sync::Arc, time::Duration}; + +use alloy::{ + eips::{BlockId, BlockNumberOrTag}, + network::Network, + providers::{Provider, RootProvider}, + pubsub::Subscription, + rpc::types::{Filter, Log}, + transports::{RpcError, TransportErrorKind}, +}; +use backon::{ExponentialBuilder, Retryable}; +use thiserror::Error; +use tracing::{error, info}; + +#[derive(Error, Debug, Clone)] +pub enum Error { + #[error("Operation timed out")] + Timeout, + #[error("RPC call failed after exhausting all retry attempts: {0}")] + RetryFailure(Arc>), + #[error("Block not found, Block Id: {0}")] + BlockNotFound(BlockId), +} + +impl From> for Error { + fn from(err: RpcError) -> Self { + Error::RetryFailure(Arc::new(err)) + } +} + +/// Provider wrapper with built-in retry and timeout mechanisms. +/// +/// This wrapper around Alloy providers automatically handles retries, +/// timeouts, and error logging for RPC calls. +#[derive(Clone)] +pub struct RobustProvider { + provider: RootProvider, + max_timeout: Duration, + max_retries: usize, + retry_interval: Duration, +} + +// RPC retry and timeout settings +/// Default timeout used by `RobustProvider` +pub const DEFAULT_MAX_TIMEOUT: Duration = Duration::from_secs(30); +/// Default maximum number of retry attempts. +pub const DEFAULT_MAX_RETRIES: usize = 5; +/// Default base delay between retries. +pub const DEFAULT_RETRY_INTERVAL: Duration = Duration::from_secs(1); + +impl RobustProvider { + /// Create a new `RobustProvider` with default settings. + #[must_use] + pub fn new(provider: RootProvider) -> Self { + Self { + provider, + max_timeout: DEFAULT_MAX_TIMEOUT, + max_retries: DEFAULT_MAX_RETRIES, + retry_interval: DEFAULT_RETRY_INTERVAL, + } + } + + #[must_use] + pub fn max_timeout(mut self, timeout: Duration) -> Self { + self.max_timeout = timeout; + self + } + + #[must_use] + pub fn max_retries(mut self, max_retries: usize) -> Self { + self.max_retries = max_retries; + self + } + + #[must_use] + pub fn retry_interval(mut self, retry_interval: Duration) -> Self { + self.retry_interval = retry_interval; + self + } + + /// Fetch a block by number with retry and timeout. + /// + /// # Errors + /// + /// Returns an error if RPC call fails repeatedly even + /// after exhausting retries or if the call times out. + pub async fn get_block_by_number( + &self, + number: BlockNumberOrTag, + ) -> Result { + info!("eth_getBlockByNumber called"); + let operation = async || self.provider.get_block_by_number(number).await; + let result = self.retry_with_total_timeout(operation).await; + if let Err(e) = &result { + error!(error = %e, "eth_getByBlockNumber failed"); + } + + result?.ok_or_else(|| Error::BlockNotFound(number.into())) + } + + /// Fetch the latest block number with retry and timeout. + /// + /// # Errors + /// + /// Returns an error if RPC call fails repeatedly even + /// after exhausting retries or if the call times out. + pub async fn get_block_number(&self) -> Result { + info!("eth_getBlockNumber called"); + let operation = async || self.provider.get_block_number().await; + let result = self.retry_with_total_timeout(operation).await; + if let Err(e) = &result { + error!(error = %e, "eth_getBlockNumber failed"); + } + result + } + + /// Fetch a block by hash with retry and timeout. + /// + /// # Errors + /// + /// Returns an error if RPC call fails repeatedly even + /// after exhausting retries or if the call times out. + pub async fn get_block_by_hash( + &self, + hash: alloy::primitives::BlockHash, + ) -> Result { + info!("eth_getBlockByHash called"); + let operation = async || self.provider.get_block_by_hash(hash).await; + let result = self.retry_with_total_timeout(operation).await; + if let Err(e) = &result { + error!(error = %e, "eth_getBlockByHash failed"); + } + + result?.ok_or_else(|| Error::BlockNotFound(hash.into())) + } + + /// Fetch logs for the given filter with retry and timeout. + /// + /// # Errors + /// + /// Returns an error if RPC call fails repeatedly even + /// after exhausting retries or if the call times out. + pub async fn get_logs(&self, filter: &Filter) -> Result, Error> { + info!("eth_getLogs called"); + let operation = async || self.provider.get_logs(filter).await; + let result = self.retry_with_total_timeout(operation).await; + if let Err(e) = &result { + error!(error = %e, "eth_getLogs failed"); + } + result + } + + /// Subscribe to new block headers with retry and timeout. + /// + /// # Errors + /// + /// Returns an error if RPC call fails repeatedly even + /// after exhausting retries or if the call times out. + pub async fn subscribe_blocks(&self) -> Result, Error> { + info!("eth_subscribe called"); + let operation = async || self.provider.subscribe_blocks().await; + let result = self.retry_with_total_timeout(operation).await; + if let Err(e) = &result { + error!(error = %e, "eth_subscribe failed"); + } + result + } + + /// Execute `operation` with exponential backoff and a total timeout. + /// + /// Wraps the retry logic with `tokio::time::timeout(self.max_timeout, ...)` so + /// the entire operation (including time spent inside the RPC call) cannot exceed + /// `max_timeout`. + /// + /// # Errors + /// + /// - Returns [`RpcError`] with message "total operation timeout exceeded" + /// if the overall timeout elapses. + /// - Propagates any [`RpcError`] from the underlying retries. + async fn retry_with_total_timeout(&self, operation: F) -> Result + where + F: Fn() -> Fut, + Fut: Future>>, + { + let retry_strategy = ExponentialBuilder::default() + .with_max_times(self.max_retries) + .with_min_delay(self.retry_interval); + + match tokio::time::timeout( + self.max_timeout, + operation.retry(retry_strategy).sleep(tokio::time::sleep), + ) + .await + { + Ok(res) => res.map_err(Error::from), + Err(_) => Err(Error::Timeout), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use alloy::network::Ethereum; + use std::sync::atomic::{AtomicUsize, Ordering}; + use tokio::time::sleep; + + fn test_provider( + timeout: u64, + max_retries: usize, + retry_interval: u64, + ) -> RobustProvider { + RobustProvider { + provider: RootProvider::new_http("http://localhost:8545".parse().unwrap()), + max_timeout: Duration::from_millis(timeout), + max_retries, + retry_interval: Duration::from_millis(retry_interval), + } + } + + #[tokio::test] + async fn test_retry_with_timeout_succeeds_on_first_attempt() { + let provider = test_provider(100, 3, 10); + + let call_count = AtomicUsize::new(0); + + let result = provider + .retry_with_total_timeout(|| async { + call_count.fetch_add(1, Ordering::SeqCst); + let count = call_count.load(Ordering::SeqCst); + Ok(count) + }) + .await; + + assert!(matches!(result, Ok(1))); + } + + #[tokio::test] + async fn test_retry_with_timeout_retries_on_error() { + let provider = test_provider(100, 3, 10); + + let call_count = AtomicUsize::new(0); + + let result = provider + .retry_with_total_timeout(|| async { + call_count.fetch_add(1, Ordering::SeqCst); + let count = call_count.load(Ordering::SeqCst); + match count { + 3 => Ok(count), + _ => Err(TransportErrorKind::BackendGone.into()), + } + }) + .await; + + assert!(matches!(result, Ok(3))); + } + + #[tokio::test] + async fn test_retry_with_timeout_fails_after_max_retries() { + let provider = test_provider(100, 2, 10); + + let call_count = AtomicUsize::new(0); + + let result: Result<(), Error> = provider + .retry_with_total_timeout(|| async { + call_count.fetch_add(1, Ordering::SeqCst); + Err(TransportErrorKind::BackendGone.into()) + }) + .await; + + assert!(matches!(result, Err(Error::RetryFailure(_)))); + assert_eq!(call_count.load(Ordering::SeqCst), 3); + } + + #[tokio::test] + async fn test_retry_with_timeout_respects_max_timeout() { + let max_timeout = 50; + let provider = test_provider(max_timeout, 10, 1); + + let result = provider + .retry_with_total_timeout(|| async { + sleep(Duration::from_millis(max_timeout + 10)).await; + Ok(42) + }) + .await; + + assert!(matches!(result, Err(Error::Timeout))); + } +}