Skip to content

Commit c0dfa53

Browse files
committed
Add rtt to Client
1 parent 625d1da commit c0dfa53

File tree

4 files changed

+120
-0
lines changed

4 files changed

+120
-0
lines changed

.config/nats.dic

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,12 @@ RequestErrorKind
135135
rustls
136136
Acker
137137
EndpointSchema
138+
<<<<<<< HEAD
138139
auth
139140
filter_subject
140141
filter_subjects
141142
rollup
142143
IoT
144+
=======
145+
RttError
146+
>>>>>>> 85121a7 (Add `rtt` to `Client`)

async-nats/src/client.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,35 @@ impl Client {
485485
Ok(())
486486
}
487487

488+
/// Calculates the round trip time between this client and the server,
489+
/// if the server is currently connected.
490+
///
491+
/// # Examples
492+
///
493+
/// ```no_run
494+
/// # #[tokio::main]
495+
/// # async fn main() -> Result<(), async_nats::Error> {
496+
/// let client = async_nats::connect("demo.nats.io").await?;
497+
/// let rtt = client.rtt().await?;
498+
/// println!("server rtt: {:?}", rtt);
499+
/// # Ok(())
500+
/// # }
501+
/// ```
502+
pub async fn rtt(&self) -> Result<Duration, RttError> {
503+
let (tx, rx) = tokio::sync::oneshot::channel();
504+
505+
self.sender.send(Command::Rtt { result: tx }).await?;
506+
507+
let rtt = rx
508+
.await
509+
// first handle rx error
510+
.map_err(|err| RttError(Box::new(err)))?
511+
// second handle the actual rtt error
512+
.map_err(|err| RttError(Box::new(err)))?;
513+
514+
Ok(rtt)
515+
}
516+
488517
/// Returns the current state of the connection.
489518
///
490519
/// # Examples
@@ -684,3 +713,14 @@ impl From<SubscribeError> for RequestError {
684713
RequestError::with_source(RequestErrorKind::Other, e)
685714
}
686715
}
716+
717+
/// Error returned when doing a round-trip time measurement fails.
718+
#[derive(Debug, Error)]
719+
#[error("failed to measure round-trip time: {0}")]
720+
pub struct RttError(#[source] Box<dyn std::error::Error + Send + Sync>);
721+
722+
impl From<tokio::sync::mpsc::error::SendError<Command>> for RttError {
723+
fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
724+
RttError(Box::new(err))
725+
}
726+
}

async-nats/src/lib.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ use thiserror::Error;
124124
use futures::future::FutureExt;
125125
use futures::select;
126126
use futures::stream::Stream;
127+
use std::time::Instant;
127128
use tracing::{debug, error};
128129

129130
use core::fmt;
@@ -280,6 +281,9 @@ pub(crate) enum Command {
280281
result: oneshot::Sender<Result<(), io::Error>>,
281282
},
282283
TryFlush,
284+
Rtt {
285+
result: oneshot::Sender<Result<Duration, io::Error>>,
286+
},
283287
}
284288

285289
/// `ClientOp` represents all actions of `Client`.
@@ -323,6 +327,9 @@ pub(crate) struct ConnectionHandler {
323327
info_sender: tokio::sync::watch::Sender<ServerInfo>,
324328
ping_interval: Interval,
325329
flush_interval: Interval,
330+
last_ping_time: Option<Instant>,
331+
last_pong_time: Option<Instant>,
332+
rtt_senders: Vec<oneshot::Sender<Result<Duration, io::Error>>>,
326333
}
327334

328335
impl ConnectionHandler {
@@ -347,6 +354,9 @@ impl ConnectionHandler {
347354
info_sender,
348355
ping_interval,
349356
flush_interval,
357+
last_ping_time: None,
358+
last_pong_time: None,
359+
rtt_senders: Vec::new(),
350360
}
351361
}
352362

@@ -425,6 +435,22 @@ impl ConnectionHandler {
425435
}
426436
ServerOp::Pong => {
427437
debug!("received PONG");
438+
if self.pending_pings == 1 {
439+
self.last_pong_time = Some(Instant::now());
440+
441+
while let Some(sender) = self.rtt_senders.pop() {
442+
if let (Some(ping), Some(pong)) = (self.last_ping_time, self.last_pong_time)
443+
{
444+
let rtt = pong.duration_since(ping);
445+
sender.send(Ok(rtt)).map_err(|_| {
446+
io::Error::new(
447+
io::ErrorKind::Other,
448+
"one shot failed to be received",
449+
)
450+
})?;
451+
}
452+
}
453+
}
428454
self.pending_pings = self.pending_pings.saturating_sub(1);
429455
}
430456
ServerOp::Error(error) => {
@@ -538,6 +564,14 @@ impl ConnectionHandler {
538564
}
539565
}
540566
}
567+
Command::Rtt { result } => {
568+
self.rtt_senders.push(result);
569+
570+
if self.pending_pings == 0 {
571+
// do a ping and expect a pong - will calculate rtt when handling the pong
572+
self.handle_ping().await?;
573+
}
574+
}
541575
Command::Flush { result } => {
542576
if let Err(_err) = self.handle_flush().await {
543577
if let Err(err) = self.handle_disconnect().await {
@@ -612,8 +646,39 @@ impl ConnectionHandler {
612646
Ok(())
613647
}
614648

649+
async fn handle_ping(&mut self) -> Result<(), io::Error> {
650+
debug!(
651+
"PING command. Pending pings {}, max pings {}",
652+
self.pending_pings, MAX_PENDING_PINGS
653+
);
654+
self.pending_pings += 1;
655+
self.ping_interval.reset();
656+
657+
if self.pending_pings > MAX_PENDING_PINGS {
658+
debug!(
659+
"pending pings {}, max pings {}. disconnecting",
660+
self.pending_pings, MAX_PENDING_PINGS
661+
);
662+
self.handle_disconnect().await?;
663+
}
664+
665+
if self.pending_pings == 1 {
666+
// start the clock for calculating round trip time
667+
self.last_ping_time = Some(Instant::now());
668+
}
669+
670+
if let Err(_err) = self.connection.write_op(&ClientOp::Ping).await {
671+
self.handle_disconnect().await?;
672+
}
673+
674+
self.handle_flush().await?;
675+
Ok(())
676+
}
677+
615678
async fn handle_disconnect(&mut self) -> io::Result<()> {
616679
self.pending_pings = 0;
680+
self.last_ping_time = None;
681+
self.last_pong_time = None;
617682
self.connector.events_tx.try_send(Event::Disconnected).ok();
618683
self.connector.state_tx.send(State::Disconnected).ok();
619684
self.handle_reconnect().await?;

async-nats/tests/client_tests.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -867,4 +867,15 @@ mod client {
867867
.await
868868
.unwrap();
869869
}
870+
871+
#[tokio::test]
872+
async fn rtt() {
873+
let server = nats_server::run_basic_server();
874+
let client = async_nats::connect(server.client_url()).await.unwrap();
875+
876+
let rtt = client.rtt().await.unwrap();
877+
878+
println!("rtt: {:?}", rtt);
879+
assert!(rtt.as_nanos() > 0);
880+
}
870881
}

0 commit comments

Comments
 (0)