@@ -239,7 +239,7 @@ def __init__(
239239 self ._ssl_options = ssl_options
240240
241241 # Table state
242- self .table = None
242+ self .table = self . _create_empty_table ()
243243 self .table_row_index = 0
244244
245245 # Initialize download manager
@@ -260,23 +260,20 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
260260 Returns:
261261 pyarrow.Table
262262 """
263- if not self .table :
264- logger .debug ("CloudFetchQueue: no more rows available" )
265- # Return empty pyarrow table to cause retry of fetch
266- return self ._create_empty_table ()
263+
267264 logger .debug ("CloudFetchQueue: trying to get {} next rows" .format (num_rows ))
268265 results = self .table .slice (0 , 0 )
269- while num_rows > 0 and self .table :
266+ while num_rows > 0 and self .table .num_rows > 0 :
267+ # Replace current table with the next table if we are at the end of the current table
268+ if self .table_row_index == self .table .num_rows :
269+ self .table = self ._create_next_table ()
270+ self .table_row_index = 0
271+
270272 # Get remaining of num_rows or the rest of the current table, whichever is smaller
271273 length = min (num_rows , self .table .num_rows - self .table_row_index )
272274 table_slice = self .table .slice (self .table_row_index , length )
273275 results = pyarrow .concat_tables ([results , table_slice ])
274276 self .table_row_index += table_slice .num_rows
275-
276- # Replace current table with the next table if we are at the end of the current table
277- if self .table_row_index == self .table .num_rows :
278- self .table = self ._create_next_table ()
279- self .table_row_index = 0
280277 num_rows -= table_slice .num_rows
281278
282279 logger .debug ("CloudFetchQueue: collected {} next rows" .format (results .num_rows ))
@@ -290,11 +287,8 @@ def remaining_rows(self) -> "pyarrow.Table":
290287 pyarrow.Table
291288 """
292289
293- if not self .table :
294- # Return empty pyarrow table to cause retry of fetch
295- return self ._create_empty_table ()
296290 results = self .table .slice (0 , 0 )
297- while self .table :
291+ while self .table . num_rows > 0 :
298292 table_slice = self .table .slice (
299293 self .table_row_index , self .table .num_rows - self .table_row_index
300294 )
@@ -304,17 +298,11 @@ def remaining_rows(self) -> "pyarrow.Table":
304298 self .table_row_index = 0
305299 return results
306300
307- def _create_table_at_offset (self , offset : int ) -> Union [ "pyarrow.Table" , None ] :
301+ def _create_table_at_offset (self , offset : int ) -> "pyarrow.Table" :
308302 """Create next table at the given row offset"""
309303
310304 # Create next table by retrieving the logical next downloaded file, or return None to signal end of queue
311305 downloaded_file = self .download_manager .get_next_downloaded_file (offset )
312- if not downloaded_file :
313- logger .debug (
314- "CloudFetchQueue: Cannot find downloaded file for row {}" .format (offset )
315- )
316- # None signals no more Arrow tables can be built from the remaining handlers if any remain
317- return None
318306 arrow_table = create_arrow_table_from_arrow_file (
319307 downloaded_file .file_bytes , self .description
320308 )
@@ -330,7 +318,7 @@ def _create_table_at_offset(self, offset: int) -> Union["pyarrow.Table", None]:
330318 return arrow_table
331319
332320 @abstractmethod
333- def _create_next_table (self ) -> Union [ "pyarrow.Table" , None ] :
321+ def _create_next_table (self ) -> "pyarrow.Table" :
334322 """Create next table by retrieving the logical next downloaded file."""
335323 pass
336324
@@ -349,7 +337,7 @@ class ThriftCloudFetchQueue(CloudFetchQueue):
349337
350338 def __init__ (
351339 self ,
352- schema_bytes ,
340+ schema_bytes : Optional [ bytes ] ,
353341 max_download_threads : int ,
354342 ssl_options : SSLOptions ,
355343 start_row_offset : int = 0 ,
@@ -395,26 +383,27 @@ def __init__(
395383 )
396384 self .download_manager .add_link (result_link )
397385
398- # Initialize table and position
399- self .table = self ._create_next_table ()
386+ # Initialize table and position
387+ self .table = self ._create_next_table ()
388+ else :
389+ self .table = self ._create_empty_table ()
400390
401391 def _expiry_callback (self , link : TSparkArrowResultLink ):
402392 raise Error ("Cloudfetch link has expired" )
403393
404- def _create_next_table (self ) -> Union [ "pyarrow.Table" , None ] :
394+ def _create_next_table (self ) -> "pyarrow.Table" :
405395 logger .debug (
406396 "ThriftCloudFetchQueue: Trying to get downloaded file for row {}" .format (
407397 self .start_row_index
408398 )
409399 )
410400 arrow_table = self ._create_table_at_offset (self .start_row_index )
411- if arrow_table :
412- self .start_row_index += arrow_table .num_rows
413- logger .debug (
414- "ThriftCloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}" .format (
415- arrow_table .num_rows , self .start_row_index
416- )
401+ self .start_row_index += arrow_table .num_rows
402+ logger .debug (
403+ "ThriftCloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}" .format (
404+ arrow_table .num_rows , self .start_row_index
417405 )
406+ )
418407 return arrow_table
419408
420409
0 commit comments