Skip to content

Commit 3ae7d04

Browse files
Merge branch 'main' into sea-e2e-tests
2 parents d9f59a8 + e732e96 commit 3ae7d04

38 files changed

+1529
-727
lines changed

examples/experimental/tests/test_sea_sync_query.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ def test_sea_sync_query_with_cloud_fetch():
7272
f"{actual_row_count} rows retrieved against {requested_row_count} requested"
7373
)
7474

75+
# Verify total row count
76+
if actual_row_count != requested_row_count:
77+
logger.error(
78+
f"FAIL: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}"
79+
)
80+
return False
81+
7582
# Close resources
7683
cursor.close()
7784
connection.close()
@@ -132,15 +139,27 @@ def test_sea_sync_query_without_cloud_fetch():
132139
# For non-cloud fetch, use a smaller row count to avoid exceeding inline limits
133140
requested_row_count = 100
134141
cursor = connection.cursor()
135-
logger.info("Executing synchronous query without cloud fetch: SELECT 100 rows")
142+
logger.info(
143+
f"Executing synchronous query without cloud fetch: SELECT {requested_row_count} rows"
144+
)
136145
cursor.execute(
137146
"SELECT id, 'test_value_' || CAST(id as STRING) as test_value FROM range(1, 101)"
138147
)
139148

140149
results = [cursor.fetchone()]
141150
results.extend(cursor.fetchmany(10))
142151
results.extend(cursor.fetchall())
143-
logger.info(f"{len(results)} rows retrieved against 100 requested")
152+
actual_row_count = len(results)
153+
logger.info(
154+
f"{actual_row_count} rows retrieved against {requested_row_count} requested"
155+
)
156+
157+
# Verify total row count
158+
if actual_row_count != requested_row_count:
159+
logger.error(
160+
f"FAIL: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}"
161+
)
162+
return False
144163

145164
# Close resources
146165
cursor.close()

