Skip to content

Commit c68e42f

Browse files
committed
revert change to connection_uuid
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 6212710 commit c68e42f

File tree

1 file changed

+45
-37
lines changed

1 file changed

+45
-37
lines changed

src/databricks/sql/thrift_backend.py

Lines changed: 45 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,6 @@
7272
"_retry_delay_default": (float, 5, 1, 60),
7373
}
7474

75-
# Add thread local storage
76-
_connection_uuid = threading.local()
77-
7875

7976
class ThriftBackend:
8077
CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE
@@ -226,7 +223,7 @@ def __init__(
226223
raise
227224

228225
self._request_lock = threading.RLock()
229-
_connection_uuid.value = None
226+
self._connection_uuid = None
230227

231228
# TODO: Move this bounding logic into DatabricksRetryPolicy for v3 (PECO-918)
232229
def _initialize_retry_args(self, kwargs):
@@ -259,14 +256,14 @@ def _initialize_retry_args(self, kwargs):
259256
)
260257

261258
@staticmethod
262-
def _check_response_for_error(response):
259+
def _check_response_for_error(response, connection_uuid=None):
263260
if response.status and response.status.statusCode in [
264261
ttypes.TStatusCode.ERROR_STATUS,
265262
ttypes.TStatusCode.INVALID_HANDLE_STATUS,
266263
]:
267264
raise DatabaseError(
268265
response.status.errorMessage,
269-
connection_uuid=getattr(_connection_uuid, "value", None),
266+
connection_uuid=connection_uuid,
270267
)
271268

272269
@staticmethod
@@ -320,7 +317,7 @@ def _handle_request_error(self, error_info, attempt, elapsed):
320317
network_request_error = RequestError(
321318
user_friendly_error_message,
322319
full_error_info_context,
323-
getattr(_connection_uuid, "value", None),
320+
self._connection_uuid,
324321
error_info.error,
325322
)
326323
logger.info(network_request_error.message_with_context())
@@ -493,7 +490,7 @@ def attempt_request(attempt):
493490
if not isinstance(response_or_error_info, RequestErrorInfo):
494491
# log nothing here, presume that main request logging covers
495492
response = response_or_error_info
496-
ThriftBackend._check_response_for_error(response)
493+
ThriftBackend._check_response_for_error(response, self._connection_uuid)
497494
return response
498495

499496
error_info = response_or_error_info
@@ -508,7 +505,7 @@ def _check_protocol_version(self, t_open_session_resp):
508505
"Error: expected server to use a protocol version >= "
509506
"SPARK_CLI_SERVICE_PROTOCOL_V2, "
510507
"instead got: {}".format(protocol_version),
511-
connection_uuid=getattr(_connection_uuid, "value", None),
508+
connection_uuid=self._connection_uuid,
512509
)
513510

514511
def _check_initial_namespace(self, catalog, schema, response):
@@ -522,15 +519,15 @@ def _check_initial_namespace(self, catalog, schema, response):
522519
raise InvalidServerResponseError(
523520
"Setting initial namespace not supported by the DBR version, "
524521
"Please use a Databricks SQL endpoint or a cluster with DBR >= 9.0.",
525-
connection_uuid=getattr(_connection_uuid, "value", None),
522+
connection_uuid=self._connection_uuid,
526523
)
527524

528525
if catalog:
529526
if not response.canUseMultipleCatalogs:
530527
raise InvalidServerResponseError(
531528
"Unexpected response from server: Trying to set initial catalog to {}, "
532529
+ "but server does not support multiple catalogs.".format(catalog), # type: ignore
533-
connection_uuid=getattr(_connection_uuid, "value", None),
530+
connection_uuid=self._connection_uuid,
534531
)
535532

536533
def _check_session_configuration(self, session_configuration):
@@ -545,7 +542,7 @@ def _check_session_configuration(self, session_configuration):
545542
TIMESTAMP_AS_STRING_CONFIG,
546543
session_configuration[TIMESTAMP_AS_STRING_CONFIG],
547544
),
548-
connection_uuid=getattr(_connection_uuid, "value", None),
545+
connection_uuid=self._connection_uuid,
549546
)
550547

