
Commit b2cf39b

fixed
1 parent e0ca049 commit b2cf39b

File tree: 3 files changed (+82 -35 lines)

src/databricks/sql/result_set.py
src/databricks/sql/utils.py
tests/unit/test_util.py

src/databricks/sql/result_set.py (20 additions, 34 deletions)

@@ -20,7 +20,12 @@
 
 from databricks.sql.types import Row
 from databricks.sql.exc import RequestError, CursorAlreadyClosedError
-from databricks.sql.utils import ExecuteResponse, ColumnTable, ColumnQueue
+from databricks.sql.utils import (
+    ExecuteResponse,
+    ColumnTable,
+    ColumnQueue,
+    concat_table_chunks,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -251,23 +256,6 @@ def _convert_arrow_table(self, table):
         res = df.to_numpy(na_value=None, dtype="object")
         return [ResultRow(*v) for v in res]
 
-    def merge_columnar(self, result1, result2) -> "ColumnTable":
-        """
-        Function to merge / combining the columnar results into a single result
-        :param result1:
-        :param result2:
-        :return:
-        """
-
-        if result1.column_names != result2.column_names:
-            raise ValueError("The columns in the results don't match")
-
-        merged_result = [
-            result1.column_table[i] + result2.column_table[i]
-            for i in range(result1.num_columns)
-        ]
-        return ColumnTable(merged_result, result1.column_names)
-
     def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
         """
         Fetch the next set of rows of a query result, returning a PyArrow table.
@@ -292,7 +280,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
             n_remaining_rows -= partial_results.num_rows
             self._next_row_index += partial_results.num_rows
 
-        return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
+        return concat_table_chunks(partial_result_chunks)
 
     def fetchmany_columnar(self, size: int):
         """
@@ -305,19 +293,19 @@ def fetchmany_columnar(self, size: int):
         results = self.results.next_n_rows(size)
         n_remaining_rows = size - results.num_rows
         self._next_row_index += results.num_rows
-
+        partial_result_chunks = [results]
         while (
             n_remaining_rows > 0
             and not self.has_been_closed_server_side
             and self.has_more_rows
         ):
             self._fill_results_buffer()
             partial_results = self.results.next_n_rows(n_remaining_rows)
-            results = self.merge_columnar(results, partial_results)
+            partial_result_chunks.append(partial_results)
             n_remaining_rows -= partial_results.num_rows
             self._next_row_index += partial_results.num_rows
 
-        return results
+        return concat_table_chunks(partial_result_chunks)
 
     def fetchall_arrow(self) -> "pyarrow.Table":
         """Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
@@ -327,36 +315,34 @@ def fetchall_arrow(self) -> "pyarrow.Table":
         while not self.has_been_closed_server_side and self.has_more_rows:
             self._fill_results_buffer()
             partial_results = self.results.remaining_rows()
-            if isinstance(results, ColumnTable) and isinstance(
-                partial_results, ColumnTable
-            ):
-                results = self.merge_columnar(results, partial_results)
-            else:
-                partial_result_chunks.append(partial_results)
+            partial_result_chunks.append(partial_results)
             self._next_row_index += partial_results.num_rows
 
+        result_table = concat_table_chunks(partial_result_chunks)
         # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table
         # Valid only for metadata commands result set
-        if isinstance(results, ColumnTable) and pyarrow:
+        if isinstance(result_table, ColumnTable) and pyarrow:
             data = {
                 name: col
-                for name, col in zip(results.column_names, results.column_table)
+                for name, col in zip(
+                    result_table.column_names, result_table.column_table
+                )
             }
             return pyarrow.Table.from_pydict(data)
-        return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
+        return result_table
 
     def fetchall_columnar(self):
         """Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
         results = self.results.remaining_rows()
         self._next_row_index += results.num_rows
-
+        partial_result_chunks = [results]
         while not self.has_been_closed_server_side and self.has_more_rows:
            self._fill_results_buffer()
            partial_results = self.results.remaining_rows()
-            results = self.merge_columnar(results, partial_results)
+            partial_result_chunks.append(partial_results)
            self._next_row_index += partial_results.num_rows
 
-        return results
+        return concat_table_chunks(partial_result_chunks)
 
     def fetchone(self) -> Optional[Row]:
         """

src/databricks/sql/utils.py (22 additions, 0 deletions)

@@ -785,3 +785,25 @@ def _create_python_tuple(t_col_value_wrapper):
             result[i] = None
 
     return tuple(result)
+
+
+def concat_table_chunks(
+    table_chunks: List[Union["pyarrow.Table", ColumnTable]]
+) -> Union["pyarrow.Table", ColumnTable]:
+    if len(table_chunks) == 0:
+        return table_chunks
+
+    if isinstance(table_chunks[0], ColumnTable):
+        ## Check if all have the same column names
+        if not all(
+            table.column_names == table_chunks[0].column_names for table in table_chunks
+        ):
+            raise ValueError("The columns in the results don't match")
+
+        result_table = table_chunks[0].column_table
+        for i in range(1, len(table_chunks)):
+            for j in range(table_chunks[i].num_columns):
+                result_table[j].extend(table_chunks[i].column_table[j])
+        return ColumnTable(result_table, table_chunks[0].column_names)
+    else:
+        return pyarrow.concat_tables(table_chunks, use_threads=True)
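For context, concat_table_chunks dispatches on the type of the first chunk: a list of ColumnTables comes back as a single ColumnTable, anything else is handed to pyarrow.concat_tables, and an empty list is returned as-is. A usage sketch mirroring the unit tests below:

    from databricks.sql.utils import ColumnTable, concat_table_chunks

    chunks = [
        ColumnTable([[1, 2], [5, 6]], ["col1", "col2"]),
        ColumnTable([[3, 4], [7, 8]], ["col1", "col2"]),
    ]
    merged = concat_table_chunks(chunks)
    assert merged.column_table == [[1, 2, 3, 4], [5, 6, 7, 8]]
    assert merged.column_names == ["col1", "col2"]

Worth noting: the ColumnTable branch extends the first chunk's column lists via list.extend, so the first input table is mutated in place as a side effect.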

tests/unit/test_util.py (40 additions, 1 deletion)

@@ -1,8 +1,17 @@
 import decimal
 import datetime
 from datetime import timezone, timedelta
+import pytest
+from databricks.sql.utils import (
+    convert_to_assigned_datatypes_in_column_table,
+    ColumnTable,
+    concat_table_chunks,
+)
 
-from databricks.sql.utils import convert_to_assigned_datatypes_in_column_table
+try:
+    import pyarrow
+except ImportError:
+    pyarrow = None
 
 
 class TestUtils:
@@ -122,3 +131,33 @@ def test_convert_to_assigned_datatypes_in_column_table(self):
         for index, entry in enumerate(converted_column_table):
             assert entry[0] == expected_convertion[index][0]
             assert isinstance(entry[0], expected_convertion[index][1])
+
+    def test_concat_table_chunks_column_table(self):
+        column_table1 = ColumnTable([[1, 2], [5, 6]], ["col1", "col2"])
+        column_table2 = ColumnTable([[3, 4], [7, 8]], ["col1", "col2"])
+
+        result_table = concat_table_chunks([column_table1, column_table2])
+
+        assert result_table.column_table == [[1, 2, 3, 4], [5, 6, 7, 8]]
+        assert result_table.column_names == ["col1", "col2"]
+
+    @pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed")
+    def test_concat_table_chunks_arrow_table(self):
+        arrow_table1 = pyarrow.Table.from_pydict({"col1": [1, 2], "col2": [5, 6]})
+        arrow_table2 = pyarrow.Table.from_pydict({"col1": [3, 4], "col2": [7, 8]})
+
+        result_table = concat_table_chunks([arrow_table1, arrow_table2])
+        assert result_table.column_names == ["col1", "col2"]
+        assert result_table.column("col1").to_pylist() == [1, 2, 3, 4]
+        assert result_table.column("col2").to_pylist() == [5, 6, 7, 8]
+
+    def test_concat_table_chunks_empty(self):
+        result_table = concat_table_chunks([])
+        assert result_table == []
+
+    def test_concat_table_chunks__incorrect_column_names_error(self):
+        column_table1 = ColumnTable([[1, 2], [5, 6]], ["col1", "col2"])
+        column_table2 = ColumnTable([[3, 4], [7, 8]], ["col1", "col3"])
+
+        with pytest.raises(ValueError):
+            concat_table_chunks([column_table1, column_table2])
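The pyarrow import is wrapped in try/except so the Arrow test is skipped (via the skipif marker) when PyArrow is not installed, leaving only the ColumnTable path exercised. The new tests can be run in isolation with, for example:

    pytest tests/unit/test_util.py -k concat_table_chunks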
