Skip to content

Commit 33bccae

Browse files
committed
Add rtt to Client
1 parent 87d7f04 commit 33bccae

File tree

4 files changed

+150
-19
lines changed

4 files changed

+150
-19
lines changed

.config/nats.dic

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,4 @@ ConnectError
133133
DNS
134134
RequestErrorKind
135135
rustls
136+
RttError

async-nats/src/client.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,38 @@ impl Client {
463463
Ok(())
464464
}
465465

466+
/// Calculates the round trip time between this client and the server,
467+
/// if the server is currently connected.
468+
///
469+
/// # Examples
470+
///
471+
/// ```no_run
472+
/// # #[tokio::main]
473+
/// # async fn main() -> Result<(), async_nats::Error> {
474+
/// let client = async_nats::connect("demo.nats.io").await?;
475+
/// let rtt = client.rtt().await?;
476+
/// println!("server rtt: {:?}", rtt);
477+
/// # Ok(())
478+
/// # }
479+
/// ```
480+
pub async fn rtt(&self) -> Result<Duration, RttError> {
481+
let (tx, rx) = tokio::sync::oneshot::channel();
482+
483+
self.sender
484+
.send(Command::Rtt { result: tx })
485+
.await
486+
.map_err(|err| RttError::with_source(RttErrorKind::PingError, err))?;
487+
488+
let rtt = rx
489+
.await
490+
// first handle rx error
491+
.map_err(|err| RttError::with_source(RttErrorKind::PingError, err))?
492+
// second handle the atual ping error
493+
.map_err(|err| RttError::with_source(RttErrorKind::PingError, err))?;
494+
495+
Ok(rtt)
496+
}
497+
466498
/// Returns the current state of the connection.
467499
///
468500
/// # Examples
@@ -688,3 +720,48 @@ impl From<SubscribeError> for RequestError {
688720
RequestError::with_source(RequestErrorKind::Other, e)
689721
}
690722
}
723+
724+
/// Error returned when doing a round-trip time measurement fails.
725+
/// To enumerate over the variants, call [RttError::kind].
726+
#[derive(Debug, Error)]
727+
pub struct RttError {
728+
kind: RttErrorKind,
729+
source: Option<Box<dyn std::error::Error + Send + Sync>>,
730+
}
731+
732+
impl Display for RttError {
733+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
734+
let source_info = self
735+
.source
736+
.as_ref()
737+
.map(|e| e.to_string())
738+
.unwrap_or_else(|| "no details".into());
739+
match self.kind {
740+
RttErrorKind::PingError => {
741+
write!(f, "failed to ping server: {}", source_info)
742+
}
743+
RttErrorKind::Other => write!(f, "rtt failed: {}", source_info),
744+
}
745+
}
746+
}
747+
748+
impl RttError {
749+
fn with_source<E>(kind: RttErrorKind, source: E) -> RttError
750+
where
751+
E: Into<Box<dyn std::error::Error + Send + Sync>>,
752+
{
753+
RttError {
754+
kind,
755+
source: Some(source.into()),
756+
}
757+
}
758+
pub fn kind(&self) -> RttErrorKind {
759+
self.kind
760+
}
761+
}
762+
763+
#[derive(Debug, PartialEq, Clone, Copy)]
764+
pub enum RttErrorKind {
765+
PingError,
766+
Other,
767+
}

async-nats/src/lib.rs

Lines changed: 61 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ use thiserror::Error;
105105
use futures::future::FutureExt;
106106
use futures::select;
107107
use futures::stream::Stream;
108+
use std::time::Instant;
108109
use tracing::{debug, error};
109110

110111
use core::fmt;
@@ -261,6 +262,9 @@ pub enum Command {
261262
},
262263
TryFlush,
263264
Connect(ConnectInfo),
265+
Rtt {
266+
result: oneshot::Sender<Result<Duration, io::Error>>,
267+
},
264268
}
265269

266270
/// `ClientOp` represents all actions of `Client`.
@@ -301,10 +305,13 @@ pub(crate) struct ConnectionHandler {
301305
connector: Connector,
302306
subscriptions: HashMap<u64, Subscription>,
303307
pending_pings: usize,
308+
pending_pongs: usize,
304309
max_pings: usize,
305310
info_sender: tokio::sync::watch::Sender<ServerInfo>,
306311
ping_interval: Interval,
307312
flush_interval: Interval,
313+
rtt_instant: Option<Instant>,
314+
rtt_sender: Option<oneshot::Sender<Result<Duration, io::Error>>>,
308315
}
309316

