Skip to content

Commit b5fbeb6

Browse files
Merge branch 'less-defensive-download' into sea-link-expiry
2 parents 0475abf + 3c1ff9b commit b5fbeb6

File tree

5 files changed

+192
-101
lines changed

5 files changed

+192
-101
lines changed

src/databricks/sql/backend/sea/queue.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -384,14 +384,14 @@ def _expiry_callback(self, link: TSparkArrowResultLink):
384384
return
385385
self.link_fetcher._restart_from_expired_link(link)
386386

387-
def _create_next_table(self) -> Union["pyarrow.Table", None]:
387+
def _create_next_table(self) -> "pyarrow.Table":
388388
"""Create next table by retrieving the logical next downloaded file."""
389389
if self.link_fetcher is None:
390-
return None
390+
return self._create_empty_table()
391391

392392
chunk_link = self.link_fetcher.get_chunk_link(self._current_chunk_index)
393393
if chunk_link is None:
394-
return None
394+
return self._create_empty_table()
395395

396396
row_offset = chunk_link.row_offset
397397
# NOTE: link has already been submitted to download manager at this point

src/databricks/sql/cloudfetch/download_manager.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import logging
22

33
from concurrent.futures import ThreadPoolExecutor, Future
4+
import threading
45
from typing import Callable, List, Optional, Union, Generic, TypeVar
56

67
from databricks.sql.cloudfetch.downloader import (
78
ResultSetDownloadHandler,
89
DownloadableResultSettings,
910
DownloadedFile,
1011
)
12+
from databricks.sql.exc import Error
1113
from databricks.sql.types import SSLOptions
1214

1315
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
@@ -56,25 +58,23 @@ def __init__(
5658
)
5759
self._pending_links.append(link)
5860

59-
self._download_tasks: List[TaskWithMetadata[DownloadedFile]] = []
6061
self._max_download_threads: int = max_download_threads
62+
63+
self._download_condition = threading.Condition()
64+
self._download_tasks: List[TaskWithMetadata[DownloadedFile]] = []
6165
self._thread_pool = ThreadPoolExecutor(max_workers=self._max_download_threads)
6266

6367
self._downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
6468
self._ssl_options = ssl_options
6569
self._expiry_callback = expiry_callback
6670

67-
def get_next_downloaded_file(
68-
self, next_row_offset: int
69-
) -> Union[DownloadedFile, None]:
71+
def get_next_downloaded_file(self, next_row_offset: int) -> DownloadedFile:
7072
"""
7173
Get next file that starts at given offset.
7274
7375
This function gets the next downloaded file in which its rows start at the specified next_row_offset
7476
in relation to the full result. File downloads are scheduled if not already, and once the correct
7577
download handler is located, the function waits for the download status and returns the resulting file.
76-
If there are no more downloads, a download was not successful, or the correct file could not be located,
77-
this function returns None.
7878
7979
Args:
8080
next_row_offset (int): The offset of the starting row of the next file we want data from.
@@ -84,8 +84,11 @@ def get_next_downloaded_file(
8484
self._schedule_downloads()
8585

8686
# Block until at least one download task has been scheduled (raises if the manager was shut down)
87-
if len(self._download_tasks) == 0:
88-
return None
87+
while len(self._download_tasks) == 0:
88+
if self._thread_pool._shutdown:
89+
raise Error("download manager shut down before file was ready")
90+
with self._download_condition:
91+
self._download_condition.wait()
8992

9093
task = self._download_tasks.pop(0)
9194
# Future's `result()` method will wait for the call to complete, and return
@@ -160,6 +163,9 @@ def _schedule_downloads(self):
160163
task = TaskWithMetadata(future, link)
161164
self._download_tasks.append(task)
162165

166+
with self._download_condition:
167+
self._download_condition.notify_all()
168+
163169
def add_link(self, link: TSparkArrowResultLink):
164170
"""
165171
Add more links to the download manager.
@@ -176,8 +182,11 @@ def add_link(self, link: TSparkArrowResultLink):
176182
)
177183
self._pending_links.append(link)
178184

185+
self._schedule_downloads()
186+
179187
def _shutdown_manager(self):
    """
    Stop the download manager and wake any thread waiting on it.

    Drops all pending links and queued download tasks, shuts the thread pool
    down without waiting for in-flight downloads, and then notifies every
    thread blocked on the download condition so it can re-check the pool's
    shutdown state.
    """
    # Clear download handlers and shutdown the thread pool
    self._pending_links = []
    self._download_tasks = []
    self._thread_pool.shutdown(wait=False)
    # Condition.notify_all() raises RuntimeError unless the condition's lock
    # is held by the caller, so acquire it before notifying.
    with self._download_condition:
        self._download_condition.notify_all()

