Skip to content

Commit 80b7bc3

Browse files
Merge remote-tracking branch 'origin/backend-interface' into fetch-interface
2 parents 00d9aeb + 1ec8c45 commit 80b7bc3

File tree

5 files changed

+142
-55
lines changed

5 files changed

+142
-55
lines changed

src/databricks/sql/backend/databricks_client.py

Lines changed: 85 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,20 @@
1+
"""
2+
Abstract client interface for interacting with Databricks SQL services.
3+
4+
Implementations of this class are responsible for:
5+
- Managing connections to Databricks SQL services
6+
- Handling authentication
7+
- Executing SQL queries and commands
8+
- Retrieving query results
9+
- Fetching metadata about catalogs, schemas, tables, and columns
10+
- Managing error handling and retries
11+
"""
12+
113
from abc import ABC, abstractmethod
2-
from typing import Dict, Tuple, List, Optional, Any, Union
14+
from typing import Dict, Tuple, List, Optional, Any, Union, TYPE_CHECKING
15+
16+
if TYPE_CHECKING:
17+
from databricks.sql.client import Cursor
318

419
from databricks.sql.thrift_api.TCLIService import ttypes
520
from databricks.sql.backend.types import SessionId, CommandId, CommandState
@@ -22,10 +37,42 @@ def open_session(
2237
catalog: Optional[str],
2338
schema: Optional[str],
2439
) -> SessionId:
40+
"""
41+
Opens a new session with the Databricks SQL service.
42+
43+
This method establishes a new session with the server and returns a session
44+
identifier that can be used for subsequent operations.
45+
46+
Args:
47+
session_configuration: Optional dictionary of configuration parameters for the session
48+
catalog: Optional catalog name to use as the initial catalog for the session
49+
schema: Optional schema name to use as the initial schema for the session
50+
51+
Returns:
52+
SessionId: A session identifier object that can be used for subsequent operations
53+
54+
Raises:
55+
Error: If the session configuration is invalid
56+
OperationalError: If there's an error establishing the session
57+
InvalidServerResponseError: If the server response is invalid or unexpected
58+
"""
2559
pass
2660

2761
@abstractmethod
2862
def close_session(self, session_id: SessionId) -> None:
63+
"""
64+
Closes an existing session with the Databricks SQL service.
65+
66+
This method terminates the session identified by the given session ID and
67+
releases any resources associated with it.
68+
69+
Args:
70+
session_id: The session identifier returned by open_session()
71+
72+
Raises:
73+
ValueError: If the session ID is invalid
74+
OperationalError: If there's an error closing the session
75+
"""
2976
pass
3077

