
Commit a094659 (parent: f101b19)

latency logs functionality

Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>

File tree: 4 files changed, 189 additions and 1 deletion

  src/databricks/sql/client.py
  src/databricks/sql/telemetry/latency_logger.py
  src/databricks/sql/telemetry/telemetry_client.py
  src/databricks/sql/thrift_backend.py


src/databricks/sql/client.py

Lines changed: 16 additions & 1 deletion
@@ -61,7 +61,7 @@
     DriverConnectionParameters,
     HostDetails,
 )
-
+from databricks.sql.telemetry.latency_logger import log_latency
 
 logger = logging.getLogger(__name__)
 
@@ -758,6 +758,7 @@ def _handle_staging_operation(
                 session_id_hex=self.connection.get_session_id_hex(),
             )
 
+    @log_latency()
     def _handle_staging_put(
         self, presigned_url: str, local_file: str, headers: Optional[dict] = None
     ):
@@ -797,6 +798,7 @@ def _handle_staging_put(
                 + "but not yet applied on the server. It's possible this command may fail later."
             )
 
+    @log_latency()
     def _handle_staging_get(
         self, local_file: str, presigned_url: str, headers: Optional[dict] = None
     ):
@@ -824,6 +826,7 @@ def _handle_staging_get(
         with open(local_file, "wb") as fp:
             fp.write(r.content)
 
+    @log_latency()
     def _handle_staging_remove(
         self, presigned_url: str, headers: Optional[dict] = None
     ):
@@ -837,6 +840,7 @@ def _handle_staging_remove(
             session_id_hex=self.connection.get_session_id_hex(),
         )
 
+    @log_latency()
     def execute(
         self,
         operation: str,
@@ -927,6 +931,7 @@ def execute(
 
         return self
 
+    @log_latency()
     def execute_async(
         self,
         operation: str,
@@ -1052,6 +1057,7 @@ def executemany(self, operation, seq_of_parameters):
             self.execute(operation, parameters)
         return self
 
+    @log_latency()
     def catalogs(self) -> "Cursor":
         """
         Get all available catalogs.
@@ -1075,6 +1081,7 @@ def catalogs(self) -> "Cursor":
         )
         return self
 
+    @log_latency()
     def schemas(
         self, catalog_name: Optional[str] = None, schema_name: Optional[str] = None
     ) -> "Cursor":
@@ -1103,6 +1110,7 @@ def schemas(
         )
         return self
 
+    @log_latency()
     def tables(
         self,
         catalog_name: Optional[str] = None,
@@ -1138,6 +1146,7 @@ def tables(
         )
         return self
 
+    @log_latency()
     def columns(
         self,
         catalog_name: Optional[str] = None,
@@ -1173,6 +1182,7 @@ def columns(
         )
         return self
 
+    @log_latency()
     def fetchall(self) -> List[Row]:
         """
         Fetch all (remaining) rows of a query result, returning them as a sequence of sequences.
@@ -1206,6 +1216,7 @@ def fetchone(self) -> Optional[Row]:
             session_id_hex=self.connection.get_session_id_hex(),
         )
 
+    @log_latency()
     def fetchmany(self, size: int) -> List[Row]:
         """
         Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a
@@ -1231,6 +1242,7 @@ def fetchmany(self, size: int) -> List[Row]:
             session_id_hex=self.connection.get_session_id_hex(),
         )
 
+    @log_latency()
     def fetchall_arrow(self) -> "pyarrow.Table":
         self._check_not_closed()
         if self.active_result_set:
@@ -1241,6 +1253,7 @@ def fetchall_arrow(self) -> "pyarrow.Table":
             session_id_hex=self.connection.get_session_id_hex(),
         )
 
+    @log_latency()
     def fetchmany_arrow(self, size) -> "pyarrow.Table":
         self._check_not_closed()
         if self.active_result_set:
@@ -1406,6 +1419,7 @@ def _fill_results_buffer(self):
         self.results = results
         self.has_more_rows = has_more_rows
 
+    @log_latency()
     def _convert_columnar_table(self, table):
         column_names = [c[0] for c in self.description]
         ResultRow = Row(*column_names)
@@ -1418,6 +1432,7 @@ def _convert_columnar_table(self, table):
 
         return result
 
+    @log_latency()
     def _convert_arrow_table(self, table):
         column_names = [c[0] for c in self.description]
         ResultRow = Row(*column_names)
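
Because the instrumentation lives on the existing Cursor and ResultSet methods, callers get latency telemetry without code changes. A minimal caller-side sketch, assuming the connector's documented connect API; hostname, HTTP path, and token values are placeholders:

from databricks import sql

# Placeholder credentials; substitute real workspace values.
with sql.connect(
    server_hostname="<workspace-host>",
    http_path="<warehouse-http-path>",
    access_token="<personal-access-token>",
) as connection:
    with connection.cursor() as cursor:
        # execute() and fetchall() are each wrapped by @log_latency(),
        # so each call emits its own latency record on completion.
        cursor.execute("SELECT 1")
        rows = cursor.fetchall()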
src/databricks/sql/telemetry/latency_logger.py (new file)

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
+import time
+import functools
+from typing import Optional
+from uuid import UUID
+from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
+from databricks.sql.telemetry.models.event import (
+    SqlExecutionEvent,
+)
+from databricks.sql.telemetry.models.enums import ExecutionResultFormat, StatementType
+from databricks.sql.utils import ColumnQueue
+
+
+# Helper to get statement_id/query_id from instance if available
+def _get_statement_id(instance) -> Optional[str]:
+    """
+    Get statement ID from an instance using various methods:
+    1. For Cursor: Use query_id property which returns UUID from active_op_handle
+    2. For ResultSet: Use command_id which contains operationId
+
+    Note: ThriftBackend itself doesn't have a statement ID since one backend
+    can handle multiple concurrent operations/cursors.
+    """
+    if hasattr(instance, "query_id"):
+        return instance.query_id
+
+    if hasattr(instance, "command_id") and instance.command_id:
+        return str(UUID(bytes=instance.command_id.operationId.guid))
+
+    return None
+
+
+def _get_session_id_hex(instance) -> Optional[str]:
+    if hasattr(instance, "connection") and instance.connection:
+        return instance.connection.get_session_id_hex()
+    if hasattr(instance, "get_session_id_hex"):
+        return instance.get_session_id_hex()
+    return None
+
+
+def _get_statement_type(func_name: str) -> StatementType:  # TODO: implement this
+    return StatementType.SQL
+
+
+def _get_is_compressed(instance) -> bool:
+    """
+    Get compression status from instance:
+    1. Direct lz4_compression attribute (Connection)
+    2. Through connection attribute (Cursor/ResultSet)
+    3. Through thrift_backend attribute (Cursor)
+    """
+    if hasattr(instance, "lz4_compression"):
+        return instance.lz4_compression
+    if hasattr(instance, "connection") and instance.connection:
+        return instance.connection.lz4_compression
+    if hasattr(instance, "thrift_backend") and instance.thrift_backend:
+        return instance.thrift_backend.lz4_compressed
+    return False
+
+
+def _get_execution_result(instance) -> ExecutionResultFormat:
+    """
+    Get execution result format from instance:
+    1. For ResultSet: Check if using cloud fetch (external_links) or arrow/columnar format
+    2. For Cursor: Check through active_result_set
+    3. For ThriftBackend: Check result format from server
+    """
+    if hasattr(instance, "_use_cloud_fetch") and instance._use_cloud_fetch:
+        return ExecutionResultFormat.EXTERNAL_LINKS
+
+    if hasattr(instance, "active_result_set") and instance.active_result_set:
+        if isinstance(instance.active_result_set.results, ColumnQueue):
+            return ExecutionResultFormat.COLUMNAR_INLINE
+        return ExecutionResultFormat.INLINE_ARROW
+
+    if hasattr(instance, "thrift_backend") and instance.thrift_backend:
+        if hasattr(instance.thrift_backend, "_use_arrow_native_complex_types"):
+            return ExecutionResultFormat.INLINE_ARROW
+
+    return ExecutionResultFormat.FORMAT_UNSPECIFIED
+
+
+def _get_retry_count(instance) -> int:
+    """
+    Get retry count from instance by checking retry_policy.history length.
+    The retry_policy is only accessible through thrift_backend.
+    """
+    # TODO: implement this
+
+    return 0
+
+
+def log_latency():
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(self, *args, **kwargs):
+            start_time = time.perf_counter()
+            result = None
+            try:
+                result = func(self, *args, **kwargs)
+                return result
+            finally:
+                end_time = time.perf_counter()
+                duration_ms = int((end_time - start_time) * 1000)
+
+                session_id_hex = _get_session_id_hex(self)
+
+                if session_id_hex:
+                    statement_id = _get_statement_id(self)
+                    statement_type = _get_statement_type(func.__name__)
+                    is_compressed = _get_is_compressed(self)
+                    execution_result = _get_execution_result(self)
+                    retry_count = _get_retry_count(self)
+
+                    sql_exec_event = SqlExecutionEvent(
+                        statement_type=statement_type,
+                        is_compressed=is_compressed,
+                        execution_result=execution_result,
+                        retry_count=retry_count,
+                    )
+
+                    telemetry_client = TelemetryClientFactory.get_telemetry_client(
+                        session_id_hex
+                    )
+                    telemetry_client.export_latency_log(
+                        latency_ms=duration_ms,
+                        sql_execution_event=sql_exec_event,
+                        sql_statement_id=statement_id,
+                    )
+
+        return wrapper
+
+    return decorator
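
The core of the new module is the try/finally timing pattern: the duration is computed in the finally block, so a latency record is emitted whether the wrapped call returns normally or raises. A self-contained sketch of that pattern (the names here are illustrative, not the driver's API):

import functools
import time

def timed(report):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                # Runs on both the success and the exception path.
                report(func.__name__, int((time.perf_counter() - start) * 1000))
        return wrapper
    return decorator

@timed(lambda name, ms: print(f"{name} took {ms} ms"))
def flaky():
    raise RuntimeError("boom")

try:
    flaky()
except RuntimeError:
    pass  # the latency was still reported before the exception propagated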

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 39 additions & 0 deletions
@@ -112,6 +112,12 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent):
     def export_failure_log(self, error_name, error_message):
         raise NotImplementedError("Subclasses must implement export_failure_log")
 
+    @abstractmethod
+    def export_latency_log(
+        self, latency_ms, sql_execution_event, sql_statement_id=None
+    ):
+        raise NotImplementedError("Subclasses must implement export_latency_log")
+
     @abstractmethod
     def close(self):
         raise NotImplementedError("Subclasses must implement close")

@@ -136,6 +142,11 @@ def export_initial_telemetry_log(self, driver_connection_params, user_agent):
     def export_failure_log(self, error_name, error_message):
         pass
 
+    def export_latency_log(
+        self, latency_ms, sql_execution_event, sql_statement_id=None
+    ):
+        pass
+
     def close(self):
         pass
 
@@ -299,6 +310,34 @@ def export_failure_log(self, error_name, error_message):
         except Exception as e:
             logger.debug("Failed to export failure log: %s", e)
 
+    def export_latency_log(
+        self, latency_ms, sql_execution_event, sql_statement_id=None
+    ):
+        logger.debug("Exporting latency log for connection %s", self._session_id_hex)
+        try:
+            telemetry_frontend_log = TelemetryFrontendLog(
+                frontend_log_event_id=str(uuid.uuid4()),
+                context=FrontendLogContext(
+                    client_context=TelemetryClientContext(
+                        timestamp_millis=int(time.time() * 1000),
+                        user_agent=self._user_agent,
+                    )
+                ),
+                entry=FrontendLogEntry(
+                    sql_driver_log=TelemetryEvent(
+                        session_id=self._session_id_hex,
+                        system_configuration=TelemetryHelper.get_driver_system_configuration(),
+                        driver_connection_params=self._driver_connection_params,
+                        sql_statement_id=sql_statement_id,
+                        sql_operation=sql_execution_event,
+                        operation_latency_ms=latency_ms,
+                    )
+                ),
+            )
+            self._export_event(telemetry_frontend_log)
+        except Exception as e:
+            logger.debug("Failed to export latency log: %s", e)
+
     def close(self):
         """Flush remaining events before closing"""
         logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex)

src/databricks/sql/thrift_backend.py

Lines changed: 3 additions & 0 deletions
@@ -583,6 +583,9 @@ def open_session(self, session_configuration, catalog, schema):
             self._transport.close()
             raise
 
+    def get_session_id_hex(self) -> str:
+        return self._session_id_hex
+
     def close_session(self, session_handle) -> None:
         req = ttypes.TCloseSessionReq(sessionHandle=session_handle)
         try:
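
This accessor exists so the latency logger's duck-typed lookup can resolve a session ID directly from a ThriftBackend, which has no .connection attribute. An illustrative sketch that mirrors the helper's resolution order (FakeBackend is a stand-in for demonstration, not driver code):

class FakeBackend:
    # Stand-in for ThriftBackend; only the parts the helper touches.
    _session_id_hex = "0123abcd"

    def get_session_id_hex(self) -> str:
        return self._session_id_hex

def resolve_session_id_hex(instance):
    # Same resolution order as _get_session_id_hex() in latency_logger.py.
    if hasattr(instance, "connection") and instance.connection:
        return instance.connection.get_session_id_hex()
    if hasattr(instance, "get_session_id_hex"):
        return instance.get_session_id_hex()
    return None

assert resolve_session_id_hex(FakeBackend()) == "0123abcd"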