src/databricks/sql/utils.py

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def __init__(
239239
self._ssl_options = ssl_options
240240

241241
# Table state
242-
self.table = None
242+
self.table = self._create_empty_table()
243243
self.table_row_index = 0
244244

245245
# Initialize download manager
@@ -260,23 +260,20 @@ def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
260260
Returns:
261261
pyarrow.Table
262262
"""
263-
if not self.table:
264-
logger.debug("CloudFetchQueue: no more rows available")
265-
# Return empty pyarrow table to cause retry of fetch
266-
return self._create_empty_table()
263+
267264
logger.debug("CloudFetchQueue: trying to get {} next rows".format(num_rows))
268265
results = self.table.slice(0, 0)
269-
while num_rows > 0 and self.table:
266+
while num_rows > 0 and self.table.num_rows > 0:
267+
# Replace current table with the next table if we are at the end of the current table
268+
if self.table_row_index == self.table.num_rows:
269+
self.table = self._create_next_table()
270+
self.table_row_index = 0
271+
270272
# Get remaining of num_rows or the rest of the current table, whichever is smaller
271273
length = min(num_rows, self.table.num_rows - self.table_row_index)
272274
table_slice = self.table.slice(self.table_row_index, length)
273275
results = pyarrow.concat_tables([results, table_slice])
274276
self.table_row_index += table_slice.num_rows
275-
276-
# Replace current table with the next table if we are at the end of the current table
277-
if self.table_row_index == self.table.num_rows:
278-
self.table = self._create_next_table()
279-
self.table_row_index = 0
280277
num_rows -= table_slice.num_rows
281278

282279
logger.debug("CloudFetchQueue: collected {} next rows".format(results.num_rows))
@@ -290,11 +287,8 @@ def remaining_rows(self) -> "pyarrow.Table":
290287
pyarrow.Table
291288
"""
292289

293-
if not self.table:
294-
# Return empty pyarrow table to cause retry of fetch
295-
return self._create_empty_table()
296290
results = self.table.slice(0, 0)
297-
while self.table:
291+
while self.table.num_rows > 0:
298292
table_slice = self.table.slice(
299293
self.table_row_index, self.table.num_rows - self.table_row_index
300294
)
@@ -304,17 +298,11 @@ def remaining_rows(self) -> "pyarrow.Table":
304298
self.table_row_index = 0
305299
return results
306300

307-
def _create_table_at_offset(self, offset: int) -> Union["pyarrow.Table", None]:
301+
def _create_table_at_offset(self, offset: int) -> "pyarrow.Table":
308302
"""Create next table at the given row offset"""
309303

310304
# Create next table by retrieving the logical next downloaded file (the download manager no longer returns None here)
311305
downloaded_file = self.download_manager.get_next_downloaded_file(offset)
312-
if not downloaded_file:
313-
logger.debug(
314-
"CloudFetchQueue: Cannot find downloaded file for row {}".format(offset)
315-
)
316-
# None signals no more Arrow tables can be built from the remaining handlers if any remain
317-
return None
318306
arrow_table = create_arrow_table_from_arrow_file(
319307
downloaded_file.file_bytes, self.description
320308
)
@@ -330,7 +318,7 @@ def _create_table_at_offset(self, offset: int) -> Union["pyarrow.Table", None]:
330318
return arrow_table
331319

332320
@abstractmethod
def _create_next_table(self) -> "pyarrow.Table":
    """
    Produce the table for the logical next downloaded file.

    Abstract hook: this base definition has no behaviour of its own.
    """
    ...
336324

@@ -349,7 +337,7 @@ class ThriftCloudFetchQueue(CloudFetchQueue):
349337

350338
def __init__(
351339
self,
352-
schema_bytes,
340+
schema_bytes: Optional[bytes],
353341
max_download_threads: int,
354342
ssl_options: SSLOptions,
355343
start_row_offset: int = 0,
@@ -395,26 +383,27 @@ def __init__(
395383
)
396384
self.download_manager.add_link(result_link)
397385

398-
# Initialize table and position
399-
self.table = self._create_next_table()
386+
# Initialize table and position
387+
self.table = self._create_next_table()
388+
else:
389+
self.table = self._create_empty_table()
400390

401391
def _expiry_callback(self, link: TSparkArrowResultLink):
    """
    Reject an expired CloudFetch link.

    Args:
        link: The result link reported as expired; it is not inspected here.

    Raises:
        Error: Always.
    """
    message = "Cloudfetch link has expired"
    raise Error(message)
403393

404-
def _create_next_table(self) -> "pyarrow.Table":
    """
    Build the Arrow table that begins at the queue's current start row.

    Delegates to ``_create_table_at_offset`` using ``self.start_row_index``,
    then advances ``self.start_row_index`` by the number of rows in the
    table that came back.

    Returns:
        pyarrow.Table: The table produced by ``_create_table_at_offset``.
    """
    request_message = (
        "ThriftCloudFetchQueue: Trying to get downloaded file for row {}".format(
            self.start_row_index
        )
    )
    logger.debug(request_message)

    next_table = self._create_table_at_offset(self.start_row_index)

    # Move the offset past the rows just fetched so the following call asks
    # for the file that comes after this one.
    self.start_row_index += next_table.num_rows

    found_message = "ThriftCloudFetchQueue: Found downloaded file, row count: {}, new start offset: {}".format(
        next_table.num_rows, self.start_row_index
    )
    logger.debug(found_message)
    return next_table
419408

420409

0 commit comments

Comments
 (0)