3178
# == Query Execution, Command Management ==
@@ -37,7 +84,7 @@ def execute_command(
3784
max_rows: int,
3885
max_bytes: int,
3986
lz4_compression: bool,
40-
cursor: Any,
87+
cursor: "Cursor",
4188
use_cloud_fetch: bool,
4289
parameters: List[ttypes.TSparkParameter],
4390
async_op: bool,
@@ -47,6 +94,19 @@ def execute_command(
4794

4895
@abstractmethod
4996
def cancel_command(self, command_id: CommandId) -> None:
97+
"""
98+
Cancels a running command or query.
99+
100+
This method attempts to cancel a command that is currently being executed.
101+
It can be called from a different thread than the one executing the command.
102+
103+
Args:
104+
command_id: The command identifier to cancel
105+
106+
Raises:
107+
ValueError: If the command ID is invalid
108+
OperationalError: If there's an error canceling the command
109+
"""
50110
pass
51111

52112
@abstractmethod
@@ -82,7 +142,7 @@ def get_schemas(
82142
session_id: SessionId,
83143
max_rows: int,
84144
max_bytes: int,
85-
cursor: Any,
145+
cursor: "Cursor",
86146
catalog_name: Optional[str] = None,
87147
schema_name: Optional[str] = None,
88148
) -> "ResultSet":
@@ -94,7 +154,7 @@ def get_tables(
94154
session_id: SessionId,
95155
max_rows: int,
96156
max_bytes: int,
97-
cursor: Any,
157+
cursor: "Cursor",
98158
catalog_name: Optional[str] = None,
99159
schema_name: Optional[str] = None,
100160
table_name: Optional[str] = None,
@@ -108,35 +168,45 @@ def get_columns(
108168
session_id: SessionId,
109169
max_rows: int,
110170
max_bytes: int,
111-
cursor: Any,
171+
cursor: "Cursor",
112172
catalog_name: Optional[str] = None,
113173
schema_name: Optional[str] = None,
114174
table_name: Optional[str] = None,
115175
column_name: Optional[str] = None,
116176
) -> "ResultSet":
117177
pass
118178

119-
# == Utility Methods ==
120-
@abstractmethod
121-
def handle_to_id(self, session_id: SessionId) -> Any:
122-
pass
123-
124-
@abstractmethod
125-
def handle_to_hex_id(self, session_id: SessionId) -> str:
126-
pass
127-
128-
# Properties related to specific backend features
179+
# == Properties ==
129180
@property
130181
@abstractmethod
131182
def staging_allowed_local_path(self) -> Union[None, str, List[str]]:
183+
"""
184+
Gets the allowed local paths for staging operations.
185+
186+
Returns:
187+
Union[None, str, List[str]]: The allowed local paths for staging operations,
188+
or None if staging is not allowed
189+
"""
132190
pass
133191

134192
@property
135193
@abstractmethod
136194
def ssl_options(self) -> SSLOptions:
195+
"""
196+
Gets the SSL options for this client.
197+
198+
Returns:
199+
SSLOptions: The SSL configuration options
200+
"""
137201
pass
138202

139203
@property
140204
@abstractmethod
141205
def max_download_threads(self) -> int:
206+
"""
207+
Gets the maximum number of download threads for cloud fetch operations.
208+
209+
Returns:
210+
int: The maximum number of download threads
211+
"""
142212
pass

src/databricks/sql/backend/thrift_backend.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
import time
66
import uuid
77
import threading
8-
from typing import List, Union, Any
8+
from typing import List, Union, Any, TYPE_CHECKING
9+
10+
if TYPE_CHECKING:
11+
from databricks.sql.client import Cursor
912

1013
from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState
1114
from databricks.sql.backend.types import (
@@ -585,12 +588,12 @@ def open_session(self, session_configuration, catalog, schema) -> SessionId:
585588
response = self.make_request(self._client.OpenSession, open_session_req)
586589
self._check_initial_namespace(catalog, schema, response)
587590
self._check_protocol_version(response)
588-
info = (
591+
properties = (
589592
{"serverProtocolVersion": response.serverProtocolVersion}
590593
if response.serverProtocolVersion
591594
else {}
592595
)
593-
return SessionId.from_thrift_handle(response.sessionHandle, info)
596+
return SessionId.from_thrift_handle(response.sessionHandle, properties)
594597
except:
595598
self._transport.close()
596599
raise
@@ -802,6 +805,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
802805
arrow_queue_opt = None
803806

804807
command_id = CommandId.from_thrift_handle(resp.operationHandle)
808+
805809
return ExecuteResponse(
806810
arrow_queue=arrow_queue_opt,
807811
status=operation_state,
@@ -814,7 +818,9 @@ def _results_message_to_execute_response(self, resp, operation_state):
814818
arrow_schema_bytes=schema_bytes,
815819
)
816820

817-
def get_execution_result(self, command_id: CommandId, cursor):
821+
def get_execution_result(
822+
self, command_id: CommandId, cursor: "Cursor"
823+
) -> ExecuteResponse:
818824
thrift_handle = command_id.to_thrift_handle()
819825
if not thrift_handle:
820826
raise ValueError("Not a valid Thrift command ID")
@@ -939,7 +945,7 @@ def execute_command(
939945
max_rows: int,
940946
max_bytes: int,
941947
lz4_compression: bool,
942-
cursor: Any,
948+
cursor: "Cursor",
943949
use_cloud_fetch=True,
944950
parameters=[],
945951
async_op=False,
@@ -1037,7 +1043,7 @@ def get_schemas(
10371043
session_id: SessionId,
10381044
max_rows: int,
10391045
max_bytes: int,
1040-
cursor: Any,
1046+
cursor: "Cursor",
10411047
catalog_name=None,
10421048
schema_name=None,
10431049
) -> "ResultSet":
@@ -1071,7 +1077,7 @@ def get_tables(
10711077
session_id: SessionId,
10721078
max_rows: int,
10731079
max_bytes: int,
1074-
cursor: Any,
1080+
cursor: "Cursor",
10751081
catalog_name=None,
10761082
schema_name=None,
10771083
table_name=None,
@@ -1109,7 +1115,7 @@ def get_columns(
11091115
session_id: SessionId,
11101116
max_rows: int,
11111117
max_bytes: int,
1112-
cursor: Any,
1118+
cursor: "Cursor",
11131119
catalog_name=None,
11141120
schema_name=None,
11151121
table_name=None,
@@ -1230,18 +1236,4 @@ def close_command(self, command_id: CommandId) -> None:
12301236
logger.debug("ThriftBackend.close_command(command_id=%s)", command_id)
12311237
req = ttypes.TCloseOperationReq(operationHandle=thrift_handle)
12321238
resp = self.make_request(self._client.CloseOperation, req)
1233-
logger.debug(
1234-
"ThriftBackend.close_command(command_id=%s) -> %s", command_id, resp
1235-
)
1236-
1237-
def handle_to_id(self, session_id: SessionId) -> Any:
1238-
"""Get the raw session ID from a SessionId"""
1239-
if session_id.backend_type != BackendType.THRIFT:
1240-
raise ValueError("Not a valid Thrift session ID")
1241-
return session_id.guid
1242-
1243-
def handle_to_hex_id(self, session_id: SessionId) -> str:
1244-
"""Get the hex representation of a session ID"""
1245-
if session_id.backend_type != BackendType.THRIFT:
1246-
raise ValueError("Not a valid Thrift session ID")
1247-
return guid_to_hex_id(session_id.guid)
1239+
return resp.status

src/databricks/sql/backend/types.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __init__(
7878
backend_type: BackendType,
7979
guid: Any,
8080
secret: Optional[Any] = None,
81-
info: Optional[Dict[str, Any]] = None,
81+
properties: Optional[Dict[str, Any]] = None,
8282
):
8383
"""
8484
Initialize a SessionId.
@@ -92,10 +92,28 @@ def __init__(
9292
self.backend_type = backend_type
9393
self.guid = guid
9494
self.secret = secret
95-
self.info = info or {}
95+
self.properties = properties or {}
96+
97+
def __str__(self) -> str:
98+
"""
99+
Return a string representation of the SessionId.
100+
101+
For SEA backend, returns the guid.
102+
For Thrift backend, returns a format like "guid|secret".
103+
104+
Returns:
105+
A string representation of the session ID
106+
"""
107+
if self.backend_type == BackendType.SEA:
108+
return str(self.guid)
109+
elif self.backend_type == BackendType.THRIFT:
110+
return f"{self.get_hex_id()}|{guid_to_hex_id(self.secret) if isinstance(self.secret, bytes) else str(self.secret)}"
111+
return str(self.guid)
96112

97113
@classmethod
98-
def from_thrift_handle(cls, session_handle, info: Optional[Dict[str, Any]] = None):
114+
def from_thrift_handle(
115+
cls, session_handle, properties: Optional[Dict[str, Any]] = None
116+
):
99117
"""
100118
Create a SessionId from a Thrift session handle.
101119
@@ -112,15 +130,15 @@ def from_thrift_handle(cls, session_handle, info: Optional[Dict[str, Any]] = Non
112130
secret_bytes = session_handle.sessionId.secret
113131

114132
if session_handle.serverProtocolVersion is not None:
115-
if info is None:
116-
info = {}
117-
info["serverProtocolVersion"] = session_handle.serverProtocolVersion
133+
if properties is None:
134+
properties = {}
135+
properties["serverProtocolVersion"] = session_handle.serverProtocolVersion
118136

119-
return cls(BackendType.THRIFT, guid_bytes, secret_bytes, info)
137+
return cls(BackendType.THRIFT, guid_bytes, secret_bytes, properties)
120138

121139
@classmethod
122140
def from_sea_session_id(
123-
cls, session_id: str, info: Optional[Dict[str, Any]] = None
141+
cls, session_id: str, properties: Optional[Dict[str, Any]] = None
124142
):
125143
"""
126144
Create a SessionId from a SEA session ID.
@@ -131,7 +149,7 @@ def from_sea_session_id(
131149
Returns:
132150
A SessionId instance
133151
"""
134-
return cls(BackendType.SEA, session_id, info=info)
152+
return cls(BackendType.SEA, session_id, properties=properties)
135153

136154
def to_thrift_handle(self):
137155
"""
@@ -146,7 +164,7 @@ def to_thrift_handle(self):
146164
from databricks.sql.thrift_api.TCLIService import ttypes
147165

148166
handle_identifier = ttypes.THandleIdentifier(guid=self.guid, secret=self.secret)
149-
server_protocol_version = self.info.get("serverProtocolVersion")
167+
server_protocol_version = self.properties.get("serverProtocolVersion")
150168
return ttypes.TSessionHandle(
151169
sessionId=handle_identifier, serverProtocolVersion=server_protocol_version
152170
)
@@ -163,7 +181,13 @@ def to_sea_session_id(self):
163181

164182
return self.guid
165183

166-
def to_hex_id(self) -> str:
184+
def get_id(self) -> Any:
185+
"""
186+
Get the ID of the session.
187+
"""
188+
return self.guid
189+
190+
def get_hex_id(self) -> str:
167191
"""
168192
Get a hexadecimal string representation of the session ID.
169193
@@ -180,9 +204,10 @@ def get_protocol_version(self):
180204
Get the server protocol version for this session.
181205
182206
Returns:
183-
The server protocol version or None if this is not a Thrift session ID
207+
The server protocol version or None if it does not exist
208+
It is not expected to exist for SEA sessions.
184209
"""
185-
return self.info.get("serverProtocolVersion")
210+
return self.properties.get("serverProtocolVersion")
186211

187212

188213
class CommandId:

src/databricks/sql/session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,11 @@ def get_session_id(self) -> SessionId:
116116

117117
def get_id(self):
118118
"""Get the raw session ID (backend-specific)"""
119-
return self.backend.handle_to_id(self._session_id)
119+
return self._session_id.get_id()
120120

121121
def get_id_hex(self) -> str:
122122
"""Get the session ID in hex format"""
123-
return self.backend.handle_to_hex_id(self._session_id)
123+
return self._session_id.get_hex_id()
124124

125125
def close(self) -> None:
126126
"""Close the underlying session."""

0 commit comments

Comments
 (0)