1+ import threading
2+ from unittest .mock import patch
3+ import pytest
4+
5+ from databricks .sql .telemetry .telemetry_client import TelemetryClient , TelemetryClientFactory
6+ from tests .e2e .test_driver import PySQLPytestTestCase
7+
8+ def run_in_threads (target , num_threads , pass_index = False ):
9+ """Helper to run target function in multiple threads."""
10+ threads = [
11+ threading .Thread (target = target , args = (i ,) if pass_index else ())
12+ for i in range (num_threads )
13+ ]
14+ for t in threads :
15+ t .start ()
16+ for t in threads :
17+ t .join ()
18+
19+
20+ class TestE2ETelemetry (PySQLPytestTestCase ):
21+
22+ @pytest .fixture (autouse = True )
23+ def telemetry_setup_teardown (self ):
24+ """
25+ This fixture ensures the TelemetryClientFactory is in a clean state
26+ before each test and shuts it down afterward. Using a fixture makes
27+ this robust and automatic.
28+ """
29+ # --- SETUP ---
30+ if TelemetryClientFactory ._executor :
31+ TelemetryClientFactory ._executor .shutdown (wait = True )
32+ TelemetryClientFactory ._clients .clear ()
33+ TelemetryClientFactory ._executor = None
34+ TelemetryClientFactory ._initialized = False
35+
36+ yield # This is where the test runs
37+
38+ # --- TEARDOWN ---
39+ if TelemetryClientFactory ._executor :
40+ TelemetryClientFactory ._executor .shutdown (wait = True )
41+ TelemetryClientFactory ._executor = None
42+ TelemetryClientFactory ._initialized = False
43+
44+ def test_concurrent_queries_sends_telemetry (self ):
45+ """
46+ An E2E test where concurrent threads execute real queries against
47+ the staging endpoint, while we capture and verify the generated telemetry.
48+ """
49+ num_threads = 5
50+ captured_telemetry = []
51+ captured_telemetry_lock = threading .Lock ()
52+
53+ original_send_telemetry = TelemetryClient ._send_telemetry
54+
55+ def send_telemetry_wrapper (self_client , events ):
56+ with captured_telemetry_lock :
57+ captured_telemetry .extend (events )
58+ original_send_telemetry (self_client , events )
59+
60+ with patch .object (TelemetryClient , "_send_telemetry" , send_telemetry_wrapper ):
61+
62+ def execute_query_worker (thread_id ):
63+ """Each thread creates a connection and executes a query."""
64+ with self .connection (extra_params = {"enable_telemetry" : True }) as conn :
65+ with conn .cursor () as cursor :
66+ cursor .execute (f"SELECT { thread_id } " )
67+ cursor .fetchall ()
68+
69+ # Run the workers concurrently
70+ run_in_threads (execute_query_worker , num_threads , pass_index = True )
71+
72+ if TelemetryClientFactory ._executor :
73+ TelemetryClientFactory ._executor .shutdown (wait = True )
74+
75+ # --- VERIFICATION ---
76+ assert len (captured_telemetry ) == num_threads * 2 # 2 events per thread (initial_telemetry_log, latency_log (execute))
77+
78+ events_with_latency = [
79+ e for e in captured_telemetry
80+ if e .entry .sql_driver_log .operation_latency_ms is not None and e .entry .sql_driver_log .sql_statement_id is not None
81+ ]
82+ assert len (events_with_latency ) == num_threads # 1 event per thread (execute)
0 commit comments