Skip to content

Commit b404af7

Browse files
refactors
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent bddef1f commit b404af7

File tree

5 files changed: 188 additions (+188) and 95 deletions (-95).

src/databricks/sql/backend/filters.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -75,13 +75,19 @@ def _filter_sea_result_set(
7575
is_staging_operation=False,
7676
)
7777

78+
# Create a new ResultData object with filtered data
79+
from databricks.sql.backend.sea.models.base import ResultData
80+
81+
result_data = ResultData(data=filtered_rows, external_links=None)
82+
7883
# Create a new SeaResultSet with the filtered data
7984
filtered_result_set = SeaResultSet(
8085
connection=result_set.connection,
8186
execute_response=execute_response,
8287
sea_client=cast(SeaDatabricksClient, result_set.backend),
8388
buffer_size_bytes=result_set.buffer_size_bytes,
8489
arraysize=result_set.arraysize,
90+
result_data=result_data,
8591
)
8692

8793
return filtered_result_set

src/databricks/sql/backend/sea/backend.py

Lines changed: 25 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -449,14 +449,15 @@ def _get_schema_bytes(self, sea_response) -> Optional[bytes]:
449449

450450
def _results_message_to_execute_response(self, sea_response, command_id):
451451
"""
452-
Convert a SEA response to an ExecuteResponse.
452+
Convert a SEA response to an ExecuteResponse and extract result data.
453453
454454
Args:
455455
sea_response: The response from the SEA API
456456
command_id: The command ID
457457
458458
Returns:
459-
ExecuteResponse: The normalized execute response
459+
tuple: (ExecuteResponse, ResultData, ResultManifest) - The normalized execute response,
460+
result data object, and manifest object
460461
"""
461462
# Extract status
462463
status_data = sea_response.get("status", {})
@@ -498,8 +499,10 @@ def _results_message_to_execute_response(self, sea_response, command_id):
498499
# Check for compression
499500
lz4_compressed = manifest_data.get("result_compression") == "LZ4_FRAME"
500501

501-
# Create results queue
502-
results_queue = None
502+
# Initialize result_data_obj and manifest_obj
503+
result_data_obj = None
504+
manifest_obj = None
505+
503506
result_data = sea_response.get("result", {})
504507
if result_data:
505508
# Convert external links
@@ -528,31 +531,19 @@ def _results_message_to_execute_response(self, sea_response, command_id):
528531
data=result_data.get("data_array"), external_links=external_links
529532
)
530533

531-
# Create the manifest object
532-
manifest_obj = ResultManifest(
533-
format=manifest_data.get("format", ""),
534-
schema=manifest_data.get("schema", {}),
535-
total_row_count=manifest_data.get("total_row_count", 0),
536-
total_byte_count=manifest_data.get("total_byte_count", 0),
537-
total_chunk_count=manifest_data.get("total_chunk_count", 0),
538-
truncated=manifest_data.get("truncated", False),
539-
chunks=manifest_data.get("chunks"),
540-
result_compression=manifest_data.get("result_compression"),
541-
)
542-
543-
results_queue = SeaResultSetQueueFactory.build_queue(
544-
result_data_obj,
545-
manifest_obj,
546-
command_id.to_sea_statement_id(),
547-
description=description,
548-
schema_bytes=schema_bytes,
549-
max_download_threads=self.max_download_threads,
550-
ssl_options=self.ssl_options,
551-
sea_client=self,
552-
lz4_compressed=lz4_compressed,
553-
)
534+
# Create the manifest object
535+
manifest_obj = ResultManifest(
536+
format=manifest_data.get("format", ""),
537+
schema=manifest_data.get("schema", {}),
538+
total_row_count=manifest_data.get("total_row_count", 0),
539+
total_byte_count=manifest_data.get("total_byte_count", 0),
540+
total_chunk_count=manifest_data.get("total_chunk_count", 0),
541+
truncated=manifest_data.get("truncated", False),
542+
chunks=manifest_data.get("chunks"),
543+
result_compression=manifest_data.get("result_compression"),
544+
)
554545

