@@ -74,13 +74,14 @@ def build_queue(
             ResultSetQueue
         """
         if row_set_type == TSparkRowSetType.ARROW_BASED_SET:
-            arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table(
+            arrow_record_batches, n_valid_rows = convert_arrow_based_set_to_arrow_table(
                 t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes
             )
-            converted_arrow_table = convert_decimals_in_arrow_table(
-                arrow_table, description
-            )
-            return ArrowQueue(converted_arrow_table, n_valid_rows)
+            # converted_arrow_table = convert_decimals_in_arrow_table(
+            #     arrow_table, description
+            # )
+            arrow_stream_table = ArrowStreamTable(arrow_record_batches, n_valid_rows, description)
+            return ArrowQueue(arrow_stream_table, n_valid_rows, description)
         elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET:
             column_table, column_names = convert_column_based_set_to_column_table(
                 t_row_set.columns, description
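With this hunk, `build_queue` no longer materializes a single `pyarrow.Table`: the decoded record batches are wrapped in an `ArrowStreamTable`, which is in turn handed to `ArrowQueue`. Note that the table-level decimal conversion is commented out here; a later hunk shows a per-batch `convert_decimals_in_record_batch`, suggesting that conversion moves batch-by-batch. A minimal usage sketch of the new assembly, assuming this PR's classes are importable from `databricks.sql.utils` (the import path and the `description` tuple below are assumptions for illustration):

```python
import pyarrow
from databricks.sql.utils import ArrowQueue, ArrowStreamTable  # assumed path

# Made-up record batches standing in for the decoded Thrift payload.
record_batches = [pyarrow.RecordBatch.from_pydict({"id": [1, 2, 3]})]
n_valid_rows = sum(b.num_rows for b in record_batches)
description = [("id", "int", None, None, None, None, None)]  # DB-API 7-tuple

stream_table = ArrowStreamTable(record_batches, n_valid_rows, description)
queue = ArrowQueue(stream_table, n_valid_rows, description)
chunk = queue.next_n_rows(2)  # now an ArrowStreamTable, not a pyarrow.Table
```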
@@ -159,7 +160,6 @@ def __init__(self, record_batches: List["pyarrow.RecordBatch"], num_rows: int, c
         self.record_batches = record_batches
         self.num_rows = num_rows
         self.column_description = column_description
-        self.curr_batch_index = 0

     def append(self, other: ArrowStreamTable):
         if self.column_description != other.column_description:
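Dropping `curr_batch_index` suggests consumption state now lives in the batch list itself: batches are popped or sliced as they are consumed, as the next hunk shows. For `append`, the visible contract is that two streams may only be merged when their column descriptions match. A self-contained mimic of that contract, with illustrative names (not the PR's code, and the body past the guard is an assumption since it is outside this diff):

```python
from typing import List
import pyarrow


def append_batches(
    left: List[pyarrow.RecordBatch], right: List[pyarrow.RecordBatch]
) -> List[pyarrow.RecordBatch]:
    # Mirrors the guard shown above: refuse to merge mismatched shapes.
    if left and right and left[0].schema != right[0].schema:
        raise ValueError("column descriptions differ; cannot append")
    return left + right
```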
@@ -187,7 +187,9 @@ def next_n_rows(self, req_num_rows: int):
             req_num_rows = 0

         return ArrowStreamTable(consumed_batches, consumed_num_rows, self.column_description)
-
+
+    def remaining_rows(self):
+        return self

     def convert_decimals_in_record_batch(self, batch: "pyarrow.RecordBatch") -> "pyarrow.RecordBatch":
         new_columns = []
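`remaining_rows` simply hands back the stream itself, while `next_n_rows` (whose tail is visible above) carves up to `req_num_rows` rows off the front into a new `ArrowStreamTable`. A self-contained mimic of that batch-wise carving, assuming whole batches are consumed first and the final batch is split with a zero-copy `slice` (names here are illustrative):

```python
from typing import List
import pyarrow


def take_rows(batches: List[pyarrow.RecordBatch], n: int) -> List[pyarrow.RecordBatch]:
    """Destructively consume up to n rows from the front of `batches`."""
    taken: List[pyarrow.RecordBatch] = []
    while batches and n > 0:
        head = batches[0]
        if head.num_rows <= n:
            taken.append(batches.pop(0))    # consume the whole batch
            n -= head.num_rows
        else:
            taken.append(head.slice(0, n))  # zero-copy partial slice
            batches[0] = head.slice(n)      # remainder stays queued
            n = 0
    return taken


batches = [
    pyarrow.RecordBatch.from_pydict({"x": [1, 2, 3]}),
    pyarrow.RecordBatch.from_pydict({"x": [4, 5]}),
]
first = take_rows(batches, 4)
assert sum(b.num_rows for b in first) == 4
assert sum(b.num_rows for b in batches) == 1  # one row left queued
```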
@@ -258,9 +260,9 @@ def remaining_rows(self):
 class ArrowQueue(ResultSetQueue):
     def __init__(
         self,
-        arrow_table: "pyarrow.Table",
+        arrow_stream_table: ArrowStreamTable,
         n_valid_rows: int,
-        start_row_index: int = 0,
+        column_description,
     ):
         """
         A queue-like wrapper over an Arrow table
@@ -269,25 +271,27 @@ def __init__(
         :param n_valid_rows: The index of the last valid row in the table
         :param start_row_index: The first row in the table we should start fetching from
         """
-        self.cur_row_index = start_row_index
-        self.arrow_table = arrow_table
+        self.arrow_stream_table = arrow_stream_table
         self.n_valid_rows = n_valid_rows
+        self.column_description = column_description

-    def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
+    def next_n_rows(self, num_rows: int):
         """Get upto the next n rows of the Arrow dataframe"""
-        length = min(num_rows, self.n_valid_rows - self.cur_row_index)
-        # Note that the table.slice API is not the same as Python's slice
-        # The second argument should be length, not end index
-        slice = self.arrow_table.slice(self.cur_row_index, length)
-        self.cur_row_index += slice.num_rows
-        return slice
+        return self.arrow_stream_table.next_n_rows(num_rows)
+        # length = min(num_rows, self.n_valid_rows - self.cur_row_index)
+        # # Note that the table.slice API is not the same as Python's slice
+        # # The second argument should be length, not end index
+        # slice = self.arrow_table.slice(self.cur_row_index, length)
+        # self.cur_row_index += slice.num_rows
+        # return slice

-    def remaining_rows(self) -> "pyarrow.Table":
-        slice = self.arrow_table.slice(
-            self.cur_row_index, self.n_valid_rows - self.cur_row_index
-        )
-        self.cur_row_index += slice.num_rows
-        return slice
+    def remaining_rows(self):
+        return self.arrow_stream_table.remaining_rows()
+        # slice = self.arrow_table.slice(
+        #     self.cur_row_index, self.n_valid_rows - self.cur_row_index
+        # )
+        # self.cur_row_index += slice.num_rows
+        # return slice


 class CloudFetchQueue(ResultSetQueue):
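Both queue methods now delegate to the wrapped stream and return an `ArrowStreamTable` rather than a `pyarrow.Table`. A caller that still needs a materialized table can rebuild one from the wrapper's `record_batches` attribute (the one set in `ArrowStreamTable.__init__` above); a hedged sketch, not part of this PR:

```python
import pyarrow


def to_table(stream_table) -> pyarrow.Table:
    # record_batches is the attribute set in ArrowStreamTable.__init__;
    # from_batches concatenates them into a single eager Table.
    if not stream_table.record_batches:
        return pyarrow.table({})  # degenerate empty table
    return pyarrow.Table.from_batches(stream_table.record_batches)
```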
@@ -740,8 +744,8 @@ def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema
             if lz4_compressed
             else arrow_batch.batch
         )
-    arrow_table = pyarrow.ipc.open_stream(ba).read_all()
-    return arrow_table, n_rows
+    arrow_record_batches = list(pyarrow.ipc.open_stream(ba))
+    return arrow_record_batches, n_rows


 def convert_decimals_in_arrow_table(table, description) -> "pyarrow.Table":
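The behavioral difference in this last hunk: `read_all()` eagerly concatenates the whole IPC stream into one `pyarrow.Table`, while iterating the reader yields the individual `pyarrow.RecordBatch` objects, which is what lets the queue slice lazily. A self-contained round-trip demonstrating both readings of the same buffer (data is made up):

```python
import pyarrow
import pyarrow.ipc

batch = pyarrow.RecordBatch.from_pydict({"x": [1, 2, 3]})
sink = pyarrow.BufferOutputStream()
with pyarrow.ipc.new_stream(sink, batch.schema) as writer:
    writer.write_batch(batch)
    writer.write_batch(batch)
buf = sink.getvalue()

table = pyarrow.ipc.open_stream(buf).read_all()  # old path: one 6-row Table
batches = list(pyarrow.ipc.open_stream(buf))     # new path: two RecordBatches
assert table.num_rows == sum(b.num_rows for b in batches) == 6
```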