Skip to content

Commit 49082fb

Browse files
committed
added the get_attribute functions to the classes
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 34b63e4 commit 49082fb

File tree

4 files changed

+80
-111
lines changed

4 files changed

+80
-111
lines changed

.github/workflows/code-quality-checks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ jobs:
112112
# run test suite
113113
#----------------------------------------------
114114
- name: Run tests
115-
run: poetry run python -m pytest tests/unit -v -s
115+
run: poetry run python -m pytest tests/unit -v -s
116116
check-linting:
117117
runs-on: ubuntu-latest
118118
strategy:

src/databricks/sql/client.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
transform_paramstyle,
3232
ColumnTable,
3333
ColumnQueue,
34+
ArrowQueue,
35+
CloudFetchQueue,
3436
)
3537
from databricks.sql.parameters.native import (
3638
DbsqlParameterBase,
@@ -62,6 +64,7 @@
6264
HostDetails,
6365
)
6466
from databricks.sql.telemetry.latency_logger import log_latency
67+
from databricks.sql.telemetry.models.enums import ExecutionResultFormat, StatementType
6568

6669
logger = logging.getLogger(__name__)
6770

@@ -1355,6 +1358,35 @@ def setoutputsize(self, size, column=None):
13551358
"""Does nothing by default"""
13561359
pass
13571360

1361+
def get_statement_id(self) -> Optional[str]:
1362+
return self.query_id
1363+
1364+
def get_session_id_hex(self) -> Optional[str]:
1365+
return self.connection.get_session_id_hex()
1366+
1367+
def get_is_compressed(self) -> bool:
1368+
return self.connection.lz4_compression
1369+
1370+
def get_execution_result(self) -> ExecutionResultFormat:
1371+
if self.active_result_set is None:
1372+
return ExecutionResultFormat.FORMAT_UNSPECIFIED
1373+
1374+
if isinstance(self.active_result_set.results, ColumnQueue):
1375+
return ExecutionResultFormat.COLUMNAR_INLINE
1376+
elif isinstance(self.active_result_set.results, CloudFetchQueue):
1377+
return ExecutionResultFormat.EXTERNAL_LINKS
1378+
elif isinstance(self.active_result_set.results, ArrowQueue):
1379+
return ExecutionResultFormat.INLINE_ARROW
1380+
return ExecutionResultFormat.FORMAT_UNSPECIFIED
1381+
1382+
def get_retry_count(self) -> int:
1383+
# return len(self.thrift_backend.retry_policy.history)
1384+
return 0
1385+
1386+
def get_statement_type(self, func_name: str) -> StatementType:
1387+
# TODO: Implement this
1388+
return StatementType.SQL
1389+
13581390

13591391
class ResultSet:
13601392
def __init__(
@@ -1654,3 +1686,31 @@ def map_col_type(type_):
16541686
(column.name, map_col_type(column.datatype), None, None, None, None, None)
16551687
for column in table_schema_message.columns
16561688
]
1689+
1690+
def get_statement_id(self) -> Optional[str]:
1691+
if self.command_id:
1692+
return str(UUID(bytes=self.command_id.operationId.guid))
1693+
return None
1694+
1695+
def get_session_id_hex(self) -> Optional[str]:
1696+
return self.connection.get_session_id_hex()
1697+
1698+
def get_is_compressed(self) -> bool:
1699+
return self.lz4_compressed
1700+
1701+
def get_execution_result(self) -> ExecutionResultFormat:
1702+
if isinstance(self.results, ColumnQueue):
1703+
return ExecutionResultFormat.COLUMNAR_INLINE
1704+
elif isinstance(self.results, CloudFetchQueue):
1705+
return ExecutionResultFormat.EXTERNAL_LINKS
1706+
elif isinstance(self.results, ArrowQueue):
1707+
return ExecutionResultFormat.INLINE_ARROW
1708+
return ExecutionResultFormat.FORMAT_UNSPECIFIED
1709+
1710+
def get_statement_type(self, func_name: str) -> StatementType:
1711+
# TODO: Implement this
1712+
return StatementType.SQL
1713+
1714+
def get_retry_count(self) -> int:
1715+
# return len(self.thrift_backend.retry_policy.history)
1716+
return 0

src/databricks/sql/telemetry/latency_logger.py

