Skip to content

Commit d80ea49

Browse files
formatting (black)
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent d01f3e4 commit d80ea49

File tree

6 files changed

+160
-75
lines changed

6 files changed

+160
-75
lines changed

src/databricks/sql/backend/thrift_backend.py

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -801,16 +801,19 @@ def _results_message_to_execute_response(self, resp, operation_state):
801801
if status is None:
802802
raise ValueError(f"Invalid operation state: {operation_state}")
803803

804-
return ExecuteResponse(
805-
command_id=command_id,
806-
status=status,
807-
description=description,
808-
has_more_rows=has_more_rows,
809-
results_queue=arrow_queue_opt,
810-
has_been_closed_server_side=has_been_closed_server_side,
811-
lz4_compressed=lz4_compressed,
812-
is_staging_operation=is_staging_operation,
813-
), schema_bytes
804+
return (
805+
ExecuteResponse(
806+
command_id=command_id,
807+
status=status,
808+
description=description,
809+
has_more_rows=has_more_rows,
810+
results_queue=arrow_queue_opt,
811+
has_been_closed_server_side=has_been_closed_server_side,
812+
lz4_compressed=lz4_compressed,
813+
is_staging_operation=is_staging_operation,
814+
),
815+
schema_bytes,
816+
)
814817