310317
impl ConnectionHandler {
@@ -326,10 +333,13 @@ impl ConnectionHandler {
326333
connector,
327334
subscriptions: HashMap::new(),
328335
pending_pings: 0,
336+
pending_pongs: 0,
329337
max_pings: 2,
330338
info_sender,
331339
ping_interval,
332340
flush_interval,
341+
rtt_instant: None,
342+
rtt_sender: None,
333343
}
334344
}
335345

@@ -398,6 +408,18 @@ impl ConnectionHandler {
398408
ServerOp::Pong => {
399409
debug!("received PONG");
400410
self.pending_pings = self.pending_pings.saturating_sub(1);
411+
412+
if self.pending_pongs == 1 {
413+
if let (Some(sender), Some(rtt)) = (self.rtt_sender.take(), self.rtt_instant) {
414+
sender.send(Ok(rtt.elapsed())).map_err(|_| {
415+
io::Error::new(io::ErrorKind::Other, "one shot failed to be received")
416+
})?;
417+
}
418+
419+
// reset the pending pongs (we have at most 1 at any given moment to measure rtt)
420+
self.pending_pongs = 0;
421+
self.rtt_instant = None;
422+
}
401423
}
402424
ServerOp::Error(error) => {
403425
self.connector
@@ -509,26 +531,17 @@ impl ConnectionHandler {
509531
}
510532
}
511533
Command::Ping => {
512-
debug!(
513-
"PING command. Pending pings {}, max pings {}",
514-
self.pending_pings, self.max_pings
515-
);
516-
self.pending_pings += 1;
517-
self.ping_interval.reset();
518-
519-
if self.pending_pings > self.max_pings {
520-
debug!(
521-
"pending pings {}, max pings {}. disconnecting",
522-
self.pending_pings, self.max_pings
523-
);
524-
self.handle_disconnect().await?;
525-
}
526-
527-
if let Err(_err) = self.connection.write_op(&ClientOp::Ping).await {
528-
self.handle_disconnect().await?;
534+
self.handle_ping().await?;
535+
}
536+
Command::Rtt { result } => {
537+
self.rtt_sender = Some(result);
538+
539+
if self.pending_pongs == 0 {
540+
// start the clock for calculating round trip time
541+
self.rtt_instant = Some(Instant::now());
542+
// do a ping and stop clock when handling pong
543+
self.handle_ping().await?;
529544
}
530-
531-
self.handle_flush().await?;
532545
}
533546
Command::Flush { result } => {
534547
if let Err(_err) = self.handle_flush().await {
@@ -613,8 +626,37 @@ impl ConnectionHandler {
613626
Ok(())
614627
}
615628

629+
async fn handle_ping(&mut self) -> Result<(), io::Error> {
630+
debug!(
631+
"PING command. Pending pings {}, max pings {}",
632+
self.pending_pings, self.max_pings
633+
);
634+
self.pending_pings += 1;
635+
self.ping_interval.reset();
636+
637+
if self.pending_pongs == 0 {
638+
self.pending_pongs = 1;
639+
}
640+
641+
if self.pending_pings > self.max_pings {
642+
debug!(
643+
"pending pings {}, max pings {}. disconnecting",
644+
self.pending_pings, self.max_pings
645+
);
646+
self.handle_disconnect().await?;
647+
}
648+
649+
if let Err(_err) = self.connection.write_op(&ClientOp::Ping).await {
650+
self.handle_disconnect().await?;
651+
}
652+
653+
self.handle_flush().await?;
654+
Ok(())
655+
}
656+
616657
async fn handle_disconnect(&mut self) -> io::Result<()> {
617658
self.pending_pings = 0;
659+
self.pending_pongs = 0;
618660
self.connector.events_tx.try_send(Event::Disconnected).ok();
619661
self.connector.state_tx.send(State::Disconnected).ok();
620662
self.handle_reconnect().await?;

async-nats/tests/client_tests.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -764,4 +764,15 @@ mod client {
764764
drop(servers.remove(0));
765765
rx.recv().await;
766766
}
767+
768+
#[tokio::test]
769+
async fn rtt() {
770+
let server = nats_server::run_basic_server();
771+
let client = async_nats::connect(server.client_url()).await.unwrap();
772+
773+
let rtt = client.rtt().await.unwrap();
774+
775+
println!("rtt: {:?}", rtt);
776+
assert!(rtt.as_nanos() > 0);
777+
}
767778
}

0 commit comments

Comments
 (0)