Skip to content

Commit 85121a7

Browse files
committed
Add rtt to Client
1 parent d46f745 commit 85121a7

File tree

4 files changed

+117
-0
lines changed

4 files changed

+117
-0
lines changed

.config/nats.dic

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,4 @@ RequestErrorKind
135135
rustls
136136
Acker
137137
EndpointSchema
138+
RttError

async-nats/src/client.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,35 @@ impl Client {
485485
Ok(())
486486
}
487487

488+
/// Calculates the round trip time between this client and the server,
489+
/// if the server is currently connected.
490+
///
491+
/// # Examples
492+
///
493+
/// ```no_run
494+
/// # #[tokio::main]
495+
/// # async fn main() -> Result<(), async_nats::Error> {
496+
/// let client = async_nats::connect("demo.nats.io").await?;
497+
/// let rtt = client.rtt().await?;
498+
/// println!("server rtt: {:?}", rtt);
499+
/// # Ok(())
500+
/// # }
501+
/// ```
502+
pub async fn rtt(&self) -> Result<Duration, RttError> {
503+
let (tx, rx) = tokio::sync::oneshot::channel();
504+
505+
self.sender.send(Command::Rtt { result: tx }).await?;
506+
507+
let rtt = rx
508+
.await
509+
// first handle rx error
510+
.map_err(|err| RttError(Box::new(err)))?
511+
// second handle the actual rtt error
512+
.map_err(|err| RttError(Box::new(err)))?;
513+
514+
Ok(rtt)
515+
}
516+
488517
/// Returns the current state of the connection.
489518
///
490519
/// # Examples
@@ -713,3 +742,14 @@ impl From<SubscribeError> for RequestError {
713742
RequestError::with_source(RequestErrorKind::Other, e)
714743
}
715744
}
745+
746+
/// Error returned when doing a round-trip time measurement fails.
747+
#[derive(Debug, Error)]
748+
#[error("failed to measure round-trip time: {0}")]
749+
pub struct RttError(#[source] Box<dyn std::error::Error + Send + Sync>);
750+
751+
impl From<tokio::sync::mpsc::error::SendError<Command>> for RttError {
752+
fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
753+
RttError(Box::new(err))
754+
}
755+
}

async-nats/src/lib.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ use thiserror::Error;
104104
use futures::future::FutureExt;
105105
use futures::select;
106106
use futures::stream::Stream;
107+
use std::time::Instant;
107108
use tracing::{debug, error};
108109

109110
use core::fmt;
@@ -259,6 +260,9 @@ pub(crate) enum Command {
259260
result: oneshot::Sender<Result<(), io::Error>>,
260261
},
261262
TryFlush,
263+
Rtt {
264+
result: oneshot::Sender<Result<Duration, io::Error>>,
265+
},
262266
}
263267

264268
/// `ClientOp` represents all actions of `Client`.
@@ -302,6 +306,9 @@ pub(crate) struct ConnectionHandler {
302306
info_sender: tokio::sync::watch::Sender<ServerInfo>,
303307
ping_interval: Interval,
304308
flush_interval: Interval,
309+
last_ping_time: Option<Instant>,
310+
last_pong_time: Option<Instant>,
311+
rtt_senders: Vec<oneshot::Sender<Result<Duration, io::Error>>>,
305312
}
306313

307314
impl ConnectionHandler {
@@ -326,6 +333,9 @@ impl ConnectionHandler {
326333
info_sender,
327334
ping_interval,
328335
flush_interval,
336+
last_ping_time: None,
337+
last_pong_time: None,
338+
rtt_senders: Vec::new(),
329339
}
330340
}
331341

@@ -404,6 +414,22 @@ impl ConnectionHandler {
404414
}
405415
ServerOp::Pong => {
406416
debug!("received PONG");
417+
if self.pending_pings == 1 {
418+
self.last_pong_time = Some(Instant::now());
419+
420+
while let Some(sender) = self.rtt_senders.pop() {
421+
if let (Some(ping), Some(pong)) = (self.last_ping_time, self.last_pong_time)
422+
{
423+
let rtt = pong.duration_since(ping);
424+
sender.send(Ok(rtt)).map_err(|_| {
425+
io::Error::new(
426+
io::ErrorKind::Other,
427+
"one shot failed to be received",
428+
)
429+
})?;
430+
}
431+
}
432+
}
407433
self.pending_pings = self.pending_pings.saturating_sub(1);
408434
}
409435
ServerOp::Error(error) => {
@@ -517,6 +543,14 @@ impl ConnectionHandler {
517543
}
518544
}
519545
}
546+
Command::Rtt { result } => {
547+
self.rtt_senders.push(result);
548+
549+
if self.pending_pings == 0 {
550+
// do a ping and expect a pong - will calculate rtt when handling the pong
551+
self.handle_ping().await?;
552+
}
553+
}
520554
Command::Flush { result } => {
521555
if let Err(_err) = self.handle_flush().await {
522556
if let Err(err) = self.handle_disconnect().await {
@@ -591,8 +625,39 @@ impl ConnectionHandler {
591625
Ok(())
592626
}
593627

628+
async fn handle_ping(&mut self) -> Result<(), io::Error> {
629+
debug!(
630+
"PING command. Pending pings {}, max pings {}",
631+
self.pending_pings, MAX_PENDING_PINGS
632+
);
633+
self.pending_pings += 1;
634+
self.ping_interval.reset();
635+
636+
if self.pending_pings > MAX_PENDING_PINGS {
637+
debug!(
638+
"pending pings {}, max pings {}. disconnecting",
639+
self.pending_pings, MAX_PENDING_PINGS
640+
);
641+
self.handle_disconnect().await?;
642+
}
643+
644+
if self.pending_pings == 1 {
645+
// start the clock for calculating round trip time
646+
self.last_ping_time = Some(Instant::now());
647+
}
648+
649+
if let Err(_err) = self.connection.write_op(&ClientOp::Ping).await {
650+
self.handle_disconnect().await?;
651+
}
652+
653+
self.handle_flush().await?;
654+
Ok(())
655+
}
656+
594657
async fn handle_disconnect(&mut self) -> io::Result<()> {
595658
self.pending_pings = 0;
659+
self.last_ping_time = None;
660+
self.last_pong_time = None;
596661
self.connector.events_tx.try_send(Event::Disconnected).ok();
597662
self.connector.state_tx.send(State::Disconnected).ok();
598663
self.handle_reconnect().await?;

async-nats/tests/client_tests.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -813,4 +813,15 @@ mod client {
813813
drop(servers.remove(0));
814814
rx.recv().await;
815815
}
816+
817+
#[tokio::test]
818+
async fn rtt() {
819+
let server = nats_server::run_basic_server();
820+
let client = async_nats::connect(server.client_url()).await.unwrap();
821+
822+
let rtt = client.rtt().await.unwrap();
823+
824+
println!("rtt: {:?}", rtt);
825+
assert!(rtt.as_nanos() > 0);
826+
}
816827
}

0 commit comments

Comments
 (0)