@@ -623,7 +623,10 @@ def test_handle_execute_response_sets_compression_in_direct_results(
623623 status = Mock (),
624624 operationHandle = Mock (),
625625 directResults = ttypes .TSparkDirectResults (
626- operationStatus = op_status ,
626+ operationStatus = ttypes .TGetOperationStatusResp (
627+ status = self .okay_status ,
628+ operationState = ttypes .TOperationState .FINISHED_STATE ,
629+ ),
627630 resultSetMetadata = ttypes .TGetResultSetMetadataResp (
628631 status = self .okay_status ,
629632 resultFormat = ttypes .TSparkRowSetType .ARROW_BASED_SET ,
@@ -832,9 +835,10 @@ def test_handle_execute_response_checks_direct_results_for_error_statuses(self):
832835 thrift_backend ._handle_execute_response (error_resp , Mock ())
833836 self .assertIn ("this is a bad error" , str (cm .exception ))
834837
838+ @patch ("databricks.sql.backend.thrift_backend.ThriftResultSet" )
835839 @patch ("databricks.sql.backend.thrift_backend.TCLIService.Client" , autospec = True )
836840 def test_handle_execute_response_can_handle_without_direct_results (
837- self , tcli_service_class
841+ self , tcli_service_class , mock_result_set
838842 ):
839843 tcli_service_instance = tcli_service_class .return_value
840844
@@ -878,10 +882,10 @@ def test_handle_execute_response_can_handle_without_direct_results(
878882 auth_provider = AuthProvider (),
879883 ssl_options = SSLOptions (),
880884 )
881- execute_response , _ = thrift_backend . _handle_execute_response (
882- execute_resp , Mock ()
883- )
884-
885+ (
886+ execute_response ,
887+ _ ,
888+ ) = thrift_backend . _handle_execute_response ( execute_resp , Mock ())
885889 self .assertEqual (
886890 execute_response .status ,
887891 CommandState .SUCCEEDED ,
@@ -947,8 +951,14 @@ def test_use_arrow_schema_if_available(self, tcli_service_class):
947951 tcli_service_instance .GetResultSetMetadata .return_value = (
948952 t_get_result_set_metadata_resp
949953 )
954+ tcli_service_instance .GetOperationStatus .return_value = (
955+ ttypes .TGetOperationStatusResp (
956+ status = self .okay_status ,
957+ operationState = ttypes .TOperationState .FINISHED_STATE ,
958+ )
959+ )
950960 thrift_backend = self ._make_fake_thrift_backend ()
951- execute_response , arrow_schema_bytes = thrift_backend ._handle_execute_response (
961+ execute_response , _ = thrift_backend ._handle_execute_response (
952962 t_execute_resp , Mock ()
953963 )
954964
@@ -973,8 +983,14 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class):
973983 )
974984
975985 tcli_service_instance .GetResultSetMetadata .return_value = hive_schema_req
986+ tcli_service_instance .GetOperationStatus .return_value = (
987+ ttypes .TGetOperationStatusResp (
988+ status = self .okay_status ,
989+ operationState = ttypes .TOperationState .FINISHED_STATE ,
990+ )
991+ )
976992 thrift_backend = self ._make_fake_thrift_backend ()
977- thrift_backend ._handle_execute_response (t_execute_resp , Mock ())
993+ _ , _ = thrift_backend ._handle_execute_response (t_execute_resp , Mock ())
978994
979995 self .assertEqual (
980996 hive_schema_mock ,
@@ -988,10 +1004,10 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class):
9881004 def test_handle_execute_response_reads_has_more_rows_in_direct_results (
9891005 self , tcli_service_class , build_queue
9901006 ):
991- for has_more_rows , resp_type in itertools .product (
1007+ for is_direct_results , resp_type in itertools .product (
9921008 [True , False ], self .execute_response_types
9931009 ):
994- with self .subTest (has_more_rows = has_more_rows , resp_type = resp_type ):
1010+ with self .subTest (is_direct_results = is_direct_results , resp_type = resp_type ):
9951011 tcli_service_instance = tcli_service_class .return_value
9961012 results_mock = Mock ()
9971013 results_mock .startRowOffset = 0
@@ -1003,7 +1019,7 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results(
10031019 resultSetMetadata = self .metadata_resp ,
10041020 resultSet = ttypes .TFetchResultsResp (
10051021 status = self .okay_status ,
1006- hasMoreRows = has_more_rows ,
1022+ hasMoreRows = is_direct_results ,
10071023 results = results_mock ,
10081024 ),
10091025 closeOperation = Mock (),
@@ -1019,11 +1035,12 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results(
10191035 )
10201036 thrift_backend = self ._make_fake_thrift_backend ()
10211037
1022- execute_response , _ = thrift_backend ._handle_execute_response (
1023- execute_resp , Mock ()
1024- )
1038+ (
1039+ execute_response ,
1040+ has_more_rows_result ,
1041+ ) = thrift_backend ._handle_execute_response (execute_resp , Mock ())
10251042
1026- self .assertEqual (has_more_rows , execute_response . has_more_rows )
1043+ self .assertEqual (is_direct_results , has_more_rows_result )
10271044
10281045 @patch (
10291046 "databricks.sql.utils.ResultSetQueueFactory.build_queue" , return_value = Mock ()
@@ -1032,10 +1049,10 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results(
10321049 def test_handle_execute_response_reads_has_more_rows_in_result_response (
10331050 self , tcli_service_class , build_queue
10341051 ):
1035- for has_more_rows , resp_type in itertools .product (
1052+ for is_direct_results , resp_type in itertools .product (
10361053 [True , False ], self .execute_response_types
10371054 ):
1038- with self .subTest (has_more_rows = has_more_rows , resp_type = resp_type ):
1055+ with self .subTest (is_direct_results = is_direct_results , resp_type = resp_type ):
10391056 tcli_service_instance = tcli_service_class .return_value
10401057 results_mock = MagicMock ()
10411058 results_mock .startRowOffset = 0
@@ -1048,7 +1065,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
10481065
10491066 fetch_results_resp = ttypes .TFetchResultsResp (
10501067 status = self .okay_status ,
1051- hasMoreRows = has_more_rows ,
1068+ hasMoreRows = is_direct_results ,
10521069 results = results_mock ,
10531070 resultSetMetadata = ttypes .TGetResultSetMetadataResp (
10541071 resultFormat = ttypes .TSparkRowSetType .ARROW_BASED_SET
@@ -1081,7 +1098,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
10811098 description = Mock (),
10821099 )
10831100
1084- self .assertEqual (has_more_rows , has_more_rows_resp )
1101+ self .assertEqual (is_direct_results , has_more_rows_resp )
10851102
10861103 @patch ("databricks.sql.backend.thrift_backend.TCLIService.Client" , autospec = True )
10871104 def test_arrow_batches_row_count_are_respected (self , tcli_service_class ):
@@ -1136,9 +1153,10 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class):
11361153
11371154 self .assertEqual (arrow_queue .n_valid_rows , 15 * 10 )
11381155
1156+ @patch ("databricks.sql.backend.thrift_backend.ThriftResultSet" )
11391157 @patch ("databricks.sql.backend.thrift_backend.TCLIService.Client" , autospec = True )
11401158 def test_execute_statement_calls_client_and_handle_execute_response (
1141- self , tcli_service_class
1159+ self , tcli_service_class , mock_result_set
11421160 ):
11431161 tcli_service_instance = tcli_service_class .return_value
11441162 response = Mock ()
@@ -1151,14 +1169,15 @@ def test_execute_statement_calls_client_and_handle_execute_response(
11511169 auth_provider = AuthProvider (),
11521170 ssl_options = SSLOptions (),
11531171 )
1154- thrift_backend ._handle_execute_response = Mock (return_value = (Mock (), Mock ()))
1172+ thrift_backend ._handle_execute_response = Mock ()
1173+ thrift_backend ._handle_execute_response .return_value = (Mock (), Mock ())
11551174 cursor_mock = Mock ()
11561175
11571176 result = thrift_backend .execute_command (
11581177 "foo" , Mock (), 100 , 200 , Mock (), cursor_mock
11591178 )
11601179 # Verify the result is a ResultSet
1161- self .assertIsInstance (result , ResultSet )
1180+ self .assertEqual (result , mock_result_set . return_value )
11621181
11631182 # Check call to client
11641183 req = tcli_service_instance .ExecuteStatement .call_args [0 ][0 ]
@@ -1170,9 +1189,10 @@ def test_execute_statement_calls_client_and_handle_execute_response(
11701189 response , cursor_mock
11711190 )
11721191
1192+ @patch ("databricks.sql.backend.thrift_backend.ThriftResultSet" )
11731193 @patch ("databricks.sql.backend.thrift_backend.TCLIService.Client" , autospec = True )
11741194 def test_get_catalogs_calls_client_and_handle_execute_response (
1175- self , tcli_service_class
1195+ self , tcli_service_class , mock_result_set
11761196 ):
11771197 tcli_service_instance = tcli_service_class .return_value
11781198 response = Mock ()
@@ -1185,12 +1205,13 @@ def test_get_catalogs_calls_client_and_handle_execute_response(
11851205 auth_provider = AuthProvider (),
11861206 ssl_options = SSLOptions (),
11871207 )
1188- thrift_backend ._handle_execute_response = Mock (return_value = (Mock (), Mock ()))
1208+ thrift_backend ._handle_execute_response = Mock ()
1209+ thrift_backend ._handle_execute_response .return_value = (Mock (), Mock ())
11891210 cursor_mock = Mock ()
11901211
11911212 result = thrift_backend .get_catalogs (Mock (), 100 , 200 , cursor_mock )
11921213 # Verify the result is a ResultSet
1193- self .assertIsInstance (result , ResultSet )
1214+ self .assertEqual (result , mock_result_set . return_value )
11941215
11951216 # Check call to client
11961217 req = tcli_service_instance .GetCatalogs .call_args [0 ][0 ]
@@ -1201,9 +1222,10 @@ def test_get_catalogs_calls_client_and_handle_execute_response(
12011222 response , cursor_mock
12021223 )
12031224
1225+ @patch ("databricks.sql.backend.thrift_backend.ThriftResultSet" )
12041226 @patch ("databricks.sql.backend.thrift_backend.TCLIService.Client" , autospec = True )
12051227 def test_get_schemas_calls_client_and_handle_execute_response (
1206- self , tcli_service_class
1228+ self , tcli_service_class , mock_result_set
12071229 ):
12081230 tcli_service_instance = tcli_service_class .return_value
12091231 response = Mock ()
@@ -1216,7 +1238,8 @@ def test_get_schemas_calls_client_and_handle_execute_response(
12161238 auth_provider = AuthProvider (),
12171239 ssl_options = SSLOptions (),
12181240 )
1219- thrift_backend ._handle_execute_response = Mock (return_value = (Mock (), Mock ()))
1241+ thrift_backend ._handle_execute_response = Mock ()
1242+ thrift_backend ._handle_execute_response .return_value = (Mock (), Mock ())
12201243 cursor_mock = Mock ()
12211244
12221245 result = thrift_backend .get_schemas (
@@ -1228,7 +1251,7 @@ def test_get_schemas_calls_client_and_handle_execute_response(
12281251 schema_name = "schema_pattern" ,
12291252 )
12301253 # Verify the result is a ResultSet
1231- self .assertIsInstance (result , ResultSet )
1254+ self .assertEqual (result , mock_result_set . return_value )
12321255
12331256 # Check call to client
12341257 req = tcli_service_instance .GetSchemas .call_args [0 ][0 ]
@@ -1241,9 +1264,10 @@ def test_get_schemas_calls_client_and_handle_execute_response(
12411264 response , cursor_mock
12421265 )
12431266
1267+ @patch ("databricks.sql.backend.thrift_backend.ThriftResultSet" )
12441268 @patch ("databricks.sql.backend.thrift_backend.TCLIService.Client" , autospec = True )
12451269 def test_get_tables_calls_client_and_handle_execute_response (
1246- self , tcli_service_class
1270+ self , tcli_service_class , mock_result_set
12471271 ):
12481272 tcli_service_instance = tcli_service_class .return_value
12491273 response = Mock ()
@@ -1256,7 +1280,8 @@ def test_get_tables_calls_client_and_handle_execute_response(
12561280 auth_provider = AuthProvider (),
12571281 ssl_options = SSLOptions (),
12581282 )
1259- thrift_backend ._handle_execute_response = Mock (return_value = (Mock (), Mock ()))
1283+ thrift_backend ._handle_execute_response = Mock ()
1284+ thrift_backend ._handle_execute_response .return_value = (Mock (), Mock ())
12601285 cursor_mock = Mock ()
12611286
12621287 result = thrift_backend .get_tables (
@@ -1270,7 +1295,7 @@ def test_get_tables_calls_client_and_handle_execute_response(
12701295 table_types = ["type1" , "type2" ],
12711296 )
12721297 # Verify the result is a ResultSet
1273- self .assertIsInstance (result , ResultSet )
1298+ self .assertEqual (result , mock_result_set . return_value )
12741299
12751300 # Check call to client
12761301 req = tcli_service_instance .GetTables .call_args [0 ][0 ]
@@ -1285,9 +1310,10 @@ def test_get_tables_calls_client_and_handle_execute_response(
12851310 response , cursor_mock
12861311 )
12871312
1313+ @patch ("databricks.sql.backend.thrift_backend.ThriftResultSet" )
12881314 @patch ("databricks.sql.backend.thrift_backend.TCLIService.Client" , autospec = True )
12891315 def test_get_columns_calls_client_and_handle_execute_response (
1290- self , tcli_service_class
1316+ self , tcli_service_class , mock_result_set
12911317 ):
12921318 tcli_service_instance = tcli_service_class .return_value
12931319 response = Mock ()
@@ -1300,7 +1326,8 @@ def test_get_columns_calls_client_and_handle_execute_response(
13001326 auth_provider = AuthProvider (),
13011327 ssl_options = SSLOptions (),
13021328 )
1303- thrift_backend ._handle_execute_response = Mock (return_value = (Mock (), Mock ()))
1329+ thrift_backend ._handle_execute_response = Mock ()
1330+ thrift_backend ._handle_execute_response .return_value = (Mock (), Mock ())
13041331 cursor_mock = Mock ()
13051332
13061333 result = thrift_backend .get_columns (
@@ -1314,7 +1341,7 @@ def test_get_columns_calls_client_and_handle_execute_response(
13141341 column_name = "column_pattern" ,
13151342 )
13161343 # Verify the result is a ResultSet
1317- self .assertIsInstance (result , ResultSet )
1344+ self .assertEqual (result , mock_result_set . return_value )
13181345
13191346 # Check call to client
13201347 req = tcli_service_instance .GetColumns .call_args [0 ][0 ]
@@ -2203,14 +2230,23 @@ def test_protocol_v3_fails_if_initial_namespace_set(self, tcli_client_class):
22032230 str (cm .exception ),
22042231 )
22052232
2233+ @patch ("databricks.sql.backend.thrift_backend.ThriftResultSet" )
22062234 @patch ("databricks.sql.backend.thrift_backend.TCLIService.Client" , autospec = True )
22072235 @patch (
22082236 "databricks.sql.backend.thrift_backend.ThriftDatabricksClient._handle_execute_response"
22092237 )
22102238 def test_execute_command_sets_complex_type_fields_correctly (
2211- self , mock_handle_execute_response , tcli_service_class
2239+ self , mock_handle_execute_response , tcli_service_class , mock_result_set
22122240 ):
22132241 tcli_service_instance = tcli_service_class .return_value
2242+ # Set up the mock to return a tuple with two values
2243+ mock_execute_response = Mock ()
2244+ mock_arrow_schema = Mock ()
2245+ mock_handle_execute_response .return_value = (
2246+ mock_execute_response ,
2247+ mock_arrow_schema ,
2248+ )
2249+
22142250 # Iterate through each possible combination of native types (True, False and unset)
22152251 for complex , timestamp , decimals in itertools .product (
22162252 [True , False , None ], [True , False , None ], [True , False , None ]
0 commit comments