Skip to content

Commit aab8ce5

Browse files
Merge branch 'main' into less-defensive-download
2 parents cd8389f + 0a7a6ab commit aab8ce5

31 files changed

+553
-472
lines changed

examples/experimental/tests/test_sea_sync_query.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ def test_sea_sync_query_with_cloud_fetch():
7272
f"{actual_row_count} rows retrieved against {requested_row_count} requested"
7373
)
7474

75+
# Verify total row count
76+
if actual_row_count != requested_row_count:
77+
logger.error(
78+
f"FAIL: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}"
79+
)
80+
return False
81+
7582
# Close resources
7683
cursor.close()
7784
connection.close()
@@ -132,15 +139,27 @@ def test_sea_sync_query_without_cloud_fetch():
132139
# For non-cloud fetch, use a smaller row count to avoid exceeding inline limits
133140
requested_row_count = 100
134141
cursor = connection.cursor()
135-
logger.info("Executing synchronous query without cloud fetch: SELECT 100 rows")
142+
logger.info(
143+
f"Executing synchronous query without cloud fetch: SELECT {requested_row_count} rows"
144+
)
136145
cursor.execute(
137146
"SELECT id, 'test_value_' || CAST(id as STRING) as test_value FROM range(1, 101)"
138147
)
139148

140149
results = [cursor.fetchone()]
141150
results.extend(cursor.fetchmany(10))
142151
results.extend(cursor.fetchall())
143-
logger.info(f"{len(results)} rows retrieved against 100 requested")
152+
actual_row_count = len(results)
153+
logger.info(
154+
f"{actual_row_count} rows retrieved against {requested_row_count} requested"
155+
)
156+
157+
# Verify total row count
158+
if actual_row_count != requested_row_count:
159+
logger.error(
160+
f"FAIL: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}"
161+
)
162+
return False
144163

145164
# Close resources
146165
cursor.close()

