1+ # tests/e2e/test_telemetry_retry.py
2+
3+ import pytest
4+ import logging
5+ from unittest .mock import patch , MagicMock
6+ from functools import wraps
7+ import time
8+ from concurrent .futures import Future
9+
10+ # Imports for the code being tested
11+ from databricks .sql .telemetry .telemetry_client import TelemetryClientFactory
12+ from databricks .sql .telemetry .models .event import DriverConnectionParameters , HostDetails , DatabricksClientType
13+ from databricks .sql .telemetry .models .enums import AuthMech
14+ from databricks .sql .auth .retry import DatabricksRetryPolicy , CommandType
15+
16+ # Imports for mocking the network layer correctly
17+ from urllib3 .connectionpool import HTTPSConnectionPool
18+ from urllib3 .exceptions import MaxRetryError
19+ from requests .exceptions import ConnectionError as RequestsConnectionError
20+
21+ PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn'
22+
23+ # Helper to create a mock that looks and acts like a urllib3.response.HTTPResponse.
24+ def create_urllib3_response (status , headers = None , body = b'{}' ):
25+ """Create a proper mock response that simulates urllib3's HTTPResponse"""
26+ mock_response = MagicMock ()
27+ mock_response .status = status
28+ mock_response .headers = headers or {}
29+ mock_response .msg = headers or {} # For urllib3~=1.0 compatibility
30+ mock_response .data = body
31+ mock_response .read .return_value = body
32+ mock_response .get_redirect_location .return_value = False
33+ mock_response .closed = False
34+ mock_response .isclosed .return_value = False
35+ return mock_response
36+
37+ @pytest .mark .usefixtures ("caplog" )
38+ class TestTelemetryClientRetries :
39+ """
40+ Test suite for verifying the retry mechanism of the TelemetryClient.
41+ This suite patches the low-level urllib3 connection to correctly
42+ trigger and test the retry logic configured in the requests adapter.
43+ """
44+
45+ @pytest .fixture (autouse = True )
46+ def setup_and_teardown (self , caplog ):
47+ caplog .set_level (logging .DEBUG )
48+ TelemetryClientFactory ._initialized = False
49+ TelemetryClientFactory ._clients = {}
50+ TelemetryClientFactory ._executor = None
51+ yield
52+ if TelemetryClientFactory ._executor :
53+ TelemetryClientFactory ._executor .shutdown (wait = True )
54+ TelemetryClientFactory ._initialized = False
55+ TelemetryClientFactory ._clients = {}
56+ TelemetryClientFactory ._executor = None
57+
58+ def get_client (self , session_id , total_retries = 3 ):
59+ TelemetryClientFactory .initialize_telemetry_client (
60+ telemetry_enabled = True ,
61+ session_id_hex = session_id ,
62+ auth_provider = None ,
63+ host_url = "test.databricks.com" ,
64+ )
65+ client = TelemetryClientFactory .get_telemetry_client (session_id )
66+
67+ retry_policy = DatabricksRetryPolicy (
68+ delay_min = 0.01 ,
69+ delay_max = 0.02 ,
70+ stop_after_attempts_duration = 2.0 ,
71+ stop_after_attempts_count = total_retries ,
72+ delay_default = 0.1 ,
73+ force_dangerous_codes = [],
74+ urllib3_kwargs = {'total' : total_retries }
75+ )
76+ adapter = client ._session .adapters .get ("https://" )
77+ adapter .max_retries = retry_policy
78+ return client , adapter
79+
80+ def wait_for_async_request (self , timeout = 2.0 ):
81+ """Wait for async telemetry request to complete"""
82+ start_time = time .time ()
83+ while time .time () - start_time < timeout :
84+ if TelemetryClientFactory ._executor and TelemetryClientFactory ._executor ._threads :
85+ # Wait a bit more for threads to complete
86+ time .sleep (0.1 )
87+ else :
88+ break
89+ time .sleep (0.1 ) # Extra buffer for completion
90+
91+ def test_success_no_retry (self ):
92+ client , _ = self .get_client ("session-success" )
93+ params = DriverConnectionParameters (
94+ http_path = "test-path" ,
95+ mode = DatabricksClientType .THRIFT ,
96+ host_info = HostDetails (host_url = "test.databricks.com" , port = 443 ),
97+ auth_mech = AuthMech .PAT
98+ )
99+ with patch (PATCH_TARGET ) as mock_get_conn :
100+ mock_get_conn .return_value .getresponse .return_value = create_urllib3_response (200 )
101+
102+ client .export_initial_telemetry_log (params , "test-agent" )
103+ self .wait_for_async_request ()
104+ TelemetryClientFactory .close (client ._session_id_hex )
105+
106+ mock_get_conn .return_value .getresponse .assert_called_once ()
107+
108+ def test_retry_on_503_then_succeeds (self ):
109+ client , _ = self .get_client ("session-retry-once" )
110+ with patch (PATCH_TARGET ) as mock_get_conn :
111+ mock_get_conn .return_value .getresponse .side_effect = [
112+ create_urllib3_response (503 ),
113+ create_urllib3_response (200 ),
114+ ]
115+
116+ client .export_failure_log ("TestError" , "Test message" )
117+ self .wait_for_async_request ()
118+ TelemetryClientFactory .close (client ._session_id_hex )
119+
120+ assert mock_get_conn .return_value .getresponse .call_count == 2
121+
122+ def test_respects_retry_after_header (self , caplog ):
123+ client , _ = self .get_client ("session-retry-after" )
124+ with patch (PATCH_TARGET ) as mock_get_conn :
125+ mock_get_conn .return_value .getresponse .side_effect = [
126+ create_urllib3_response (429 , headers = {'Retry-After' : '1' }), # Use integer seconds to avoid parsing issues
127+ create_urllib3_response (200 )
128+ ]
129+
130+ client .export_failure_log ("TestError" , "Test message" )
131+ self .wait_for_async_request ()
132+ TelemetryClientFactory .close (client ._session_id_hex )
133+
134+ # Check that the request was retried (should be 2 calls: initial + 1 retry)
135+ assert mock_get_conn .return_value .getresponse .call_count == 2
136+ assert "Retrying after" in caplog .text
137+
138+ def test_exceeds_retry_count_limit (self , caplog ):
139+ client , _ = self .get_client ("session-exceed-limit" , total_retries = 3 )
140+ expected_call_count = 4
141+ with patch (PATCH_TARGET ) as mock_get_conn :
142+ mock_get_conn .return_value .getresponse .return_value = create_urllib3_response (503 )
143+
144+ client .export_failure_log ("TestError" , "Test message" )
145+ self .wait_for_async_request ()
146+ TelemetryClientFactory .close (client ._session_id_hex )
147+
148+ assert mock_get_conn .return_value .getresponse .call_count == expected_call_count
149+ assert "Telemetry request failed with exception" in caplog .text
150+ assert "Max retries exceeded" in caplog .text
151+
152+ def test_no_retry_on_401_unauthorized (self , caplog ):
153+ """Test that 401 responses are not retried (per retry policy)"""
154+ client , _ = self .get_client ("session-401" )
155+ with patch (PATCH_TARGET ) as mock_get_conn :
156+ mock_get_conn .return_value .getresponse .return_value = create_urllib3_response (401 )
157+
158+ client .export_failure_log ("TestError" , "Test message" )
159+ self .wait_for_async_request ()
160+ TelemetryClientFactory .close (client ._session_id_hex )
161+
162+ # 401 should not be retried based on the retry policy
163+ mock_get_conn .return_value .getresponse .assert_called_once ()
164+ assert "Telemetry request failed with status code: 401" in caplog .text
165+
166+ def test_retries_on_400_bad_request (self , caplog ):
167+ """Test that 400 responses are retried (this is the current behavior for telemetry)"""
168+ client , _ = self .get_client ("session-400" )
169+ with patch (PATCH_TARGET ) as mock_get_conn :
170+ mock_get_conn .return_value .getresponse .return_value = create_urllib3_response (400 )
171+
172+ client .export_failure_log ("TestError" , "Test message" )
173+ self .wait_for_async_request ()
174+ TelemetryClientFactory .close (client ._session_id_hex )
175+
176+ # Based on the logs, 400 IS being retried (this is the actual behavior for CommandType.OTHER)
177+ expected_call_count = 4 # total + 1 (initial + 3 retries)
178+ assert mock_get_conn .return_value .getresponse .call_count == expected_call_count
179+ assert "Telemetry request failed with exception" in caplog .text
180+ assert "Max retries exceeded" in caplog .text
181+
182+ def test_no_retry_on_403_forbidden (self , caplog ):
183+ """Test that 403 responses are not retried (per retry policy)"""
184+ client , _ = self .get_client ("session-403" )
185+ with patch (PATCH_TARGET ) as mock_get_conn :
186+ mock_get_conn .return_value .getresponse .return_value = create_urllib3_response (403 )
187+
188+ client .export_failure_log ("TestError" , "Test message" )
189+ self .wait_for_async_request ()
190+ TelemetryClientFactory .close (client ._session_id_hex )
191+
192+ # 403 should not be retried based on the retry policy
193+ mock_get_conn .return_value .getresponse .assert_called_once ()
194+ assert "Telemetry request failed with status code: 403" in caplog .text
195+
196+ def test_retry_policy_command_type_is_set_to_other (self ):
197+ client , adapter = self .get_client ("session-command-type" )
198+
199+ original_send = adapter .send
200+ @wraps (original_send )
201+ def wrapper (request , ** kwargs ):
202+ assert adapter .max_retries .command_type == CommandType .OTHER
203+ return original_send (request , ** kwargs )
204+
205+ with patch .object (adapter , 'send' , side_effect = wrapper , autospec = True ), \
206+ patch (PATCH_TARGET ) as mock_get_conn :
207+ mock_get_conn .return_value .getresponse .return_value = create_urllib3_response (200 )
208+
209+ client .export_failure_log ("TestError" , "Test message" )
210+ self .wait_for_async_request ()
211+ TelemetryClientFactory .close (client ._session_id_hex )
212+
213+ assert adapter .send .call_count == 1
0 commit comments