551548
def open_session(self, session_configuration, catalog, schema):
@@ -576,7 +573,7 @@ def open_session(self, session_configuration, catalog, schema):
576573
response = self.make_request(self._client.OpenSession, open_session_req)
577574
self._check_initial_namespace(catalog, schema, response)
578575
self._check_protocol_version(response)
579-
_connection_uuid.value = (
576+
self._connection_uuid = (
580577
self.handle_to_hex_id(response.sessionHandle)
581578
if response.sessionHandle
582579
else None
@@ -605,7 +602,7 @@ def _check_command_not_in_error_or_closed_state(
605602
and self.guid_to_hex_id(op_handle.operationId.guid),
606603
"diagnostic-info": get_operations_resp.diagnosticInfo,
607604
},
608-
connection_uuid=getattr(_connection_uuid, "value", None),
605+
connection_uuid=self._connection_uuid,
609606
)
610607
else:
611608
raise ServerOperationError(
@@ -615,7 +612,7 @@ def _check_command_not_in_error_or_closed_state(
615612
and self.guid_to_hex_id(op_handle.operationId.guid),
616613
"diagnostic-info": None,
617614
},
618-
connection_uuid=getattr(_connection_uuid, "value", None),
615+
connection_uuid=self._connection_uuid,
619616
)
620617
elif get_operations_resp.operationState == ttypes.TOperationState.CLOSED_STATE:
621618
raise DatabaseError(
@@ -626,7 +623,7 @@ def _check_command_not_in_error_or_closed_state(
626623
"operation-id": op_handle
627624
and self.guid_to_hex_id(op_handle.operationId.guid)
628625
},
629-
connection_uuid=getattr(_connection_uuid, "value", None),
626+
connection_uuid=self._connection_uuid,
630627
)
631628

632629
def _poll_for_status(self, op_handle):
@@ -649,7 +646,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti
649646
else:
650647
raise OperationalError(
651648
"Unsupported TRowSet instance {}".format(t_row_set),
652-
connection_uuid=getattr(_connection_uuid, "value", None),
649+
connection_uuid=self._connection_uuid,
653650
)
654651
return convert_decimals_in_arrow_table(arrow_table, description), num_rows
655652

@@ -658,7 +655,7 @@ def _get_metadata_resp(self, op_handle):
658655
return self.make_request(self._client.GetResultSetMetadata, req)
659656

660657
@staticmethod
661-
def _hive_schema_to_arrow_schema(t_table_schema):
658+
def _hive_schema_to_arrow_schema(t_table_schema, connection_uuid=None):
662659
def map_type(t_type_entry):
663660
if t_type_entry.primitiveEntry:
664661
return {
@@ -690,7 +687,7 @@ def map_type(t_type_entry):
690687
# even for complex types
691688
raise OperationalError(
692689
"Thrift protocol error: t_type_entry not a primitiveEntry",
693-
connection_uuid=getattr(_connection_uuid, "value", None),
690+
connection_uuid=connection_uuid,
694691
)
695692

696693
def convert_col(t_column_desc):
@@ -701,7 +698,7 @@ def convert_col(t_column_desc):
701698
return pyarrow.schema([convert_col(col) for col in t_table_schema.columns])
702699

703700
@staticmethod
704-
def _col_to_description(col):
701+
def _col_to_description(col, connection_uuid=None):
705702
type_entry = col.typeDesc.types[0]
706703

707704
if type_entry.primitiveEntry:
@@ -711,7 +708,7 @@ def _col_to_description(col):
711708
else:
712709
raise OperationalError(
713710
"Thrift protocol error: t_type_entry not a primitiveEntry",
714-
connection_uuid=getattr(_connection_uuid, "value", None),
711+
connection_uuid=connection_uuid,
715712
)
716713

717714
if type_entry.primitiveEntry.type == ttypes.TTypeId.DECIMAL_TYPE:
@@ -725,17 +722,18 @@ def _col_to_description(col):
725722
raise OperationalError(
726723
"Decimal type did not provide typeQualifier precision, scale in "
727724
"primitiveEntry {}".format(type_entry.primitiveEntry),
728-
connection_uuid=getattr(_connection_uuid, "value", None),
725+
connection_uuid=connection_uuid,
729726
)
730727
else:
731728
precision, scale = None, None
732729

733730
return col.columnName, cleaned_type, None, None, precision, scale, None
734731

735732
@staticmethod
736-
def _hive_schema_to_description(t_table_schema):
733+
def _hive_schema_to_description(t_table_schema, connection_uuid=None):
737734
return [
738-
ThriftBackend._col_to_description(col) for col in t_table_schema.columns
735+
ThriftBackend._col_to_description(col, connection_uuid)
736+
for col in t_table_schema.columns
739737
]
740738

741739
def _results_message_to_execute_response(self, resp, operation_state):
@@ -756,7 +754,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
756754
t_result_set_metadata_resp.resultFormat
757755
]
758756
),
759-
connection_uuid=getattr(_connection_uuid, "value", None),
757+
connection_uuid=self._connection_uuid,
760758
)
761759
direct_results = resp.directResults
762760
has_been_closed_server_side = direct_results and direct_results.closeOperation
@@ -766,13 +764,16 @@ def _results_message_to_execute_response(self, resp, operation_state):
766764
or direct_results.resultSet.hasMoreRows
767765
)
768766
description = self._hive_schema_to_description(
769-
t_result_set_metadata_resp.schema
767+
t_result_set_metadata_resp.schema,
768+
self._connection_uuid,
770769
)
771770

