Skip to content

Commit a35d3e9

Browse files
committed
set statement type to query for chunk download
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 2564d41 commit a35d3e9

File tree

12 files changed

+21
-97
lines changed

12 files changed

+21
-97
lines changed

src/databricks/sql/backend/sea/queue.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,14 +135,12 @@ def __init__(
135135
super().__init__(
136136
max_download_threads=max_download_threads,
137137
ssl_options=ssl_options,
138-
# TODO: fix these arguments when telemetry is implemented in SEA
139-
session_id_hex=None,
140138
statement_id=statement_id,
141-
statement_type=StatementType.NONE,
142139
chunk_id=0,
143140
schema_bytes=None,
144141
lz4_compressed=lz4_compressed,
145142
description=description,
143+
session_id_hex=None, # TODO: fix this argument when telemetry is implemented in SEA
146144
)
147145

148146
self._sea_client = sea_client

src/databricks/sql/backend/thrift_backend.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -889,7 +889,6 @@ def get_execution_result(
889889
arrow_schema_bytes=schema_bytes,
890890
result_format=t_result_set_metadata_resp.resultFormat,
891891
)
892-
execute_response.command_id.set_statement_type(StatementType.QUERY)
893892

894893
return ThriftResultSet(
895894
connection=cursor.connection,
@@ -1029,8 +1028,6 @@ def execute_command(
10291028
if resp.directResults and resp.directResults.resultSet:
10301029
t_row_set = resp.directResults.resultSet.results
10311030

1032-
execute_response.command_id.set_statement_type(StatementType.QUERY)
1033-
10341031
return ThriftResultSet(
10351032
connection=cursor.connection,
10361033
execute_response=execute_response,
@@ -1072,8 +1069,6 @@ def get_catalogs(
10721069
if resp.directResults and resp.directResults.resultSet:
10731070
t_row_set = resp.directResults.resultSet.results
10741071

1075-
execute_response.command_id.set_statement_type(StatementType.METADATA)
1076-
10771072
return ThriftResultSet(
10781073
connection=cursor.connection,
10791074
execute_response=execute_response,
@@ -1121,8 +1116,6 @@ def get_schemas(
11211116
if resp.directResults and resp.directResults.resultSet:
11221117
t_row_set = resp.directResults.resultSet.results
11231118

1124-
execute_response.command_id.set_statement_type(StatementType.METADATA)
1125-
11261119
return ThriftResultSet(
11271120
connection=cursor.connection,
11281121
execute_response=execute_response,
@@ -1174,8 +1167,6 @@ def get_tables(
11741167
if resp.directResults and resp.directResults.resultSet:
11751168
t_row_set = resp.directResults.resultSet.results
11761169

1177-
execute_response.command_id.set_statement_type(StatementType.METADATA)
1178-
11791170
return ThriftResultSet(
11801171
connection=cursor.connection,
11811172
execute_response=execute_response,
@@ -1227,8 +1218,6 @@ def get_columns(
12271218
if resp.directResults and resp.directResults.resultSet:
12281219
t_row_set = resp.directResults.resultSet.results
12291220

1230-
execute_response.command_id.set_statement_type(StatementType.METADATA)
1231-
12321221
return ThriftResultSet(
12331222
connection=cursor.connection,
12341223
execute_response=execute_response,
@@ -1315,7 +1304,6 @@ def fetch_results(
13151304
ssl_options=self._ssl_options,
13161305
session_id_hex=self._session_id_hex,
13171306
statement_id=command_id.to_hex_guid(),
1318-
statement_type=command_id.statement_type,
13191307
chunk_id=chunk_id,
13201308
)
13211309

src/databricks/sql/backend/types.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,6 @@ def __init__(
301301
self.operation_type = operation_type
302302
self.has_result_set = has_result_set
303303
self.modified_row_count = modified_row_count
304-
self._statement_type = StatementType.NONE
305304

306305
def __str__(self) -> str:
307306
"""
@@ -413,19 +412,6 @@ def to_hex_guid(self) -> str:
413412
else:
414413
return str(self.guid)
415414

416-
def set_statement_type(self, statement_type: StatementType):
417-
"""
418-
Set the statement type for this command.
419-
"""
420-
self._statement_type = statement_type
421-
422-
@property
423-
def statement_type(self) -> StatementType:
424-
"""
425-
Get the statement type for this command.
426-
"""
427-
return self._statement_type
428-
429415

430416
@dataclass
431417
class ExecuteResponse:

src/databricks/sql/client.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -708,7 +708,7 @@ def _handle_staging_operation(
708708
session_id_hex=self.connection.get_session_id_hex(),
709709
)
710710

711-
@log_latency()
711+
@log_latency(StatementType.SQL)
712712
def _handle_staging_put(
713713
self, presigned_url: str, local_file: str, headers: Optional[dict] = None
714714
):
@@ -717,7 +717,6 @@ def _handle_staging_put(
717717
Raise an exception if request fails. Returns no data.
718718
"""
719719

720-
self.statement_type = StatementType.SQL
721720
if local_file is None:
722721
raise ProgrammingError(
723722
"Cannot perform PUT without specifying a local_file",
@@ -749,7 +748,7 @@ def _handle_staging_put(
749748
+ "but not yet applied on the server. It's possible this command may fail later."
750749
)
751750

752-
@log_latency()
751+
@log_latency(StatementType.SQL)
753752
def _handle_staging_get(
754753
self, local_file: str, presigned_url: str, headers: Optional[dict] = None
755754
):
@@ -758,7 +757,6 @@ def _handle_staging_get(
758757
Raise an exception if request fails. Returns no data.
759758
"""
760759

761-
self.statement_type = StatementType.SQL
762760
if local_file is None:
763761
raise ProgrammingError(
764762
"Cannot perform GET without specifying a local_file",
@@ -778,13 +776,12 @@ def _handle_staging_get(
778776
with open(local_file, "wb") as fp:
779777
fp.write(r.content)
780778

781-
@log_latency()
779+
@log_latency(StatementType.SQL)
782780
def _handle_staging_remove(
783781
self, presigned_url: str, headers: Optional[dict] = None
784782
):
785783
"""Make an HTTP DELETE request to the presigned_url"""
786784

787-
self.statement_type = StatementType.SQL
788785
r = requests.delete(url=presigned_url, headers=headers)
789786

790787
if not r.ok:
@@ -793,7 +790,7 @@ def _handle_staging_remove(
793790
session_id_hex=self.connection.get_session_id_hex(),
794791
)
795792

796-
@log_latency()
793+
@log_latency(StatementType.QUERY)
797794
def execute(
798795
self,
799796
operation: str,
@@ -832,7 +829,6 @@ def execute(
832829
:returns self
833830
"""
834831

835-
self.statement_type = StatementType.QUERY
836832
logger.debug(
837833
"Cursor.execute(operation=%s, parameters=%s)", operation, parameters
838834
)
@@ -879,7 +875,7 @@ def execute(
879875

880876
return self
881877

882-
@log_latency()
878+
@log_latency(StatementType.QUERY)
883879
def execute_async(
884880
self,
885881
operation: str,
@@ -895,7 +891,6 @@ def execute_async(
895891
:return:
896892
"""
897893

898-
self.statement_type = StatementType.QUERY
899894
param_approach = self._determine_parameter_approach(parameters)
900895
if param_approach == ParameterApproach.NONE:
901896
prepared_params = NO_NATIVE_PARAMS
@@ -999,14 +994,13 @@ def executemany(self, operation, seq_of_parameters):
999994
self.execute(operation, parameters)
1000995
return self
1001996

1002-
@log_latency()
997+
@log_latency(StatementType.METADATA)
1003998
def catalogs(self) -> "Cursor":
1004999
"""
10051000
Get all available catalogs.
10061001
10071002
:returns self
10081003
"""
1009-
self.statement_type = StatementType.METADATA
10101004
self._check_not_closed()
10111005
self._close_and_clear_active_result_set()
10121006
self.active_result_set = self.backend.get_catalogs(
@@ -1017,7 +1011,7 @@ def catalogs(self) -> "Cursor":
10171011
)
10181012
return self
10191013

1020-
@log_latency()
1014+
@log_latency(StatementType.METADATA)
10211015
def schemas(
10221016
self, catalog_name: Optional[str] = None, schema_name: Optional[str] = None
10231017
) -> "Cursor":
@@ -1027,7 +1021,6 @@ def schemas(
10271021
Names can contain % wildcards.
10281022
:returns self
10291023
"""
1030-
self.statement_type = StatementType.METADATA
10311024
self._check_not_closed()
10321025
self._close_and_clear_active_result_set()
10331026
self.active_result_set = self.backend.get_schemas(
@@ -1040,7 +1033,7 @@ def schemas(
10401033
)
10411034
return self
10421035

1043-
@log_latency()
1036+
@log_latency(StatementType.METADATA)
10441037
def tables(
10451038
self,
10461039
catalog_name: Optional[str] = None,
@@ -1054,7 +1047,6 @@ def tables(
10541047
Names can contain % wildcards.
10551048
:returns self
10561049
"""
1057-
self.statement_type = StatementType.METADATA
10581050
self._check_not_closed()
10591051
self._close_and_clear_active_result_set()
10601052

@@ -1070,7 +1062,7 @@ def tables(
10701062
)
10711063
return self
10721064

1073-
@log_latency()
1065+
@log_latency(StatementType.METADATA)
10741066
def columns(
10751067
self,
10761068
catalog_name: Optional[str] = None,
@@ -1084,7 +1076,6 @@ def columns(
10841076
Names can contain % wildcards.
10851077
:returns self
10861078
"""
1087-
self.statement_type = StatementType.METADATA
10881079
self._check_not_closed()
10891080
self._close_and_clear_active_result_set()
10901081

src/databricks/sql/cloudfetch/download_manager.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ def __init__(
2424
ssl_options: SSLOptions,
2525
session_id_hex: Optional[str],
2626
statement_id: str,
27-
statement_type: StatementType,
2827
chunk_id: int,
2928
):
3029
self._pending_links: List[Tuple[int, TSparkArrowResultLink]] = []
@@ -48,7 +47,6 @@ def __init__(
4847
self._ssl_options = ssl_options
4948
self.session_id_hex = session_id_hex
5049
self.statement_id = statement_id
51-
self.statement_type = statement_type
5250

5351
def get_next_downloaded_file(
5452
self, next_row_offset: int
@@ -111,7 +109,6 @@ def _schedule_downloads(self):
111109
chunk_id=chunk_id,
112110
session_id_hex=self.session_id_hex,
113111
statement_id=self.statement_id,
114-
statement_type=self.statement_type,
115112
)
116113
task = self._thread_pool.submit(handler.run)
117114
self._download_tasks.append(task)

src/databricks/sql/cloudfetch/downloader.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,17 +72,15 @@ def __init__(
7272
chunk_id: int,
7373
session_id_hex: Optional[str],
7474
statement_id: str,
75-
statement_type: StatementType,
7675
):
7776
self.settings = settings
7877
self.link = link
7978
self._ssl_options = ssl_options
8079
self.chunk_id = chunk_id
8180
self.session_id_hex = session_id_hex
8281
self.statement_id = statement_id
83-
self.statement_type = statement_type
8482

85-
@log_latency()
83+
@log_latency(StatementType.QUERY)
8684
def run(self) -> DownloadedFile:
8785
"""
8886
Download the file described in the cloud fetch link.

src/databricks/sql/result_set.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,6 @@ def __init__(
217217
:param ssl_options: SSL options for cloud fetch
218218
:param is_direct_results: Whether there are more rows to fetch
219219
"""
220-
self.statement_type = execute_response.command_id.statement_type
221220
self.num_downloaded_chunks = 0
222221

223222
# Initialize ThriftResultSet-specific attributes
@@ -240,7 +239,6 @@ def __init__(
240239
ssl_options=ssl_options,
241240
session_id_hex=session_id_hex,
242241
statement_id=execute_response.command_id.to_hex_guid(),
243-
statement_type=self.statement_type,
244242
chunk_id=self.num_downloaded_chunks,
245243
)
246244
if t_row_set and t_row_set.resultLinks:

src/databricks/sql/telemetry/latency_logger.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,6 @@ def get_retry_count(self):
4343
def get_chunk_id(self):
4444
pass
4545

46-
def get_statement_type(self):
47-
pass
48-
4946

5047
class CursorExtractor(TelemetryExtractor):
5148
"""
@@ -86,9 +83,6 @@ def get_retry_count(self) -> int:
8683
def get_chunk_id(self):
8784
return None
8885

89-
def get_statement_type(self):
90-
return self.statement_type
91-
9286

9387
class ResultSetDownloadHandlerExtractor(TelemetryExtractor):
9488
"""
@@ -114,9 +108,6 @@ def get_retry_count(self) -> Optional[int]:
114108
def get_chunk_id(self) -> Optional[int]:
115109
return self._obj.chunk_id
116110

117-
def get_statement_type(self):
118-
return self.statement_type
119-
120111

121112
def get_extractor(obj):
122113
"""
@@ -144,7 +135,7 @@ def get_extractor(obj):
144135
return None
145136

146137

147-
def log_latency():
138+
def log_latency(statement_type: StatementType = StatementType.NONE):
148139
"""
149140
Decorator for logging execution latency and telemetry information.
150141
@@ -159,7 +150,7 @@ def log_latency():
159150
- Sends the telemetry data asynchronously via TelemetryClient
160151
161152
Usage:
162-
@log_latency()
153+
@log_latency(StatementType.QUERY)
163154
def execute(self, query):
164155
# Method implementation
165156
pass
@@ -199,7 +190,7 @@ def _safe_call(func_to_call):
199190
statement_id = _safe_call(extractor.get_statement_id)
200191

201192
sql_exec_event = SqlExecutionEvent(
202-
statement_type=_safe_call(extractor.get_statement_type),
193+
statement_type=statement_type,
203194
is_compressed=_safe_call(extractor.get_is_compressed),
204195
execution_result=_safe_call(
205196
extractor.get_execution_result_format

0 commit comments

Comments
 (0)