Commit cfc047f

More fix
1 parent e726c33 commit cfc047f

1 file changed: +30 −26 lines


src/databricks/sql/utils.py

Lines changed: 30 additions & 26 deletions
@@ -74,13 +74,14 @@ def build_queue(
         ResultSetQueue
         """
         if row_set_type == TSparkRowSetType.ARROW_BASED_SET:
-            arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table(
+            arrow_record_batches, n_valid_rows = convert_arrow_based_set_to_arrow_table(
                 t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes
             )
-            converted_arrow_table = convert_decimals_in_arrow_table(
-                arrow_table, description
-            )
-            return ArrowQueue(converted_arrow_table, n_valid_rows)
+            # converted_arrow_table = convert_decimals_in_arrow_table(
+            #     arrow_table, description
+            # )
+            arrow_stream_table = ArrowStreamTable(arrow_record_batches, n_valid_rows, description)
+            return ArrowQueue(arrow_stream_table, n_valid_rows, description)
         elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET:
             column_table, column_names = convert_column_based_set_to_column_table(
                 t_row_set.columns, description
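
What the new path wires together, as a minimal runnable sketch: ArrowStreamTable and ArrowQueue are the classes touched by this diff (importable from databricks.sql.utils on this branch); the demo batches and the DB-API-style description tuple are made-up stand-ins.

    import pyarrow
    from databricks.sql.utils import ArrowQueue, ArrowStreamTable

    # Stand-in for the decompressed ARROW_BASED_SET payload: a list of
    # pyarrow.RecordBatch objects rather than one materialized Table.
    schema = pyarrow.schema([("id", pyarrow.int64())])
    batches = [
        pyarrow.record_batch([pyarrow.array([1, 2, 3])], schema=schema),
        pyarrow.record_batch([pyarrow.array([4, 5])], schema=schema),
    ]
    n_valid_rows = sum(b.num_rows for b in batches)

    # Hypothetical cursor description for the single column.
    description = [("id", "bigint", None, None, None, None, None)]

    queue = ArrowQueue(ArrowStreamTable(batches, n_valid_rows, description), n_valid_rows, description)
    first_two = queue.next_n_rows(2)   # an ArrowStreamTable holding 2 rows
    rest = queue.remaining_rows()      # whatever is still queued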
@@ -159,7 +160,6 @@ def __init__(self, record_batches: List["pyarrow.RecordBatch"], num_rows: int, c
         self.record_batches = record_batches
         self.num_rows = num_rows
         self.column_description = column_description
-        self.curr_batch_index = 0
 
     def append(self, other: ArrowStreamTable):
         if self.column_description != other.column_description:
@@ -187,7 +187,9 @@ def next_n_rows(self, req_num_rows: int):
             req_num_rows = 0
 
         return ArrowStreamTable(consumed_batches, consumed_num_rows, self.column_description)
-
+
+    def remaining_rows(self):
+        return self
 
     def convert_decimals_in_record_batch(self,batch: "pyarrow.RecordBatch") -> "pyarrow.RecordBatch":
         new_columns = []
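
The eager convert_decimals_in_arrow_table call that the first hunk comments out has a per-batch counterpart here, convert_decimals_in_record_batch, whose body sits mostly outside this hunk (in the file it is a method reading self.column_description). A sketch of what such a per-batch conversion could look like, written as a free function for self-containedness and assuming a DB-API-style description where precision and scale sit at indexes 4 and 5; the real method may convert via pandas/Decimal rather than a direct cast:

    import pyarrow

    def convert_decimals_in_record_batch(batch, description):
        # Cast any column the cursor description marks as "decimal" to a
        # proper Arrow decimal128 type; leave other columns untouched.
        new_columns, new_fields = [], []
        for i, col in enumerate(batch.columns):
            name, type_code = description[i][0], description[i][1]
            if type_code == "decimal":
                precision, scale = description[i][4], description[i][5]
                dtype = pyarrow.decimal128(precision, scale)
                new_columns.append(col.cast(dtype))
                new_fields.append(pyarrow.field(name, dtype))
            else:
                new_columns.append(col)
                new_fields.append(batch.schema.field(i))
        return pyarrow.RecordBatch.from_arrays(
            new_columns, schema=pyarrow.schema(new_fields)
        )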
@@ -258,9 +260,9 @@ def remaining_rows(self):
 class ArrowQueue(ResultSetQueue):
     def __init__(
         self,
-        arrow_table: "pyarrow.Table",
+        arrow_stream_table: ArrowStreamTable,
         n_valid_rows: int,
-        start_row_index: int = 0,
+        column_description,
     ):
         """
         A queue-like wrapper over an Arrow table
@@ -269,25 +271,27 @@ def __init__(
         :param n_valid_rows: The index of the last valid row in the table
         :param start_row_index: The first row in the table we should start fetching from
         """
-        self.cur_row_index = start_row_index
-        self.arrow_table = arrow_table
+        self.arrow_stream_table = arrow_stream_table
         self.n_valid_rows = n_valid_rows
+        self.column_description = column_description
 
-    def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
+    def next_n_rows(self, num_rows: int):
         """Get upto the next n rows of the Arrow dataframe"""
-        length = min(num_rows, self.n_valid_rows - self.cur_row_index)
-        # Note that the table.slice API is not the same as Python's slice
-        # The second argument should be length, not end index
-        slice = self.arrow_table.slice(self.cur_row_index, length)
-        self.cur_row_index += slice.num_rows
-        return slice
+        return self.arrow_stream_table.next_n_rows(num_rows)
+        # length = min(num_rows, self.n_valid_rows - self.cur_row_index)
+        # # Note that the table.slice API is not the same as Python's slice
+        # # The second argument should be length, not end index
+        # slice = self.arrow_table.slice(self.cur_row_index, length)
+        # self.cur_row_index += slice.num_rows
+        # return slice
 
-    def remaining_rows(self) -> "pyarrow.Table":
-        slice = self.arrow_table.slice(
-            self.cur_row_index, self.n_valid_rows - self.cur_row_index
-        )
-        self.cur_row_index += slice.num_rows
-        return slice
+    def remaining_rows(self):
+        return self.arrow_stream_table.remaining_rows()
+        # slice = self.arrow_table.slice(
+        #     self.cur_row_index, self.n_valid_rows - self.cur_row_index
+        # )
+        # self.cur_row_index += slice.num_rows
+        # return slice
 
 
 class CloudFetchQueue(ResultSetQueue):
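
ArrowQueue now delegates both fetch paths to ArrowStreamTable, whose consuming logic is only partially visible in the hunks above (req_num_rows, consumed_batches, consumed_num_rows). A sketch of how such a next_n_rows could walk the batch list, splitting a batch with RecordBatch.slice when only part of it is needed; the exact bookkeeping is an assumption:

    def next_n_rows(self, req_num_rows):
        # Consume whole batches from the front of the queue, slicing the
        # final batch if only part of it is needed to satisfy the request.
        consumed_batches, consumed_num_rows = [], 0
        while req_num_rows > 0 and self.record_batches:
            batch = self.record_batches[0]
            if batch.num_rows > req_num_rows:
                # RecordBatch.slice(offset, length): keep the tail queued.
                self.record_batches[0] = batch.slice(req_num_rows)
                batch = batch.slice(0, req_num_rows)
            else:
                self.record_batches.pop(0)
            consumed_batches.append(batch)
            consumed_num_rows += batch.num_rows
            self.num_rows -= batch.num_rows
            req_num_rows -= batch.num_rows
        return ArrowStreamTable(consumed_batches, consumed_num_rows, self.column_description)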
@@ -740,8 +744,8 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema
             if lz4_compressed
             else arrow_batch.batch
         )
-    arrow_table = pyarrow.ipc.open_stream(ba).read_all()
-    return arrow_table, n_rows
+    arrow_record_batches = list(pyarrow.ipc.open_stream(ba))
+    return arrow_record_batches, n_rows
 
 
 def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":
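
The substantive change is this last hunk: pyarrow.ipc.open_stream(ba).read_all() materialized the whole result as one pyarrow.Table, whereas iterating the reader yields the stream's individual pyarrow.RecordBatch objects, which the new ArrowStreamTable can consume incrementally. Both calls are standard pyarrow APIs; the round-trip below just demonstrates the difference on made-up data:

    import io
    import pyarrow

    # Build an IPC stream standing in for `ba` (schema bytes followed by
    # the decompressed batch bytes accumulated in the loop above).
    schema = pyarrow.schema([("x", pyarrow.int32())])
    sink = io.BytesIO()
    with pyarrow.ipc.new_stream(sink, schema) as writer:
        writer.write_batch(pyarrow.record_batch([pyarrow.array([1, 2], pyarrow.int32())], schema=schema))
        writer.write_batch(pyarrow.record_batch([pyarrow.array([3], pyarrow.int32())], schema=schema))
    ba = sink.getvalue()

    # Old behavior: one fully materialized Table.
    arrow_table = pyarrow.ipc.open_stream(ba).read_all()

    # New behavior: the reader is iterable, yielding one RecordBatch per
    # batch in the stream, so rows can be consumed batch by batch.
    arrow_record_batches = list(pyarrow.ipc.open_stream(ba))

    assert arrow_table.num_rows == sum(b.num_rows for b in arrow_record_batches) == 3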
