66from databricks .sql .telemetry .telemetry_client import TelemetryClientFactory
77from databricks .sql .auth .retry import DatabricksRetryPolicy
88
9- PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn'
9+ PATCH_TARGET = "urllib3.connectionpool.HTTPSConnectionPool._get_conn"
10+
1011
1112def create_mock_conn (responses ):
1213 """Creates a mock connection object whose getresponse() method yields a series of responses."""
@@ -16,15 +17,18 @@ def create_mock_conn(responses):
1617 mock_http_response = MagicMock ()
1718 mock_http_response .status = resp .get ("status" )
1819 mock_http_response .headers = resp .get ("headers" , {})
19- body = resp .get ("body" , b'{}' )
20+ body = resp .get ("body" , b"{}" )
2021 mock_http_response .fp = io .BytesIO (body )
22+
2123 def release ():
2224 mock_http_response .fp .close ()
25+
2326 mock_http_response .release_conn = release
2427 mock_http_responses .append (mock_http_response )
2528 mock_conn .getresponse .side_effect = mock_http_responses
2629 return mock_conn
2730
31+
2832class TestTelemetryClientRetries :
2933 @pytest .fixture (autouse = True )
3034 def setup_and_teardown (self ):
@@ -49,28 +53,28 @@ def get_client(self, session_id, num_retries=3):
4953 host_url = "test.databricks.com" ,
5054 )
5155 client = TelemetryClientFactory .get_telemetry_client (session_id )
52-
56+
5357 retry_policy = DatabricksRetryPolicy (
5458 delay_min = 0.01 ,
5559 delay_max = 0.02 ,
5660 stop_after_attempts_duration = 2.0 ,
57- stop_after_attempts_count = num_retries ,
61+ stop_after_attempts_count = num_retries ,
5862 delay_default = 0.1 ,
5963 force_dangerous_codes = [],
60- urllib3_kwargs = {' total' : num_retries }
64+ urllib3_kwargs = {" total" : num_retries },
6165 )
6266 adapter = client ._http_client .session .adapters .get ("https://" )
6367 adapter .max_retries = retry_policy
6468 return client
6569
6670 @pytest .mark .parametrize (
67- "status_code, description" ,
68- [
69- (401 , "Unauthorized" ),
70- (403 , "Forbidden" ),
71- (501 , "Not Implemented" ),
72- (200 , "Success" ),
73- ],
71+ "status_code, description" ,
72+ [
73+ (401 , "Unauthorized" ),
74+ (403 , "Forbidden" ),
75+ (501 , "Not Implemented" ),
76+ (200 , "Success" ),
77+ ],
7478 )
7579 def test_non_retryable_status_codes_are_not_retried (self , status_code , description ):
7680 """
@@ -80,7 +84,9 @@ def test_non_retryable_status_codes_are_not_retried(self, status_code, descripti
8084 client = self .get_client (f"session-{ status_code } " )
8185 mock_responses = [{"status" : status_code }]
8286
83- with patch (PATCH_TARGET , return_value = create_mock_conn (mock_responses )) as mock_get_conn :
87+ with patch (
88+ PATCH_TARGET , return_value = create_mock_conn (mock_responses )
89+ ) as mock_get_conn :
8490 client .export_failure_log ("TestError" , "Test message" )
8591 TelemetryClientFactory .close (client ._session_id_hex )
8692
@@ -92,16 +98,26 @@ def test_exceeds_retry_count_limit(self):
9298 Verifies that the client respects the Retry-After header and retries on 429, 502, 503.
9399 """
94100 num_retries = 3
95- expected_total_calls = num_retries + 1
101+ expected_total_calls = num_retries + 1
96102 retry_after = 1
97103 client = self .get_client ("session-exceed-limit" , num_retries = num_retries )
98- mock_responses = [{"status" : 503 , "headers" : {"Retry-After" : str (retry_after )}}, {"status" : 429 }, {"status" : 502 }, {"status" : 503 }]
99-
100- with patch (PATCH_TARGET , return_value = create_mock_conn (mock_responses )) as mock_get_conn :
104+ mock_responses = [
105+ {"status" : 503 , "headers" : {"Retry-After" : str (retry_after )}},
106+ {"status" : 429 },
107+ {"status" : 502 },
108+ {"status" : 503 },
109+ ]
110+
111+ with patch (
112+ PATCH_TARGET , return_value = create_mock_conn (mock_responses )
113+ ) as mock_get_conn :
101114 start_time = time .time ()
102115 client .export_failure_log ("TestError" , "Test message" )
103116 TelemetryClientFactory .close (client ._session_id_hex )
104117 end_time = time .time ()
105-
106- assert mock_get_conn .return_value .getresponse .call_count == expected_total_calls
107- assert end_time - start_time > retry_after
118+
119+ assert (
120+ mock_get_conn .return_value .getresponse .call_count
121+ == expected_total_calls
122+ )
123+ assert end_time - start_time > retry_after
0 commit comments