From 7619a69703d9be7d1d41956ad04e2f0be9636d62 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 1 Aug 2025 01:32:32 +0530 Subject: [PATCH 1/2] flush fix, sync fix in e2e test Signed-off-by: Sai Shree Pradhan --- src/databricks/sql/telemetry/telemetry_client.py | 3 +++ tests/e2e/test_concurrent_telemetry.py | 12 ++++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 9960490c5..75c29b19c 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -127,6 +127,9 @@ def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): def close(self): pass + def _flush(self): + pass + class TelemetryClient(BaseTelemetryClient): """ diff --git a/tests/e2e/test_concurrent_telemetry.py b/tests/e2e/test_concurrent_telemetry.py index 656bcd21f..2edce7740 100644 --- a/tests/e2e/test_concurrent_telemetry.py +++ b/tests/e2e/test_concurrent_telemetry.py @@ -35,6 +35,7 @@ def telemetry_setup_teardown(self): if TelemetryClientFactory._executor: TelemetryClientFactory._executor.shutdown(wait=True) TelemetryClientFactory._executor = None + TelemetryClientFactory._stop_flush_thread() TelemetryClientFactory._initialized = False def test_concurrent_queries_sends_telemetry(self): @@ -101,8 +102,15 @@ def execute_query_worker(thread_id): # Run the workers concurrently run_in_threads(execute_query_worker, num_threads, pass_index=True) - if TelemetryClientFactory._executor: - TelemetryClientFactory._executor.shutdown(wait=True) + timeout_seconds = 60 # Max time to wait for telemetry to arrive + start_time = time.time() + expected_event_count = num_threads * 2 # initial_log + latency_log per thread + + # Poll until the expected number of events are captured or we time out + while len(captured_telemetry) < expected_event_count: + if time.time() - start_time > timeout_seconds: + break # Exit loop if timeout is reached + time.sleep(0.1) # --- VERIFICATION --- assert not captured_exceptions From acf02bde9c6cb7812b952ada8af152a7f4aeb4b2 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 1 Aug 2025 10:19:19 +0530 Subject: [PATCH 2/2] sync fix Signed-off-by: Sai Shree Pradhan --- tests/e2e/test_concurrent_telemetry.py | 46 +++++++++++++------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/tests/e2e/test_concurrent_telemetry.py b/tests/e2e/test_concurrent_telemetry.py index 2edce7740..cb3aee21f 100644 --- a/tests/e2e/test_concurrent_telemetry.py +++ b/tests/e2e/test_concurrent_telemetry.py @@ -1,3 +1,4 @@ +from concurrent.futures import wait import random import threading import time @@ -48,8 +49,7 @@ def test_concurrent_queries_sends_telemetry(self): captured_telemetry = [] captured_session_ids = [] captured_statement_ids = [] - captured_responses = [] - captured_exceptions = [] + captured_futures = [] original_send_telemetry = TelemetryClient._send_telemetry original_callback = TelemetryClient._telemetry_request_callback @@ -64,18 +64,9 @@ def callback_wrapper(self_client, future, sent_count): Wraps the original callback to capture the server's response or any exceptions from the async network call. """ - try: - original_callback(self_client, future, sent_count) - - # Now, capture the result for our assertions - response = future.result() - response.raise_for_status() # Raise an exception for 4xx/5xx errors - telemetry_response = response.json() - with capture_lock: - captured_responses.append(telemetry_response) - except Exception as e: - with capture_lock: - captured_exceptions.append(e) + with capture_lock: + captured_futures.append(future) + original_callback(self_client, future, sent_count) with patch.object(TelemetryClient, "_send_telemetry", send_telemetry_wrapper), \ patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper): @@ -102,17 +93,26 @@ def execute_query_worker(thread_id): # Run the workers concurrently run_in_threads(execute_query_worker, num_threads, pass_index=True) - timeout_seconds = 60 # Max time to wait for telemetry to arrive + timeout_seconds = 60 start_time = time.time() - expected_event_count = num_threads * 2 # initial_log + latency_log per thread - - # Poll until the expected number of events are captured or we time out - while len(captured_telemetry) < expected_event_count: - if time.time() - start_time > timeout_seconds: - break # Exit loop if timeout is reached - time.sleep(0.1) + expected_event_count = num_threads + + while len(captured_futures) < expected_event_count and time.time() - start_time < timeout_seconds: + time.sleep(0.1) + + done, not_done = wait(captured_futures, timeout=timeout_seconds) + assert not not_done + + captured_exceptions = [] + captured_responses = [] + for future in done: + try: + response = future.result() + response.raise_for_status() + captured_responses.append(response.json()) + except Exception as e: + captured_exceptions.append(e) - # --- VERIFICATION --- assert not captured_exceptions assert len(captured_responses) > 0