555-
return ExecuteResponse(
546+
execute_response = ExecuteResponse(
556547
command_id=command_id,
557548
status=state,
558549
description=description,
@@ -563,6 +554,8 @@ def _results_message_to_execute_response(self, sea_response, command_id):
563554
arrow_schema_bytes=schema_bytes,
564555
result_format=manifest_data.get("format"),
565556
)
557+
558+
return execute_response, result_data_obj, manifest_obj
566559

567560
def execute_command(
568561
self,
@@ -782,8 +775,8 @@ def get_execution_result(
782775
# Create and return a SeaResultSet
783776
from databricks.sql.result_set import SeaResultSet
784777

785-
# Convert the response to an ExecuteResponse
786-
execute_response = self._results_message_to_execute_response(
778+
# Convert the response to an ExecuteResponse and extract result data
779+
execute_response, result_data, manifest = self._results_message_to_execute_response(
787780
response_data, command_id
788781
)
789782

@@ -793,6 +786,8 @@ def get_execution_result(
793786
sea_client=self,
794787
buffer_size_bytes=cursor.buffer_size_bytes,
795788
arraysize=cursor.arraysize,
789+
result_data=result_data,
790+
manifest=manifest,
796791
)
797792

798793
# == Metadata Operations ==

src/databricks/sql/result_set.py

Lines changed: 6 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -450,6 +450,8 @@ def __init__(
450450
sea_client: "SeaDatabricksClient",
451451
buffer_size_bytes: int = 104857600,
452452
arraysize: int = 10000,
453+
result_data: Optional["ResultData"] = None,
454+
manifest: Optional["ResultManifest"] = None,
453455
):
454456
"""
455457
Initialize a SeaResultSet with the response from a SEA query execution.
@@ -460,6 +462,8 @@ def __init__(
460462
sea_client: The SeaDatabricksClient instance for direct access
461463
buffer_size_bytes: Buffer size for fetching results
462464
arraysize: Default number of rows to fetch
465+
result_data: Result data from SEA response (optional)
466+
manifest: Manifest from SEA response (optional)
463467
"""
464468
# Extract and store SEA-specific properties
465469
self.statement_id = (
@@ -468,58 +472,10 @@ def __init__(
468472
else None
469473
)
470474

471-
# Get the response data from the SEA backend
472-
response_data = sea_client.http_client._make_request(
473-
method="GET",
474-
path=sea_client.STATEMENT_PATH_WITH_ID.format(self.statement_id),
475-
data={"statement_id": self.statement_id},
476-
)
477-
478475
# Build the results queue
479476
results_queue = None
480477

481-
# Extract data from the response
482-
result_data = response_data.get("result", {})
483-
manifest_data = response_data.get("manifest", {})
484-
485478
if result_data:
486-
# Convert external links
487-
external_links = None
488-
if "external_links" in result_data:
489-
external_links = []
490-
for link_data in result_data["external_links"]:
491-
external_links.append(
492-
ExternalLink(
493-
external_link=link_data.get("external_link", ""),
494-
expiration=link_data.get("expiration", ""),
495-
chunk_index=link_data.get("chunk_index", 0),
496-
byte_count=link_data.get("byte_count", 0),
497-
row_count=link_data.get("row_count", 0),
498-
row_offset=link_data.get("row_offset", 0),
499-
next_chunk_index=link_data.get("next_chunk_index"),
500-
next_chunk_internal_link=link_data.get("next_chunk_internal_link"),
501-
http_headers=link_data.get("http_headers", {}),
502-
)
503-
)
504-
505-
# Create the result data object
506-
result_data_obj = ResultData(
507-
data=result_data.get("data_array"), external_links=external_links
508-
)
509-
510-
# Create the manifest object
511-
manifest_obj = ResultManifest(
512-
format=manifest_data.get("format", ""),
513-
schema=manifest_data.get("schema", {}),
514-
total_row_count=manifest_data.get("total_row_count", 0),
515-
total_byte_count=manifest_data.get("total_byte_count", 0),
516-
total_chunk_count=manifest_data.get("total_chunk_count", 0),
517-
truncated=manifest_data.get("truncated", False),
518-
chunks=manifest_data.get("chunks"),
519-
result_compression=manifest_data.get("result_compression"),
520-
)
521-
522-
# Build the queue based on the response data
523479
from typing import cast, List
524480

525481
# Convert description to the expected format
@@ -528,8 +484,8 @@ def __init__(
528484
desc = cast(List[Tuple[Any, ...]], execute_response.description)
529485

530486
results_queue = SeaResultSetQueueFactory.build_queue(
531-
result_data_obj,
532-
manifest_obj,
487+
result_data,
488+
manifest,
533489
str(self.statement_id),
534490
description=desc,
535491
schema_bytes=execute_response.arrow_schema_bytes if execute_response.arrow_schema_bytes else None,

src/databricks/sql/utils.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,7 @@ class SeaResultSetQueueFactory(ABC):
129129
@staticmethod
130130
def build_queue(
131131
sea_result_data: ResultData,
132-
manifest: ResultManifest,
132+
manifest: Optional[ResultManifest],
133133
statement_id: str,
134134
description: Optional[List[Tuple[Any, ...]]] = None,
135135
schema_bytes: Optional[bytes] = None,
@@ -176,6 +176,10 @@ def build_queue(
176176
raise ValueError(
177177
"SEA client is required for EXTERNAL_LINKS disposition"
178178
)
179+
if not manifest:
180+
raise ValueError(
181+
"Manifest is required for EXTERNAL_LINKS disposition"
182+
)
179183

180184
return SeaCloudFetchQueue(
181185
initial_links=sea_result_data.external_links,

0 commit comments

Comments
 (0)