From af1851b02373359e3a04b00f23bccf1cebdd17ae Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 23 Jul 2025 10:33:36 +0530 Subject: [PATCH 1/3] telemetry retry Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/common/http.py | 71 +++++++++++- src/databricks/sql/exc.py | 4 +- .../sql/telemetry/telemetry_client.py | 5 +- tests/e2e/test_telemetry_retry.py | 107 ++++++++++++++++++ tests/unit/test_telemetry.py | 2 +- 5 files changed, 184 insertions(+), 5 deletions(-) create mode 100644 tests/e2e/test_telemetry_retry.py diff --git a/src/databricks/sql/common/http.py b/src/databricks/sql/common/http.py index ec4e3341a..0cd2919c0 100644 --- a/src/databricks/sql/common/http.py +++ b/src/databricks/sql/common/http.py @@ -5,8 +5,10 @@ import threading from dataclasses import dataclass from contextlib import contextmanager -from typing import Generator +from typing import Generator, Optional import logging +from requests.adapters import HTTPAdapter +from databricks.sql.auth.retry import DatabricksRetryPolicy, CommandType logger = logging.getLogger(__name__) @@ -81,3 +83,70 @@ def execute( def close(self): self.session.close() + + +class TelemetryHTTPAdapter(HTTPAdapter): + """ + Custom HTTP adapter to prepare our DatabricksRetryPolicy before each request. + This ensures the retry timer is started and the command type is set correctly, + allowing the policy to manage its state for the duration of the request retries. + """ + + def send(self, request, **kwargs): + self.max_retries.command_type = CommandType.OTHER + self.max_retries.start_retry_timer() + return super().send(request, **kwargs) + + +class TelemetryHttpClient: # TODO: Unify all the http clients in the PySQL Connector + """Singleton HTTP client for sending telemetry data.""" + + _instance: Optional["TelemetryHttpClient"] = None + _lock = threading.Lock() + + TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT = 3 + TELEMETRY_RETRY_DELAY_MIN = 1.0 + TELEMETRY_RETRY_DELAY_MAX = 10.0 + TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION = 30.0 + + def __init__(self): + """Initializes the session and mounts the custom retry adapter.""" + retry_policy = DatabricksRetryPolicy( + delay_min=self.TELEMETRY_RETRY_DELAY_MIN, + delay_max=self.TELEMETRY_RETRY_DELAY_MAX, + stop_after_attempts_count=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_COUNT, + stop_after_attempts_duration=self.TELEMETRY_RETRY_STOP_AFTER_ATTEMPTS_DURATION, + delay_default=1.0, + force_dangerous_codes=[], + ) + adapter = TelemetryHTTPAdapter(max_retries=retry_policy) + self.session = requests.Session() + self.session.mount("https://", adapter) + self.session.mount("http://", adapter) + + @classmethod + def get_instance(cls) -> "TelemetryHttpClient": + """Get the singleton instance of the TelemetryHttpClient.""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: + logger.debug("Initializing singleton TelemetryHttpClient") + cls._instance = TelemetryHttpClient() + return cls._instance + + def post(self, url: str, **kwargs) -> requests.Response: + """ + Executes a POST request using the configured session. + + This is a blocking call intended to be run in a background thread. + """ + logger.debug("Executing telemetry POST request to: %s", url) + return self.session.post(url, **kwargs) + + def close(self): + """Closes the underlying requests.Session.""" + logger.debug("Closing TelemetryHttpClient session.") + self.session.close() + # Clear the instance to allow for re-initialization if needed + with TelemetryHttpClient._lock: + TelemetryHttpClient._instance = None diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 65235f630..4a772c49b 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -2,8 +2,6 @@ import logging logger = logging.getLogger(__name__) -from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory - ### PEP-249 Mandated ### # https://peps.python.org/pep-0249/#exceptions @@ -22,6 +20,8 @@ def __init__( error_name = self.__class__.__name__ if session_id_hex: + from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory + telemetry_client = TelemetryClientFactory.get_telemetry_client( session_id_hex ) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 2c389513a..1690d0c3f 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -4,6 +4,7 @@ import logging from concurrent.futures import ThreadPoolExecutor from typing import Dict, Optional +from databricks.sql.common.http import TelemetryHttpClient from databricks.sql.telemetry.models.event import ( TelemetryEvent, DriverSystemConfiguration, @@ -159,6 +160,7 @@ def __init__( self._driver_connection_params = None self._host_url = host_url self._executor = executor + self._http_client = TelemetryHttpClient.get_instance() def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" @@ -207,7 +209,7 @@ def _send_telemetry(self, events): try: logger.debug("Submitting telemetry request to thread pool") future = self._executor.submit( - requests.post, + self._http_client.post, url, data=request.to_json(), headers=headers, @@ -433,6 +435,7 @@ def close(session_id_hex): ) try: TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryHttpClient.close() except Exception as e: logger.debug("Failed to shutdown thread pool executor: %s", e) TelemetryClientFactory._executor = None diff --git a/tests/e2e/test_telemetry_retry.py b/tests/e2e/test_telemetry_retry.py new file mode 100644 index 000000000..11055b558 --- /dev/null +++ b/tests/e2e/test_telemetry_retry.py @@ -0,0 +1,107 @@ +import pytest +from unittest.mock import patch, MagicMock +import io +import time + +from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory +from databricks.sql.auth.retry import DatabricksRetryPolicy + +PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn' + +def create_mock_conn(responses): + """Creates a mock connection object whose getresponse() method yields a series of responses.""" + mock_conn = MagicMock() + mock_http_responses = [] + for resp in responses: + mock_http_response = MagicMock() + mock_http_response.status = resp.get("status") + mock_http_response.headers = resp.get("headers", {}) + body = resp.get("body", b'{}') + mock_http_response.fp = io.BytesIO(body) + def release(): + mock_http_response.fp.close() + mock_http_response.release_conn = release + mock_http_responses.append(mock_http_response) + mock_conn.getresponse.side_effect = mock_http_responses + return mock_conn + +class TestTelemetryClientRetries: + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + TelemetryClientFactory._initialized = False + TelemetryClientFactory._clients = {} + TelemetryClientFactory._executor = None + yield + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._initialized = False + TelemetryClientFactory._clients = {} + TelemetryClientFactory._executor = None + + def get_client(self, session_id, num_retries=3): + """ + Configures a client with a specific number of retries. + """ + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id, + auth_provider=None, + host_url="test.databricks.com", + ) + client = TelemetryClientFactory.get_telemetry_client(session_id) + + retry_policy = DatabricksRetryPolicy( + delay_min=0.01, + delay_max=0.02, + stop_after_attempts_duration=2.0, + stop_after_attempts_count=num_retries, + delay_default=0.1, + force_dangerous_codes=[], + urllib3_kwargs={'total': num_retries} + ) + adapter = client._http_client.session.adapters.get("https://") + adapter.max_retries = retry_policy + return client + + @pytest.mark.parametrize( + "status_code, description", + [ + (401, "Unauthorized"), + (403, "Forbidden"), + (501, "Not Implemented"), + (200, "Success"), + ], + ) + def test_non_retryable_status_codes_are_not_retried(self, status_code, description): + """ + Verifies that terminal error codes (401, 403, 501) and success codes (200) are not retried. + """ + # Use the status code in the session ID for easier debugging if it fails + client = self.get_client(f"session-{status_code}") + mock_responses = [{"status": status_code}] + + with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: + client.export_failure_log("TestError", "Test message") + TelemetryClientFactory.close(client._session_id_hex) + + mock_get_conn.return_value.getresponse.assert_called_once() + + def test_exceeds_retry_count_limit(self): + """ + Verifies that the client retries up to the specified number of times before giving up. + Verifies that the client respects the Retry-After header and retries on 429, 502, 503. + """ + num_retries = 3 + expected_total_calls = num_retries + 1 + retry_after = 1 + client = self.get_client("session-exceed-limit", num_retries=num_retries) + mock_responses = [{"status": 503, "headers": {"Retry-After": str(retry_after)}}, {"status": 429}, {"status": 502}, {"status": 503}] + + with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: + start_time = time.time() + client.export_failure_log("TestError", "Test message") + TelemetryClientFactory.close(client._session_id_hex) + end_time = time.time() + + assert mock_get_conn.return_value.getresponse.call_count == expected_total_calls + assert end_time - start_time > retry_after \ No newline at end of file diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 4e6e928ab..33db0a245 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -90,7 +90,7 @@ def test_network_request_flow(self, mock_post, mock_telemetry_client): args, kwargs = client._executor.submit.call_args # Verify correct function and URL - assert args[0] == requests.post + assert args[0] == client._http_client.post assert args[1] == "https://test-host.com/telemetry-ext" assert kwargs["headers"]["Authorization"] == "Bearer test-token" From a13155ba4cf06f7a0a7194be21e05de8651f1e6c Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 23 Jul 2025 10:41:43 +0530 Subject: [PATCH 2/3] shifted tests to unit test, removed unused imports Signed-off-by: Sai Shree Pradhan --- .../sql/telemetry/telemetry_client.py | 1 - tests/e2e/test_telemetry_retry.py | 107 ------------------ tests/unit/test_telemetry.py | 1 - 3 files changed, 109 deletions(-) delete mode 100644 tests/e2e/test_telemetry_retry.py diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 1690d0c3f..8462e7ffe 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -1,6 +1,5 @@ import threading import time -import requests import logging from concurrent.futures import ThreadPoolExecutor from typing import Dict, Optional diff --git a/tests/e2e/test_telemetry_retry.py b/tests/e2e/test_telemetry_retry.py deleted file mode 100644 index 11055b558..000000000 --- a/tests/e2e/test_telemetry_retry.py +++ /dev/null @@ -1,107 +0,0 @@ -import pytest -from unittest.mock import patch, MagicMock -import io -import time - -from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory -from databricks.sql.auth.retry import DatabricksRetryPolicy - -PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn' - -def create_mock_conn(responses): - """Creates a mock connection object whose getresponse() method yields a series of responses.""" - mock_conn = MagicMock() - mock_http_responses = [] - for resp in responses: - mock_http_response = MagicMock() - mock_http_response.status = resp.get("status") - mock_http_response.headers = resp.get("headers", {}) - body = resp.get("body", b'{}') - mock_http_response.fp = io.BytesIO(body) - def release(): - mock_http_response.fp.close() - mock_http_response.release_conn = release - mock_http_responses.append(mock_http_response) - mock_conn.getresponse.side_effect = mock_http_responses - return mock_conn - -class TestTelemetryClientRetries: - @pytest.fixture(autouse=True) - def setup_and_teardown(self): - TelemetryClientFactory._initialized = False - TelemetryClientFactory._clients = {} - TelemetryClientFactory._executor = None - yield - if TelemetryClientFactory._executor: - TelemetryClientFactory._executor.shutdown(wait=True) - TelemetryClientFactory._initialized = False - TelemetryClientFactory._clients = {} - TelemetryClientFactory._executor = None - - def get_client(self, session_id, num_retries=3): - """ - Configures a client with a specific number of retries. - """ - TelemetryClientFactory.initialize_telemetry_client( - telemetry_enabled=True, - session_id_hex=session_id, - auth_provider=None, - host_url="test.databricks.com", - ) - client = TelemetryClientFactory.get_telemetry_client(session_id) - - retry_policy = DatabricksRetryPolicy( - delay_min=0.01, - delay_max=0.02, - stop_after_attempts_duration=2.0, - stop_after_attempts_count=num_retries, - delay_default=0.1, - force_dangerous_codes=[], - urllib3_kwargs={'total': num_retries} - ) - adapter = client._http_client.session.adapters.get("https://") - adapter.max_retries = retry_policy - return client - - @pytest.mark.parametrize( - "status_code, description", - [ - (401, "Unauthorized"), - (403, "Forbidden"), - (501, "Not Implemented"), - (200, "Success"), - ], - ) - def test_non_retryable_status_codes_are_not_retried(self, status_code, description): - """ - Verifies that terminal error codes (401, 403, 501) and success codes (200) are not retried. - """ - # Use the status code in the session ID for easier debugging if it fails - client = self.get_client(f"session-{status_code}") - mock_responses = [{"status": status_code}] - - with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: - client.export_failure_log("TestError", "Test message") - TelemetryClientFactory.close(client._session_id_hex) - - mock_get_conn.return_value.getresponse.assert_called_once() - - def test_exceeds_retry_count_limit(self): - """ - Verifies that the client retries up to the specified number of times before giving up. - Verifies that the client respects the Retry-After header and retries on 429, 502, 503. - """ - num_retries = 3 - expected_total_calls = num_retries + 1 - retry_after = 1 - client = self.get_client("session-exceed-limit", num_retries=num_retries) - mock_responses = [{"status": 503, "headers": {"Retry-After": str(retry_after)}}, {"status": 429}, {"status": 502}, {"status": 503}] - - with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: - start_time = time.time() - client.export_failure_log("TestError", "Test message") - TelemetryClientFactory.close(client._session_id_hex) - end_time = time.time() - - assert mock_get_conn.return_value.getresponse.call_count == expected_total_calls - assert end_time - start_time > retry_after \ No newline at end of file diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 33db0a245..6c4c2edfe 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -1,6 +1,5 @@ import uuid import pytest -import requests from unittest.mock import patch, MagicMock from databricks.sql.telemetry.telemetry_client import ( From f97136075d1cc0464f70e26bd148127bf1fc9549 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Wed, 23 Jul 2025 10:50:11 +0530 Subject: [PATCH 3/3] tests Signed-off-by: Sai Shree Pradhan --- tests/unit/test_telemetry_retry.py | 107 +++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 tests/unit/test_telemetry_retry.py diff --git a/tests/unit/test_telemetry_retry.py b/tests/unit/test_telemetry_retry.py new file mode 100644 index 000000000..11055b558 --- /dev/null +++ b/tests/unit/test_telemetry_retry.py @@ -0,0 +1,107 @@ +import pytest +from unittest.mock import patch, MagicMock +import io +import time + +from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory +from databricks.sql.auth.retry import DatabricksRetryPolicy + +PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn' + +def create_mock_conn(responses): + """Creates a mock connection object whose getresponse() method yields a series of responses.""" + mock_conn = MagicMock() + mock_http_responses = [] + for resp in responses: + mock_http_response = MagicMock() + mock_http_response.status = resp.get("status") + mock_http_response.headers = resp.get("headers", {}) + body = resp.get("body", b'{}') + mock_http_response.fp = io.BytesIO(body) + def release(): + mock_http_response.fp.close() + mock_http_response.release_conn = release + mock_http_responses.append(mock_http_response) + mock_conn.getresponse.side_effect = mock_http_responses + return mock_conn + +class TestTelemetryClientRetries: + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + TelemetryClientFactory._initialized = False + TelemetryClientFactory._clients = {} + TelemetryClientFactory._executor = None + yield + if TelemetryClientFactory._executor: + TelemetryClientFactory._executor.shutdown(wait=True) + TelemetryClientFactory._initialized = False + TelemetryClientFactory._clients = {} + TelemetryClientFactory._executor = None + + def get_client(self, session_id, num_retries=3): + """ + Configures a client with a specific number of retries. + """ + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex=session_id, + auth_provider=None, + host_url="test.databricks.com", + ) + client = TelemetryClientFactory.get_telemetry_client(session_id) + + retry_policy = DatabricksRetryPolicy( + delay_min=0.01, + delay_max=0.02, + stop_after_attempts_duration=2.0, + stop_after_attempts_count=num_retries, + delay_default=0.1, + force_dangerous_codes=[], + urllib3_kwargs={'total': num_retries} + ) + adapter = client._http_client.session.adapters.get("https://") + adapter.max_retries = retry_policy + return client + + @pytest.mark.parametrize( + "status_code, description", + [ + (401, "Unauthorized"), + (403, "Forbidden"), + (501, "Not Implemented"), + (200, "Success"), + ], + ) + def test_non_retryable_status_codes_are_not_retried(self, status_code, description): + """ + Verifies that terminal error codes (401, 403, 501) and success codes (200) are not retried. + """ + # Use the status code in the session ID for easier debugging if it fails + client = self.get_client(f"session-{status_code}") + mock_responses = [{"status": status_code}] + + with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: + client.export_failure_log("TestError", "Test message") + TelemetryClientFactory.close(client._session_id_hex) + + mock_get_conn.return_value.getresponse.assert_called_once() + + def test_exceeds_retry_count_limit(self): + """ + Verifies that the client retries up to the specified number of times before giving up. + Verifies that the client respects the Retry-After header and retries on 429, 502, 503. + """ + num_retries = 3 + expected_total_calls = num_retries + 1 + retry_after = 1 + client = self.get_client("session-exceed-limit", num_retries=num_retries) + mock_responses = [{"status": 503, "headers": {"Retry-After": str(retry_after)}}, {"status": 429}, {"status": 502}, {"status": 503}] + + with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn: + start_time = time.time() + client.export_failure_log("TestError", "Test message") + TelemetryClientFactory.close(client._session_id_hex) + end_time = time.time() + + assert mock_get_conn.return_value.getresponse.call_count == expected_total_calls + assert end_time - start_time > retry_after \ No newline at end of file