Skip to content

Commit e116d9b

Browse files
fix merge artifacts
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent dd7c410 commit e116d9b

File tree

7 files changed

+319
-72
lines changed

7 files changed

+319
-72
lines changed

src/databricks/sql/backend/filters.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from databricks.sql.utils import JsonQueue, SeaResultSetQueueFactory
2121
from databricks.sql.backend.types import ExecuteResponse, CommandId
2222
from databricks.sql.backend.sea.models.base import ResultData
23+
from databricks.sql.backend.sea.backend import SeaDatabricksClient
2324

2425
if TYPE_CHECKING:
2526
from databricks.sql.result_set import ResultSet, SeaResultSet
@@ -69,19 +70,21 @@ def _filter_sea_result_set(
6970
status=result_set.status,
7071
description=result_set.description,
7172
has_more_rows=result_set._has_more_rows,
72-
results_queue=JsonQueue(filtered_rows),
7373
has_been_closed_server_side=result_set.has_been_closed_server_side,
7474
lz4_compressed=False,
7575
is_staging_operation=False,
7676
)
7777

78-
return SeaResultSet(
78+
# Create a new SeaResultSet with the filtered data
79+
filtered_result_set = SeaResultSet(
7980
connection=result_set.connection,
8081
execute_response=execute_response,
81-
sea_client=result_set.backend,
82+
sea_client=cast(SeaDatabricksClient, result_set.backend),
8283
buffer_size_bytes=result_set.buffer_size_bytes,
8384
arraysize=result_set.arraysize,
8485
)
86+
87+
return filtered_result_set
8588

8689
@staticmethod
8790
def filter_by_column_values(

src/databricks/sql/backend/sea/backend.py

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import logging
22
import uuid
33
import time
4-
from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING
4+
import re
5+
from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING, Set
6+
7+
from databricks.sql.backend.sea.utils.constants import ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP
58

69
if TYPE_CHECKING:
710
from databricks.sql.client import Cursor
@@ -43,6 +46,32 @@
4346
logger = logging.getLogger(__name__)
4447

4548

49+
def _filter_session_configuration(
50+
session_configuration: Optional[Dict[str, str]]
51+
) -> Optional[Dict[str, str]]:
52+
if not session_configuration:
53+
return None
54+
55+
filtered_session_configuration = {}
56+
ignored_configs: Set[str] = set()
57+
58+
for key, value in session_configuration.items():
59+
if key.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP:
60+
filtered_session_configuration[key.lower()] = value
61+
else:
62+
ignored_configs.add(key)
63+
64+
if ignored_configs:
65+
logger.warning(
66+
"Some session configurations were ignored because they are not supported: %s",
67+
ignored_configs,
68+
)
69+
logger.warning(
70+
"Supported session configurations are: %s",
71+
list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()),
72+
)
73+
74+
return filtered_session_configuration
4675
class SeaDatabricksClient(DatabricksClient):
4776
"""
4877
Statement Execution API (SEA) implementation of the DatabricksClient interface.
@@ -60,7 +89,7 @@ class SeaDatabricksClient(DatabricksClient):
6089
CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel"
6190
CHUNKS_PATH_WITH_ID = STATEMENT_PATH + "/{}/result/chunks"
6291
CHUNK_PATH_WITH_ID_AND_INDEX = STATEMENT_PATH + "/{}/result/chunks/{}"
63-
92+
6493
def __init__(
6594
self,
6695
server_hostname: str,
@@ -92,6 +121,7 @@ def __init__(
92121
)
93122

94123
self._max_download_threads = kwargs.get("max_download_threads", 10)
124+
self.ssl_options = ssl_options
95125

96126
# Extract warehouse ID from http_path
97127
self.warehouse_id = self._extract_warehouse_id(http_path)
@@ -232,9 +262,45 @@ def close_session(self, session_id: SessionId) -> None:
232262
self.http_client._make_request(
233263
method="DELETE",
234264
path=self.SESSION_PATH_WITH_ID.format(sea_session_id),
235-
params=request.to_dict(),
265+
data=request_data.to_dict(),
236266
)
237267

268+
@staticmethod
269+
def get_default_session_configuration_value(name: str) -> Optional[str]:
270+
"""
271+
Get the default value for a session configuration parameter.
272+
273+
Args:
274+
name: The name of the session configuration parameter
275+
276+
Returns:
277+
The default value if the parameter is supported, None otherwise
278+
"""
279+
return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.get(name.upper())
280+
281+
@staticmethod
282+
def is_session_configuration_parameter_supported(name: str) -> bool:
283+
"""
284+
Check if a session configuration parameter is supported.
285+
286+
Args:
287+
name: The name of the session configuration parameter
288+
289+
Returns:
290+
True if the parameter is supported, False otherwise
291+
"""
292+
return name.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP
293+
294+
@staticmethod
295+
def get_allowed_session_configurations() -> List[str]:
296+
"""
297+
Get the list of allowed session configuration parameters.
298+
299+
Returns:
300+
List of allowed session configuration parameter names
301+
"""
302+
return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys())
303+
238304
def fetch_chunk_links(
239305
self, statement_id: str, chunk_index: int
240306
) -> List["ExternalLink"]:
@@ -491,10 +557,11 @@ def _results_message_to_execute_response(self, sea_response, command_id):
491557
status=state,
492558
description=description,
493559
has_more_rows=False,
494-
results_queue=results_queue,
495560
has_been_closed_server_side=False,
496561
lz4_compressed=lz4_compressed,
497562
is_staging_operation=False,
563+
arrow_schema_bytes=schema_bytes,
564+
result_format=manifest_data.get("format"),
498565
)
499566

500567
def execute_command(

src/databricks/sql/backend/sea/models/responses.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,12 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse":
3939
error_code=error_data.get("error_code"),
4040
)
4141

42+
state = CommandState.from_sea_state(status_data.get("state", ""))
43+
if state is None:
44+
raise ValueError(f"Invalid state: {status_data.get('state', '')}")
45+
4246
status = StatementStatus(
43-
state=CommandState.from_sea_state(status_data.get("state", "")),
47+
state=state,
4448
error=error,
4549
sql_state=status_data.get("sql_state"),
4650
)
@@ -119,8 +123,12 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse":
119123
error_code=error_data.get("error_code"),
120124
)
121125

126+
state = CommandState.from_sea_state(status_data.get("state", ""))
127+
if state is None:
128+
raise ValueError(f"Invalid state: {status_data.get('state', '')}")
129+
122130
status = StatementStatus(
123-
state=CommandState.from_sea_state(status_data.get("state", "")),
131+
state=state,
124132
error=error,
125133
sql_state=status_data.get("sql_state"),
126134
)

src/databricks/sql/cloud_fetch_queue.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77

88
from abc import ABC
9-
from typing import Any, List, Optional, Union, TYPE_CHECKING
9+
from typing import Any, List, Optional, Tuple, Union, TYPE_CHECKING
1010

1111
if TYPE_CHECKING:
1212
from databricks.sql.backend.sea.backend import SeaDatabricksClient
@@ -137,7 +137,7 @@ def __init__(
137137
max_download_threads: int,
138138
ssl_options: SSLOptions,
139139
lz4_compressed: bool = True,
140-
description: Optional[List[List[Any]]] = None,
140+
description: Optional[List[Tuple[Any, ...]]] = None,
141141
):
142142
"""
143143
Initialize the base CloudFetchQueue.
@@ -227,7 +227,7 @@ def __init__(
227227
statement_id: str,
228228
total_chunk_count: int,
229229
lz4_compressed: bool = False,
230-
description: Optional[List[List[Any]]] = None,
230+
description: Optional[List[Tuple[Any, ...]]] = None,
231231
):
232232
"""
233233
Initialize the SEA CloudFetchQueue.
@@ -496,7 +496,7 @@ def __init__(
496496
start_row_offset: int = 0,
497497
result_links: Optional[List[TSparkArrowResultLink]] = None,
498498
lz4_compressed: bool = True,
499-
description: Optional[List[List[Any]]] = None,
499+
description: Optional[List[Tuple[Any, ...]]] = None,
500500
):
501501
"""
502502
Initialize the Thrift CloudFetchQueue.

src/databricks/sql/result_set.py

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import pandas
77

88
from databricks.sql.backend.sea.backend import SeaDatabricksClient
9+
from databricks.sql.backend.sea.models.base import ExternalLink, ResultData, ResultManifest
910
from databricks.sql.cloud_fetch_queue import SeaCloudFetchQueue
11+
from databricks.sql.utils import SeaResultSetQueueFactory
1012

1113
try:
1214
import pyarrow
@@ -465,6 +467,77 @@ def __init__(
465467
if execute_response.command_id
466468
else None
467469
)
470+
471+
# Get the response data from the SEA backend
472+
response_data = sea_client.http_client._make_request(
473+
method="GET",
474+
path=sea_client.STATEMENT_PATH_WITH_ID.format(self.statement_id),
475+
data={"statement_id": self.statement_id},
476+
)
477+
478+
# Build the results queue
479+
results_queue = None
480+
481+
# Extract data from the response
482+
result_data = response_data.get("result", {})
483+
manifest_data = response_data.get("manifest", {})
484+
485+
if result_data:
486+
# Convert external links
487+
external_links = None
488+
if "external_links" in result_data:
489+
external_links = []
490+
for link_data in result_data["external_links"]:
491+
external_links.append(
492+
ExternalLink(
493+
external_link=link_data.get("external_link", ""),
494+
expiration=link_data.get("expiration", ""),
495+
chunk_index=link_data.get("chunk_index", 0),
496+
byte_count=link_data.get("byte_count", 0),
497+
row_count=link_data.get("row_count", 0),
498+
row_offset=link_data.get("row_offset", 0),
499+
next_chunk_index=link_data.get("next_chunk_index"),
500+
next_chunk_internal_link=link_data.get("next_chunk_internal_link"),
501+
http_headers=link_data.get("http_headers", {}),
502+
)
503+
)
504+
505+
# Create the result data object
506+
result_data_obj = ResultData(
507+
data=result_data.get("data_array"), external_links=external_links
508+
)
509+
510+
# Create the manifest object
511+
manifest_obj = ResultManifest(
512+
format=manifest_data.get("format", ""),
513+
schema=manifest_data.get("schema", {}),
514+
total_row_count=manifest_data.get("total_row_count", 0),
515+
total_byte_count=manifest_data.get("total_byte_count", 0),
516+
total_chunk_count=manifest_data.get("total_chunk_count", 0),
517+
truncated=manifest_data.get("truncated", False),
518+
chunks=manifest_data.get("chunks"),
519+
result_compression=manifest_data.get("result_compression"),
520+
)
521+
522+
# Build the queue based on the response data
523+
from typing import cast, List
524+
525+
# Convert description to the expected format
526+
desc = None
527+
if execute_response.description:
528+
desc = cast(List[Tuple[Any, ...]], execute_response.description)
529+
530+
results_queue = SeaResultSetQueueFactory.build_queue(
531+
result_data_obj,
532+
manifest_obj,
533+
str(self.statement_id),
534+
description=desc,
535+
schema_bytes=execute_response.arrow_schema_bytes if execute_response.arrow_schema_bytes else None,
536+
max_download_threads=sea_client.max_download_threads,
537+
ssl_options=sea_client.ssl_options,
538+
sea_client=sea_client,
539+
lz4_compressed=execute_response.lz4_compressed,
540+
)
468541

469542
# Call parent constructor with common attributes
470543
super().__init__(
@@ -476,14 +549,12 @@ def __init__(
476549
status=execute_response.status,
477550
has_been_closed_server_side=execute_response.has_been_closed_server_side,
478551
has_more_rows=execute_response.has_more_rows,
479-
results_queue=execute_response.results_queue,
480552
description=execute_response.description,
481553
is_staging_operation=execute_response.is_staging_operation,
482554
)
483555

484556
# Initialize queue for result data if not provided
485-
if not self.results:
486-
self.results = JsonQueue([])
557+
self.results = results_queue or JsonQueue([])
487558

488559
def _convert_to_row_objects(self, rows):
489560
"""

src/databricks/sql/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def build_queue(
6363
max_download_threads: Optional[int] = None,
6464
ssl_options: Optional[SSLOptions] = None,
6565
lz4_compressed: bool = True,
66-
description: Optional[List[Tuple]] = None,
66+
description: Optional[List[Tuple[Any, ...]]] = None,
6767
) -> ResultSetQueue:
6868
"""
6969
Factory method to build a result set queue for Thrift backend.
@@ -131,7 +131,7 @@ def build_queue(
131131
sea_result_data: ResultData,
132132
manifest: ResultManifest,
133133
statement_id: str,
134-
description: Optional[List[List[Any]]] = None,
134+
description: Optional[List[Tuple[Any, ...]]] = None,
135135
schema_bytes: Optional[bytes] = None,
136136
max_download_threads: Optional[int] = None,
137137
ssl_options: Optional[SSLOptions] = None,

0 commit comments

Comments
 (0)