Skip to content

Commit 1ea13c2

Browse files
Merge branch 'sea-migration' into comparator
2 parents f29198d + 80692e3 commit 1ea13c2

26 files changed

+358
-411
lines changed

examples/experimental/tests/test_sea_sync_query.py

Lines changed: 21 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,13 @@ def test_sea_sync_query_with_cloud_fetch():
7272
f"{actual_row_count} rows retrieved against {requested_row_count} requested"
7373
)
7474

75+
# Verify total row count
76+
if actual_row_count != requested_row_count:
77+
logger.error(
78+
f"FAIL: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}"
79+
)
80+
return False
81+
7582
# Close resources
7683
cursor.close()
7784
connection.close()
@@ -132,15 +139,27 @@ def test_sea_sync_query_without_cloud_fetch():
132139
# For non-cloud fetch, use a smaller row count to avoid exceeding inline limits
133140
requested_row_count = 100
134141
cursor = connection.cursor()
135-
logger.info("Executing synchronous query without cloud fetch: SELECT 100 rows")
142+
logger.info(
143+
f"Executing synchronous query without cloud fetch: SELECT {requested_row_count} rows"
144+
)
136145
cursor.execute(
137146
"SELECT id, 'test_value_' || CAST(id as STRING) as test_value FROM range(1, 101)"
138147
)
139148

140149
results = [cursor.fetchone()]
141150
results.extend(cursor.fetchmany(10))
142151
results.extend(cursor.fetchall())
143-
logger.info(f"{len(results)} rows retrieved against 100 requested")
152+
actual_row_count = len(results)
153+
logger.info(
154+
f"{actual_row_count} rows retrieved against {requested_row_count} requested"
155+
)
156+
157+
# Verify total row count
158+
if actual_row_count != requested_row_count:
159+
logger.error(
160+
f"FAIL: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}"
161+
)
162+
return False
144163

145164
# Close resources
146165
cursor.close()

