From 73832906b769feeea8865c0ad1311032d01de026 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 09:55:05 +0530 Subject: [PATCH 1/7] decouple session class from existing Connection ensure maintenance of current APIs of Connection while delegating responsibility Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 152 ++++++++++------------------------ src/databricks/sql/session.py | 146 ++++++++++++++++++++++++++++++++ 2 files changed, 190 insertions(+), 108 deletions(-) create mode 100644 src/databricks/sql/session.py diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index f24a6584a..0fbd20df5 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -19,6 +19,8 @@ OperationalError, SessionAlreadyClosedError, CursorAlreadyClosedError, + Error, + NotSupportedError, ) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.thrift_backend import ThriftBackend @@ -45,6 +47,7 @@ from databricks.sql.types import Row, SSLOptions from databricks.sql.auth.auth import get_python_sql_connector_auth_provider from databricks.sql.experimental.oauth_persistence import OAuthPersistence +from databricks.sql.session import Session from databricks.sql.thrift_api.TCLIService.ttypes import ( TSparkParameter, @@ -218,66 +221,24 @@ def read(self) -> Optional[OAuthToken]: access_token_kv = {"access_token": access_token} kwargs = {**kwargs, **access_token_kv} - self.open = False - self.host = server_hostname - self.port = kwargs.get("_port", 443) self.disable_pandas = kwargs.get("_disable_pandas", False) self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True) + self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) + self._cursors = [] # type: List[Cursor] - auth_provider = get_python_sql_connector_auth_provider( - server_hostname, **kwargs - ) - - user_agent_entry = kwargs.get("user_agent_entry") - if user_agent_entry is None: - user_agent_entry = kwargs.get("_user_agent_entry") - if user_agent_entry is not None: - logger.warning( - "[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. " - "This parameter will be removed in the upcoming releases." - ) - - if user_agent_entry: - useragent_header = "{}/{} ({})".format( - USER_AGENT_NAME, __version__, user_agent_entry - ) - else: - useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) - - base_headers = [("User-Agent", useragent_header)] - - self._ssl_options = SSLOptions( - # Double negation is generally a bad thing, but we have to keep backward compatibility - tls_verify=not kwargs.get( - "_tls_no_verify", False - ), # by default - verify cert and host - tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), - tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), - tls_client_cert_file=kwargs.get("_tls_client_cert_file"), - tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), - tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), - ) - - self.thrift_backend = ThriftBackend( - self.host, - self.port, + # Create the session + self.session = Session( + server_hostname, http_path, - (http_headers or []) + base_headers, - auth_provider, - ssl_options=self._ssl_options, - _use_arrow_native_complex_types=_use_arrow_native_complex_types, - **kwargs, - ) - - self._open_session_resp = self.thrift_backend.open_session( - session_configuration, catalog, schema + http_headers, + session_configuration, + catalog, + schema, + _use_arrow_native_complex_types, + **kwargs ) - self._session_handle = self._open_session_resp.sessionHandle - self.protocol_version = self.get_protocol_version(self._open_session_resp) - self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) - self.open = True - logger.info("Successfully opened session " + str(self.get_session_id_hex())) - self._cursors = [] # type: List[Cursor] + + logger.info("Successfully opened connection with session " + str(self.get_session_id_hex())) self.use_inline_params = self._set_use_inline_params_with_warning( kwargs.get("use_inline_params", False) @@ -318,7 +279,7 @@ def __exit__(self, exc_type, exc_value, traceback): self.close() def __del__(self): - if self.open: + if self.session.open: logger.debug( "Closing unclosed connection for session " "{}".format(self.get_session_id_hex()) @@ -330,34 +291,27 @@ def __del__(self): logger.debug("Couldn't close unclosed connection: {}".format(e.message)) def get_session_id(self): - return self.thrift_backend.handle_to_id(self._session_handle) + """Get the session ID from the Session object""" + return self.session.get_session_id() - @staticmethod - def get_protocol_version(openSessionResp): - """ - Since the sessionHandle will sometimes have a serverProtocolVersion, it takes - precedence over the serverProtocolVersion defined in the OpenSessionResponse. - """ - if ( - openSessionResp.sessionHandle - and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion") - and openSessionResp.sessionHandle.serverProtocolVersion - ): - return openSessionResp.sessionHandle.serverProtocolVersion - return openSessionResp.serverProtocolVersion + def get_session_id_hex(self): + """Get the session ID in hex format from the Session object""" + return self.session.get_session_id_hex() @staticmethod def server_parameterized_queries_enabled(protocolVersion): - if ( - protocolVersion - and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8 - ): - return True - else: - return False + """Delegate to Session class static method""" + return Session.server_parameterized_queries_enabled(protocolVersion) - def get_session_id_hex(self): - return self.thrift_backend.handle_to_hex_id(self._session_handle) + @property + def protocol_version(self): + """Get the protocol version from the Session object""" + return self.session.protocol_version + + @staticmethod + def get_protocol_version(openSessionResp): + """Delegate to Session class static method""" + return Session.get_protocol_version(openSessionResp) def cursor( self, @@ -369,12 +323,12 @@ def cursor( Will throw an Error if the connection has been closed. """ - if not self.open: + if not self.session.open: raise Error("Cannot create cursor from closed connection") cursor = Cursor( self, - self.thrift_backend, + self.session.thrift_backend, arraysize=arraysize, result_buffer_size_bytes=buffer_size_bytes, ) @@ -390,28 +344,10 @@ def _close(self, close_cursors=True) -> None: for cursor in self._cursors: cursor.close() - logger.info(f"Closing session {self.get_session_id_hex()}") - if not self.open: - logger.debug("Session appears to have been closed already") - try: - self.thrift_backend.close_session(self._session_handle) - except RequestError as e: - if isinstance(e.args[1], SessionAlreadyClosedError): - logger.info("Session was closed by a prior request") - except DatabaseError as e: - if "Invalid SessionHandle" in str(e): - logger.warning( - f"Attempted to close session that was already closed: {e}" - ) - else: - logger.warning( - f"Attempt to close session raised an exception at the server: {e}" - ) + self.session.close() except Exception as e: - logger.error(f"Attempt to close session raised a local exception: {e}") - - self.open = False + logger.error(f"Attempt to close session raised an exception: {e}") def commit(self): """No-op because Databricks does not support transactions""" @@ -811,7 +747,7 @@ def execute( self._close_and_clear_active_result_set() execute_response = self.thrift_backend.execute_command( operation=prepared_operation, - session_handle=self.connection._session_handle, + session_handle=self.connection.session._session_handle, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -874,7 +810,7 @@ def execute_async( self._close_and_clear_active_result_set() self.thrift_backend.execute_command( operation=prepared_operation, - session_handle=self.connection._session_handle, + session_handle=self.connection.session._session_handle, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -970,7 +906,7 @@ def catalogs(self) -> "Cursor": self._check_not_closed() self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_catalogs( - session_handle=self.connection._session_handle, + session_handle=self.connection.session._session_handle, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -996,7 +932,7 @@ def schemas( self._check_not_closed() self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_schemas( - session_handle=self.connection._session_handle, + session_handle=self.connection.session._session_handle, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1029,7 +965,7 @@ def tables( self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_tables( - session_handle=self.connection._session_handle, + session_handle=self.connection.session._session_handle, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1064,7 +1000,7 @@ def columns( self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_columns( - session_handle=self.connection._session_handle, + session_handle=self.connection.session._session_handle, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1493,7 +1429,7 @@ def close(self) -> None: if ( self.op_state != self.thrift_backend.CLOSED_OP_STATE and not self.has_been_closed_server_side - and self.connection.open + and self.connection.session.open ): self.thrift_backend.close_command(self.command_id) except RequestError as e: diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py new file mode 100644 index 000000000..4920550e7 --- /dev/null +++ b/src/databricks/sql/session.py @@ -0,0 +1,146 @@ +import logging +from typing import Dict, Tuple, List, Optional, Any + +from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.types import SSLOptions +from databricks.sql.auth.auth import get_python_sql_connector_auth_provider +from databricks.sql.exc import SessionAlreadyClosedError, DatabaseError, RequestError +from databricks.sql import __version__ +from databricks.sql import USER_AGENT_NAME +from databricks.sql.thrift_backend import ThriftBackend + +logger = logging.getLogger(__name__) + + +class Session: + def __init__( + self, + server_hostname: str, + http_path: str, + http_headers: Optional[List[Tuple[str, str]]] = None, + session_configuration: Optional[Dict[str, Any]] = None, + catalog: Optional[str] = None, + schema: Optional[str] = None, + _use_arrow_native_complex_types: Optional[bool] = True, + **kwargs, + ) -> None: + """ + Create a session to a Databricks SQL endpoint or a Databricks cluster. + + This class handles all session-related behavior and communication with the backend. + """ + self.open = False + self.host = server_hostname + self.port = kwargs.get("_port", 443) + + auth_provider = get_python_sql_connector_auth_provider( + server_hostname, **kwargs + ) + + user_agent_entry = kwargs.get("user_agent_entry") + if user_agent_entry is None: + user_agent_entry = kwargs.get("_user_agent_entry") + if user_agent_entry is not None: + logger.warning( + "[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. " + "This parameter will be removed in the upcoming releases." + ) + + if user_agent_entry: + useragent_header = "{}/{} ({})".format( + USER_AGENT_NAME, __version__, user_agent_entry + ) + else: + useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) + + base_headers = [("User-Agent", useragent_header)] + + self._ssl_options = SSLOptions( + # Double negation is generally a bad thing, but we have to keep backward compatibility + tls_verify=not kwargs.get( + "_tls_no_verify", False + ), # by default - verify cert and host + tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), + tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), + tls_client_cert_file=kwargs.get("_tls_client_cert_file"), + tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), + tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), + ) + + self.thrift_backend = ThriftBackend( + self.host, + self.port, + http_path, + (http_headers or []) + base_headers, + auth_provider, + ssl_options=self._ssl_options, + _use_arrow_native_complex_types=_use_arrow_native_complex_types, + **kwargs, + ) + + self._open_session_resp = self.thrift_backend.open_session( + session_configuration, catalog, schema + ) + self._session_handle = self._open_session_resp.sessionHandle + self.protocol_version = self.get_protocol_version(self._open_session_resp) + self.open = True + logger.info("Successfully opened session " + str(self.get_session_id_hex())) + + @staticmethod + def get_protocol_version(openSessionResp): + """ + Since the sessionHandle will sometimes have a serverProtocolVersion, it takes + precedence over the serverProtocolVersion defined in the OpenSessionResponse. + """ + if ( + openSessionResp.sessionHandle + and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion") + and openSessionResp.sessionHandle.serverProtocolVersion + ): + return openSessionResp.sessionHandle.serverProtocolVersion + return openSessionResp.serverProtocolVersion + + @staticmethod + def server_parameterized_queries_enabled(protocolVersion): + if ( + protocolVersion + and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8 + ): + return True + else: + return False + + def get_session_handle(self): + return self._session_handle + + def get_session_id(self): + return self.thrift_backend.handle_to_id(self._session_handle) + + def get_session_id_hex(self): + return self.thrift_backend.handle_to_hex_id(self._session_handle) + + def close(self) -> None: + """Close the underlying session.""" + logger.info(f"Closing session {self.get_session_id_hex()}") + if not self.open: + logger.debug("Session appears to have been closed already") + return + + try: + self.thrift_backend.close_session(self._session_handle) + except RequestError as e: + if isinstance(e.args[1], SessionAlreadyClosedError): + logger.info("Session was closed by a prior request") + except DatabaseError as e: + if "Invalid SessionHandle" in str(e): + logger.warning( + f"Attempted to close session that was already closed: {e}" + ) + else: + logger.warning( + f"Attempt to close session raised an exception at the server: {e}" + ) + except Exception as e: + logger.error(f"Attempt to close session raised a local exception: {e}") + + self.open = False \ No newline at end of file From 0e6efd88af09da9c4e5783fc77c14bcc744fb339 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 10:35:55 +0530 Subject: [PATCH 2/7] add open property to Connection to ensure maintenance of existing API Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 0fbd20df5..6c89ef0a1 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -313,6 +313,11 @@ def get_protocol_version(openSessionResp): """Delegate to Session class static method""" return Session.get_protocol_version(openSessionResp) + @property + def open(self) -> bool: + """Return whether the connection is open by checking if the session is open.""" + return self.session.open + def cursor( self, arraysize: int = DEFAULT_ARRAY_SIZE, From e7ebe2bb5584501225a79cc2ac6bdfdb44696c59 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 10:36:36 +0530 Subject: [PATCH 3/7] update unit tests to address ThriftBackend through session instead of through Connection Signed-off-by: varun-edachali-dbx --- tests/unit/test_client.py | 49 ++++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index c39aeb524..58607cf49 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -80,7 +80,7 @@ class ClientTestSuite(unittest.TestCase): "access_token": "tok", } - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_close_uses_the_correct_session_id(self, mock_client_class): instance = mock_client_class.return_value @@ -95,7 +95,7 @@ def test_close_uses_the_correct_session_id(self, mock_client_class): close_session_id = instance.close_session.call_args[0][0].sessionId self.assertEqual(close_session_id, b"\x22") - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_auth_args(self, mock_client_class): # Test that the following auth args work: # token = foo, @@ -122,7 +122,7 @@ def test_auth_args(self, mock_client_class): self.assertEqual(args["http_path"], http_path) connection.close() - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_http_header_passthrough(self, mock_client_class): http_headers = [("foo", "bar")] databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) @@ -130,7 +130,7 @@ def test_http_header_passthrough(self, mock_client_class): call_args = mock_client_class.call_args[0][3] self.assertIn(("foo", "bar"), call_args) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_tls_arg_passthrough(self, mock_client_class): databricks.sql.connect( **self.DUMMY_CONNECTION_ARGS, @@ -146,7 +146,7 @@ def test_tls_arg_passthrough(self, mock_client_class): self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert") self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password") - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_useragent_header(self, mock_client_class): databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) @@ -167,7 +167,7 @@ def test_useragent_header(self, mock_client_class): http_headers = mock_client_class.call_args[0][3] self.assertIn(user_agent_header_with_entry, http_headers) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) @patch("%s.client.ResultSet" % PACKAGE_NAME) def test_closing_connection_closes_commands(self, mock_result_set_class): # Test once with has_been_closed_server side, once without @@ -184,7 +184,7 @@ def test_closing_connection_closes_commands(self, mock_result_set_class): ) mock_result_set_class.return_value.close.assert_called_once_with() - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_cant_open_cursor_on_closed_connection(self, mock_client_class): connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) self.assertTrue(connection.open) @@ -194,7 +194,7 @@ def test_cant_open_cursor_on_closed_connection(self, mock_client_class): connection.cursor() self.assertIn("closed", str(cm.exception)) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) @patch("%s.client.Cursor" % PACKAGE_NAME) def test_arraysize_buffer_size_passthrough( self, mock_cursor_class, mock_client_class @@ -214,7 +214,10 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self): thrift_backend=mock_backend, execute_response=Mock(), ) - mock_connection.open = False + # Setup session mock on the mock_connection + mock_session = Mock() + mock_session.open = False + type(mock_connection).session = PropertyMock(return_value=mock_session) result_set.close() @@ -226,7 +229,11 @@ def test_closing_result_set_hard_closes_commands(self): mock_results_response.has_been_closed_server_side = False mock_connection = Mock() mock_thrift_backend = Mock() - mock_connection.open = True + # Setup session mock on the mock_connection + mock_session = Mock() + mock_session.open = True + type(mock_connection).session = PropertyMock(return_value=mock_session) + result_set = client.ResultSet( mock_connection, mock_results_response, mock_thrift_backend ) @@ -283,7 +290,7 @@ def test_context_manager_closes_cursor(self): cursor.close = mock_close mock_close.assert_called_once_with() - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_context_manager_closes_connection(self, mock_client_class): instance = mock_client_class.return_value @@ -396,7 +403,7 @@ def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( self.assertTrue(logger_instance.warning.called) self.assertFalse(mock_thrift_backend.cancel_command.called) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_max_number_of_retries_passthrough(self, mock_client_class): databricks.sql.connect( _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS @@ -406,7 +413,7 @@ def test_max_number_of_retries_passthrough(self, mock_client_class): mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_socket_timeout_passthrough(self, mock_client_class): databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234) @@ -419,7 +426,7 @@ def test_version_is_canonical(self): ) self.assertIsNotNone(re.match(canonical_version_re, version)) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_configuration_passthrough(self, mock_client_class): mock_session_config = Mock() databricks.sql.connect( @@ -431,7 +438,7 @@ def test_configuration_passthrough(self, mock_client_class): mock_session_config, ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_initial_namespace_passthrough(self, mock_client_class): mock_cat = Mock() mock_schem = Mock() @@ -505,7 +512,7 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set( "last operation", ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_commit_a_noop(self, mock_thrift_backend_class): c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) c.commit() @@ -518,7 +525,7 @@ def test_setoutputsizes_a_noop(self): cursor = client.Cursor(Mock(), Mock()) cursor.setoutputsize(1) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_rollback_not_supported(self, mock_thrift_backend_class): c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) with self.assertRaises(NotSupportedError): @@ -603,7 +610,7 @@ def test_column_name_api(self): }, ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_finalizer_closes_abandoned_connection(self, mock_client_class): instance = mock_client_class.return_value @@ -620,7 +627,7 @@ def test_finalizer_closes_abandoned_connection(self, mock_client_class): close_session_id = instance.close_session.call_args[0][0].sessionId self.assertEqual(close_session_id, b"\x22") - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_cursor_keeps_connection_alive(self, mock_client_class): instance = mock_client_class.return_value @@ -639,7 +646,7 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): @patch("%s.utils.ExecuteResponse" % PACKAGE_NAME, autospec=True) @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_staging_operation_response_is_handled( self, mock_client_class, mock_handle_staging_operation, mock_execute_response ): @@ -658,7 +665,7 @@ def test_staging_operation_response_is_handled( mock_handle_staging_operation.call_count == 1 - @patch("%s.client.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) def test_access_current_query_id(self): operation_id = "EE6A8778-21FC-438B-92D8-96AC51EE3821" From c63f6fd5fd254e7c4b97761583e621b9b0c65ecc Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 10:43:24 +0530 Subject: [PATCH 4/7] chore: move session specific tests from test_client to test_session Signed-off-by: varun-edachali-dbx --- tests/unit/test_client.py | 161 ------------------------------- tests/unit/test_session.py | 187 +++++++++++++++++++++++++++++++++++++ 2 files changed, 187 insertions(+), 161 deletions(-) create mode 100644 tests/unit/test_session.py diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 58607cf49..ecbf3493b 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -80,93 +80,6 @@ class ClientTestSuite(unittest.TestCase): "access_token": "tok", } - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_close_uses_the_correct_session_id(self, mock_client_class): - instance = mock_client_class.return_value - - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp - - connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - connection.close() - - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") - - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_auth_args(self, mock_client_class): - # Test that the following auth args work: - # token = foo, - # token = None, _tls_client_cert_file = something, _use_cert_as_auth = True - connection_args = [ - { - "server_hostname": "foo", - "http_path": None, - "access_token": "tok", - }, - { - "server_hostname": "foo", - "http_path": None, - "_tls_client_cert_file": "something", - "_use_cert_as_auth": True, - "access_token": None, - }, - ] - - for args in connection_args: - connection = databricks.sql.connect(**args) - host, port, http_path, *_ = mock_client_class.call_args[0] - self.assertEqual(args["server_hostname"], host) - self.assertEqual(args["http_path"], http_path) - connection.close() - - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_http_header_passthrough(self, mock_client_class): - http_headers = [("foo", "bar")] - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) - - call_args = mock_client_class.call_args[0][3] - self.assertIn(("foo", "bar"), call_args) - - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_tls_arg_passthrough(self, mock_client_class): - databricks.sql.connect( - **self.DUMMY_CONNECTION_ARGS, - _tls_verify_hostname="hostname", - _tls_trusted_ca_file="trusted ca file", - _tls_client_cert_key_file="trusted client cert", - _tls_client_cert_key_password="key password", - ) - - kwargs = mock_client_class.call_args[1] - self.assertEqual(kwargs["_tls_verify_hostname"], "hostname") - self.assertEqual(kwargs["_tls_trusted_ca_file"], "trusted ca file") - self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert") - self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password") - - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_useragent_header(self, mock_client_class): - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - - http_headers = mock_client_class.call_args[0][3] - user_agent_header = ( - "User-Agent", - "{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__), - ) - self.assertIn(user_agent_header, http_headers) - - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, user_agent_entry="foobar") - user_agent_header_with_entry = ( - "User-Agent", - "{}/{} ({})".format( - databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar" - ), - ) - http_headers = mock_client_class.call_args[0][3] - self.assertIn(user_agent_header_with_entry, http_headers) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) @patch("%s.client.ResultSet" % PACKAGE_NAME) def test_closing_connection_closes_commands(self, mock_result_set_class): @@ -290,21 +203,6 @@ def test_context_manager_closes_cursor(self): cursor.close = mock_close mock_close.assert_called_once_with() - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_context_manager_closes_connection(self, mock_client_class): - instance = mock_client_class.return_value - - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp - - with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: - pass - - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") - def dict_product(self, dicts): """ Generate cartesion product of values in input dictionary, outputting a dictionary @@ -403,21 +301,6 @@ def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( self.assertTrue(logger_instance.warning.called) self.assertFalse(mock_thrift_backend.cancel_command.called) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_max_number_of_retries_passthrough(self, mock_client_class): - databricks.sql.connect( - _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS - ) - - self.assertEqual( - mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 - ) - - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_socket_timeout_passthrough(self, mock_client_class): - databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) - self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234) - def test_version_is_canonical(self): version = databricks.sql.__version__ canonical_version_re = ( @@ -426,33 +309,6 @@ def test_version_is_canonical(self): ) self.assertIsNotNone(re.match(canonical_version_re, version)) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_configuration_passthrough(self, mock_client_class): - mock_session_config = Mock() - databricks.sql.connect( - session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS - ) - - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][0], - mock_session_config, - ) - - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_initial_namespace_passthrough(self, mock_client_class): - mock_cat = Mock() - mock_schem = Mock() - - databricks.sql.connect( - **self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem - ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][1], mock_cat - ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][2], mock_schem - ) - def test_execute_parameter_passthrough(self): mock_thrift_backend = ThriftBackendMockFactory.new() cursor = client.Cursor(Mock(), mock_thrift_backend) @@ -610,23 +466,6 @@ def test_column_name_api(self): }, ) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_finalizer_closes_abandoned_connection(self, mock_client_class): - instance = mock_client_class.return_value - - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp - - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - - # not strictly necessary as the refcount is 0, but just to be sure - gc.collect() - - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_cursor_keeps_connection_alive(self, mock_client_class): instance = mock_client_class.return_value diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py new file mode 100644 index 000000000..6a49abef6 --- /dev/null +++ b/tests/unit/test_session.py @@ -0,0 +1,187 @@ +import unittest +from unittest.mock import patch, MagicMock, Mock, PropertyMock +import gc + +from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, +) + +import databricks.sql + + +class SessionTestSuite(unittest.TestCase): + """ + Unit tests for Session functionality + """ + + PACKAGE_NAME = "databricks.sql" + DUMMY_CONNECTION_ARGS = { + "server_hostname": "foo", + "http_path": "dummy_path", + "access_token": "tok", + } + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_close_uses_the_correct_session_id(self, mock_client_class): + instance = mock_client_class.return_value + + mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() + mock_open_session_resp.sessionHandle.sessionId = b"\x22" + instance.open_session.return_value = mock_open_session_resp + + connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + connection.close() + + # Check the close session request has an id of x22 + close_session_id = instance.close_session.call_args[0][0].sessionId + self.assertEqual(close_session_id, b"\x22") + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_auth_args(self, mock_client_class): + # Test that the following auth args work: + # token = foo, + # token = None, _tls_client_cert_file = something, _use_cert_as_auth = True + connection_args = [ + { + "server_hostname": "foo", + "http_path": None, + "access_token": "tok", + }, + { + "server_hostname": "foo", + "http_path": None, + "_tls_client_cert_file": "something", + "_use_cert_as_auth": True, + "access_token": None, + }, + ] + + for args in connection_args: + connection = databricks.sql.connect(**args) + host, port, http_path, *_ = mock_client_class.call_args[0] + self.assertEqual(args["server_hostname"], host) + self.assertEqual(args["http_path"], http_path) + connection.close() + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_http_header_passthrough(self, mock_client_class): + http_headers = [("foo", "bar")] + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) + + call_args = mock_client_class.call_args[0][3] + self.assertIn(("foo", "bar"), call_args) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_tls_arg_passthrough(self, mock_client_class): + databricks.sql.connect( + **self.DUMMY_CONNECTION_ARGS, + _tls_verify_hostname="hostname", + _tls_trusted_ca_file="trusted ca file", + _tls_client_cert_key_file="trusted client cert", + _tls_client_cert_key_password="key password", + ) + + kwargs = mock_client_class.call_args[1] + self.assertEqual(kwargs["_tls_verify_hostname"], "hostname") + self.assertEqual(kwargs["_tls_trusted_ca_file"], "trusted ca file") + self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert") + self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password") + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_useragent_header(self, mock_client_class): + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + + http_headers = mock_client_class.call_args[0][3] + user_agent_header = ( + "User-Agent", + "{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__), + ) + self.assertIn(user_agent_header, http_headers) + + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, user_agent_entry="foobar") + user_agent_header_with_entry = ( + "User-Agent", + "{}/{} ({})".format( + databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar" + ), + ) + http_headers = mock_client_class.call_args[0][3] + self.assertIn(user_agent_header_with_entry, http_headers) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_context_manager_closes_connection(self, mock_client_class): + instance = mock_client_class.return_value + + mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() + mock_open_session_resp.sessionHandle.sessionId = b"\x22" + instance.open_session.return_value = mock_open_session_resp + + with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: + pass + + # Check the close session request has an id of x22 + close_session_id = instance.close_session.call_args[0][0].sessionId + self.assertEqual(close_session_id, b"\x22") + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_max_number_of_retries_passthrough(self, mock_client_class): + databricks.sql.connect( + _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS + ) + + self.assertEqual( + mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 + ) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_socket_timeout_passthrough(self, mock_client_class): + databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) + self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_configuration_passthrough(self, mock_client_class): + mock_session_config = Mock() + databricks.sql.connect( + session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS + ) + + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][0], + mock_session_config, + ) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_initial_namespace_passthrough(self, mock_client_class): + mock_cat = Mock() + mock_schem = Mock() + + databricks.sql.connect( + **self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem + ) + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][1], mock_cat + ) + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][2], mock_schem + ) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_finalizer_closes_abandoned_connection(self, mock_client_class): + instance = mock_client_class.return_value + + mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() + mock_open_session_resp.sessionHandle.sessionId = b"\x22" + instance.open_session.return_value = mock_open_session_resp + + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + + # not strictly necessary as the refcount is 0, but just to be sure + gc.collect() + + # Check the close session request has an id of x22 + close_session_id = instance.close_session.call_args[0][0].sessionId + self.assertEqual(close_session_id, b"\x22") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 23cc72a0e6243e63c26b6887b7ea799c5adccca0 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 11:14:16 +0530 Subject: [PATCH 5/7] formatting (black) as in CONTRIBUTING.md Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 9 ++++++--- src/databricks/sql/session.py | 6 +++--- tests/unit/test_client.py | 2 +- tests/unit/test_session.py | 4 ++-- tests/unit/test_thrift_backend.py | 4 +++- 5 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 6c89ef0a1..098aa0548 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -235,10 +235,13 @@ def read(self) -> Optional[OAuthToken]: catalog, schema, _use_arrow_native_complex_types, - **kwargs + **kwargs, + ) + + logger.info( + "Successfully opened connection with session " + + str(self.get_session_id_hex()) ) - - logger.info("Successfully opened connection with session " + str(self.get_session_id_hex())) self.use_inline_params = self._set_use_inline_params_with_warning( kwargs.get("use_inline_params", False) diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 4920550e7..a308b71d5 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -26,13 +26,13 @@ def __init__( ) -> None: """ Create a session to a Databricks SQL endpoint or a Databricks cluster. - + This class handles all session-related behavior and communication with the backend. """ self.open = False self.host = server_hostname self.port = kwargs.get("_port", 443) - + auth_provider = get_python_sql_connector_auth_provider( server_hostname, **kwargs ) @@ -143,4 +143,4 @@ def close(self) -> None: except Exception as e: logger.error(f"Attempt to close session raised a local exception: {e}") - self.open = False \ No newline at end of file + self.open = False diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index ecbf3493b..b67101943 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -146,7 +146,7 @@ def test_closing_result_set_hard_closes_commands(self): mock_session = Mock() mock_session.open = True type(mock_connection).session = PropertyMock(return_value=mock_session) - + result_set = client.ResultSet( mock_connection, mock_results_response, mock_thrift_backend ) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 6a49abef6..eb392a229 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -11,7 +11,7 @@ class SessionTestSuite(unittest.TestCase): """ - Unit tests for Session functionality + Unit tests for Session functionality """ PACKAGE_NAME = "databricks.sql" @@ -184,4 +184,4 @@ def test_finalizer_closes_abandoned_connection(self, mock_client_class): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 7fe318446..458ea9a82 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -86,7 +86,9 @@ def test_make_request_checks_thrift_status_code(self): def _make_type_desc(self, type): return ttypes.TTypeDesc( - types=[ttypes.TTypeEntry(primitiveEntry=ttypes.TPrimitiveTypeEntry(type=type))] + types=[ + ttypes.TTypeEntry(primitiveEntry=ttypes.TPrimitiveTypeEntry(type=type)) + ] ) def _make_fake_thrift_backend(self): From 49d2bcd97b8af11a4135749a74913c7d67eee828 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 16:18:44 +0530 Subject: [PATCH 6/7] use connection open property instead of long chain through session Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 098aa0548..54a097641 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -282,7 +282,7 @@ def __exit__(self, exc_type, exc_value, traceback): self.close() def __del__(self): - if self.session.open: + if self.open: logger.debug( "Closing unclosed connection for session " "{}".format(self.get_session_id_hex()) @@ -331,7 +331,7 @@ def cursor( Will throw an Error if the connection has been closed. """ - if not self.session.open: + if not self.open: raise Error("Cannot create cursor from closed connection") cursor = Cursor( @@ -1437,7 +1437,7 @@ def close(self) -> None: if ( self.op_state != self.thrift_backend.CLOSED_OP_STATE and not self.has_been_closed_server_side - and self.connection.session.open + and self.connection.open ): self.thrift_backend.close_command(self.command_id) except RequestError as e: From 078f91aa101778db773ccf170f6cada552da275f Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 16:47:02 +0530 Subject: [PATCH 7/7] trigger integration workflow Signed-off-by: varun-edachali-dbx