Skip to content

Commit 97707f5

Browse files
Merge branch 'sea-migration' into sea-http-client
2 parents 41f1130 + c07beb1 commit 97707f5

File tree

4 files changed

+103
-183
lines changed

4 files changed

+103
-183
lines changed

src/databricks/sql/backend/sea/backend.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,17 @@
5050

5151

5252
def _filter_session_configuration(
53-
session_configuration: Optional[Dict[str, str]]
54-
) -> Optional[Dict[str, str]]:
53+
session_configuration: Optional[Dict[str, Any]],
54+
) -> Dict[str, str]:
5555
if not session_configuration:
56-
return None
56+
return {}
5757

5858
filtered_session_configuration = {}
5959
ignored_configs: Set[str] = set()
6060

6161
for key, value in session_configuration.items():
6262
if key.upper() in ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP:
63-
filtered_session_configuration[key.lower()] = value
63+
filtered_session_configuration[key.lower()] = str(value)
6464
else:
6565
ignored_configs.add(key)
6666

@@ -188,7 +188,7 @@ def max_download_threads(self) -> int:
188188

189189
def open_session(
190190
self,
191-
session_configuration: Optional[Dict[str, str]],
191+
session_configuration: Optional[Dict[str, Any]],
192192
catalog: Optional[str],
193193
schema: Optional[str],
194194
) -> SessionId:

src/databricks/sql/backend/sea/queue.py

Lines changed: 22 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from abc import ABC
4-
from typing import List, Optional, Tuple, Union
4+
from typing import List, Optional, Tuple, Union, TYPE_CHECKING
55

66
from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager
77

@@ -12,12 +12,13 @@
1212

1313
import dateutil
1414

