Skip to content
Closed
148 changes: 46 additions & 102 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
OperationalError,
SessionAlreadyClosedError,
CursorAlreadyClosedError,
Error,
NotSupportedError,
)
from databricks.sql.thrift_api.TCLIService import ttypes
from databricks.sql.thrift_backend import ThriftBackend
Expand All @@ -45,6 +47,7 @@
from databricks.sql.types import Row, SSLOptions
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
from databricks.sql.session import Session

from databricks.sql.thrift_api.TCLIService.ttypes import (
TSparkParameter,
Expand Down Expand Up @@ -218,66 +221,27 @@ def read(self) -> Optional[OAuthToken]:
access_token_kv = {"access_token": access_token}
kwargs = {**kwargs, **access_token_kv}

self.open = False
self.host = server_hostname
self.port = kwargs.get("_port", 443)
self.disable_pandas = kwargs.get("_disable_pandas", False)
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
self._cursors = [] # type: List[Cursor]

auth_provider = get_python_sql_connector_auth_provider(
server_hostname, **kwargs
)

user_agent_entry = kwargs.get("user_agent_entry")
if user_agent_entry is None:
user_agent_entry = kwargs.get("_user_agent_entry")
if user_agent_entry is not None:
logger.warning(
"[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. "
"This parameter will be removed in the upcoming releases."
)

if user_agent_entry:
useragent_header = "{}/{} ({})".format(
USER_AGENT_NAME, __version__, user_agent_entry
)
else:
useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)

base_headers = [("User-Agent", useragent_header)]

self._ssl_options = SSLOptions(
# Double negation is generally a bad thing, but we have to keep backward compatibility
tls_verify=not kwargs.get(
"_tls_no_verify", False
), # by default - verify cert and host
tls_verify_hostname=kwargs.get("_tls_verify_hostname", True),
tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
)