Lines changed: 18 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -1,91 +1,9 @@
11
import time
22
import functools
3-
from typing import Optional
4-
from uuid import UUID
53
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
64
from databricks.sql.telemetry.models.event import (
75
SqlExecutionEvent,
86
)
9-
from databricks.sql.telemetry.models.enums import ExecutionResultFormat, StatementType
10-
from databricks.sql.utils import ColumnQueue
11-
12-
# Helper to get statement_id/query_id from instance if available
13-
def _get_statement_id(instance) -> Optional[str]:
14-
"""
15-
Get statement ID from an instance using various methods:
16-
1. For Cursor: Use query_id property which returns UUID from active_op_handle
17-
2. For ResultSet: Use command_id which contains operationId
18-
19-
Note: ThriftBackend itself doesn't have a statement ID since one backend
20-
can handle multiple concurrent operations/cursors.
21-
"""
22-
if hasattr(instance, "query_id"):
23-
return instance.query_id
24-
25-
if hasattr(instance, "command_id") and instance.command_id:
26-
return str(UUID(bytes=instance.command_id.operationId.guid))
27-
28-
return None
29-
30-
31-
def _get_session_id_hex(instance) -> Optional[str]:
32-
if hasattr(instance, "connection") and instance.connection:
33-
return instance.connection.get_session_id_hex()
34-
if hasattr(instance, "get_session_id_hex"):
35-
return instance.get_session_id_hex()
36-
return None
37-
38-
39-
def _get_statement_type(func_name: str) -> StatementType: # TODO: implement this
40-
return StatementType.SQL
41-
42-
43-
def _get_is_compressed(instance) -> bool:
44-
"""
45-
Get compression status from instance:
46-
1. Direct lz4_compression attribute (Connection)
47-
2. Through connection attribute (Cursor/ResultSet)
48-
3. Through thrift_backend attribute (Cursor)
49-
"""
50-
if hasattr(instance, "lz4_compression"):
51-
return instance.lz4_compression
52-
if hasattr(instance, "connection") and instance.connection:
53-
return instance.connection.lz4_compression
54-
if hasattr(instance, "thrift_backend") and instance.thrift_backend:
55-
return instance.thrift_backend.lz4_compressed
56-
return False
57-
58-
59-
def _get_execution_result(instance) -> ExecutionResultFormat:
60-
"""
61-
Get execution result format from instance:
62-
1. For ResultSet: Check if using cloud fetch (external_links) or arrow/columnar format
63-
2. For Cursor: Check through active_result_set
64-
3. For ThriftBackend: Check result format from server
65-
"""
66-
if hasattr(instance, "_use_cloud_fetch") and instance._use_cloud_fetch:
67-
return ExecutionResultFormat.EXTERNAL_LINKS
68-
69-
if hasattr(instance, "active_result_set") and instance.active_result_set:
70-
if isinstance(instance.active_result_set.results, ColumnQueue):
71-
return ExecutionResultFormat.COLUMNAR_INLINE
72-
return ExecutionResultFormat.INLINE_ARROW
73-
74-
if hasattr(instance, "thrift_backend") and instance.thrift_backend:
75-
if hasattr(instance.thrift_backend, "_use_arrow_native_complex_types"):
76-
return ExecutionResultFormat.INLINE_ARROW
77-
78-
return ExecutionResultFormat.FORMAT_UNSPECIFIED
79-
80-
81-
def _get_retry_count(instance) -> int:
82-
"""
83-
Get retry count from instance by checking retry_policy.history length.
84-
The retry_policy is only accessible through thrift_backend.
85-
"""
86-
# TODO: implement this
87-
88-
return 0
897

908

919
def log_latency():
@@ -101,30 +19,24 @@ def wrapper(self, *args, **kwargs):
10119
end_time = time.perf_counter()
10220
duration_ms = int((end_time - start_time) * 1000)
10321

104-
session_id_hex = _get_session_id_hex(self)
105-
106-
if session_id_hex:
107-
statement_id = _get_statement_id(self)
108-
statement_type = _get_statement_type(func.__name__)
109-
is_compressed = _get_is_compressed(self)
110-
execution_result = _get_execution_result(self)
111-
retry_count = _get_retry_count(self)
112-
113-
sql_exec_event = SqlExecutionEvent(
114-
statement_type=statement_type,
115-
is_compressed=is_compressed,
116-
execution_result=execution_result,
117-
retry_count=retry_count,
118-
)
119-
120-
telemetry_client = TelemetryClientFactory.get_telemetry_client(
121-
session_id_hex
122-
)
123-
telemetry_client.export_latency_log(
124-
latency_ms=duration_ms,
125-
sql_execution_event=sql_exec_event,
126-
sql_statement_id=statement_id,
127-
)
22+
session_id_hex = self.get_session_id_hex()
23+
statement_id = self.get_statement_id()
24+
25+
sql_exec_event = SqlExecutionEvent(
26+
statement_type=self.get_statement_type(func.__name__),
27+
is_compressed=self.get_is_compressed(),
28+
execution_result=self.get_execution_result(),
29+
retry_count=self.get_retry_count(),
30+
)
31+
32+
telemetry_client = TelemetryClientFactory.get_telemetry_client(
33+
session_id_hex
34+
)
35+
telemetry_client.export_latency_log(
36+
latency_ms=duration_ms,
37+
sql_execution_event=sql_exec_event,
38+
sql_statement_id=statement_id,
39+
)
12840

12941
return wrapper
13042

src/databricks/sql/thrift_backend.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
convert_column_based_set_to_arrow_table,
4242
)
4343
from databricks.sql.types import SSLOptions
44-
from typing import Optional
44+
4545

4646
logger = logging.getLogger(__name__)
4747

@@ -584,9 +584,6 @@ def open_session(self, session_configuration, catalog, schema):
584584
self._transport.close()
585585
raise
586586

587-
def get_session_id_hex(self) -> Optional[str]:
588-
return self._session_id_hex
589-
590587
def close_session(self, session_handle) -> None:
591588
req = ttypes.TCloseSessionReq(sessionHandle=session_handle)
592589
try:

0 commit comments

Comments
 (0)