@@ -52,7 +52,6 @@ def new(cls):
5252 )
5353
5454 ThriftBackendMock .execute_command .return_value = MockTExecuteStatementResp
55-
5655 return ThriftBackendMock
5756
5857 @classmethod
@@ -352,30 +351,30 @@ def test_context_manager_closes_cursor(self):
352351 finally :
353352 cursor .close .assert_called ()
354353
355- # @patch("%s.client.ThriftBackend" % PACKAGE_NAME)
356- # def test_context_manager_closes_connection(self, mock_client_class):
357- # print("hellow1")
358- # instance = mock_client_class.return_value
359-
360- # mock_open_session_resp = MagicMock(spec=TOpenSessionResp)()
361- # mock_open_session_resp.sessionHandle.sessionId = b"\x22"
362- # instance.open_session.return_value = mock_open_session_resp
363-
364- # with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection:
365- # pass
366-
367- # # Check the close session request has an id of x22
368- # close_session_id = instance.close_session.call_args[0][0].sessionId
369- # self.assertEqual(close_session_id, b"\x22")
370-
371- # connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
372- # connection.close = Mock()
373- # try:
374- # with self.assertRaises(KeyboardInterrupt):
375- # with connection:
376- # raise KeyboardInterrupt("Simulated interrupt")
377- # finally:
378- # connection.close.assert_called()
354+ @patch ("%s.client.ThriftBackend" % PACKAGE_NAME )
355+ def test_context_manager_closes_connection (self , mock_client_class ):
356+ print ("hellow1" )
357+ instance = mock_client_class .return_value
358+
359+ mock_open_session_resp = MagicMock (spec = TOpenSessionResp )()
360+ mock_open_session_resp .sessionHandle .sessionId = b"\x22 "
361+ instance .open_session .return_value = mock_open_session_resp
362+
363+ with databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS ) as connection :
364+ pass
365+
366+ # Check the close session request has an id of x22
367+ close_session_id = instance .close_session .call_args [0 ][0 ].sessionId
368+ self .assertEqual (close_session_id , b"\x22 " )
369+
370+ connection = databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS )
371+ connection .close = Mock ()
372+ try :
373+ with self .assertRaises (KeyboardInterrupt ):
374+ with connection :
375+ raise KeyboardInterrupt ("Simulated interrupt" )
376+ finally :
377+ connection .close .assert_called ()
379378
380379 def dict_product (self , dicts ):
381380 """
0 commit comments