815818
def get_execution_result(
816819
self, command_id: CommandId, cursor: "Cursor"
@@ -881,7 +884,7 @@ def get_execution_result(
881884
buffer_size_bytes=cursor.buffer_size_bytes,
882885
arraysize=cursor.arraysize,
883886
use_cloud_fetch=cursor.connection.use_cloud_fetch,
884-
arrow_schema_bytes=schema_bytes
887+
arrow_schema_bytes=schema_bytes,
885888
)
886889

887890
def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
@@ -991,7 +994,9 @@ def execute_command(
991994
self._handle_execute_response_async(resp, cursor)
992995
return None
993996
else:
994-
execute_response, arrow_schema_bytes = self._handle_execute_response(resp, cursor)
997+
execute_response, arrow_schema_bytes = self._handle_execute_response(
998+
resp, cursor
999+
)
9951000

9961001
return ThriftResultSet(
9971002
connection=cursor.connection,
@@ -1000,7 +1005,7 @@ def execute_command(
10001005
buffer_size_bytes=max_bytes,
10011006
arraysize=max_rows,
10021007
use_cloud_fetch=use_cloud_fetch,
1003-
arrow_schema_bytes=arrow_schema_bytes
1008+
arrow_schema_bytes=arrow_schema_bytes,
10041009
)
10051010

10061011
def get_catalogs(
@@ -1022,7 +1027,9 @@ def get_catalogs(
10221027
)
10231028
resp = self.make_request(self._client.GetCatalogs, req)
10241029

1025-
execute_response, arrow_schema_bytes = self._handle_execute_response(resp, cursor)
1030+
execute_response, arrow_schema_bytes = self._handle_execute_response(
1031+
resp, cursor
1032+
)
10261033

10271034
return ThriftResultSet(
10281035
connection=cursor.connection,
@@ -1031,7 +1038,7 @@ def get_catalogs(
10311038
buffer_size_bytes=max_bytes,
10321039
arraysize=max_rows,
10331040
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1034-
arrow_schema_bytes=arrow_schema_bytes
1041+
arrow_schema_bytes=arrow_schema_bytes,
10351042
)
10361043

10371044
def get_schemas(
@@ -1057,7 +1064,9 @@ def get_schemas(
10571064
)
10581065
resp = self.make_request(self._client.GetSchemas, req)
10591066

1060-
execute_response, arrow_schema_bytes = self._handle_execute_response(resp, cursor)
1067+
execute_response, arrow_schema_bytes = self._handle_execute_response(
1068+
resp, cursor
1069+
)
10611070

10621071
return ThriftResultSet(
10631072
connection=cursor.connection,
@@ -1066,7 +1075,7 @@ def get_schemas(
10661075
buffer_size_bytes=max_bytes,
10671076
arraysize=max_rows,
10681077
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1069-
arrow_schema_bytes=arrow_schema_bytes
1078+
arrow_schema_bytes=arrow_schema_bytes,
10701079
)
10711080

10721081
def get_tables(
@@ -1096,7 +1105,9 @@ def get_tables(
10961105
)
10971106
resp = self.make_request(self._client.GetTables, req)
10981107

1099-
execute_response, arrow_schema_bytes = self._handle_execute_response(resp, cursor)
1108+
execute_response, arrow_schema_bytes = self._handle_execute_response(
1109+
resp, cursor
1110+
)
11001111

11011112
return ThriftResultSet(
11021113
connection=cursor.connection,
@@ -1105,7 +1116,7 @@ def get_tables(
11051116
buffer_size_bytes=max_bytes,
11061117
arraysize=max_rows,
11071118
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1108-
arrow_schema_bytes=arrow_schema_bytes
1119+
arrow_schema_bytes=arrow_schema_bytes,
11091120
)
11101121

11111122
def get_columns(
@@ -1135,7 +1146,9 @@ def get_columns(
11351146
)
11361147
resp = self.make_request(self._client.GetColumns, req)
11371148

1138-
execute_response, arrow_schema_bytes = self._handle_execute_response(resp, cursor)
1149+
execute_response, arrow_schema_bytes = self._handle_execute_response(
1150+
resp, cursor
1151+
)
11391152

11401153
return ThriftResultSet(
11411154
connection=cursor.connection,
@@ -1144,7 +1157,7 @@ def get_columns(
11441157
buffer_size_bytes=max_bytes,
11451158
arraysize=max_rows,
11461159
use_cloud_fetch=cursor.connection.use_cloud_fetch,
1147-
arrow_schema_bytes=arrow_schema_bytes
1160+
arrow_schema_bytes=arrow_schema_bytes,
11481161
)
11491162

11501163
def _handle_execute_response(self, resp, cursor):
@@ -1158,9 +1171,10 @@ def _handle_execute_response(self, resp, cursor):
11581171
resp.directResults and resp.directResults.operationStatus,
11591172
)
11601173

1161-
execute_response, arrow_schema_bytes = self._results_message_to_execute_response(
1162-
resp, final_operation_state
1163-
)
1174+
(
1175+
execute_response,
1176+
arrow_schema_bytes,
1177+
) = self._results_message_to_execute_response(resp, final_operation_state)
11641178
execute_response.command_id = command_id
11651179
return execute_response, arrow_schema_bytes
11661180

tests/unit/test_client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
121121

122122
# Verify initial state
123123
self.assertEqual(real_result_set.has_been_closed_server_side, closed)
124-
expected_status = CommandState.CLOSED if closed else CommandState.SUCCEEDED
124+
expected_status = (
125+
CommandState.CLOSED if closed else CommandState.SUCCEEDED
126+
)
125127
self.assertEqual(real_result_set.status, expected_status)
126128

127129
# Mock execute_command to return our real result set

tests/unit/test_result_set_filter.py

Lines changed: 49 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -54,19 +54,30 @@ def sea_response_with_tables(self):
5454
"data_array": [
5555
["schema1", "table1", False, None, "catalog1", "TABLE", None],
5656
["schema1", "table2", False, None, "catalog1", "VIEW", None],
57-
["schema1", "table3", False, None, "catalog1", "SYSTEM TABLE", None],
57+
[
58+
"schema1",
59+
"table3",
60+
False,
61+
None,
62+
"catalog1",
63+
"SYSTEM TABLE",
64+
None,
65+
],
5866
["schema1", "table4", False, None, "catalog1", "EXTERNAL", None],
5967
]
6068
},
6169
}
6270

6371
@pytest.fixture
64-
def sea_result_set_with_tables(self, mock_connection, mock_sea_client, sea_response_with_tables):
72+
def sea_result_set_with_tables(
73+
self, mock_connection, mock_sea_client, sea_response_with_tables
74+
):
6575
"""Create a SeaResultSet with table data."""
6676
# Create a deep copy of the response to avoid test interference
6777
import copy
78+
6879
sea_response_copy = copy.deepcopy(sea_response_with_tables)
69-
80+
7081
return SeaResultSet(
7182
connection=mock_connection,
7283
sea_client=mock_sea_client,
@@ -78,11 +89,15 @@ def sea_result_set_with_tables(self, mock_connection, mock_sea_client, sea_respo
7889
def test_filter_tables_by_type_default(self, sea_result_set_with_tables):
7990
"""Test filtering tables by type with default types."""
8091
# Default types are TABLE, VIEW, SYSTEM TABLE
81-
filtered_result = ResultSetFilter.filter_tables_by_type(sea_result_set_with_tables)
92+
filtered_result = ResultSetFilter.filter_tables_by_type(
93+
sea_result_set_with_tables
94+
)
8295

8396
# Verify that only the default types are included
8497
assert len(filtered_result._response["result"]["data_array"]) == 3
85-
table_types = [row[5] for row in filtered_result._response["result"]["data_array"]]
98+
table_types = [
99+
row[5] for row in filtered_result._response["result"]["data_array"]
100+
]
86101
assert "TABLE" in table_types
87102
assert "VIEW" in table_types
88103
assert "SYSTEM TABLE" in table_types
@@ -97,7 +112,9 @@ def test_filter_tables_by_type_custom(self, sea_result_set_with_tables):
97112

98113
# Verify that only the specified types are included
99114
assert len(filtered_result._response["result"]["data_array"]) == 2
100-
table_types = [row[5] for row in filtered_result._response["result"]["data_array"]]
115+
table_types = [
116+
row[5] for row in filtered_result._response["result"]["data_array"]
117+
]
101118
assert "TABLE" in table_types
102119
assert "EXTERNAL" in table_types
103120
assert "VIEW" not in table_types
@@ -112,7 +129,9 @@ def test_filter_tables_by_type_case_insensitive(self, sea_result_set_with_tables
112129

113130
# Verify that the matching types are included despite case differences
114131
assert len(filtered_result._response["result"]["data_array"]) == 2
115-
table_types = [row[5] for row in filtered_result._response["result"]["data_array"]]
132+
table_types = [
133+
row[5] for row in filtered_result._response["result"]["data_array"]
134+
]
116135
assert "TABLE" in table_types
117136
assert "VIEW" in table_types
118137
assert "SYSTEM TABLE" not in table_types
@@ -126,7 +145,9 @@ def test_filter_tables_by_type_empty_list(self, sea_result_set_with_tables):
126145

127146
# Verify that default types are used
128147
assert len(filtered_result._response["result"]["data_array"]) == 3
129-
table_types = [row[5] for row in filtered_result._response["result"]["data_array"]]
148+
table_types = [
149+
row[5] for row in filtered_result._response["result"]["data_array"]
150+
]
130151
assert "TABLE" in table_types
131152
assert "VIEW" in table_types
132153
assert "SYSTEM TABLE" in table_types
@@ -144,21 +165,27 @@ def test_filter_by_column_values(self, sea_result_set_with_tables):
144165

145166
# Filter by table name in column index 1
146167
filtered_result = ResultSetFilter.filter_by_column_values(
147-
sea_result_set_with_tables, column_index=1, allowed_values=["table1", "table3"]
168+
sea_result_set_with_tables,
169+
column_index=1,
170+
allowed_values=["table1", "table3"],
148171
)
149172

150173
# Only rows with table1 or table3 should be included
151174
assert len(filtered_result._response["result"]["data_array"]) == 2
152-
table_names = [row[1] for row in filtered_result._response["result"]["data_array"]]
175+
table_names = [
176+
row[1] for row in filtered_result._response["result"]["data_array"]
177+
]
153178
assert "table1" in table_names
154179
assert "table3" in table_names
155180
assert "table2" not in table_names
156181
assert "table4" not in table_names
157182

158-
def test_filter_by_column_values_case_sensitive(self, mock_connection, mock_sea_client, sea_response_with_tables):
183+
def test_filter_by_column_values_case_sensitive(
184+
self, mock_connection, mock_sea_client, sea_response_with_tables
185+
):
159186
"""Test case-sensitive filtering by column values."""
160187
import copy
161-
188+
162189
# Create a fresh result set for the first test
163190
sea_response_copy1 = copy.deepcopy(sea_response_with_tables)
164191
result_set1 = SeaResultSet(
@@ -168,18 +195,18 @@ def test_filter_by_column_values_case_sensitive(self, mock_connection, mock_sea_
168195
buffer_size_bytes=1000,
169196
arraysize=100,
170197
)
171-
198+
172199
# First test: Case-sensitive filtering with lowercase values (should find no matches)
173200
filtered_result = ResultSetFilter.filter_by_column_values(
174201
result_set1,
175202
column_index=5, # tableType column
176203
allowed_values=["table", "view"], # lowercase
177204
case_sensitive=True,
178205
)
179-
206+
180207
# Verify no matches with lowercase values
181208
assert len(filtered_result._response["result"]["data_array"]) == 0
182-
209+
183210
# Create a fresh result set for the second test
184211
sea_response_copy2 = copy.deepcopy(sea_response_with_tables)
185212
result_set2 = SeaResultSet(
@@ -189,20 +216,22 @@ def test_filter_by_column_values_case_sensitive(self, mock_connection, mock_sea_
189216
buffer_size_bytes=1000,
190217
arraysize=100,
191218
)
192-
219+
193220
# Second test: Case-sensitive filtering with correct case (should find matches)
194221
filtered_result = ResultSetFilter.filter_by_column_values(
195222
result_set2,
196223
column_index=5, # tableType column
197224
allowed_values=["TABLE", "VIEW"], # correct case
198225
case_sensitive=True,
199226
)
200-
227+
201228
# Verify matches with correct case
202229
assert len(filtered_result._response["result"]["data_array"]) == 2
203-
230+
204231
# Extract the table types from the filtered results
205-
table_types = [row[5] for row in filtered_result._response["result"]["data_array"]]
232+
table_types = [
233+
row[5] for row in filtered_result._response["result"]["data_array"]
234+
]
206235
assert "TABLE" in table_types
207236
assert "VIEW" in table_types
208237

@@ -214,4 +243,4 @@ def test_filter_by_column_values_out_of_bounds(self, sea_result_set_with_tables)
214243
)
215244

216245
# No rows should match
217-
assert len(filtered_result._response["result"]["data_array"]) == 0
246+
assert len(filtered_result._response["result"]["data_array"]) == 0

tests/unit/test_sea_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def sea_client(self, mock_http_client):
5454
def sea_session_id(self):
5555
"""Create a SEA session ID."""
5656
return SessionId.from_sea_session_id("test-session-123")
57-
57+
5858
@pytest.fixture
5959
def sea_command_id(self):
6060
"""Create a SEA command ID."""
@@ -538,7 +538,7 @@ def test_get_execution_result(
538538
# Verify basic properties of the result
539539
assert result.statement_id == "test-statement-123"
540540
assert result.status == CommandState.SUCCEEDED
541-
541+
542542
# Verify the HTTP request
543543
mock_http_client._make_request.assert_called_once()
544544
args, kwargs = mock_http_client._make_request.call_args

0 commit comments

Comments
 (0)