src/databricks/sql/backend/databricks_client.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,7 @@ def execute_command(
9696
max_rows: Maximum number of rows to fetch in a single fetch batch
9797
max_bytes: Maximum number of bytes to fetch in a single fetch batch
9898
lz4_compression: Whether to use LZ4 compression for result data
99-
cursor: The cursor object that will handle the results
99+
cursor: The cursor object that will handle the results. The command ID is set on this cursor.
100100
use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets
101101
parameters: List of parameters to bind to the query
102102
async_op: Whether to execute the command asynchronously
@@ -282,7 +282,9 @@ def get_tables(
282282
max_bytes: Maximum number of bytes to fetch in a single batch
283283
cursor: The cursor object that will handle the results
284284
catalog_name: Optional catalog name pattern to filter by
285+
If catalog_name is None, results are fetched across all catalogs.
285286
schema_name: Optional schema name pattern to filter by
287+
If schema_name is None, results are fetched across all schemas.
286288
table_name: Optional table name pattern to filter by
287289
table_types: Optional list of table types to filter by (e.g., ['TABLE', 'VIEW'])
288290
@@ -321,6 +323,7 @@ def get_columns(
321323
catalog_name: Optional catalog name pattern to filter by
322324
schema_name: Optional schema name pattern to filter by
323325
table_name: Optional table name pattern to filter by
326+
If table_name is None, results are fetched across all tables.
324327
column_name: Optional column name pattern to filter by
325328
326329
Returns:

src/databricks/sql/backend/sea/backend.py

Lines changed: 44 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,12 @@
55
import re
66
from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set
77

8-
from databricks.sql.backend.sea.models.base import ExternalLink, ResultManifest
8+
from databricks.sql.backend.sea.models.base import (
9+
ExternalLink,
10+
ResultManifest,
11+
StatementStatus,
12+
)
13+
from databricks.sql.backend.sea.models.responses import GetChunksResponse
914
from databricks.sql.backend.sea.utils.constants import (
1015
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
1116
ResultFormat,
@@ -45,14 +50,34 @@
4550
GetStatementResponse,
4651
CreateSessionResponse,
4752
)
48-
from databricks.sql.backend.sea.models.responses import GetChunksResponse
4953

5054
logger = logging.getLogger(__name__)
5155

5256

5357
def _filter_session_configuration(
5458
session_configuration: Optional[Dict[str, Any]],
5559
) -> Dict[str, str]:
60+
"""
61+
Filter and normalise the provided session configuration parameters.
62+
63+
The Statement Execution API supports only a subset of SQL session
64+
configuration options. This helper validates the supplied
65+
``session_configuration`` dictionary against the allow-list defined in
66+
``ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP`` and returns a new
67+
dictionary that contains **only** the supported parameters.
68+
69+
Args:
70+
session_configuration: Optional mapping of session configuration
71+
names to their desired values. Key comparison is
72+
case-insensitive.
73+
74+
Returns:
75+
Dict[str, str]: A dictionary containing only the supported
76+
configuration parameters with lower-case keys and string values. If
77+
*session_configuration* is ``None`` or empty, an empty dictionary is
78+
returned.
79+
"""
80+
5681
if not session_configuration:
5782
return {}
5883

@@ -143,7 +168,7 @@ def __init__(
143168
http_path=http_path,
144169
http_headers=http_headers,
145170
auth_provider=auth_provider,
146-
ssl_options=self._ssl_options,
171+
ssl_options=ssl_options,
147172
**kwargs,
148173
)
149174

@@ -275,29 +300,6 @@ def close_session(self, session_id: SessionId) -> None:
275300
data=request_data.to_dict(),
276301
)
277302

278-
@staticmethod
279-
def get_default_session_configuration_value(name: str) -> Optional[str]:
280-
"""
281-
Get the default value for a session configuration parameter.
282-
283-
Args:
284-
name: The name of the session configuration parameter
285-
286-
Returns:
287-
The default value if the parameter is supported, None otherwise
288-
"""
289-
return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.get(name.upper())
290-
291-
@staticmethod
292-
def get_allowed_session_configurations() -> List[str]:
293-
"""
294-
Get the list of allowed session configuration parameters.
295-
296-
Returns:
297-
List of allowed session configuration parameter names
298-
"""
299-
return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys())
300-
301303
def _extract_description_from_manifest(
302304
self, manifest: ResultManifest
303305
) -> List[Tuple]:
@@ -309,7 +311,7 @@ def _extract_description_from_manifest(
309311
manifest: The ResultManifest object containing schema information
310312
311313
Returns:
312-
List[Tuple]: A list of column tuples
314+
Optional[List]: A list of column tuples or None if no columns are found
313315
"""
314316

315317
schema_data = manifest.schema
@@ -397,8 +399,9 @@ def _response_to_result_set(
397399
)
398400

399401
def _check_command_not_in_failed_or_closed_state(
400-
self, state: CommandState, command_id: CommandId
402+
self, status: StatementStatus, command_id: CommandId
401403
) -> None:
404+
state = status.state
402405
if state == CommandState.CLOSED:
403406
raise DatabaseError(
404407
"Command {} unexpectedly closed server side".format(command_id),
@@ -407,8 +410,11 @@ def _check_command_not_in_failed_or_closed_state(
407410
},
408411
)
409412
if state == CommandState.FAILED:
413+
error = status.error
414+
error_code = error.error_code if error else "UNKNOWN_ERROR_CODE"
415+
error_message = error.message if error else "UNKNOWN_ERROR_MESSAGE"
410416
raise ServerOperationError(
411-
"Command {} failed".format(command_id),
417+
"Command failed: {} - {}".format(error_code, error_message),
412418
{
413419
"operation-id": command_id,
414420
},
@@ -422,16 +428,18 @@ def _wait_until_command_done(
422428
"""
423429

424430
final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response
425-
426-
state = final_response.status.state
427431
command_id = CommandId.from_sea_statement_id(final_response.statement_id)
428432

429-
while state in [CommandState.PENDING, CommandState.RUNNING]:
433+
while final_response.status.state in [
434+
CommandState.PENDING,
435+
CommandState.RUNNING,
436+
]:
430437
time.sleep(self.POLL_INTERVAL_SECONDS)
431438
final_response = self._poll_query(command_id)
432-
state = final_response.status.state
433439

434-
self._check_command_not_in_failed_or_closed_state(state, command_id)
440+
self._check_command_not_in_failed_or_closed_state(
441+
final_response.status, command_id
442+
)
435443

436444
return final_response
437445

@@ -465,7 +473,7 @@ def execute_command(
465473
enforce_embedded_schema_correctness: Whether to enforce schema correctness
466474
467475
Returns:
468-
SeaResultSet: A SeaResultSet instance for the executed command
476+
ResultSet: A SeaResultSet instance for the executed command
469477
"""
470478

471479
if session_id.backend_type != BackendType.SEA:
@@ -521,14 +529,6 @@ def execute_command(
521529
)
522530
response = ExecuteStatementResponse.from_dict(response_data)
523531
statement_id = response.statement_id
524-
if not statement_id:
525-
raise ServerOperationError(
526-
"Failed to execute command: No statement ID returned",
527-
{
528-
"operation-id": None,
529-
"diagnostic-info": None,
530-
},
531-
)
532532

533533
command_id = CommandId.from_sea_statement_id(statement_id)
534534

@@ -560,8 +560,6 @@ def cancel_command(self, command_id: CommandId) -> None:
560560
raise ValueError("Not a valid SEA command ID")
561561

562562
sea_statement_id = command_id.to_sea_statement_id()
563-
if sea_statement_id is None:
564-
raise ValueError("Not a valid SEA command ID")
565563

566564
request = CancelStatementRequest(statement_id=sea_statement_id)
567565
self._http_client._make_request(
@@ -585,8 +583,6 @@ def close_command(self, command_id: CommandId) -> None:
585583
raise ValueError("Not a valid SEA command ID")
586584

587585
sea_statement_id = command_id.to_sea_statement_id()
588-
if sea_statement_id is None:
589-
raise ValueError("Not a valid SEA command ID")
590586

591587
request = CloseStatementRequest(statement_id=sea_statement_id)
592588
self._http_client._make_request(
@@ -604,8 +600,6 @@ def _poll_query(self, command_id: CommandId) -> GetStatementResponse:
604600
raise ValueError("Not a valid SEA command ID")
605601

606602
sea_statement_id = command_id.to_sea_statement_id()
607-
if sea_statement_id is None:
608-
raise ValueError("Not a valid SEA command ID")
609603

610604
request = GetStatementRequest(statement_id=sea_statement_id)
611605
response_data = self._http_client._make_request(
@@ -628,7 +622,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
628622
CommandState: The current state of the command
629623
630624
Raises:
631-
ProgrammingError: If the command ID is invalid
625+
ValueError: If the command ID is invalid
632626
"""
633627

634628
response = self._poll_query(command_id)

src/databricks/sql/backend/sea/models/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99
StatementStatus,
1010
ExternalLink,
1111
ResultData,
12-
ColumnInfo,
1312
ResultManifest,
1413
)
1514

@@ -36,7 +35,6 @@
3635
"StatementStatus",
3736
"ExternalLink",
3837
"ResultData",
39-
"ColumnInfo",
4038
"ResultManifest",
4139
# Request models
4240
"StatementParameter",

src/databricks/sql/backend/sea/models/base.py

Lines changed: 1 addition & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -67,25 +67,12 @@ class ResultData:
6767
attachment: Optional[bytes] = None
6868

6969

70-
@dataclass
71-
class ColumnInfo:
72-
"""Information about a column in the result set."""
73-
74-
name: str
75-
type_name: str
76-
type_text: str
77-
nullable: bool = True
78-
precision: Optional[int] = None
79-
scale: Optional[int] = None
80-
ordinal_position: Optional[int] = None
81-
82-
8370
@dataclass
8471
class ResultManifest:
8572
"""Manifest information for a result set."""
8673

8774
format: str
88-
schema: Dict[str, Any] # Will contain column information
75+
schema: Dict[str, Any]
8976
total_row_count: int
9077
total_byte_count: int
9178
total_chunk_count: int

src/databricks/sql/backend/sea/models/requests.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -54,8 +54,8 @@ def to_dict(self) -> Dict[str, Any]:
5454
result["parameters"] = [
5555
{
5656
"name": param.name,
57-
**({"value": param.value} if param.value is not None else {}),
58-
**({"type": param.type} if param.type is not None else {}),
57+
"value": param.value,
58+
"type": param.type,
5959
}
6060
for param in self.parameters
6161
]

src/databricks/sql/backend/sea/queue.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,7 @@ def build_queue(
7272
return JsonQueue(result_data.data)
7373
elif manifest.format == ResultFormat.ARROW_STREAM.value:
7474
if result_data.attachment is not None:
75+
# direct results from Hybrid disposition
7576
arrow_file = (
7677
ResultSetDownloadHandler._decompress_data(result_data.attachment)
7778
if lz4_compressed

src/databricks/sql/backend/sea/utils/filters.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@ def _filter_sea_result_set(
5353
# Reuse the command_id from the original result set
5454
command_id = result_set.command_id
5555

56-
# Create an ExecuteResponse with the filtered data
56+
# Create an ExecuteResponse for the filtered data
5757
execute_response = ExecuteResponse(
5858
command_id=command_id,
5959
status=result_set.status,

0 commit comments

Comments
 (0)