Skip to content

Commit fb35f69

Browse files
refactor fetch interface
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 448476b commit fb35f69

17 files changed

+556
-687
lines changed

src/databricks/sql/backend/databricks_client.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,7 @@
1717
from databricks.sql.client import Cursor
1818

1919
from databricks.sql.thrift_api.TCLIService import ttypes
20-
from databricks.sql.backend.types import SessionId, CommandId, CommandState
21-
from databricks.sql.utils import ExecuteResponse
20+
from databricks.sql.backend.types import SessionId, CommandId, CommandState, ExecuteResponse
2221
from databricks.sql.types import SSLOptions
2322

2423
# Forward reference for type hints

src/databricks/sql/backend/filters.py

Lines changed: 37 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,12 @@
1717
TYPE_CHECKING,
1818
)
1919

20-
# Import SeaResultSet for type checking
21-
from databricks.sql.backend.sea_result_set import SeaResultSet
20+
from databricks.sql.utils import JsonQueue, SeaResultSetQueueFactory
21+
from databricks.sql.backend.types import ExecuteResponse, CommandId
22+
from databricks.sql.backend.models.base import ResultData
2223

2324
if TYPE_CHECKING:
24-
from databricks.sql.result_set import ResultSet
25+
from databricks.sql.result_set import ResultSet, SeaResultSet
2526

2627
logger = logging.getLogger(__name__)
2728

@@ -36,38 +37,47 @@ class ResultSetFilter:
3637

