@@ -104,6 +104,7 @@ use thiserror::Error;
104104use futures:: future:: FutureExt ;
105105use futures:: select;
106106use futures:: stream:: Stream ;
107+ use std:: time:: Instant ;
107108use tracing:: { debug, error} ;
108109
109110use core:: fmt;
@@ -259,6 +260,9 @@ pub(crate) enum Command {
259260 result : oneshot:: Sender < Result < ( ) , io:: Error > > ,
260261 } ,
261262 TryFlush ,
263+ Rtt {
264+ result : oneshot:: Sender < Result < Duration , io:: Error > > ,
265+ } ,
262266}
263267
264268/// `ClientOp` represents all actions of `Client`.
@@ -302,6 +306,9 @@ pub(crate) struct ConnectionHandler {
302306 info_sender : tokio:: sync:: watch:: Sender < ServerInfo > ,
303307 ping_interval : Interval ,
304308 flush_interval : Interval ,
309+ last_ping_time : Option < Instant > ,
310+ last_pong_time : Option < Instant > ,
311+ rtt_senders : Vec < oneshot:: Sender < Result < Duration , io:: Error > > > ,
305312}
306313
307314impl ConnectionHandler {
@@ -326,6 +333,9 @@ impl ConnectionHandler {
326333 info_sender,
327334 ping_interval,
328335 flush_interval,
336+ last_ping_time : None ,
337+ last_pong_time : None ,
338+ rtt_senders : Vec :: new ( ) ,
329339 }
330340 }
331341
@@ -404,6 +414,23 @@ impl ConnectionHandler {
404414 }
405415 ServerOp :: Pong => {
406416 debug ! ( "received PONG" ) ;
417+ if self . pending_pings == 1 {
418+ // Do we even need to store the last_pong_time?
419+ self . last_pong_time = Some ( Instant :: now ( ) ) ;
420+
421+ while let Some ( sender) = self . rtt_senders . pop ( ) {
422+ if let ( Some ( ping) , Some ( pong) ) = ( self . last_ping_time , self . last_pong_time )
423+ {
424+ let rtt = pong. duration_since ( ping) ;
425+ sender. send ( Ok ( rtt) ) . map_err ( |_| {
426+ io:: Error :: new (
427+ io:: ErrorKind :: Other ,
428+ "one shot failed to be received" ,
429+ )
430+ } ) ?;
431+ }
432+ }
433+ }
407434 self . pending_pings = self . pending_pings . saturating_sub ( 1 ) ;
408435 }
409436 ServerOp :: Error ( error) => {
@@ -517,6 +544,14 @@ impl ConnectionHandler {
517544 }
518545 }
519546 }
547+ Command :: Rtt { result } => {
548+ self . rtt_senders . push ( result) ;
549+
550+ if self . pending_pings == 0 {
551+ // do a ping and expect a pong - will calculate rtt when handling the pong
552+ self . handle_ping ( ) . await ?;
553+ }
554+ }
520555 Command :: Flush { result } => {
521556 if let Err ( _err) = self . handle_flush ( ) . await {
522557 if let Err ( err) = self . handle_disconnect ( ) . await {
@@ -591,8 +626,39 @@ impl ConnectionHandler {
591626 Ok ( ( ) )
592627 }
593628
629+ async fn handle_ping ( & mut self ) -> Result < ( ) , io:: Error > {
630+ debug ! (
631+ "PING command. Pending pings {}, max pings {}" ,
632+ self . pending_pings, MAX_PENDING_PINGS
633+ ) ;
634+ self . pending_pings += 1 ;
635+ self . ping_interval . reset ( ) ;
636+
637+ if self . pending_pings > MAX_PENDING_PINGS {
638+ debug ! (
639+ "pending pings {}, max pings {}. disconnecting" ,
640+ self . pending_pings, MAX_PENDING_PINGS
641+ ) ;
642+ self . handle_disconnect ( ) . await ?;
643+ }
644+
645+ if self . pending_pings == 1 {
646+ // start the clock for calculating round trip time
647+ self . last_ping_time = Some ( Instant :: now ( ) ) ;
648+ }
649+
650+ if let Err ( _err) = self . connection . write_op ( & ClientOp :: Ping ) . await {
651+ self . handle_disconnect ( ) . await ?;
652+ }
653+
654+ self . handle_flush ( ) . await ?;
655+ Ok ( ( ) )
656+ }
657+
594658 async fn handle_disconnect ( & mut self ) -> io:: Result < ( ) > {
595659 self . pending_pings = 0 ;
660+ self . last_ping_time = None ;
661+ self . last_pong_time = None ;
596662 self . connector . events_tx . try_send ( Event :: Disconnected ) . ok ( ) ;
597663 self . connector . state_tx . send ( State :: Disconnected ) . ok ( ) ;
598664 self . handle_reconnect ( ) . await ?;
0 commit comments