Skip to content

Commit bc7ae81

Browse files
nits: string literals around type defs, naming, excess changes
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 77e7061 commit bc7ae81

File tree

4 files changed

+28
-25
lines changed

4 files changed

+28
-25
lines changed

src/databricks/sql/backend/thrift_backend.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,7 @@ def open_session(self, session_configuration, catalog, schema) -> SessionId:
602602
session_id = SessionId.from_thrift_handle(
603603
response.sessionHandle, properties
604604
)
605-
self._session_id_hex = session_id.guid_hex
605+
self._session_id_hex = session_id.hex_guid
606606
return session_id
607607
except:
608608
self._transport.close()
@@ -832,7 +832,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
832832
return execute_response, is_direct_results
833833

834834
def get_execution_result(
835-
self, command_id: CommandId, cursor: "Cursor"
835+
self, command_id: CommandId, cursor: Cursor
836836
) -> "ResultSet":
837837
thrift_handle = command_id.to_thrift_handle()
838838
if not thrift_handle:
@@ -1044,8 +1044,8 @@ def get_catalogs(
10441044
session_id: SessionId,
10451045
max_rows: int,
10461046
max_bytes: int,
1047-
cursor: "Cursor",
1048-
) -> "ResultSet":
1047+
cursor: Cursor,
1048+
) -> ResultSet:
10491049
thrift_handle = session_id.to_thrift_handle()
10501050
if not thrift_handle:
10511051
raise ValueError("Not a valid Thrift session ID")
@@ -1087,7 +1087,7 @@ def get_schemas(
10871087
cursor: Cursor,
10881088
catalog_name=None,
10891089
schema_name=None,
1090-
) -> "ResultSet":
1090+
) -> ResultSet:
10911091
from databricks.sql.result_set import ThriftResultSet
10921092

10931093
thrift_handle = session_id.to_thrift_handle()
@@ -1135,7 +1135,7 @@ def get_tables(
11351135
schema_name=None,
11361136
table_name=None,
11371137
table_types=None,
1138-
) -> "ResultSet":
1138+
) -> ResultSet:
11391139
from databricks.sql.result_set import ThriftResultSet
11401140

11411141
thrift_handle = session_id.to_thrift_handle()
@@ -1185,7 +1185,7 @@ def get_columns(
11851185
schema_name=None,
11861186
table_name=None,
11871187
column_name=None,
1188-
) -> "ResultSet":
1188+
) -> ResultSet:
11891189
from databricks.sql.result_set import ThriftResultSet
11901190

11911191
thrift_handle = session_id.to_thrift_handle()

src/databricks/sql/backend/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def __str__(self) -> str:
160160
if isinstance(self.secret, bytes)
161161
else str(self.secret)
162162
)
163-
return f"{self.guid_hex}|{secret_hex}"
163+
return f"{self.hex_guid}|{secret_hex}"
164164
return str(self.guid)
165165

166166
@classmethod
@@ -239,7 +239,7 @@ def to_sea_session_id(self):
239239
return self.guid
240240

241241
@property
242-
def guid_hex(self) -> str:
242+
def hex_guid(self) -> str:
243243
"""
244244
Get a hexadecimal string representation of the session ID.
245245

src/databricks/sql/result_set.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from abc import ABC, abstractmethod
24
from typing import List, Optional, Any, TYPE_CHECKING
35

@@ -33,8 +35,8 @@ class ResultSet(ABC):
3335

3436
def __init__(
3537
self,
36-
connection: "Connection",
37-
backend: "DatabricksClient",
38+
connection: Connection,
39+
backend: DatabricksClient,
3840
arraysize: int,
3941
buffer_size_bytes: int,
4042
command_id: CommandId,
@@ -51,8 +53,8 @@ def __init__(
5153
A ResultSet manages the results of a single command.
5254
5355
Parameters:
54-
:param connection: The parent connection
55-
:param backend: The backend client
56+
:param connection: The parent connection that was used to execute this command
57+
:param backend: The specialised backend client to be invoked in the fetch phase
5658
:param arraysize: The max number of rows to fetch at a time (PEP-249)
5759
:param buffer_size_bytes: The size (in bytes) of the internal buffer + max fetch
5860
:param command_id: The command ID
@@ -156,9 +158,9 @@ class ThriftResultSet(ResultSet):
156158

157159
def __init__(
158160
self,
159-
connection: "Connection",
160-
execute_response: "ExecuteResponse",
161-
thrift_client: "ThriftDatabricksClient",
161+
connection: Connection,
162+
execute_response: ExecuteResponse,
163+
thrift_client: ThriftDatabricksClient,
162164
buffer_size_bytes: int = 104857600,
163165
arraysize: int = 10000,
164166
use_cloud_fetch: bool = True,
@@ -314,6 +316,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
314316
if size < 0:
315317
raise ValueError("size argument for fetchmany is %s but must be >= 0", size)
316318
results = self.results.next_n_rows(size)
319+
partial_result_chunks = [results]
317320
n_remaining_rows = size - results.num_rows
318321
self._next_row_index += results.num_rows
319322

@@ -324,11 +327,11 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
324327
):
325328
self._fill_results_buffer()
326329
partial_results = self.results.next_n_rows(n_remaining_rows)
327-
results = pyarrow.concat_tables([results, partial_results])
330+
partial_result_chunks.append(partial_results)
328331
n_remaining_rows -= partial_results.num_rows
329332
self._next_row_index += partial_results.num_rows
330333

331-
return results
334+
return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
332335

333336
def fetchmany_columnar(self, size: int):
334337
"""
@@ -359,7 +362,7 @@ def fetchall_arrow(self) -> "pyarrow.Table":
359362
"""Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
360363
results = self.results.remaining_rows()
361364
self._next_row_index += results.num_rows
362-
365+
partial_result_chunks = [results]
363366
while not self.has_been_closed_server_side and self.is_direct_results:
364367
self._fill_results_buffer()
365368
partial_results = self.results.remaining_rows()
@@ -368,7 +371,7 @@ def fetchall_arrow(self) -> "pyarrow.Table":
368371
):
369372
results = self.merge_columnar(results, partial_results)
370373
else:
371-
results = pyarrow.concat_tables([results, partial_results])
374+
partial_result_chunks.append(partial_results)
372375
self._next_row_index += partial_results.num_rows
373376

374377
# If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table
@@ -379,7 +382,7 @@ def fetchall_arrow(self) -> "pyarrow.Table":
379382
for name, col in zip(results.column_names, results.column_table)
380383
}
381384
return pyarrow.Table.from_pydict(data)
382-
return results
385+
return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
383386

384387
def fetchall_columnar(self):
385388
"""Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
@@ -452,9 +455,9 @@ class SeaResultSet(ResultSet):
452455

453456
def __init__(
454457
self,
455-
connection: "Connection",
456-
execute_response: "ExecuteResponse",
457-
sea_client: "SeaDatabricksClient",
458+
connection: Connection,
459+
execute_response: ExecuteResponse,
460+
sea_client: SeaDatabricksClient,
458461
buffer_size_bytes: int = 104857600,
459462
arraysize: int = 10000,
460463
result_data=None,

src/databricks/sql/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def guid(self):
156156
@property
157157
def guid_hex(self) -> str:
158158
"""Get the session ID in hex format"""
159-
return self._session_id.guid_hex
159+
return self._session_id.hex_guid
160160

161161
def close(self) -> None:
162162
"""Close the underlying session."""

0 commit comments

Comments
 (0)