Skip to content

Commit f6281fe

Browse files
committed
assert session id, statement id
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 9f6e2e1 commit f6281fe

File tree

1 file changed

+59
-4
lines changed

1 file changed

+59
-4
lines changed

tests/e2e/test_concurrent_telemetry.py

Lines changed: 59 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
from unittest.mock import patch
33
import pytest
44

5+
from databricks.sql.telemetry.models.enums import StatementType
56
from databricks.sql.telemetry.telemetry_client import TelemetryClient, TelemetryClientFactory
67
from tests.e2e.test_driver import PySQLPytestTestCase
78

@@ -49,6 +50,9 @@ def test_concurrent_queries_sends_telemetry(self):
4950
num_threads = 5
5051
captured_telemetry = []
5152
captured_telemetry_lock = threading.Lock()
53+
captured_session_ids = []
54+
captured_statement_ids = []
55+
capture_info_lock = threading.Lock()
5256

5357
original_send_telemetry = TelemetryClient._send_telemetry
5458

@@ -62,8 +66,17 @@ def send_telemetry_wrapper(self_client, events):
6266
def execute_query_worker(thread_id):
6367
"""Each thread creates a connection and executes a query."""
6468
with self.connection(extra_params={"enable_telemetry": True}) as conn:
69+
# Capture the session ID from the connection before executing the query
70+
session_id_hex = conn.get_session_id_hex()
71+
with capture_info_lock:
72+
captured_session_ids.append(session_id_hex)
73+
6574
with conn.cursor() as cursor:
6675
cursor.execute(f"SELECT {thread_id}")
76+
# Capture the statement ID after executing the query
77+
statement_id = cursor.query_id
78+
with capture_info_lock:
79+
captured_statement_ids.append(statement_id)
6780
cursor.fetchall()
6881

6982
# Run the workers concurrently
@@ -73,10 +86,52 @@ def execute_query_worker(thread_id):
7386
TelemetryClientFactory._executor.shutdown(wait=True)
7487

7588
# --- VERIFICATION ---
76-
assert len(captured_telemetry) == num_threads * 2 # 2 events per thread (initial_telemetry_log, latency_log (execute))
89+
assert len(captured_telemetry) == num_threads * 2 # 2 events per thread (initial_telemetry_log, latency_log (execute))
90+
assert len(captured_session_ids) == num_threads # One session ID per thread
91+
assert len(captured_statement_ids) == num_threads # One statement ID per thread (per query)
7792

78-
events_with_latency = [
93+
# Separate initial logs from latency logs
94+
initial_logs = [
7995
e for e in captured_telemetry
80-
if e.entry.sql_driver_log.operation_latency_ms is not None and e.entry.sql_driver_log.sql_statement_id is not None
96+
if e.entry.sql_driver_log.operation_latency_ms is None
97+
and e.entry.sql_driver_log.driver_connection_params is not None
98+
and e.entry.sql_driver_log.system_configuration is not None
8199
]
82-
assert len(events_with_latency) == num_threads # 1 event per thread (execute)
100+
latency_logs = [
101+
e for e in captured_telemetry
102+
if e.entry.sql_driver_log.operation_latency_ms is not None
103+
and e.entry.sql_driver_log.sql_statement_id is not None
104+
and e.entry.sql_driver_log.sql_operation.statement_type == StatementType.QUERY
105+
]
106+
107+
# Verify counts
108+
assert len(initial_logs) == num_threads
109+
assert len(latency_logs) == num_threads
110+
111+
# Verify that telemetry events contain the exact session IDs we captured from connections
112+
telemetry_session_ids = set()
113+
for event in captured_telemetry:
114+
session_id = event.entry.sql_driver_log.session_id
115+
assert session_id is not None
116+
telemetry_session_ids.add(session_id)
117+
118+
captured_session_ids_set = set(captured_session_ids)
119+
assert telemetry_session_ids == captured_session_ids_set
120+
assert len(captured_session_ids_set) == num_threads
121+
122+
# Verify that telemetry latency logs contain the exact statement IDs we captured from cursors
123+
telemetry_statement_ids = set()
124+
for event in latency_logs:
125+
statement_id = event.entry.sql_driver_log.sql_statement_id
126+
assert statement_id is not None
127+
telemetry_statement_ids.add(statement_id)
128+
129+
captured_statement_ids_set = set(captured_statement_ids)
130+
assert telemetry_statement_ids == captured_statement_ids_set
131+
assert len(captured_statement_ids_set) == num_threads
132+
133+
# Verify that each latency log carries both a statement ID and a session ID that we captured
134+
for event in latency_logs:
135+
log = event.entry.sql_driver_log
136+
assert log.sql_statement_id in captured_statement_ids
137+
assert log.session_id in captured_session_ids

0 commit comments

Comments
 (0)