@@ -105,6 +105,7 @@ use thiserror::Error;
105105use futures:: future:: FutureExt ;
106106use futures:: select;
107107use futures:: stream:: Stream ;
108+ use std:: time:: Instant ;
108109use tracing:: { debug, error} ;
109110
110111use core:: fmt;
@@ -261,6 +262,9 @@ pub enum Command {
261262 } ,
262263 TryFlush ,
263264 Connect ( ConnectInfo ) ,
265+ Rtt {
266+ result : oneshot:: Sender < Result < Duration , io:: Error > > ,
267+ } ,
264268}
265269
266270/// `ClientOp` represents all actions of `Client`.
@@ -305,6 +309,9 @@ pub(crate) struct ConnectionHandler {
305309 info_sender : tokio:: sync:: watch:: Sender < ServerInfo > ,
306310 ping_interval : Interval ,
307311 flush_interval : Interval ,
312+ last_ping_time : Option < Instant > ,
313+ last_pong_time : Option < Instant > ,
314+ rtt_senders : Vec < oneshot:: Sender < Result < Duration , io:: Error > > > ,
308315}
309316
310317impl ConnectionHandler {
@@ -330,6 +337,9 @@ impl ConnectionHandler {
330337 info_sender,
331338 ping_interval,
332339 flush_interval,
340+ last_ping_time : None ,
341+ last_pong_time : None ,
342+ rtt_senders : Vec :: new ( ) ,
333343 }
334344 }
335345
@@ -397,6 +407,23 @@ impl ConnectionHandler {
397407 }
398408 ServerOp :: Pong => {
399409 debug ! ( "received PONG" ) ;
410+ if self . pending_pings == 1 {
411+ // Do we even need to store the last_pong_time?
412+ self . last_pong_time = Some ( Instant :: now ( ) ) ;
413+
414+ while let Some ( sender) = self . rtt_senders . pop ( ) {
415+ if let ( Some ( ping) , Some ( pong) ) = ( self . last_ping_time , self . last_pong_time )
416+ {
417+ let rtt = pong. duration_since ( ping) ;
418+ sender. send ( Ok ( rtt) ) . map_err ( |_| {
419+ io:: Error :: new (
420+ io:: ErrorKind :: Other ,
421+ "one shot failed to be received" ,
422+ )
423+ } ) ?;
424+ }
425+ }
426+ }
400427 self . pending_pings = self . pending_pings . saturating_sub ( 1 ) ;
401428 }
402429 ServerOp :: Error ( error) => {
@@ -509,26 +536,17 @@ impl ConnectionHandler {
509536 }
510537 }
511538 Command :: Ping => {
512- debug ! (
513- "PING command. Pending pings {}, max pings {}" ,
514- self . pending_pings, self . max_pings
515- ) ;
516- self . pending_pings += 1 ;
517- self . ping_interval . reset ( ) ;
518-
519- if self . pending_pings > self . max_pings {
520- debug ! (
521- "pending pings {}, max pings {}. disconnecting" ,
522- self . pending_pings, self . max_pings
523- ) ;
524- self . handle_disconnect ( ) . await ?;
525- }
526-
527- if let Err ( _err) = self . connection . write_op ( & ClientOp :: Ping ) . await {
528- self . handle_disconnect ( ) . await ?;
539+ self . handle_ping ( ) . await ?;
540+ }
541+ Command :: Rtt { result } => {
542+ self . rtt_senders . push ( result) ;
543+
544+ if self . pending_pings == 0 {
545+ // start the clock for calculating round trip time
546+ self . last_ping_time = Some ( Instant :: now ( ) ) ;
547+ // do a ping and expect a pong - will calculate rtt when handling the pong
548+ self . handle_ping ( ) . await ?;
529549 }
530-
531- self . handle_flush ( ) . await ?;
532550 }
533551 Command :: Flush { result } => {
534552 if let Err ( _err) = self . handle_flush ( ) . await {
@@ -613,8 +631,34 @@ impl ConnectionHandler {
613631 Ok ( ( ) )
614632 }
615633
634+ async fn handle_ping ( & mut self ) -> Result < ( ) , io:: Error > {
635+ debug ! (
636+ "PING command. Pending pings {}, max pings {}" ,
637+ self . pending_pings, self . max_pings
638+ ) ;
639+ self . pending_pings += 1 ;
640+ self . ping_interval . reset ( ) ;
641+
642+ if self . pending_pings > self . max_pings {
643+ debug ! (
644+ "pending pings {}, max pings {}. disconnecting" ,
645+ self . pending_pings, self . max_pings
646+ ) ;
647+ self . handle_disconnect ( ) . await ?;
648+ }
649+
650+ if let Err ( _err) = self . connection . write_op ( & ClientOp :: Ping ) . await {
651+ self . handle_disconnect ( ) . await ?;
652+ }
653+
654+ self . handle_flush ( ) . await ?;
655+ Ok ( ( ) )
656+ }
657+
616658 async fn handle_disconnect ( & mut self ) -> io:: Result < ( ) > {
617659 self . pending_pings = 0 ;
660+ self . last_ping_time = None ;
661+ self . last_pong_time = None ;
618662 self . connector . events_tx . try_send ( Event :: Disconnected ) . ok ( ) ;
619663 self . connector . state_tx . send ( State :: Disconnected ) . ok ( ) ;
620664 self . handle_reconnect ( ) . await ?;
0 commit comments