15-
from databricks.sql.backend.sea.backend import SeaDatabricksClient
16-
from databricks.sql.backend.sea.models.base import (
17-
ExternalLink,
18-
ResultData,
19-
ResultManifest,
20-
)
15+
if TYPE_CHECKING:
16+
from databricks.sql.backend.sea.backend import SeaDatabricksClient
17+
from databricks.sql.backend.sea.models.base import (
18+
ExternalLink,
19+
ResultData,
20+
ResultManifest,
21+
)
2122
from databricks.sql.backend.sea.utils.constants import ResultFormat
2223
from databricks.sql.exc import ProgrammingError, ServerOperationError
2324
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
@@ -109,7 +110,7 @@ def __init__(
109110
result_data: ResultData,
110111
max_download_threads: int,
111112
ssl_options: SSLOptions,
112-
sea_client: "SeaDatabricksClient",
113+
sea_client: SeaDatabricksClient,
113114
statement_id: str,
114115
total_chunk_count: int,
115116
lz4_compressed: bool = False,
@@ -140,6 +141,7 @@ def __init__(
140141

141142
self._sea_client = sea_client
142143
self._statement_id = statement_id
144+
self._total_chunk_count = total_chunk_count
143145

144146
logger.debug(
145147
"SeaCloudFetchQueue: Initialize CloudFetch loader for statement {}, total chunks: {}".format(
@@ -154,12 +156,11 @@ def __init__(
154156
return None
155157

156158
# Track the current chunk we're processing
157-
self._current_chunk_link = first_link
158-
159+
self._current_chunk_index = 0
159160
# Initialize table and position
160-
self.table = self._create_table_from_link(self._current_chunk_link)
161+
self.table = self._create_table_from_link(first_link)
161162

162-
def _convert_to_thrift_link(self, link: "ExternalLink") -> TSparkArrowResultLink:
163+
def _convert_to_thrift_link(self, link: ExternalLink) -> TSparkArrowResultLink:
163164
"""Convert SEA external links to Thrift format for compatibility with existing download manager."""
164165
# Parse the ISO format expiration time
165166
expiry_time = int(dateutil.parser.parse(link.expiration).timestamp())
@@ -172,36 +173,24 @@ def _convert_to_thrift_link(self, link: "ExternalLink") -> TSparkArrowResultLink
172173
httpHeaders=link.http_headers or {},
173174
)
174175

175-
def _progress_chunk_link(self):
176+
def _get_chunk_link(self, chunk_index: int) -> Optional[ExternalLink]:
176177
"""Progress to the next chunk link."""
177-
if not self._current_chunk_link:
178-
return None
179-
180-
next_chunk_index = self._current_chunk_link.next_chunk_index
181-
182-
if next_chunk_index is None:
183-
self._current_chunk_link = None
178+
if chunk_index >= self._total_chunk_count:
184179
return None
185180

186181
try:
187-
self._current_chunk_link = self._sea_client.get_chunk_link(
188-
self._statement_id, next_chunk_index
189-
)
182+
return self._sea_client.get_chunk_link(self._statement_id, chunk_index)
190183
except Exception as e:
191184
raise ServerOperationError(
192-
f"Error fetching link for chunk {next_chunk_index}: {e}",
185+
f"Error fetching link for chunk {chunk_index}: {e}",
193186
{
194187
"operation-id": self._statement_id,
195188
"diagnostic-info": None,
196189
},
197190
)
198191

199-
logger.debug(
200-
f"SeaCloudFetchQueue: Progressed to link for chunk {next_chunk_index}: {self._current_chunk_link}"
201-
)
202-
203192
def _create_table_from_link(
204-
self, link: "ExternalLink"
193+
self, link: ExternalLink
205194
) -> Union["pyarrow.Table", None]:
206195
"""Create a table from a link."""
207196

@@ -215,11 +204,8 @@ def _create_table_from_link(
215204

216205
def _create_next_table(self) -> Union["pyarrow.Table", None]:
217206
"""Create next table by retrieving the logical next downloaded file."""
218-
219-
self._progress_chunk_link()
220-
221-
if not self._current_chunk_link:
222-
logger.debug("SeaCloudFetchQueue: No current chunk link, returning")
207+
self._current_chunk_index += 1
208+
next_chunk_link = self._get_chunk_link(self._current_chunk_index)
209+
if not next_chunk_link:
223210
return None
224-
225-
return self._create_table_from_link(self._current_chunk_link)
211+
return self._create_table_from_link(next_chunk_link)

tests/unit/test_sea_backend.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,71 @@ def test_utility_methods(self, sea_client):
624624
assert description[1][1] == "INT" # type_code
625625
assert description[1][6] is False # null_ok
626626

627+
def test_filter_session_configuration(self):
628+
"""Test that _filter_session_configuration converts all values to strings."""
629+
session_config = {
630+
"ANSI_MODE": True,
631+
"statement_timeout": 3600,
632+
"TIMEZONE": "UTC",
633+
"enable_photon": False,
634+
"MAX_FILE_PARTITION_BYTES": 128.5,
635+
"unsupported_param": "value",
636+
"ANOTHER_UNSUPPORTED": 42,
637+
}
638+
639+
result = _filter_session_configuration(session_config)
640+
641+
# Verify result is not None
642+
assert result is not None
643+
644+
# Verify all returned values are strings
645+
for key, value in result.items():
646+
assert isinstance(
647+
value, str
648+
), f"Value for key '{key}' is not a string: {type(value)}"
649+
650+
# Verify specific conversions
651+
expected_result = {
652+
"ansi_mode": "True", # boolean True -> "True", key lowercased
653+
"statement_timeout": "3600", # int -> "3600", key lowercased
654+
"timezone": "UTC", # string -> "UTC", key lowercased
655+
"enable_photon": "False", # boolean False -> "False", key lowercased
656+
"max_file_partition_bytes": "128.5", # float -> "128.5", key lowercased
657+
}
658+
659+
assert result == expected_result
660+
661+
# Test with None input
662+
assert _filter_session_configuration(None) == {}
663+
664+
# Test with only unsupported parameters
665+
unsupported_config = {
666+
"unsupported_param1": "value1",
667+
"unsupported_param2": 123,
668+
}
669+
result = _filter_session_configuration(unsupported_config)
670+
assert result == {}
671+
672+
# Test case insensitivity for keys
673+
case_insensitive_config = {
674+
"ansi_mode": "false", # lowercase key
675+
"STATEMENT_TIMEOUT": 7200, # uppercase key
676+
"TiMeZoNe": "America/New_York", # mixed case key
677+
}
678+
result = _filter_session_configuration(case_insensitive_config)
679+
expected_case_result = {
680+
"ansi_mode": "false",
681+
"statement_timeout": "7200",
682+
"timezone": "America/New_York",
683+
}
684+
assert result == expected_case_result
685+
686+
# Verify all values are strings in case insensitive test
687+
for key, value in result.items():
688+
assert isinstance(
689+
value, str
690+
), f"Value for key '{key}' is not a string: {type(value)}"
691+
627692
def test_results_message_to_execute_response_is_staging_operation(self, sea_client):
628693
"""Test that is_staging_operation is correctly set from manifest.is_volume_operation."""
629694
# Test when is_volume_operation is True

0 commit comments

Comments
 (0)