src/databricks/sql/backend/databricks_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def execute_command(
9696
max_rows: Maximum number of rows to fetch in a single fetch batch
9797
max_bytes: Maximum number of bytes to fetch in a single fetch batch
9898
lz4_compression: Whether to use LZ4 compression for result data
99-
cursor: The cursor object that will handle the results
99+
cursor: The cursor object that will handle the results. The command id is set in this cursor.
100100
use_cloud_fetch: Whether to use cloud fetch for retrieving large result sets
101101
parameters: List of parameters to bind to the query
102102
async_op: Whether to execute the command asynchronously
@@ -282,7 +282,9 @@ def get_tables(
282282
max_bytes: Maximum number of bytes to fetch in a single batch
283283
cursor: The cursor object that will handle the results
284284
catalog_name: Optional catalog name pattern to filter by
285+
if catalog_name is None, we fetch across all catalogs
285286
schema_name: Optional schema name pattern to filter by
287+
if schema_name is None, we fetch across all schemas
286288
table_name: Optional table name pattern to filter by
287289
table_types: Optional list of table types to filter by (e.g., ['TABLE', 'VIEW'])
288290
@@ -321,6 +323,7 @@ def get_columns(
321323
catalog_name: Optional catalog name pattern to filter by
322324
schema_name: Optional schema name pattern to filter by
323325
table_name: Optional table name pattern to filter by
326+
if table_name is None, we fetch across all tables
324327
column_name: Optional column name pattern to filter by
325328
326329
Returns:

src/databricks/sql/backend/sea/backend.py

Lines changed: 63 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55
import re
66
from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set
77

8-
from databricks.sql.backend.sea.models.base import ExternalLink, ResultManifest
8+
from databricks.sql.backend.sea.models.base import (
9+
ExternalLink,
10+
ResultManifest,
11+
StatementStatus,
12+
)
13+
from databricks.sql.backend.sea.models.responses import GetChunksResponse
914
from databricks.sql.backend.sea.utils.constants import (
1015
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP,
1116
ResultFormat,
@@ -14,6 +19,7 @@
1419
WaitTimeout,
1520
MetadataCommands,
1621
)
22+
from databricks.sql.backend.sea.utils.normalize import normalize_sea_type_to_thrift
1723
from databricks.sql.thrift_api.TCLIService import ttypes
1824

1925
if TYPE_CHECKING:
@@ -45,14 +51,34 @@
4551
GetStatementResponse,
4652
CreateSessionResponse,
4753
)
48-
from databricks.sql.backend.sea.models.responses import GetChunksResponse
4954

5055
logger = logging.getLogger(__name__)
5156

5257

5358
def _filter_session_configuration(
5459
session_configuration: Optional[Dict[str, Any]],
5560
) -> Dict[str, str]:
61+
"""
62+
Filter and normalise the provided session configuration parameters.
63+
64+
The Statement Execution API supports only a subset of SQL session
65+
configuration options. This helper validates the supplied
66+
``session_configuration`` dictionary against the allow-list defined in
67+
``ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP`` and returns a new
68+
dictionary that contains **only** the supported parameters.
69+
70+
Args:
71+
session_configuration: Optional mapping of session configuration
72+
names to their desired values. Key comparison is
73+
case-insensitive.
74+
75+
Returns:
76+
Dict[str, str]: A dictionary containing only the supported
77+
configuration parameters with lower-case keys and string values. If
78+
*session_configuration* is ``None`` or empty, an empty dictionary is
79+
returned.
80+
"""
81+
5682
if not session_configuration:
5783
return {}
5884

@@ -143,7 +169,7 @@ def __init__(
143169
http_path=http_path,
144170
http_headers=http_headers,
145171
auth_provider=auth_provider,
146-
ssl_options=self._ssl_options,
172+
ssl_options=ssl_options,
147173
**kwargs,
148174
)
149175

@@ -275,29 +301,6 @@ def close_session(self, session_id: SessionId) -> None:
275301
data=request_data.to_dict(),
276302
)
277303

278-
@staticmethod
279-
def get_default_session_configuration_value(name: str) -> Optional[str]:
280-
"""
281-
Get the default value for a session configuration parameter.
282-
283-
Args:
284-
name: The name of the session configuration parameter
285-
286-
Returns:
287-
The default value if the parameter is supported, None otherwise
288-
"""
289-
return ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.get(name.upper())
290-
291-
@staticmethod
292-
def get_allowed_session_configurations() -> List[str]:
293-
"""
294-
Get the list of allowed session configuration parameters.
295-
296-
Returns:
297-
List of allowed session configuration parameter names
298-
"""
299-
return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys())
300-
301304
def _extract_description_from_manifest(
302305
self, manifest: ResultManifest
303306
) -> List[Tuple]:
@@ -309,7 +312,7 @@ def _extract_description_from_manifest(
309312
manifest: The ResultManifest object containing schema information
310313
311314
Returns:
312-
List[Tuple]: A list of column tuples
315+
Optional[List]: A list of column tuples or None if no columns are found
313316
"""
314317

315318
schema_data = manifest.schema
@@ -318,15 +321,28 @@ def _extract_description_from_manifest(
318321
columns = []
319322
for col_data in columns_data:
320323
# Format: (name, type_code, display_size, internal_size, precision, scale, null_ok)
324+
name = col_data.get("name", "")
325+
type_name = col_data.get("type_name", "")
326+
327+
# Normalize SEA type to Thrift conventions before any processing
328+
type_name = normalize_sea_type_to_thrift(type_name, col_data)
329+
330+
# Now strip _TYPE suffix and convert to lowercase
331+
type_name = (
332+
type_name[:-5] if type_name.endswith("_TYPE") else type_name
333+
).lower()
334+
precision = col_data.get("type_precision")
335+
scale = col_data.get("type_scale")
336+
321337
columns.append(
322338
(
323-
col_data.get("name", ""), # name
324-
col_data.get("type_name", ""), # type_code
339+
name, # name
340+
type_name, # type_code
325341
None, # display_size (not provided by SEA)
326342
None, # internal_size (not provided by SEA)
327-
col_data.get("precision"), # precision
328-
col_data.get("scale"), # scale
329-
col_data.get("nullable", True), # null_ok
343+
precision, # precision
344+
scale, # scale
345+
None, # null_ok
330346
)
331347
)
332348

@@ -389,8 +405,9 @@ def _response_to_result_set(
389405
)
390406

391407
def _check_command_not_in_failed_or_closed_state(
392-
self, state: CommandState, command_id: CommandId
408+
self, status: StatementStatus, command_id: CommandId
393409
) -> None:
410+
state = status.state
394411
if state == CommandState.CLOSED:
395412
raise DatabaseError(
396413
"Command {} unexpectedly closed server side".format(command_id),
@@ -399,8 +416,11 @@ def _check_command_not_in_failed_or_closed_state(
399416
},
400417
)
401418
if state == CommandState.FAILED:
419+
error = status.error
420+
error_code = error.error_code if error else "UNKNOWN_ERROR_CODE"
421+
error_message = error.message if error else "UNKNOWN_ERROR_MESSAGE"
402422
raise ServerOperationError(
403-
"Command {} failed".format(command_id),
423+
"Command failed: {} - {}".format(error_code, error_message),
404424
{
405425
"operation-id": command_id,
406426
},
@@ -414,16 +434,18 @@ def _wait_until_command_done(
414434
"""
415435

416436
final_response: Union[ExecuteStatementResponse, GetStatementResponse] = response
417-
418-
state = final_response.status.state
419437
command_id = CommandId.from_sea_statement_id(final_response.statement_id)
420438

421-
while state in [CommandState.PENDING, CommandState.RUNNING]:
439+
while final_response.status.state in [
440+
CommandState.PENDING,
441+
CommandState.RUNNING,
442+
]:
422443
time.sleep(self.POLL_INTERVAL_SECONDS)
423444
final_response = self._poll_query(command_id)
424-
state = final_response.status.state
425445

426-
self._check_command_not_in_failed_or_closed_state(state, command_id)
446+
self._check_command_not_in_failed_or_closed_state(
447+
final_response.status, command_id
448+
)
427449

428450
return final_response
429451

@@ -457,7 +479,7 @@ def execute_command(
457479
enforce_embedded_schema_correctness: Whether to enforce schema correctness
458480
459481
Returns:
460-
SeaResultSet: A SeaResultSet instance for the executed command
482+
ResultSet: A SeaResultSet instance for the executed command
461483
"""
462484

463485
if session_id.backend_type != BackendType.SEA:
@@ -513,14 +535,6 @@ def execute_command(
513535
)
514536
response = ExecuteStatementResponse.from_dict(response_data)
515537
statement_id = response.statement_id
516-
if not statement_id:
517-
raise ServerOperationError(
518-
"Failed to execute command: No statement ID returned",
519-
{
520-
"operation-id": None,
521-
"diagnostic-info": None,
522-
},
523-
)
524538

525539
command_id = CommandId.from_sea_statement_id(statement_id)
526540

@@ -552,8 +566,6 @@ def cancel_command(self, command_id: CommandId) -> None:
552566
raise ValueError("Not a valid SEA command ID")
553567

554568
sea_statement_id = command_id.to_sea_statement_id()
555-
if sea_statement_id is None:
556-
raise ValueError("Not a valid SEA command ID")
557569

558570
request = CancelStatementRequest(statement_id=sea_statement_id)
559571
self._http_client._make_request(
@@ -577,8 +589,6 @@ def close_command(self, command_id: CommandId) -> None:
577589
raise ValueError("Not a valid SEA command ID")
578590

579591
sea_statement_id = command_id.to_sea_statement_id()
580-
if sea_statement_id is None:
581-
raise ValueError("Not a valid SEA command ID")
582592

583593
request = CloseStatementRequest(statement_id=sea_statement_id)
584594
self._http_client._make_request(
@@ -596,8 +606,6 @@ def _poll_query(self, command_id: CommandId) -> GetStatementResponse:
596606
raise ValueError("Not a valid SEA command ID")
597607

598608
sea_statement_id = command_id.to_sea_statement_id()
599-
if sea_statement_id is None:
600-
raise ValueError("Not a valid SEA command ID")
601609

602610
request = GetStatementRequest(statement_id=sea_statement_id)
603611
response_data = self._http_client._make_request(
@@ -620,7 +628,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
620628
CommandState: The current state of the command
621629
622630
Raises:
623-
ProgrammingError: If the command ID is invalid
631+
ValueError: If the command ID is invalid
624632
"""
625633

626634
response = self._poll_query(command_id)

src/databricks/sql/backend/sea/models/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
StatementStatus,
1010
ExternalLink,
1111
ResultData,
12-
ColumnInfo,
1312
ResultManifest,
1413
)
1514

@@ -36,7 +35,6 @@
3635
"StatementStatus",
3736
"ExternalLink",
3837
"ResultData",
39-
"ColumnInfo",
4038
"ResultManifest",
4139
# Request models
4240
"StatementParameter",

src/databricks/sql/backend/sea/models/base.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -67,25 +67,12 @@ class ResultData:
6767
attachment: Optional[bytes] = None
6868

6969

70-
@dataclass
71-
class ColumnInfo:
72-
"""Information about a column in the result set."""
73-
74-
name: str
75-
type_name: str
76-
type_text: str
77-
nullable: bool = True
78-
precision: Optional[int] = None
79-
scale: Optional[int] = None
80-
ordinal_position: Optional[int] = None
81-
82-
8370
@dataclass
8471
class ResultManifest:
8572
"""Manifest information for a result set."""
8673

8774
format: str
88-
schema: Dict[str, Any] # Will contain column information
75+
schema: Dict[str, Any]
8976
total_row_count: int
9077
total_byte_count: int
9178
total_chunk_count: int

src/databricks/sql/backend/sea/models/requests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def to_dict(self) -> Dict[str, Any]:
5454
result["parameters"] = [
5555
{
5656
"name": param.name,
57-
**({"value": param.value} if param.value is not None else {}),
58-
**({"type": param.type} if param.type is not None else {}),
57+
"value": param.value,
58+
"type": param.type,
5959
}
6060
for param in self.parameters
6161
]

0 commit comments

Comments
 (0)