@@ -128,6 +128,15 @@ def test_context_manager_closes_connection(self, mock_client_class):
128128 self .assertEqual (close_session_call_args .guid , b"\x22 " )
129129 self .assertEqual (close_session_call_args .secret , b"\x33 " )
130130
131+ connection = databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS )
132+ connection .close = Mock ()
133+ try :
134+ with self .assertRaises (KeyboardInterrupt ):
135+ with connection :
136+ raise KeyboardInterrupt ("Simulated interrupt" )
137+ finally :
138+ connection .close .assert_called ()
139+
131140 @patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
132141 def test_max_number_of_retries_passthrough (self , mock_client_class ):
133142 databricks .sql .connect (
@@ -146,33 +155,21 @@ def test_socket_timeout_passthrough(self, mock_client_class):
146155 @patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
147156 def test_configuration_passthrough (self , mock_client_class ):
148157 mock_session_config = Mock ()
149-
150- # Create a mock SessionId that will be returned by open_session
151- mock_session_id = SessionId (BackendType .THRIFT , b"\x22 " , b"\x33 " )
152- mock_client_class .return_value .open_session .return_value = mock_session_id
153-
154158 databricks .sql .connect (
155159 session_configuration = mock_session_config , ** self .DUMMY_CONNECTION_ARGS
156160 )
157161
158- # Check that open_session was called with the correct session_configuration as keyword argument
159162 call_kwargs = mock_client_class .return_value .open_session .call_args [1 ]
160163 self .assertEqual (call_kwargs ["session_configuration" ], mock_session_config )
161164
162165 @patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
163166 def test_initial_namespace_passthrough (self , mock_client_class ):
164167 mock_cat = Mock ()
165168 mock_schem = Mock ()
166-
167- # Create a mock SessionId that will be returned by open_session
168- mock_session_id = SessionId (BackendType .THRIFT , b"\x22 " , b"\x33 " )
169- mock_client_class .return_value .open_session .return_value = mock_session_id
170-
171169 databricks .sql .connect (
172170 ** self .DUMMY_CONNECTION_ARGS , catalog = mock_cat , schema = mock_schem
173171 )
174172
175- # Check that open_session was called with the correct catalog and schema as keyword arguments
176173 call_kwargs = mock_client_class .return_value .open_session .call_args [1 ]
177174 self .assertEqual (call_kwargs ["catalog" ], mock_cat )
178175 self .assertEqual (call_kwargs ["schema" ], mock_schem )
@@ -181,7 +178,6 @@ def test_initial_namespace_passthrough(self, mock_client_class):
181178 def test_finalizer_closes_abandoned_connection (self , mock_client_class ):
182179 instance = mock_client_class .return_value
183180
184- # Create a mock SessionId that will be returned by open_session
185181 mock_session_id = SessionId (BackendType .THRIFT , b"\x22 " , b"\x33 " )
186182 instance .open_session .return_value = mock_session_id
187183
0 commit comments