diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py
index 9960490c5..75c29b19c 100644
--- a/src/databricks/sql/telemetry/telemetry_client.py
+++ b/src/databricks/sql/telemetry/telemetry_client.py
@@ -127,6 +127,9 @@ def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id):
     def close(self):
         pass
 
+    def _flush(self):
+        pass
+
 
 class TelemetryClient(BaseTelemetryClient):
     """
diff --git a/tests/e2e/test_concurrent_telemetry.py b/tests/e2e/test_concurrent_telemetry.py
index 656bcd21f..cb3aee21f 100644
--- a/tests/e2e/test_concurrent_telemetry.py
+++ b/tests/e2e/test_concurrent_telemetry.py
@@ -1,3 +1,4 @@
+from concurrent.futures import wait
 import random
 import threading
 import time
@@ -35,6 +36,7 @@ def telemetry_setup_teardown(self):
         if TelemetryClientFactory._executor:
             TelemetryClientFactory._executor.shutdown(wait=True)
             TelemetryClientFactory._executor = None
+        TelemetryClientFactory._stop_flush_thread()
         TelemetryClientFactory._initialized = False
 
     def test_concurrent_queries_sends_telemetry(self):
@@ -47,8 +49,7 @@ def test_concurrent_queries_sends_telemetry(self):
         captured_telemetry = []
         captured_session_ids = []
         captured_statement_ids = []
-        captured_responses = []
-        captured_exceptions = []
+        captured_futures = []
 
         original_send_telemetry = TelemetryClient._send_telemetry
         original_callback = TelemetryClient._telemetry_request_callback
@@ -63,18 +64,9 @@ def callback_wrapper(self_client, future, sent_count):
             Wraps the original callback to capture the server's response
             or any exceptions from the async network call.
             """
-            try:
-                original_callback(self_client, future, sent_count)
-
-                # Now, capture the result for our assertions
-                response = future.result()
-                response.raise_for_status()  # Raise an exception for 4xx/5xx errors
-                telemetry_response = response.json()
-                with capture_lock:
-                    captured_responses.append(telemetry_response)
-            except Exception as e:
-                with capture_lock:
-                    captured_exceptions.append(e)
+            with capture_lock:
+                captured_futures.append(future)
+            original_callback(self_client, future, sent_count)
 
         with patch.object(TelemetryClient, "_send_telemetry", send_telemetry_wrapper), \
              patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper):
@@ -101,10 +93,26 @@ def execute_query_worker(thread_id):
             # Run the workers concurrently
             run_in_threads(execute_query_worker, num_threads, pass_index=True)
 
-            if TelemetryClientFactory._executor:
-                TelemetryClientFactory._executor.shutdown(wait=True)
+            timeout_seconds = 60
+            start_time = time.time()
+            expected_event_count = num_threads
+
+            while len(captured_futures) < expected_event_count and time.time() - start_time < timeout_seconds:
+                time.sleep(0.1)
+
+            done, not_done = wait(captured_futures, timeout=timeout_seconds)
+            assert not not_done
+
+            captured_exceptions = []
+            captured_responses = []
+            for future in done:
+                try:
+                    response = future.result()
+                    response.raise_for_status()
+                    captured_responses.append(response.json())
+                except Exception as e:
+                    captured_exceptions.append(e)
 
-            # --- VERIFICATION ---
             assert not captured_exceptions
             assert len(captured_responses) > 0
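
The test change above hinges on one pattern: capture every `Future` handed to the telemetry callback, then use `concurrent.futures.wait` to block until all of them complete before asserting on their results. Below is a minimal, standalone sketch of that pattern, not part of the diff; the `ThreadPoolExecutor`, `fake_network_call`, and plain-dict responses are illustrative stand-ins for the telemetry executor and the real HTTP responses.

```python
# Sketch of the capture-futures-then-wait pattern used in the updated test.
from concurrent.futures import ThreadPoolExecutor, wait

captured_futures = []

def fake_network_call(i):
    # Stand-in for the async telemetry POST; returns a dummy payload.
    return {"status": "ok", "event": i}

with ThreadPoolExecutor(max_workers=4) as executor:
    for i in range(10):
        # Mirrors callback_wrapper: record the future, inspect it later.
        captured_futures.append(executor.submit(fake_network_call, i))

    # Block until every captured future finishes (or the timeout expires).
    done, not_done = wait(captured_futures, timeout=60)
    assert not not_done  # nothing should still be pending

    # Collect successes and failures exactly once, after all work is done.
    exceptions, responses = [], []
    for future in done:
        try:
            responses.append(future.result())
        except Exception as exc:
            exceptions.append(exc)

assert not exceptions
assert len(responses) == 10
```

The diff follows the same idea: rather than relying on `TelemetryClientFactory._executor.shutdown(wait=True)` to flush in-flight requests, the test waits on the captured futures directly and only then checks responses and exceptions.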