Skip to content

Commit 2bb8328

Browse files
make download manager less defensive
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 806e5f5 commit 2bb8328

File tree

5 files changed

+178
-97
lines changed

5 files changed

+178
-97
lines changed

src/databricks/sql/backend/sea/queue.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,14 +357,14 @@ def __init__(
357357
# Initialize table and position
358358
self.table = self._create_next_table()
359359

360-
def _create_next_table(self) -> Union["pyarrow.Table", None]:
360+
def _create_next_table(self) -> "pyarrow.Table":
361361
"""Create next table by retrieving the logical next downloaded file."""
362362
if self.link_fetcher is None:
363-
return None
363+
return self._create_empty_table()
364364

365365
chunk_link = self.link_fetcher.get_chunk_link(self._current_chunk_index)
366366
if chunk_link is None:
367-
return None
367+
return self._create_empty_table()
368368

369369
row_offset = chunk_link.row_offset
370370
# NOTE: link has already been submitted to download manager at this point

src/databricks/sql/cloudfetch/download_manager.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import logging
22

33
from concurrent.futures import ThreadPoolExecutor, Future
4+
import threading
45
from typing import List, Union
56

67
from databricks.sql.cloudfetch.downloader import (
78
ResultSetDownloadHandler,
89
DownloadableResultSettings,
910
DownloadedFile,
1011
)
12+
from databricks.sql.exc import Error
1113
from databricks.sql.types import SSLOptions
1214

1315
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
@@ -34,16 +36,16 @@ def __init__(
3436
)
3537
self._pending_links.append(link)
3638

37-
self._download_tasks: List[Future[DownloadedFile]] = []
3839
self._max_download_threads: int = max_download_threads
40+
41+
self._download_condition = threading.Condition()
42+
self._download_tasks: List[Future[DownloadedFile]] = []
3943
self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)
4044

4145
self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
4246
self._ssl_options = ssl_options
4347

44-
def get_next_downloaded_file(
45-
self, next_row_offset: int
46-
) -> Union[DownloadedFile, None]:
48+
def get_next_downloaded_file(self, next_row_offset: int) -> DownloadedFile:
4749
"""
4850
Get next file that starts at given offset.
4951
@@ -62,8 +64,10 @@ def get_next_downloaded_file(
6264

6365
# No more files to download from this batch of links
6466
if len(self._download_tasks) == 0:
65-
self._shutdown_manager()
66-
return None
67+
if self._thread_pool._shutdown:
68+
raise Error("download manager shut down before file was ready")
69+
with self._download_condition:
70+
self._download_condition.wait()
6771

6872
task = self._download_tasks.pop(0)
6973
# Future's `result()` method will wait for the call to complete, and return
@@ -124,3 +128,4 @@ def _shutdown_manager(self):
124128
self._pending_links = []
125129
self._download_tasks = []
126130
self._thread_pool.shutdown(wait=False)
131+
self._download_condition.notify_all()

src/databricks/sql/utils.py

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def __init__(
236236
self._ssl_options = ssl_options
237237

238238
# Table state
239-
self.table = None
239+
self.table = self._create_empty_table()
240240
self.table_row_index = 0
241241

242242
# Initialize download manager
@@ -256,23 +256,20 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
256256
Returns:
257257
pyarrow.Table
258258
"""
259-
if not self.table:
260-
logger.debug("CloudFetchQueue: no more rows available")
261-
# Return empty pyarrow table to cause retry of fetch
262-
return self._create_empty_table()
259+
263260
logger.debug("CloudFetchQueue: trying to get {} next rows".format(num_rows))
264261
results = self.table.slice(0, 0)
265-
while num_rows > 0 and self.table:
262+
while num_rows > 0 and self.table.num_rows > 0:
263+
# Replace current table with the next table if we are at the end of the current table
264+
if self.table_row_index == self.table.num_rows:
265+
self.table = self._create_next_table()
266+
self.table_row_index = 0
267+
266268
# Get remaining of num_rows or the rest of the current table, whichever is smaller
267269
length = min(num_rows, self.table.num_rows - self.table_row_index)
268270
table_slice = self.table.slice(self.table_row_index, length)
269271
results = pyarrow.concat_tables([results, table_slice])
270272
self.table_row_index += table_slice.num_rows
271-
272-
# Replace current table with the next table if we are at the end of the current table
273-
if self.table_row_index == self.table.num_rows:
274-
self.table = self._create_next_table()
275-
self.table_row_index = 0
276273
num_rows -= table_slice.num_rows
277274