3738
@staticmethod
3839
def _filter_sea_result_set(
39-
result_set: SeaResultSet, filter_func: Callable[[List[Any]], bool]
40-
) -> SeaResultSet:
40+
result_set: "SeaResultSet", filter_func: Callable[[List[Any]], bool]
41+
) -> "SeaResultSet":
4142
"""
4243
Filter a SEA result set using the provided filter function.
43-
44+
4445
Args:
4546
result_set: The SEA result set to filter
4647
filter_func: Function that takes a row and returns True if the row should be included
47-
48+
4849
Returns:
4950
A filtered SEA result set
5051
"""
51-
# Create a filtered version of the result set
52-
filtered_response = result_set._response.copy()
53-
54-
# If there's a result with rows, filter them
55-
if (
56-
"result" in filtered_response
57-
and "data_array" in filtered_response["result"]
58-
):
59-
rows = filtered_response["result"]["data_array"]
60-
filtered_rows = [row for row in rows if filter_func(row)]
61-
filtered_response["result"]["data_array"] = filtered_rows
62-
63-
# Update row count if present
64-
if "row_count" in filtered_response["result"]:
65-
filtered_response["result"]["row_count"] = len(filtered_rows)
66-
67-
# Create a new result set with the filtered data
52+
# Get all remaining rows
53+
original_index = result_set.results.cur_row_index
54+
result_set.results.cur_row_index = 0 # Reset to beginning
55+
all_rows = result_set.results.remaining_rows()
56+
57+
# Filter rows
58+
filtered_rows = [row for row in all_rows if filter_func(row)]
59+
60+
# Import SeaResultSet here to avoid circular imports
61+
from databricks.sql.result_set import SeaResultSet
62+
63+
# Reuse the command_id from the original result set
64+
command_id = result_set.command_id
65+
66+
# Create an ExecuteResponse with the filtered data
67+
execute_response = ExecuteResponse(
68+
command_id=command_id,
69+
status=result_set.status,
70+
description=result_set.description,
71+
has_more_rows=result_set._has_more_rows,
72+
results_queue=JsonQueue(filtered_rows),
73+
has_been_closed_server_side=result_set.has_been_closed_server_side,
74+
lz4_compressed=False,
75+
is_staging_operation=False,
76+
)
77+
6878
return SeaResultSet(
6979
connection=result_set.connection,
70-
sea_response=filtered_response,
80+
execute_response=execute_response,
7181
sea_client=result_set.backend,
7282
buffer_size_bytes=result_set.buffer_size_bytes,
7383
arraysize=result_set.arraysize,
@@ -97,6 +107,7 @@ def filter_by_column_values(
97107
allowed_values = [v.upper() for v in allowed_values]
98108

99109
# Determine the type of result set and apply appropriate filtering
110+
from databricks.sql.result_set import SeaResultSet
100111
if isinstance(result_set, SeaResultSet):
101112
return ResultSetFilter._filter_sea_result_set(
102113
result_set,
@@ -145,4 +156,4 @@ def filter_tables_by_type(
145156
# Table type is typically in the 4th column (index 3)
146157
return ResultSetFilter.filter_by_column_values(
147158
result_set, 3, valid_types, case_sensitive=False
148-
)
159+
)

src/databricks/sql/backend/sea_backend.py

Lines changed: 74 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -8,11 +8,13 @@
88
from databricks.sql.result_set import ResultSet
99

1010
from databricks.sql.backend.databricks_client import DatabricksClient
11-
from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType
11+
from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType, ExecuteResponse
1212
from databricks.sql.exc import Error, NotSupportedError, ServerOperationError
1313
from databricks.sql.backend.utils.http_client import CustomHttpClient
1414
from databricks.sql.thrift_api.TCLIService import ttypes
1515
from databricks.sql.types import SSLOptions
16+
from databricks.sql.utils import SeaResultSetQueueFactory
17+
from databricks.sql.backend.models.base import ResultData
1618

1719
from databricks.sql.backend.models import (
1820
ExecuteStatementRequest,
@@ -227,6 +229,70 @@ def close_session(self, session_id: SessionId) -> None:
227229
params=request.to_dict(),
228230
)
229231

232+
def _results_message_to_execute_response(
    self, sea_response: dict, command_id: "CommandId"
) -> "ExecuteResponse":
    """
    Convert a SEA response to an ExecuteResponse.

    Sections of the response that are missing *or* explicitly null
    ("status", "manifest", "schema", "columns", "result") are treated as
    empty instead of raising AttributeError on the chained lookups.

    Args:
        sea_response: The response from the SEA API (a parsed JSON object)
        command_id: The command ID to attach to the returned response

    Returns:
        ExecuteResponse: The normalized execute response. ``description`` is
        a list of 7-tuples (name, type_code, display_size, internal_size,
        precision, scale, null_ok), or None when no usable column metadata
        is present. ``results_queue`` is None when the response carries no
        (or an empty) "result" section.
    """
    # NOTE: dict.get(key, default) only applies the default when the key is
    # absent; a key that is present with a JSON null yields None. Using
    # `or {}` / `or []` covers both cases.
    status_data = sea_response.get("status") or {}
    state = CommandState.from_sea_state(status_data.get("state", ""))

    # Extract description from manifest
    manifest_data = sea_response.get("manifest") or {}
    schema_data = manifest_data.get("schema") or {}
    columns_data = schema_data.get("columns") or []

    # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
    # Entries that are not dicts are skipped rather than treated as errors.
    columns = [
        (
            col_data.get("name", ""),  # name
            col_data.get("type_name", ""),  # type_code
            None,  # display_size (not provided by SEA)
            None,  # internal_size (not provided by SEA)
            col_data.get("precision"),  # precision
            col_data.get("scale"),  # scale
            col_data.get("nullable", True),  # null_ok
        )
        for col_data in columns_data
        if isinstance(col_data, dict)
    ]
    description = columns or None

    # Create results queue only when the response actually carries result data
    results_queue = None
    result_data = sea_response.get("result") or {}
    if result_data:
        results_queue = SeaResultSetQueueFactory.build_queue(
            ResultData(
                data=result_data.get("data_array"),
                external_links=result_data.get("external_links"),
            ),
            description=description,
        )

    return ExecuteResponse(
        command_id=command_id,
        status=state,
        description=description,
        has_more_rows=False,
        results_queue=results_queue,
        has_been_closed_server_side=False,
        lz4_compressed=False,  # TODO: extract from response
        is_staging_operation=False,
    )
295+
230296
def execute_command(
231297
self,
232298
operation: str,
@@ -444,11 +510,14 @@ def get_execution_result(
444510
)
445511

446512
# Create and return a SeaResultSet
447-
from databricks.sql.backend.sea_result_set import SeaResultSet
448-
513+
from databricks.sql.result_set import SeaResultSet
514+
515+
# Convert the response to an ExecuteResponse
516+
execute_response = self._results_message_to_execute_response(response_data, command_id)
517+
449518
return SeaResultSet(
450519
connection=cursor.connection,
451-
sea_response=response_data,
520+
execute_response=execute_response,
452521
sea_client=self,
453522
buffer_size_bytes=cursor.buffer_size_bytes,
454523
arraysize=cursor.arraysize,
@@ -599,4 +668,4 @@ def get_columns(
599668
enforce_embedded_schema_correctness=False,
600669
)
601670
assert result is not None, "execute_command returned None in synchronous mode"
602-
return result
671+
return result

0 commit comments

Comments
 (0)