self.thrift_backend = ThriftBackend(
self.host,
self.port,
# Create the session
self.session = Session(
server_hostname,
http_path,
(http_headers or []) + base_headers,
auth_provider,
ssl_options=self._ssl_options,
_use_arrow_native_complex_types=_use_arrow_native_complex_types,
http_headers,
session_configuration,
catalog,
schema,
_use_arrow_native_complex_types,
**kwargs,
)

self._open_session_resp = self.thrift_backend.open_session(
session_configuration, catalog, schema
logger.info(
"Successfully opened connection with session "
+ str(self.get_session_id_hex())
)
self._session_handle = self._open_session_resp.sessionHandle
self.protocol_version = self.get_protocol_version(self._open_session_resp)
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
self.open = True
logger.info("Successfully opened session " + str(self.get_session_id_hex()))
self._cursors = [] # type: List[Cursor]

self.use_inline_params = self._set_use_inline_params_with_warning(
kwargs.get("use_inline_params", False)
Expand Down Expand Up @@ -330,34 +294,32 @@ def __del__(self):
logger.debug("Couldn't close unclosed connection: {}".format(e.message))

def get_session_id(self):
return self.thrift_backend.handle_to_id(self._session_handle)
"""Get the session ID from the Session object"""
return self.session.get_session_id()

@staticmethod
def get_protocol_version(openSessionResp):
"""
Since the sessionHandle will sometimes have a serverProtocolVersion, it takes
precedence over the serverProtocolVersion defined in the OpenSessionResponse.
"""
if (
openSessionResp.sessionHandle
and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion")
and openSessionResp.sessionHandle.serverProtocolVersion
):
return openSessionResp.sessionHandle.serverProtocolVersion
return openSessionResp.serverProtocolVersion
def get_session_id_hex(self):
"""Get the session ID in hex format from the Session object"""
return self.session.get_session_id_hex()

@staticmethod
def server_parameterized_queries_enabled(protocolVersion):
if (
protocolVersion
and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8
):
return True
else:
return False
"""Delegate to Session class static method"""
return Session.server_parameterized_queries_enabled(protocolVersion)

def get_session_id_hex(self):
return self.thrift_backend.handle_to_hex_id(self._session_handle)
@property
def protocol_version(self):
"""Get the protocol version from the Session object"""
return self.session.protocol_version

@staticmethod
def get_protocol_version(openSessionResp):
"""Delegate to Session class static method"""
return Session.get_protocol_version(openSessionResp)

@property
def open(self) -> bool:
"""Return whether the connection is open by checking if the session is open."""
return self.session.open

def cursor(
self,
Expand All @@ -374,7 +336,7 @@ def cursor(

cursor = Cursor(
self,
self.thrift_backend,
self.session.thrift_backend,
arraysize=arraysize,
result_buffer_size_bytes=buffer_size_bytes,
)
Expand All @@ -390,28 +352,10 @@ def _close(self, close_cursors=True) -> None:
for cursor in self._cursors:
cursor.close()

logger.info(f"Closing session {self.get_session_id_hex()}")
if not self.open:
logger.debug("Session appears to have been closed already")

try:
self.thrift_backend.close_session(self._session_handle)
except RequestError as e:
if isinstance(e.args[1], SessionAlreadyClosedError):
logger.info("Session was closed by a prior request")
except DatabaseError as e:
if "Invalid SessionHandle" in str(e):
logger.warning(
f"Attempted to close session that was already closed: {e}"
)
else:
logger.warning(
f"Attempt to close session raised an exception at the server: {e}"
)
self.session.close()
except Exception as e:
logger.error(f"Attempt to close session raised a local exception: {e}")

self.open = False
logger.error(f"Attempt to close session raised an exception: {e}")

def commit(self):
"""No-op because Databricks does not support transactions"""
Expand Down Expand Up @@ -811,7 +755,7 @@ def execute(
self._close_and_clear_active_result_set()
execute_response = self.thrift_backend.execute_command(
operation=prepared_operation,
session_handle=self.connection._session_handle,
session_handle=self.connection.session._session_handle,
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
lz4_compression=self.connection.lz4_compression,
Expand Down Expand Up @@ -874,7 +818,7 @@ def execute_async(
self._close_and_clear_active_result_set()
self.thrift_backend.execute_command(
operation=prepared_operation,
session_handle=self.connection._session_handle,
session_handle=self.connection.session._session_handle,
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
lz4_compression=self.connection.lz4_compression,
Expand Down Expand Up @@ -970,7 +914,7 @@ def catalogs(self) -> "Cursor":
self._check_not_closed()
self._close_and_clear_active_result_set()
execute_response = self.thrift_backend.get_catalogs(
session_handle=self.connection._session_handle,
session_handle=self.connection.session._session_handle,
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
cursor=self,
Expand All @@ -996,7 +940,7 @@ def schemas(
self._check_not_closed()
self._close_and_clear_active_result_set()
execute_response = self.thrift_backend.get_schemas(
session_handle=self.connection._session_handle,
session_handle=self.connection.session._session_handle,
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
cursor=self,
Expand Down Expand Up @@ -1029,7 +973,7 @@ def tables(
self._close_and_clear_active_result_set()

execute_response = self.thrift_backend.get_tables(
session_handle=self.connection._session_handle,
session_handle=self.connection.session._session_handle,
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
cursor=self,
Expand Down Expand Up @@ -1064,7 +1008,7 @@ def columns(
self._close_and_clear_active_result_set()

execute_response = self.thrift_backend.get_columns(
session_handle=self.connection._session_handle,
session_handle=self.connection.session._session_handle,
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
cursor=self,
Expand Down
146 changes: 146 additions & 0 deletions src/databricks/sql/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import logging
from typing import Dict, Tuple, List, Optional, Any

from databricks.sql.thrift_api.TCLIService import ttypes
from databricks.sql.types import SSLOptions
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
from databricks.sql.exc import SessionAlreadyClosedError, DatabaseError, RequestError
from databricks.sql import __version__
from databricks.sql import USER_AGENT_NAME
from databricks.sql.thrift_backend import ThriftBackend

logger = logging.getLogger(__name__)


class Session:
def __init__(
self,
server_hostname: str,
http_path: str,
http_headers: Optional[List[Tuple[str, str]]] = None,
session_configuration: Optional[Dict[str, Any]] = None,
catalog: Optional[str] = None,
schema: Optional[str] = None,
_use_arrow_native_complex_types: Optional[bool] = True,
**kwargs,
) -> None:
"""
Create a session to a Databricks SQL endpoint or a Databricks cluster.

This class handles all session-related behavior and communication with the backend.
"""
self.open = False
self.host = server_hostname
self.port = kwargs.get("_port", 443)

auth_provider = get_python_sql_connector_auth_provider(
server_hostname, **kwargs
)

user_agent_entry = kwargs.get("user_agent_entry")
if user_agent_entry is None:
user_agent_entry = kwargs.get("_user_agent_entry")
if user_agent_entry is not None:
logger.warning(
"[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. "
"This parameter will be removed in the upcoming releases."
)

if user_agent_entry:
useragent_header = "{}/{} ({})".format(
USER_AGENT_NAME, __version__, user_agent_entry
)
else:
useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)

base_headers = [("User-Agent", useragent_header)]

self._ssl_options = SSLOptions(
# Double negation is generally a bad thing, but we have to keep backward compatibility
tls_verify=not kwargs.get(
"_tls_no_verify", False
), # by default - verify cert and host
tls_verify_hostname=kwargs.get("_tls_verify_hostname", True),
tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
)

self.thrift_backend = ThriftBackend(
self.host,
self.port,
http_path,
(http_headers or []) + base_headers,
auth_provider,
ssl_options=self._ssl_options,
_use_arrow_native_complex_types=_use_arrow_native_complex_types,
**kwargs,
)

self._open_session_resp = self.thrift_backend.open_session(
session_configuration, catalog, schema
)
self._session_handle = self._open_session_resp.sessionHandle
self.protocol_version = self.get_protocol_version(self._open_session_resp)
self.open = True
logger.info("Successfully opened session " + str(self.get_session_id_hex()))

@staticmethod
def get_protocol_version(openSessionResp):
"""
Since the sessionHandle will sometimes have a serverProtocolVersion, it takes
precedence over the serverProtocolVersion defined in the OpenSessionResponse.
"""
if (
openSessionResp.sessionHandle
and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion")
and openSessionResp.sessionHandle.serverProtocolVersion
):
return openSessionResp.sessionHandle.serverProtocolVersion
return openSessionResp.serverProtocolVersion

@staticmethod
def server_parameterized_queries_enabled(protocolVersion):
if (
protocolVersion
and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8
):
return True
else:
return False

def get_session_handle(self):
return self._session_handle

def get_session_id(self):
return self.thrift_backend.handle_to_id(self._session_handle)

def get_session_id_hex(self):
return self.thrift_backend.handle_to_hex_id(self._session_handle)

def close(self) -> None:
"""Close the underlying session."""
logger.info(f"Closing session {self.get_session_id_hex()}")
if not self.open:
logger.debug("Session appears to have been closed already")
return

try:
self.thrift_backend.close_session(self._session_handle)
except RequestError as e:
if isinstance(e.args[1], SessionAlreadyClosedError):
logger.info("Session was closed by a prior request")
except DatabaseError as e:
if "Invalid SessionHandle" in str(e):
logger.warning(
f"Attempted to close session that was already closed: {e}"
)
else:
logger.warning(
f"Attempt to close session raised an exception at the server: {e}"
)
except Exception as e:
logger.error(f"Attempt to close session raised a local exception: {e}")

self.open = False
Loading