Skip to content

Commit 97f9721

Browse files
committed
Add rtt to Client
1 parent d46f745 commit 97f9721

File tree

4 files changed

+118
-0
lines changed

4 files changed

+118
-0
lines changed

.config/nats.dic

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,4 @@ RequestErrorKind
135135
rustls
136136
Acker
137137
EndpointSchema
138+
RttError

async-nats/src/client.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,35 @@ impl Client {
485485
Ok(())
486486
}
487487

488+
/// Calculates the round trip time between this client and the server,
489+
/// if the server is currently connected.
490+
///
491+
/// # Examples
492+
///
493+
/// ```no_run
494+
/// # #[tokio::main]
495+
/// # async fn main() -> Result<(), async_nats::Error> {
496+
/// let client = async_nats::connect("demo.nats.io").await?;
497+
/// let rtt = client.rtt().await?;
498+
/// println!("server rtt: {:?}", rtt);
499+
/// # Ok(())
500+
/// # }
501+
/// ```
502+
pub async fn rtt(&self) -> Result<Duration, RttError> {
503+
let (tx, rx) = tokio::sync::oneshot::channel();
504+
505+
self.sender.send(Command::Rtt { result: tx }).await?;
506+
507+
let rtt = rx
508+
.await
509+
// first handle rx error
510+
.map_err(|err| RttError(Box::new(err)))?
511+
// second handle the actual rtt error
512+
.map_err(|err| RttError(Box::new(err)))?;
513+
514+
Ok(rtt)
515+
}
516+
488517
/// Returns the current state of the connection.
489518
///
490519
/// # Examples
@@ -713,3 +742,14 @@ impl From<SubscribeError> for RequestError {
713742
RequestError::with_source(RequestErrorKind::Other, e)
714743
}
715744
}
745+
746+
/// Error returned when doing a round-trip time measurement fails.
747+
#[derive(Debug, Error)]
748+
#[error("failed to measure round-trip time: {0}")]
749+
pub struct RttError(#[source] Box<dyn std::error::Error + Send + Sync>);
750+
751+
impl From<tokio::sync::mpsc::error::SendError<Command>> for RttError {
752+
fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
753+
RttError(Box::new(err))
754+
}
755+
}

async-nats/src/lib.rs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ use thiserror::Error;
104104
use futures::future::FutureExt;
105105
use futures::select;
106106
use futures::stream::Stream;
107+
use std::time::Instant;
107108
use tracing::{debug, error};
108109

109110
use core::fmt;
@@ -259,6 +260,9 @@ pub(crate) enum Command {
259260
result: oneshot::Sender<Result<(), io::Error>>,
260261
},
261262
TryFlush,
263+
Rtt {
264+
result: oneshot::Sender<Result<Duration, io::Error>>,
265+
},
262266
}
263267

264268
/// `ClientOp` represents all actions of `Client`.
@@ -302,6 +306,9 @@ pub(crate) struct ConnectionHandler {
302306
info_sender: tokio::sync::watch::Sender<ServerInfo>,
303307
ping_interval: Interval,
304308
flush_interval: Interval,
309+
last_ping_time: Option<Instant>,
310+
last_pong_time: Option<Instant>,
311+
rtt_senders: Vec<oneshot::Sender<Result<Duration, io::Error>>>,
305312
}
306313

307314
impl ConnectionHandler {
@@ -326,6 +333,9 @@ impl ConnectionHandler {
326333
info_sender,
327334
ping_interval,
328335
flush_interval,
336+
last_ping_time: None,
337+
last_pong_time: None,
338+
rtt_senders: Vec::new(),
329339
}
330340
}
331341

@@ -404,6 +414,23 @@ impl ConnectionHandler {
404414
}
405415
ServerOp::Pong => {
406416
debug!("received PONG");
417+
if self.pending_pings == 1 {
418+
// Do we even need to store the last_pong_time?
419+
self.last_pong_time = Some(Instant::now());
420+
421+
while let Some(sender) = self.rtt_senders.pop() {
422+
if let (Some(ping), Some(pong)) = (self.last_ping_time, self.last_pong_time)
423+
{
424+
let rtt = pong.duration_since(ping);
425+
sender.send(Ok(rtt)).map_err(|_| {
426+
io::Error::new(
427+
io::ErrorKind::Other,
428+
"one shot failed to be received",
429+
)
430+
})?;
431+
}
432+
}
433+
}
407434
self.pending_pings = self.pending_pings.saturating_sub(1);
408435
}
409436
ServerOp::Error(error) => {
@@ -517,6 +544,14 @@ impl ConnectionHandler {
517544
}
518545
}
519546
}
547+
Command::Rtt { result } => {
548+
self.rtt_senders.push(result);
549+
550+
if self.pending_pings == 0 {
551+
// do a ping and expect a pong - will calculate rtt when handling the pong
552+
self.handle_ping().await?;
553+
}
554+
}
520555
Command::Flush { result } => {
521556
if let Err(_err) = self.handle_flush().await {
522557
if let Err(err) = self.handle_disconnect().await {
@@ -591,8 +626,39 @@ impl ConnectionHandler {
591626
Ok(())
592627
}
593628

629+
async fn handle_ping(&mut self) -> Result<(), io::Error> {
630+
debug!(
631+
"PING command. Pending pings {}, max pings {}",
632+
self.pending_pings, MAX_PENDING_PINGS
633+
);
634+
self.pending_pings += 1;
635+
self.ping_interval.reset();
636+
637+
if self.pending_pings > MAX_PENDING_PINGS {
638+
debug!(
639+
"pending pings {}, max pings {}. disconnecting",
640+
self.pending_pings, MAX_PENDING_PINGS
641+
);
642+
self.handle_disconnect().await?;
643+
}
644+
645+
if self.pending_pings == 1 {
646+
// start the clock for calculating round trip time
647+
self.last_ping_time = Some(Instant::now());
648+
}
649+
650+
if let Err(_err) = self.connection.write_op(&ClientOp::Ping).await {
651+
self.handle_disconnect().await?;
652+
}
653+
654+
self.handle_flush().await?;
655+
Ok(())
656+
}
657+
594658
async fn handle_disconnect(&mut self) -> io::Result<()> {
595659
self.pending_pings = 0;
660+
self.last_ping_time = None;
661+
self.last_pong_time = None;
596662
self.connector.events_tx.try_send(Event::Disconnected).ok();
597663
self.connector.state_tx.send(State::Disconnected).ok();
598664
self.handle_reconnect().await?;

async-nats/tests/client_tests.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -813,4 +813,15 @@ mod client {
813813
drop(servers.remove(0));
814814
rx.recv().await;
815815
}
816+
817+
#[tokio::test]
818+
async fn rtt() {
819+
let server = nats_server::run_basic_server();
820+
let client = async_nats::connect(server.client_url()).await.unwrap();
821+
822+
let rtt = client.rtt().await.unwrap();
823+
824+
println!("rtt: {:?}", rtt);
825+
assert!(rtt.as_nanos() > 0);
826+
}
816827
}

0 commit comments

Comments
 (0)