@@ -100,7 +100,8 @@ class ConnectionStates(object):
100100 HANDSHAKE = '<handshake>'
101101 CONNECTED = '<connected>'
102102 AUTHENTICATING = '<authenticating>'
103- API_VERSIONS = '<checking_api_versions>'
103+ API_VERSIONS_SEND = '<checking_api_versions_send>'
104+ API_VERSIONS_RECV = '<checking_api_versions_recv>'
104105
105106
106107class BrokerConnection (object ):
@@ -419,7 +420,7 @@ def connect(self):
419420 self ._wrap_ssl ()
420421 else :
421422 log .debug ('%s: checking broker Api Versions' , self )
422- self .state = ConnectionStates .API_VERSIONS
423+ self .state = ConnectionStates .API_VERSIONS_SEND
423424 self .config ['state_change_callback' ](self .node_id , self ._sock , self )
424425
425426 # Connection failed
@@ -439,13 +440,13 @@ def connect(self):
439440 if self ._try_handshake ():
440441 log .debug ('%s: completed SSL handshake.' , self )
441442 log .debug ('%s: checking broker Api Versions' , self )
442- self .state = ConnectionStates .API_VERSIONS
443+ self .state = ConnectionStates .API_VERSIONS_SEND
443444 self .config ['state_change_callback' ](self .node_id , self ._sock , self )
444445
445- if self .state is ConnectionStates .API_VERSIONS :
446+ if self .state in ( ConnectionStates .API_VERSIONS_SEND , ConnectionStates . API_VERSIONS_RECV ) :
446447 if self ._try_api_versions_check ():
447448 # _try_api_versions_check has side-effects: possibly disconnected on socket errors
448- if self .state is ConnectionStates .API_VERSIONS :
449+ if self .state in ( ConnectionStates .API_VERSIONS_SEND , ConnectionStates . API_VERSIONS_RECV ) :
449450 if self .config ['security_protocol' ] in ('SASL_PLAINTEXT' , 'SASL_SSL' ):
450451 log .debug ('%s: initiating SASL authentication' , self )
451452 self .state = ConnectionStates .AUTHENTICATING
@@ -555,13 +556,17 @@ def _try_api_versions_check(self):
555556 response .add_callback (self ._handle_api_versions_response , future )
556557 response .add_errback (self ._handle_api_versions_failure , future )
557558 self ._api_versions_future = future
559+ self .state = ConnectionStates .API_VERSIONS_RECV
560+ self .config ['state_change_callback' ](self .node_id , self ._sock , self )
558561 elif self ._check_version_idx < len (self .VERSION_CHECKS ):
559562 version , request = self .VERSION_CHECKS [self ._check_version_idx ]
560563 future = Future ()
561564 response = self ._send (request , blocking = True , request_timeout_ms = (self .config ['api_version_auto_timeout_ms' ] * 0.8 ))
562565 response .add_callback (self ._handle_check_version_response , future , version )
563566 response .add_errback (self ._handle_check_version_failure , future )
564567 self ._api_versions_future = future
568+ self .state = ConnectionStates .API_VERSIONS_RECV
569+ self .config ['state_change_callback' ](self .node_id , self ._sock , self )
565570 else :
566571 raise 'Unable to determine broker version.'
567572
@@ -991,14 +996,16 @@ def connecting(self):
991996 return self .state in (ConnectionStates .CONNECTING ,
992997 ConnectionStates .HANDSHAKE ,
993998 ConnectionStates .AUTHENTICATING ,
994- ConnectionStates .API_VERSIONS )
999+ ConnectionStates .API_VERSIONS_SEND ,
1000+ ConnectionStates .API_VERSIONS_RECV )
9951001
9961002 def initializing (self ):
9971003 """Returns True if socket is connected but full connection is not complete.
9981004 During this time the connection may send api requests to the broker to
9991005 check api versions and perform SASL authentication."""
10001006 return self .state in (ConnectionStates .AUTHENTICATING ,
1001- ConnectionStates .API_VERSIONS )
1007+ ConnectionStates .API_VERSIONS_SEND ,
1008+ ConnectionStates .API_VERSIONS_RECV )
10021009
10031010 def disconnected (self ):
10041011 """Return True iff socket is closed"""
0 commit comments