src/databricks/sql/backend/databricks_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def execute_command(
9696
max_rows: Maximum number of rows to fetch in a single fetch batch
9797
max_bytes: Maximum number of bytes to fetch in a single fetch batch
9898
lz4_compression: Whether to use LZ4 compression for result data
99-
cursor: The cursor object that will handle the results
99+
cursor: The cursor object that will handle the results. The command ID is set on this cursor.
100100
use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets
101101
parameters: List of parameters to bind to the query
102102
async_op: Whether to execute the command asynchronously
@@ -282,7 +282,9 @@ def get_tables(
282282
max_bytes: Maximum number of bytes to fetch in a single batch
283283
cursor: The cursor object that will handle the results
284284
catalog_name: Optional catalog name pattern to filter by
285+
If catalog_name is None, results are fetched across all catalogs.
285286
schema_name: Optional schema name pattern to filter by
287+
If schema_name is None, results are fetched across all schemas.
286288
table_name: Optional table name pattern to filter by
287289
table_types: Optional list of table types to filter by (e.g., ['TABLE', 'VIEW'])
288290
@@ -321,6 +323,7 @@ def get_columns(
321323
catalog_name: Optional catalog name pattern to filter by
322324
schema_name: Optional schema name pattern to filter by
323325
table_name: Optional table name pattern to filter by
326+
If table_name is None, results are fetched across all tables.
324327
column_name: Optional column name pattern to filter by
325328
326329
Returns:

src/databricks/sql/backend/sea/backend.py

Lines changed: 57 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55
import re
66
from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set
77

8-
from databricks.sql.backend.sea.models.base import ExternalLink, ResultManifest
8+
from databricks.sql.backend.sea.models.base import (
9+
ExternalLink,
10+
ResultManifest,
11+
StatementStatus,
12+
)
13+
from databricks.sql.backend.sea.models.responses import GetChunksResponse
914
from databricks.sql.backend.sea.utils.constants import (
1015
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
1116
ResultFormat,
@@ -45,14 +50,34 @@
4550
GetStatementResponse,
4651
CreateSessionResponse,
4752
)
48-
from databricks.sql.backend.sea.models.responses import GetChunksResponse
4953

5054
logger = logging.getLogger(__name__)
5155

5256

5357
def _filter_session_configuration(
5458
session_configuration: Optional[Dict[str, Any]],
5559
) -> Dict[str, str]:
60+
"""
61+
Filter and normalise the provided session configuration parameters.
62+
63+
The Statement Execution API supports only a subset of SQL session
64+
configuration options. This helper validates the supplied
65+
``session_configuration`` dictionary against the allow-list defined in
66+
``ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP`` and returns a new
67+
dictionary that contains **only** the supported parameters.
68+
69+
Args:
70+
session_configuration: Optional mapping of session configuration
71+
names to their desired values. Key comparison is
72+
case-insensitive.
73+
74+
Returns:
75+
Dict[str, str]: A dictionary containing only the supported
76+
configuration parameters with lower-case keys and string values. If
77+
*session_configuration* is ``None`` or empty, an empty dictionary is
78+
returned.
79+
"""
80+
5681
if not session_configuration:
5782
return {}
5883

@@ -143,7 +168,7 @@ def __init__(
143168
http_path=http_path,
144169
http_headers=http_headers,
145170
auth_provider=auth_provider,
146-
ssl_options=self._ssl_options,
171+
ssl_options=ssl_options,
147172
**kwargs,
148173
)
149174

@@ -275,29 +300,6 @@ def close_session(self, session_id: SessionId) -> None:
275300
data=request_data.to_dict(),
276301
)
277302

278-
@staticmethod
279-
def get_default_session_configuration_value(name: str) -> Optional[str]:
280-
"""
281-
Get the default value for a session configuration parameter.
282-
283-
Args:
284-
name: The name of the session configuration parameter
285-
286-
Returns:
287-
The default value if the parameter is supported, None otherwise
288-
"""
289-
return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.get(name.upper())
290-
291-
@staticmethod
292-
def get_allowed_session_configurations() -> List[str]:
293-
"""
294-
Get the list of allowed session configuration parameters.
295-
296-
Returns:
297-
List of allowed session configuration parameter names
298-
"""
299-
return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys())
300-
301303
def _extract_description_from_manifest(
302304
self, manifest: ResultManifest
303305
) -> List[Tuple]:
@@ -309,7 +311,7 @@ def _extract_description_from_manifest(
309311
manifest: The ResultManifest object containing schema information
310312
311313
Returns:
312-
List[Tuple]: A list of column tuples
314+
Optional[List]: A list of column tuples or None if no columns are found
313315
"""
314316

315317
schema_data = manifest.schema
@@ -318,15 +320,23 @@ def _extract_description_from_manifest(
318320
columns = []
319321
for col_data in columns_data:
320322
# Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
323+
name = col_data.get("name", "")
324+
type_name = col_data.get("type_name", "")
325+
type_name = (
326+
type_name[:-5] if type_name.endswith("_TYPE") else type_name
327+
).lower()
328+
precision = col_data.get("type_precision")
329+
scale = col_data.get("type_scale")
330+
321331
columns.append(
322332
(
323-
col_data.get("name", ""), # name
324-
col_data.get("type_name", ""), # type_code
333+
name, # name
334+
type_name, # type_code
325335
None, # display_size (not provided by SEA)
326336
None, # internal_size (not provided by SEA)
327-
col_data.get("precision"), # precision
328-
col_data.get("scale"), # scale
329-
col_data.get("nullable", True), # null_ok
337+
precision, # precision
338+
scale, # scale
339+
None, # null_ok
330340
)
331341
)
332342

@@ -389,8 +399,9 @@ def _response_to_result_set(
389399
)
390400

391401
def _check_command_not_in_failed_or_closed_state(
392-
self, state: CommandState, command_id: CommandId
402+
self, status: StatementStatus, command_id: CommandId
393403
) -> None:
404+
state = status.state
394405
if state == CommandState.CLOSED:
395406
raise DatabaseError(
396407
"Command {} unexpectedly closed server side".format(command_id),
@@ -399,8 +410,11 @@ def _check_command_not_in_failed_or_closed_state(
399410
},
400411
)
401412
if state == CommandState.FAILED:
413+
error = status.error
414+
error_code = error.error_code if error else "UNKNOWN_ERROR_CODE"
415+
error_message = error.message if error else "UNKNOWN_ERROR_MESSAGE"
402416
raise ServerOperationError(
403-
"Command {} failed".format(command_id),
417+
"Command failed: {} - {}".format(error_code, error_message),
404418
{
405419
"operation-id": command_id,
406420
},
@@ -414,16 +428,18 @@ def _wait_until_command_done(
414428
"""
415429

416430
final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response
417-
418-
state = final_response.status.state
419431
command_id = CommandId.from_sea_statement_id(final_response.statement_id)
420432

421-
while state in [CommandState.PENDING, CommandState.RUNNING]:
433+
while final_response.status.state in [
434+
CommandState.PENDING,
435+
CommandState.RUNNING,
436+
]:
422437
time.sleep(self.POLL_INTERVAL_SECONDS)
423438
final_response = self._poll_query(command_id)
424-
state = final_response.status.state
425439

426-
self._check_command_not_in_failed_or_closed_state(state, command_id)
440+
self._check_command_not_in_failed_or_closed_state(
441+
final_response.status, command_id
442+
)
427443

428444
return final_response
429445

@@ -457,7 +473,7 @@ def execute_command(
457473
enforce_embedded_schema_correctness: Whether to enforce schema correctness
458474
459475
Returns:
460-
SeaResultSet: A SeaResultSet instance for the executed command
476+
ResultSet: A SeaResultSet instance for the executed command
461477
"""
462478

463479
if session_id.backend_type != BackendType.SEA:
@@ -513,14 +529,6 @@ def execute_command(
513529
)
514530
response = ExecuteStatementResponse.from_dict(response_data)
515531
statement_id = response.statement_id
516-
if not statement_id:
517-
raise ServerOperationError(
518-
"Failed to execute command: No statement ID returned",
519-
{
520-
"operation-id": None,
521-
"diagnostic-info": None,
522-
},
523-
)
524532

525533
command_id = CommandId.from_sea_statement_id(statement_id)
526534

@@ -552,8 +560,6 @@ def cancel_command(self, command_id: CommandId) -> None:
552560
raise ValueError("Not a valid SEA command ID")
553561

554562
sea_statement_id = command_id.to_sea_statement_id()
555-
if sea_statement_id is None:
556-
raise ValueError("Not a valid SEA command ID")
557563

558564
request = CancelStatementRequest(statement_id=sea_statement_id)
559565
self._http_client._make_request(
@@ -577,8 +583,6 @@ def close_command(self, command_id: CommandId) -> None:
577583
raise ValueError("Not a valid SEA command ID")
578584

579585
sea_statement_id = command_id.to_sea_statement_id()
580-
if sea_statement_id is None:
581-
raise ValueError("Not a valid SEA command ID")
582586

583587
request = CloseStatementRequest(statement_id=sea_statement_id)
584588
self._http_client._make_request(
@@ -596,8 +600,6 @@ def _poll_query(self, command_id: CommandId) -> GetStatementResponse:
596600
raise ValueError("Not a valid SEA command ID")
597601

598602
sea_statement_id = command_id.to_sea_statement_id()
599-
if sea_statement_id is None:
600-
raise ValueError("Not a valid SEA command ID")
601603

602604
request = GetStatementRequest(statement_id=sea_statement_id)
603605
response_data = self._http_client._make_request(
@@ -620,7 +622,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
620622
CommandState: The current state of the command
621623
622624
Raises:
623-
ProgrammingError: If the command ID is invalid
625+
ValueError: If the command ID is invalid
624626
"""
625627

626628
response = self._poll_query(command_id)

src/databricks/sql/backend/sea/models/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
StatementStatus,
1010
ExternalLink,
1111
ResultData,
12-
ColumnInfo,
1312
ResultManifest,
1413
)
1514

@@ -36,7 +35,6 @@
3635
"StatementStatus",
3736
"ExternalLink",
3837
"ResultData",
39-
"ColumnInfo",
4038
"ResultManifest",
4139
# Request models
4240
"StatementParameter",

src/databricks/sql/backend/sea/models/base.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -67,25 +67,12 @@ class ResultData:
6767
attachment: Optional[bytes] = None
6868

6969

70-
@dataclass
71-
class ColumnInfo:
72-
"""Information about a column in the result set."""
73-
74-
name: str
75-
type_name: str
76-
type_text: str
77-
nullable: bool = True
78-
precision: Optional[int] = None
79-
scale: Optional[int] = None
80-
ordinal_position: Optional[int] = None
81-
82-
8370
@dataclass
8471
class ResultManifest:
8572
"""Manifest information for a result set."""
8673

8774
format: str
88-
schema: Dict[str, Any] # Will contain column information
75+
schema: Dict[str, Any]
8976
total_row_count: int
9077
total_byte_count: int
9178
total_chunk_count: int

src/databricks/sql/backend/sea/models/requests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def to_dict(self) -> Dict[str, Any]:
5454
result["parameters"] = [
5555
{
5656
"name": param.name,
57-
**({"value": param.value} if param.value is not None else {}),
58-
**({"type": param.type} if param.type is not None else {}),
57+
"value": param.value,
58+
"type": param.type,
5959
}
6060
for param in self.parameters
6161
]

src/databricks/sql/backend/sea/queue.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def build_queue(
7272
return JsonQueue(result_data.data)
7373
elif manifest.format == ResultFormat.ARROW_STREAM.value:
7474
if result_data.attachment is not None:
75+
# direct results from Hybrid disposition
7576
arrow_file = (
7677
ResultSetDownloadHandler._decompress_data(result_data.attachment)
7778
if lz4_compressed
@@ -363,14 +364,14 @@ def __init__(
363364
# Initialize table and position
364365
self.table = self._create_next_table()
365366

366-
def _create_next_table(self) -> "pyarrow.Table":
367+
def _create_next_table(self) -> Union["pyarrow.Table", None]:
367368
"""Create next table by retrieving the logical next downloaded file."""
368369
if self.link_fetcher is None:
369-
return self._create_empty_table()
370+
return None
370371

371372
chunk_link = self.link_fetcher.get_chunk_link(self._current_chunk_index)
372373
if chunk_link is None:
373-
return self._create_empty_table()
374+
return None
374375

375376
row_offset = chunk_link.row_offset
376377
# NOTE: link has already been submitted to download manager at this point

src/databricks/sql/backend/sea/utils/filters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _filter_sea_result_set(
5353
# Reuse the command_id from the original result set
5454
command_id = result_set.command_id
5555

56-
# Create an ExecuteResponse with the filtered data
56+
# Create an ExecuteResponse for the filtered data
5757
execute_response = ExecuteResponse(
5858
command_id=command_id,
5959
status=result_set.status,

0 commit comments

Comments
 (0)