
Commit 120dfc0

remove service specific state from ExecuteResponse
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 88371ea commit 120dfc0
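
For orientation before the per-file hunks: after this commit, ExecuteResponse keeps only backend-agnostic execution state. Below is a minimal sketch of the resulting class, reconstructed from the types.py hunk further down (it assumes the class remains a plain dataclass; CommandId and CommandState here are stand-ins for the real types in databricks.sql.backend.types):

from dataclasses import dataclass
from typing import Any, List, Optional, Tuple

# Stand-ins for the real CommandId / CommandState types in databricks.sql.backend.types
CommandId = Any
CommandState = Any

@dataclass
class ExecuteResponse:
    command_id: CommandId
    status: CommandState
    description: List[Tuple]
    lz4_compressed: bool = True
    is_staging_operation: bool = False
    result_format: Optional[Any] = None

# The removed has_been_closed_server_side and arrow_schema_bytes fields are
# Thrift-specific, so they now arrive via the ThriftResultSet constructor instead.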

6 files changed (+103, -56 lines)


src/databricks/sql/backend/sea/client.py

Lines changed: 0 additions & 2 deletions
@@ -349,10 +349,8 @@ def _results_message_to_execute_response(
             command_id=CommandId.from_sea_statement_id(response.statement_id),
             status=response.status.state,
             description=description,
-            has_been_closed_server_side=False,
             lz4_compressed=lz4_compressed,
             is_staging_operation=response.manifest.is_volume_operation,
-            arrow_schema_bytes=None,
             result_format=response.manifest.format,
         )

src/databricks/sql/backend/sea/result_set.py

Lines changed: 20 additions & 4 deletions
@@ -15,10 +15,10 @@
 
 if TYPE_CHECKING:
     from databricks.sql.client import Connection
-from databricks.sql.exc import ProgrammingError
+from databricks.sql.exc import CursorAlreadyClosedError, ProgrammingError, RequestError
 from databricks.sql.types import Row
 from databricks.sql.backend.sea.queue import JsonQueue, SeaResultSetQueueFactory
-from databricks.sql.backend.types import ExecuteResponse
+from databricks.sql.backend.types import CommandState, ExecuteResponse
 from databricks.sql.result_set import ResultSet
 
 logger = logging.getLogger(__name__)
@@ -61,11 +61,9 @@ def __init__(
             buffer_size_bytes=buffer_size_bytes,
             command_id=execute_response.command_id,
             status=execute_response.status,
-            has_been_closed_server_side=execute_response.has_been_closed_server_side,
             description=execute_response.description,
             is_staging_operation=execute_response.is_staging_operation,
             lz4_compressed=execute_response.lz4_compressed,
-            arrow_schema_bytes=execute_response.arrow_schema_bytes,
         )
 
         # Assert that the backend is of the correct type
@@ -274,3 +272,21 @@ def fetchall(self) -> List[Row]:
             return self._create_json_table(self.fetchall_json())
         else:
             raise NotImplementedError("fetchall only supported for JSON data")
+
+    def close(self) -> None:
+        """
+        Close the result set.
+
+        If the connection has not been closed, and the result set has not already
+        been closed on the server for some other reason, issue a request to the server to close it.
+        """
+        try:
+            if self.results is not None:
+                self.results.close()
+            if self.status != CommandState.CLOSED and self.connection.open:
+                self.backend.close_command(self.command_id)
+        except RequestError as e:
+            if isinstance(e.args[1], CursorAlreadyClosedError):
+                logger.info("Operation was canceled by a prior request")
+        finally:
+            self.status = CommandState.CLOSED
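
One practical effect of the finally block above: close() on a SeaResultSet is idempotent. A hypothetical two-call sequence (rs stands for any SeaResultSet instance, not a name from this diff):

rs.close()   # issues backend.close_command(...) if status is not CLOSED and the connection is open
rs.close()   # status was set to CommandState.CLOSED in finally, so no second server round-trip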

src/databricks/sql/backend/sea/utils/filters.py

Lines changed: 0 additions & 2 deletions
@@ -60,9 +60,7 @@ def _filter_sea_result_set(
         command_id=command_id,
         status=result_set.status,
         description=result_set.description,
-        has_been_closed_server_side=result_set.has_been_closed_server_side,
         lz4_compressed=result_set.lz4_compressed,
-        arrow_schema_bytes=result_set._arrow_schema_bytes,
         is_staging_operation=False,
     )

src/databricks/sql/backend/thrift_backend.py

Lines changed: 48 additions & 20 deletions
@@ -821,14 +821,17 @@ def _results_message_to_execute_response(self, resp, operation_state):
             command_id=command_id,
             status=status,
             description=description,
-            has_been_closed_server_side=has_been_closed_server_side,
             lz4_compressed=lz4_compressed,
             is_staging_operation=t_result_set_metadata_resp.isStagingOperation,
-            arrow_schema_bytes=schema_bytes,
             result_format=t_result_set_metadata_resp.resultFormat,
         )
 
-        return execute_response, is_direct_results
+        return (
+            execute_response,
+            is_direct_results,
+            has_been_closed_server_side,
+            schema_bytes,
+        )
 
     def get_execution_result(
         self, command_id: CommandId, cursor: "Cursor"
@@ -881,10 +884,8 @@ def get_execution_result(
             command_id=command_id,
             status=status,
             description=description,
-            has_been_closed_server_side=False,
             lz4_compressed=lz4_compressed,
             is_staging_operation=is_staging_operation,
-            arrow_schema_bytes=schema_bytes,
             result_format=t_result_set_metadata_resp.resultFormat,
         )
 
@@ -898,6 +899,8 @@ def get_execution_result(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            arrow_schema_bytes=schema_bytes,
+            has_been_closed_server_side=False,
         )
 
     def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
@@ -1016,9 +1019,12 @@ def execute_command(
             self._handle_execute_response_async(resp, cursor)
             return None
         else:
-            execute_response, is_direct_results = self._handle_execute_response(
-                resp, cursor
-            )
+            (
+                execute_response,
+                is_direct_results,
+                has_been_closed_server_side,
+                schema_bytes,
+            ) = self._handle_execute_response(resp, cursor)
 
             t_row_set = None
             if resp.directResults and resp.directResults.resultSet:
@@ -1034,6 +1040,8 @@ def execute_command(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            has_been_closed_server_side=has_been_closed_server_side,
+            arrow_schema_bytes=schema_bytes,
         )
 
     def get_catalogs(
@@ -1055,9 +1063,12 @@ def get_catalogs(
         )
         resp = self.make_request(self._client.GetCatalogs, req)
 
-        execute_response, is_direct_results = self._handle_execute_response(
-            resp, cursor
-        )
+        (
+            execute_response,
+            is_direct_results,
+            has_been_closed_server_side,
+            schema_bytes,
+        ) = self._handle_execute_response(resp, cursor)
 
         t_row_set = None
         if resp.directResults and resp.directResults.resultSet:
@@ -1073,6 +1084,8 @@ def get_catalogs(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            has_been_closed_server_side=has_been_closed_server_side,
+            arrow_schema_bytes=schema_bytes,
         )
 
     def get_schemas(
@@ -1100,9 +1113,12 @@ def get_schemas(
         )
         resp = self.make_request(self._client.GetSchemas, req)
 
-        execute_response, is_direct_results = self._handle_execute_response(
-            resp, cursor
-        )
+        (
+            execute_response,
+            is_direct_results,
+            has_been_closed_server_side,
+            schema_bytes,
+        ) = self._handle_execute_response(resp, cursor)
 
         t_row_set = None
         if resp.directResults and resp.directResults.resultSet:
@@ -1118,6 +1134,8 @@ def get_schemas(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            has_been_closed_server_side=has_been_closed_server_side,
+            arrow_schema_bytes=schema_bytes,
        )
 
     def get_tables(
@@ -1149,9 +1167,12 @@ def get_tables(
         )
         resp = self.make_request(self._client.GetTables, req)
 
-        execute_response, is_direct_results = self._handle_execute_response(
-            resp, cursor
-        )
+        (
+            execute_response,
+            is_direct_results,
+            has_been_closed_server_side,
+            schema_bytes,
+        ) = self._handle_execute_response(resp, cursor)
 
         t_row_set = None
         if resp.directResults and resp.directResults.resultSet:
@@ -1167,6 +1188,8 @@ def get_tables(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            has_been_closed_server_side=has_been_closed_server_side,
+            arrow_schema_bytes=schema_bytes,
         )
 
     def get_columns(
@@ -1198,9 +1221,12 @@ def get_columns(
         )
         resp = self.make_request(self._client.GetColumns, req)
 
-        execute_response, is_direct_results = self._handle_execute_response(
-            resp, cursor
-        )
+        (
+            execute_response,
+            is_direct_results,
+            has_been_closed_server_side,
+            schema_bytes,
+        ) = self._handle_execute_response(resp, cursor)
 
         t_row_set = None
         if resp.directResults and resp.directResults.resultSet:
@@ -1216,6 +1242,8 @@ def get_columns(
             max_download_threads=self.max_download_threads,
             ssl_options=self._ssl_options,
             is_direct_results=is_direct_results,
+            has_been_closed_server_side=has_been_closed_server_side,
+            arrow_schema_bytes=schema_bytes,
         )
 
     def _handle_execute_response(self, resp, cursor):
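
The same unpack-and-forward change is repeated at every call site in this file (execute_command, get_catalogs, get_schemas, get_tables, get_columns). As a condensed, runnable sketch of the new 4-tuple shape (a hypothetical stub stands in for _handle_execute_response; it is not the real method):

# Hypothetical stub mirroring the return shape of _handle_execute_response after this commit.
def _handle_execute_response_stub():
    execute_response = object()   # placeholder for the real ExecuteResponse
    is_direct_results = True
    has_been_closed_server_side = False
    schema_bytes = b""
    return execute_response, is_direct_results, has_been_closed_server_side, schema_bytes

(
    execute_response,
    is_direct_results,
    has_been_closed_server_side,
    schema_bytes,
) = _handle_execute_response_stub()

# The two extra values then flow into the result-set constructor as
#     has_been_closed_server_side=has_been_closed_server_side,
#     arrow_schema_bytes=schema_bytes,
# exactly as in the ThriftResultSet(...) hunks above.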

src/databricks/sql/backend/types.py

Lines changed: 0 additions & 2 deletions
@@ -419,8 +419,6 @@ class ExecuteResponse:
     command_id: CommandId
     status: CommandState
     description: List[Tuple]
-    has_been_closed_server_side: bool = False
     lz4_compressed: bool = True
     is_staging_operation: bool = False
-    arrow_schema_bytes: Optional[bytes] = None
     result_format: Optional[Any] = None

src/databricks/sql/result_set.py

Lines changed: 35 additions & 26 deletions
@@ -41,12 +41,10 @@ def __init__(
         buffer_size_bytes: int,
         command_id: CommandId,
         status: CommandState,
-        has_been_closed_server_side: bool = False,
         is_direct_results: bool = False,
         description: List[Tuple] = [],
         is_staging_operation: bool = False,
         lz4_compressed: bool = False,
-        arrow_schema_bytes: Optional[bytes] = None,
     ):
         """
         A ResultSet manages the results of a single command.
@@ -57,7 +55,6 @@ def __init__(
         :param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch
         :param command_id: The command ID
         :param status: The command status
-        :param has_been_closed_server_side: Whether the command has been closed on the server
         :param is_direct_results: Whether the command has more rows
         :param description: column description of the results
         :param is_staging_operation: Whether the command is a staging operation
@@ -71,12 +68,10 @@ def __init__(
         self.description: List[Tuple] = description
         self.command_id: CommandId = command_id
         self.status: CommandState = status
-        self.has_been_closed_server_side: bool = has_been_closed_server_side
         self.is_direct_results: bool = is_direct_results
         self.results: Optional[ResultSetQueue] = None
         self._is_staging_operation: bool = is_staging_operation
         self.lz4_compressed: bool = lz4_compressed
-        self._arrow_schema_bytes: Optional[bytes] = arrow_schema_bytes
 
     def __iter__(self):
         while True:
@@ -158,28 +153,12 @@ def fetchall_arrow(self) -> "pyarrow.Table":
         """Fetch all remaining rows as an Arrow table."""
         pass
 
+    @abstractmethod
     def close(self) -> None:
         """
         Close the result set.
-
-        If the connection has not been closed, and the result set has not already
-        been closed on the server for some other reason, issue a request to the server to close it.
         """
-        try:
-            if self.results:
-                self.results.close()
-            if (
-                self.status != CommandState.CLOSED
-                and not self.has_been_closed_server_side
-                and self.connection.open
-            ):
-                self.backend.close_command(self.command_id)
-        except RequestError as e:
-            if isinstance(e.args[1], CursorAlreadyClosedError):
-                logger.info("Operation was canceled by a prior request")
-        finally:
-            self.has_been_closed_server_side = True
-            self.status = CommandState.CLOSED
+        pass
 
 
 class ThriftResultSet(ResultSet):
@@ -196,6 +175,8 @@ def __init__(
         max_download_threads: int = 10,
         ssl_options=None,
         is_direct_results: bool = True,
+        has_been_closed_server_side: bool = False,
+        arrow_schema_bytes: Optional[bytes] = None,
     ):
         """
         Initialize a ThriftResultSet with direct access to the ThriftDatabricksClient.
@@ -210,11 +191,15 @@ def __init__(
         :param max_download_threads: Maximum number of download threads for cloud fetch
         :param ssl_options: SSL options for cloud fetch
         :param is_direct_results: Whether there are more rows to fetch
+        :param has_been_closed_server_side: Whether the command has been closed on the server
+        :param arrow_schema_bytes: The schema of the result set
         """
 
         # Initialize ThriftResultSet-specific attributes
         self._use_cloud_fetch = use_cloud_fetch
         self.is_direct_results = is_direct_results
+        self.has_been_closed_server_side = has_been_closed_server_side
+        self._arrow_schema_bytes = arrow_schema_bytes
 
         # Build the results queue if t_row_set is provided
         results_queue = None
@@ -225,7 +210,7 @@ def __init__(
             results_queue = ThriftResultSetQueueFactory.build_queue(
                 row_set_type=execute_response.result_format,
                 t_row_set=t_row_set,
-                arrow_schema_bytes=execute_response.arrow_schema_bytes or b"",
+                arrow_schema_bytes=self._arrow_schema_bytes or b"",
                 max_download_threads=max_download_threads,
                 lz4_compressed=execute_response.lz4_compressed,
                 description=execute_response.description,
@@ -239,12 +224,10 @@ def __init__(
             buffer_size_bytes=buffer_size_bytes,
             command_id=execute_response.command_id,
            status=execute_response.status,
-            has_been_closed_server_side=execute_response.has_been_closed_server_side,
             is_direct_results=is_direct_results,
             description=execute_response.description,
             is_staging_operation=execute_response.is_staging_operation,
             lz4_compressed=execute_response.lz4_compressed,
-            arrow_schema_bytes=execute_response.arrow_schema_bytes,
         )
 
         # Assert that the backend is of the correct type
@@ -460,3 +443,29 @@ def map_col_type(type_):
             (column.name, map_col_type(column.datatype), None, None, None, None, None)
             for column in table_schema_message.columns
         ]
+
+    def close(self) -> None:
+        """
+        Close the result set.
+
+        If the connection has not been closed, and the result set has not already
+        been closed on the server for some other reason, issue a request to the server to close it.
+        """
+        try:
+            if self.results:
+                self.results.close()
+            print(f"has_been_closed_server_side: {self.has_been_closed_server_side}")
+            print(f"status: {self.status}")
+            print(f"connection.open: {self.connection.open}")
+            if (
+                self.status != CommandState.CLOSED
+                and not self.has_been_closed_server_side
+                and self.connection.open
+            ):
+                self.backend.close_command(self.command_id)
+        except RequestError as e:
+            if isinstance(e.args[1], CursorAlreadyClosedError):
+                logger.info("Operation was canceled by a prior request")
+        finally:
+            self.has_been_closed_server_side = True
+            self.status = CommandState.CLOSED
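
Net effect across the two result-set classes: the server-side-closed flag and the Arrow schema bytes now live only on ThriftResultSet. A hypothetical side-by-side condensation of the two close() guards added above (conditions copied from the hunks; plain strings stand in for CommandState.CLOSED and the surrounding try/finally is elided):

# Condensed restatement of the guards, not the actual methods.
def thrift_should_close_on_server(status, has_been_closed_server_side, connection_open):
    # ThriftResultSet.close(): still consults the Thrift-only server-side flag
    return status != "CLOSED" and not has_been_closed_server_side and connection_open

def sea_should_close_on_server(status, connection_open):
    # SeaResultSet.close(): no server-side flag exists for the SEA backend
    return status != "CLOSED" and connection_open

Both paths still mark the result set CLOSED in a finally block, so repeated close() calls do not trigger extra server round-trips.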

0 commit comments
