7272 "_retry_delay_default" : (float , 5 , 1 , 60 ),
7373}
7474
75- # Add thread local storage
76- _connection_uuid = threading .local ()
77-
7875
7976class ThriftBackend :
8077 CLOSED_OP_STATE = ttypes .TOperationState .CLOSED_STATE
@@ -226,7 +223,7 @@ def __init__(
             raise
 
         self._request_lock = threading.RLock()
-        _connection_uuid.value = None
+        self._connection_uuid = None
 
     # TODO: Move this bounding logic into DatabricksRetryPolicy for v3 (PECO-918)
     def _initialize_retry_args(self, kwargs):
@@ -259,14 +256,14 @@ def _initialize_retry_args(self, kwargs):
             )
 
     @staticmethod
-    def _check_response_for_error(response):
+    def _check_response_for_error(response, connection_uuid=None):
         if response.status and response.status.statusCode in [
             ttypes.TStatusCode.ERROR_STATUS,
             ttypes.TStatusCode.INVALID_HANDLE_STATUS,
         ]:
             raise DatabaseError(
                 response.status.errorMessage,
-                connection_uuid=getattr(_connection_uuid, "value", None),
+                connection_uuid=connection_uuid,
             )
 
     @staticmethod
@@ -320,7 +317,7 @@ def _handle_request_error(self, error_info, attempt, elapsed):
         network_request_error = RequestError(
             user_friendly_error_message,
             full_error_info_context,
-            getattr(_connection_uuid, "value", None),
+            self._connection_uuid,
             error_info.error,
         )
         logger.info(network_request_error.message_with_context())
@@ -493,7 +490,7 @@ def attempt_request(attempt):
             if not isinstance(response_or_error_info, RequestErrorInfo):
                 # log nothing here, presume that main request logging covers
                 response = response_or_error_info
-                ThriftBackend._check_response_for_error(response)
+                ThriftBackend._check_response_for_error(response, self._connection_uuid)
                 return response
 
             error_info = response_or_error_info
@@ -508,7 +505,7 @@ def _check_protocol_version(self, t_open_session_resp):
508505 "Error: expected server to use a protocol version >= "
509506 "SPARK_CLI_SERVICE_PROTOCOL_V2, "
510507 "instead got: {}" .format (protocol_version ),
511- connection_uuid = getattr ( _connection_uuid , "value" , None ) ,
508+ connection_uuid = self . _connection_uuid ,
512509 )
513510
514511 def _check_initial_namespace (self , catalog , schema , response ):
@@ -522,15 +519,15 @@ def _check_initial_namespace(self, catalog, schema, response):
             raise InvalidServerResponseError(
                 "Setting initial namespace not supported by the DBR version, "
                 "Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0.",
-                connection_uuid=getattr(_connection_uuid, "value", None),
+                connection_uuid=self._connection_uuid,
             )
 
         if catalog:
             if not response.canUseMultipleCatalogs:
                 raise InvalidServerResponseError(
                     "Unexpected response from server: Trying to set initial catalog to {}, "
                     + "but server does not support multiple catalogs.".format(catalog),  # type: ignore
-                    connection_uuid=getattr(_connection_uuid, "value", None),
+                    connection_uuid=self._connection_uuid,
                 )
 
     def _check_session_configuration(self, session_configuration):
@@ -545,7 +542,7 @@ def _check_session_configuration(self, session_configuration):
                     TIMESTAMP_AS_STRING_CONFIG,
                     session_configuration[TIMESTAMP_AS_STRING_CONFIG],
                 ),
-                connection_uuid=getattr(_connection_uuid, "value", None),
+                connection_uuid=self._connection_uuid,
             )
 
     def open_session(self, session_configuration, catalog, schema):
@@ -576,7 +573,7 @@ def open_session(self, session_configuration, catalog, schema):
         response = self.make_request(self._client.OpenSession, open_session_req)
         self._check_initial_namespace(catalog, schema, response)
         self._check_protocol_version(response)
-        _connection_uuid.value = (
+        self._connection_uuid = (
             self.handle_to_hex_id(response.sessionHandle)
             if response.sessionHandle
             else None
@@ -605,7 +602,7 @@ def _check_command_not_in_error_or_closed_state(
                         and self.guid_to_hex_id(op_handle.operationId.guid),
                         "diagnostic-info": get_operations_resp.diagnosticInfo,
                     },
-                    connection_uuid=getattr(_connection_uuid, "value", None),
+                    connection_uuid=self._connection_uuid,
                 )
             else:
                 raise ServerOperationError(
@@ -615,7 +612,7 @@ def _check_command_not_in_error_or_closed_state(
                         and self.guid_to_hex_id(op_handle.operationId.guid),
                         "diagnostic-info": None,
                     },
-                    connection_uuid=getattr(_connection_uuid, "value", None),
+                    connection_uuid=self._connection_uuid,
                 )
         elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE:
             raise DatabaseError(
@@ -626,7 +623,7 @@ def _check_command_not_in_error_or_closed_state(
626623 "operation-id" : op_handle
627624 and self .guid_to_hex_id (op_handle .operationId .guid )
628625 },
629- connection_uuid = getattr ( _connection_uuid , "value" , None ) ,
626+ connection_uuid = self . _connection_uuid ,
630627 )
631628
632629 def _poll_for_status (self , op_handle ):
@@ -649,7 +646,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti
         else:
             raise OperationalError(
                 "Unsupported TRowSet instance {}".format(t_row_set),
-                connection_uuid=getattr(_connection_uuid, "value", None),
+                connection_uuid=self._connection_uuid,
             )
         return convert_decimals_in_arrow_table(arrow_table, description), num_rows
 
@@ -658,7 +655,7 @@ def _get_metadata_resp(self, op_handle):
         return self.make_request(self._client.GetResultSetMetadata, req)
 
     @staticmethod
-    def _hive_schema_to_arrow_schema(t_table_schema):
+    def _hive_schema_to_arrow_schema(t_table_schema, connection_uuid=None):
         def map_type(t_type_entry):
             if t_type_entry.primitiveEntry:
                 return {
@@ -690,7 +687,7 @@ def map_type(t_type_entry):
                 # even for complex types
                 raise OperationalError(
                     "Thrift protocol error: t_type_entry not a primitiveEntry",
-                    connection_uuid=getattr(_connection_uuid, "value", None),
+                    connection_uuid=connection_uuid,
                 )
 
         def convert_col(t_column_desc):
@@ -701,7 +698,7 @@ def convert_col(t_column_desc):
         return pyarrow.schema([convert_col(col) for col in t_table_schema.columns])
 
     @staticmethod
-    def _col_to_description(col):
+    def _col_to_description(col, connection_uuid=None):
         type_entry = col.typeDesc.types[0]
 
         if type_entry.primitiveEntry:
@@ -711,7 +708,7 @@ def _col_to_description(col):
         else:
             raise OperationalError(
                 "Thrift protocol error: t_type_entry not a primitiveEntry",
-                connection_uuid=getattr(_connection_uuid, "value", None),
+                connection_uuid=connection_uuid,
             )
 
         if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE:
@@ -725,17 +722,18 @@ def _col_to_description(col):
                 raise OperationalError(
                     "Decimal type did not provide typeQualifier precision, scale in "
                     "primitiveEntry {}".format(type_entry.primitiveEntry),
-                    connection_uuid=getattr(_connection_uuid, "value", None),
+                    connection_uuid=connection_uuid,
                 )
         else:
             precision, scale = None, None
 
         return col.columnName, cleaned_type, None, None, precision, scale, None
 
     @staticmethod
-    def _hive_schema_to_description(t_table_schema):
+    def _hive_schema_to_description(t_table_schema, connection_uuid=None):
         return [
-            ThriftBackend._col_to_description(col) for col in t_table_schema.columns
+            ThriftBackend._col_to_description(col, connection_uuid)
+            for col in t_table_schema.columns
         ]
 
     def _results_message_to_execute_response(self, resp, operation_state):
@@ -756,7 +754,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
                         t_result_set_metadata_resp.resultFormat
                     ]
                 ),
-                connection_uuid=getattr(_connection_uuid, "value", None),
+                connection_uuid=self._connection_uuid,
             )
         direct_results = resp.directResults
         has_been_closed_server_side = direct_results and direct_results.closeOperation
@@ -766,13 +764,16 @@ def _results_message_to_execute_response(self, resp, operation_state):
             or direct_results.resultSet.hasMoreRows
         )
         description = self._hive_schema_to_description(
-            t_result_set_metadata_resp.schema
+            t_result_set_metadata_resp.schema,
+            self._connection_uuid,
         )
 
         if pyarrow:
             schema_bytes = (
                 t_result_set_metadata_resp.arrowSchema
-                or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
+                or self._hive_schema_to_arrow_schema(
+                    t_result_set_metadata_resp.schema, self._connection_uuid
+                )
                 .serialize()
                 .to_pybytes()
             )
@@ -833,13 +834,16 @@ def get_execution_result(self, op_handle, cursor):
         is_staging_operation = t_result_set_metadata_resp.isStagingOperation
         has_more_rows = resp.hasMoreRows
         description = self._hive_schema_to_description(
-            t_result_set_metadata_resp.schema
+            t_result_set_metadata_resp.schema,
+            self._connection_uuid,
         )
 
         if pyarrow:
             schema_bytes = (
                 t_result_set_metadata_resp.arrowSchema
-                or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
+                or self._hive_schema_to_arrow_schema(
+                    t_result_set_metadata_resp.schema, self._connection_uuid
+                )
                 .serialize()
                 .to_pybytes()
             )
@@ -893,23 +897,27 @@ def get_query_state(self, op_handle) -> "TOperationState":
         return operation_state
 
     @staticmethod
-    def _check_direct_results_for_error(t_spark_direct_results):
+    def _check_direct_results_for_error(t_spark_direct_results, connection_uuid=None):
         if t_spark_direct_results:
             if t_spark_direct_results.operationStatus:
                 ThriftBackend._check_response_for_error(
-                    t_spark_direct_results.operationStatus
+                    t_spark_direct_results.operationStatus,
+                    connection_uuid,
                 )
             if t_spark_direct_results.resultSetMetadata:
                 ThriftBackend._check_response_for_error(
-                    t_spark_direct_results.resultSetMetadata
+                    t_spark_direct_results.resultSetMetadata,
+                    connection_uuid,
                 )
             if t_spark_direct_results.resultSet:
                 ThriftBackend._check_response_for_error(
-                    t_spark_direct_results.resultSet
+                    t_spark_direct_results.resultSet,
+                    connection_uuid,
                 )
             if t_spark_direct_results.closeOperation:
                 ThriftBackend._check_response_for_error(
-                    t_spark_direct_results.closeOperation
+                    t_spark_direct_results.closeOperation,
+                    connection_uuid,
                 )
 
     def execute_command(
@@ -1058,7 +1066,7 @@ def get_columns(
 
     def _handle_execute_response(self, resp, cursor):
         cursor.active_op_handle = resp.operationHandle
-        self._check_direct_results_for_error(resp.directResults)
+        self._check_direct_results_for_error(resp.directResults, self._connection_uuid)
 
         final_operation_state = self._wait_until_command_done(
             resp.operationHandle,
@@ -1069,7 +1077,7 @@ def _handle_execute_response(self, resp, cursor):
 
     def _handle_execute_response_async(self, resp, cursor):
         cursor.active_op_handle = resp.operationHandle
-        self._check_direct_results_for_error(resp.directResults)
+        self._check_direct_results_for_error(resp.directResults, self._connection_uuid)
 
     def fetch_results(
         self,
@@ -1104,7 +1112,7 @@ def fetch_results(
11041112 "fetch_results failed due to inconsistency in the state between the client and the server. Expected results to start from {} but they instead start at {}, some result batches must have been skipped" .format (
11051113 expected_row_start_offset , resp .results .startRowOffset
11061114 ),
1107- connection_uuid = getattr ( _connection_uuid , "value" , None ) ,
1115+ connection_uuid = self . _connection_uuid ,
11081116 )
11091117
11101118 queue = ResultSetQueueFactory .build_queue (
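Reviewer note: every hunk above applies the same refactor, replacing the module-level thread-local `_connection_uuid` with an instance attribute, and threading it into the static helpers as an explicit `connection_uuid=None` parameter. A minimal sketch of the pattern follows; `Backend`, `check_response`, and `handle` are hypothetical stand-ins for illustration, not the connector's real API.

    # Sketch of the refactor pattern (hypothetical names, not the connector's API).

    class Backend:
        def __init__(self):
            # The identifier now lives on the instance instead of a
            # module-level threading.local(), so it follows the connection
            # object rather than whichever thread happens to run the request.
            self._connection_uuid = None

        @staticmethod
        def check_response(response, connection_uuid=None):
            # Static helpers receive the identifier as an explicit parameter;
            # the None default keeps callers without a connection working.
            if response.get("error"):
                raise RuntimeError(
                    "server error (connection_uuid={})".format(connection_uuid)
                )

        def handle(self, response):
            # Instance methods pass their own identifier through.
            Backend.check_response(response, self._connection_uuid)

Passing the identifier explicitly also makes the staticmethods safe to call from any thread and easier to unit-test, since there is no hidden thread-local state to set up first.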