278275
logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows))
@@ -286,11 +283,8 @@ def remaining_rows(self) -> "pyarrow.Table":
286283
pyarrow.Table
287284
"""
288285

289-
if not self.table:
290-
# Return empty pyarrow table to cause retry of fetch
291-
return self._create_empty_table()
292286
results = self.table.slice(0, 0)
293-
while self.table:
287+
while self.table.num_rows > 0:
294288
table_slice = self.table.slice(
295289
self.table_row_index, self.table.num_rows - self.table_row_index
296290
)
@@ -300,17 +294,11 @@ def remaining_rows(self) -> "pyarrow.Table":
300294
self.table_row_index = 0
301295
return results
302296

303-
def _create_table_at_offset(self, offset: int) -> Union["pyarrow.Table", None]:
297+
def _create_table_at_offset(self, offset: int) -> "pyarrow.Table":
304298
"""Create next table at the given row offset"""
305299

306300
# Create next table by retrieving the logical next downloaded file, or return None to signal end of queue
307301
downloaded_file = self.download_manager.get_next_downloaded_file(offset)
308-
if not downloaded_file:
309-
logger.debug(
310-
"CloudFetchQueue: Cannot find downloaded file for row {}".format(offset)
311-
)
312-
# None signals no more Arrow tables can be built from the remaining handlers if any remain
313-
return None
314302
arrow_table = create_arrow_table_from_arrow_file(
315303
downloaded_file.file_bytes, self.description
316304
)
@@ -326,7 +314,7 @@ def _create_table_at_offset(self, offset: int) -> Union["pyarrow.Table", None]:
326314
return arrow_table
327315

328316
@abstractmethod
329-
def _create_next_table(self) -> Union["pyarrow.Table", None]:
317+
def _create_next_table(self) -> "pyarrow.Table":
330318
"""Create next table by retrieving the logical next downloaded file."""
331319
pass
332320

@@ -345,7 +333,7 @@ class ThriftCloudFetchQueue(CloudFetchQueue):
345333

346334
def __init__(
347335
self,
348-
schema_bytes,
336+
schema_bytes: Optional[bytes],
349337
max_download_threads: int,
350338
ssl_options: SSLOptions,
351339
start_row_offset: int = 0,
@@ -390,23 +378,24 @@ def __init__(
390378
)
391379
self.download_manager.add_link(result_link)
392380

393-
# Initialize table and position
394-
self.table = self._create_next_table()
381+
# Initialize table and position
382+
self.table = self._create_next_table()
383+
else:
384+
self.table = self._create_empty_table()
395385

396-
def _create_next_table(self) -> Union["pyarrow.Table", None]:
386+
def _create_next_table(self) -> "pyarrow.Table":
397387
logger.debug(
398388
"ThriftCloudFetchQueue: Trying to get downloaded file for row {}".format(
399389
self.start_row_index
400390
)
401391
)
402392
arrow_table = self._create_table_at_offset(self.start_row_index)
403-
if arrow_table:
404-
self.start_row_index += arrow_table.num_rows
405-
logger.debug(
406-
"ThriftCloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format(
407-
arrow_table.num_rows, self.start_row_index
408-
)
393+
self.start_row_index += arrow_table.num_rows
394+
logger.debug(
395+
"ThriftCloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format(
396+
arrow_table.num_rows, self.start_row_index
409397
)
398+
)
410399
return arrow_table
411400

412401

0 commit comments

Comments
 (0)