@@ -155,7 +155,7 @@ def __eq__(self, other):
155155
156156
157157class ArrowStreamTable :
158- def __init__ (self , record_batches , num_rows , column_description ):
158+ def __init__ (self , record_batches : List [ "pyarrow.RecordBatch" ] , num_rows : int , column_description ):
159159 self .record_batches = record_batches
160160 self .num_rows = num_rows
161161 self .column_description = column_description
@@ -222,9 +222,17 @@ def batch_generator():
222222
223223 def remove_extraneous_rows (self ):
224224 num_rows_in_data = sum (batch .num_rows for batch in self .record_batches )
225- if num_rows_in_data > self .num_rows :
226- self .record_batches = self .record_batches [:self .num_rows ]
227- self .num_rows = self .num_rows
225+ rows_to_delete = num_rows_in_data - self .num_rows
226+ while rows_to_delete > 0 and self .record_batches :
227+ last_batch = self .record_batches [- 1 ]
228+ if last_batch .num_rows <= rows_to_delete :
229+ self .record_batches .pop ()
230+ rows_to_delete -= last_batch .num_rows
231+ else :
232+ keep_rows = last_batch .num_rows - rows_to_delete
233+ self .record_batches [- 1 ] = last_batch .slice (0 , keep_rows )
234+ rows_to_delete = 0
235+
228236
229237class ColumnQueue (ResultSetQueue ):
230238 def __init__ (self , column_table : ColumnTable ):
0 commit comments