Skip to content

Commit acf02bd

Browse files
committed
sync fix
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 7619a69 commit acf02bd

File tree

1 file changed

+23
-23
lines changed

1 file changed

+23
-23
lines changed

tests/e2e/test_concurrent_telemetry.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from concurrent.futures import wait
12
import random
23
import threading
34
import time
@@ -48,8 +49,7 @@ def test_concurrent_queries_sends_telemetry(self):
4849
captured_telemetry = []
4950
captured_session_ids = []
5051
captured_statement_ids = []
51-
captured_responses = []
52-
captured_exceptions = []
52+
captured_futures = []
5353

5454
original_send_telemetry = TelemetryClient._send_telemetry
5555
original_callback = TelemetryClient._telemetry_request_callback
@@ -64,18 +64,9 @@ def callback_wrapper(self_client, future, sent_count):
6464
Wraps the original callback to capture the server's response
6565
or any exceptions from the async network call.
6666
"""
67-
try:
68-
original_callback(self_client, future, sent_count)
69-
70-
# Now, capture the result for our assertions
71-
response = future.result()
72-
response.raise_for_status() # Raise an exception for 4xx/5xx errors
73-
telemetry_response = response.json()
74-
with capture_lock:
75-
captured_responses.append(telemetry_response)
76-
except Exception as e:
77-
with capture_lock:
78-
captured_exceptions.append(e)
67+
with capture_lock:
68+
captured_futures.append(future)
69+
original_callback(self_client, future, sent_count)
7970

8071
with patch.object(TelemetryClient, "_send_telemetry", send_telemetry_wrapper), \
8172
patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper):
@@ -102,17 +93,26 @@ def execute_query_worker(thread_id):
10293
# Run the workers concurrently
10394
run_in_threads(execute_query_worker, num_threads, pass_index=True)
10495

105-
timeout_seconds = 60 # Max time to wait for telemetry to arrive
96+
timeout_seconds = 60
10697
start_time = time.time()
107-
expected_event_count = num_threads * 2 # initial_log + latency_log per thread
108-
109-
# Poll until the expected number of events are captured or we time out
110-
while len(captured_telemetry) < expected_event_count:
111-
if time.time() - start_time > timeout_seconds:
112-
break # Exit loop if timeout is reached
113-
time.sleep(0.1)
98+
expected_event_count = num_threads
99+
100+
while len(captured_futures) < expected_event_count and time.time() - start_time < timeout_seconds:
101+
time.sleep(0.1)
102+
103+
done, not_done = wait(captured_futures, timeout=timeout_seconds)
104+
assert not not_done
105+
106+
captured_exceptions = []
107+
captured_responses = []
108+
for future in done:
109+
try:
110+
response = future.result()
111+
response.raise_for_status()
112+
captured_responses.append(response.json())
113+
except Exception as e:
114+
captured_exceptions.append(e)
114115

115-
# --- VERIFICATION ---
116116
assert not captured_exceptions
117117
assert len(captured_responses) > 0
118118

0 commit comments

Comments
 (0)