772771
if pyarrow:
773772
schema_bytes = (
774773
t_result_set_metadata_resp.arrowSchema
775-
or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
774+
or self._hive_schema_to_arrow_schema(
775+
t_result_set_metadata_resp.schema, self._connection_uuid
776+
)
776777
.serialize()
777778
.to_pybytes()
778779
)
@@ -833,13 +834,16 @@ def get_execution_result(self, op_handle, cursor):
833834
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
834835
has_more_rows = resp.hasMoreRows
835836
description = self._hive_schema_to_description(
836-
t_result_set_metadata_resp.schema
837+
t_result_set_metadata_resp.schema,
838+
self._connection_uuid,
837839
)
838840

839841
if pyarrow:
840842
schema_bytes = (
841843
t_result_set_metadata_resp.arrowSchema
842-
or self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
844+
or self._hive_schema_to_arrow_schema(
845+
t_result_set_metadata_resp.schema, self._connection_uuid
846+
)
843847
.serialize()
844848
.to_pybytes()
845849
)
@@ -893,23 +897,27 @@ def get_query_state(self, op_handle) -> "TOperationState":
893897
return operation_state
894898

895899
@staticmethod
896-
def _check_direct_results_for_error(t_spark_direct_results):
900+
def _check_direct_results_for_error(t_spark_direct_results, connection_uuid=None):
897901
if t_spark_direct_results:
898902
if t_spark_direct_results.operationStatus:
899903
ThriftBackend._check_response_for_error(
900-
t_spark_direct_results.operationStatus
904+
t_spark_direct_results.operationStatus,
905+
connection_uuid,
901906
)
902907
if t_spark_direct_results.resultSetMetadata:
903908
ThriftBackend._check_response_for_error(
904-
t_spark_direct_results.resultSetMetadata
909+
t_spark_direct_results.resultSetMetadata,
910+
connection_uuid,
905911
)
906912
if t_spark_direct_results.resultSet:
907913
ThriftBackend._check_response_for_error(
908-
t_spark_direct_results.resultSet
914+
t_spark_direct_results.resultSet,
915+
connection_uuid,
909916
)
910917
if t_spark_direct_results.closeOperation:
911918
ThriftBackend._check_response_for_error(
912-
t_spark_direct_results.closeOperation
919+
t_spark_direct_results.closeOperation,
920+
connection_uuid,
913921
)
914922

915923
def execute_command(
@@ -1058,7 +1066,7 @@ def get_columns(
10581066

10591067
def _handle_execute_response(self, resp, cursor):
10601068
cursor.active_op_handle = resp.operationHandle
1061-
self._check_direct_results_for_error(resp.directResults)
1069+
self._check_direct_results_for_error(resp.directResults, self._connection_uuid)
10621070

10631071
final_operation_state = self._wait_until_command_done(
10641072
resp.operationHandle,
@@ -1069,7 +1077,7 @@ def _handle_execute_response(self, resp, cursor):
10691077

10701078
def _handle_execute_response_async(self, resp, cursor):
10711079
cursor.active_op_handle = resp.operationHandle
1072-
self._check_direct_results_for_error(resp.directResults)
1080+
self._check_direct_results_for_error(resp.directResults, self._connection_uuid)
10731081

10741082
def fetch_results(
10751083
self,
@@ -1104,7 +1112,7 @@ def fetch_results(
11041112
"fetch_results failed due to inconsistency in the state between the client and the server. Expected results to start from {} but they instead start at {}, some result batches must have been skipped".format(
11051113
expected_row_start_offset, resp.results.startRowOffset
11061114
),
1107-
connection_uuid=getattr(_connection_uuid, "value", None),
1115+
connection_uuid=self._connection_uuid,
11081116
)
11091117

11101118
queue = ResultSetQueueFactory.build_queue(

0 commit comments

Comments
 (0)