
Commit 0868fe3

formatting + minor type fixes
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
Parent: a736bb4

File tree

5 files changed: +86, -53 lines

- src/databricks/sql/backend/sea/queue.py
- src/databricks/sql/cloudfetch/download_manager.py
- src/databricks/sql/cloudfetch/downloader.py
- src/databricks/sql/utils.py
- tests/unit/test_downloader.py

src/databricks/sql/backend/sea/queue.py

Lines changed: 58 additions & 39 deletions
```diff
@@ -145,55 +145,60 @@ def __init__(
     def _add_links_to_manager(self, links: List["ExternalLink"], notify: bool = True):
         """
         Add external links to both chunk mapping and download manager.
-
+
         Args:
             links: List of external links to add
             notify: Whether to notify waiting threads (default True)
         """
         for link in links:
             self.chunk_index_to_link[link.chunk_index] = link
             self.download_manager.add_link(self._convert_to_thrift_link(link))
-
+
         if notify:
             self._link_data_update.notify_all()

     def _clear_chunks_from_index(self, start_chunk_index: int):
         """
         Clear all chunks >= start_chunk_index from the chunk mapping.
-
+
         Args:
             start_chunk_index: The chunk index to start clearing from (inclusive)
         """
         chunks_to_remove = [
-            chunk_idx for chunk_idx in self.chunk_index_to_link.keys()
+            chunk_idx
+            for chunk_idx in self.chunk_index_to_link.keys()
             if chunk_idx >= start_chunk_index
         ]
-
-        logger.debug(f"LinkFetcher: Clearing chunks {chunks_to_remove} from index {start_chunk_index}")
+
+        logger.debug(
+            f"LinkFetcher: Clearing chunks {chunks_to_remove} from index {start_chunk_index}"
+        )
         for chunk_idx in chunks_to_remove:
             del self.chunk_index_to_link[chunk_idx]

     def _fetch_and_add_links(self, chunk_index: int) -> List["ExternalLink"]:
         """
         Fetch links from backend and add them to manager.
-
+
         Args:
             chunk_index: The chunk index to fetch
-
+
         Returns:
             List of fetched external links
-
+
         Raises:
             Exception: If fetching fails
         """
         logger.debug(f"LinkFetcher: Fetching links for chunk {chunk_index}")
-
+
         try:
             links = self.backend.get_chunk_links(self._statement_id, chunk_index)
             self._add_links_to_manager(links, notify=True)
-            logger.debug(f"LinkFetcher: Added {len(links)} links starting from chunk {chunk_index}")
+            logger.debug(
+                f"LinkFetcher: Added {len(links)} links starting from chunk {chunk_index}"
+            )
             return links
-
+
         except Exception as e:
             logger.error(f"LinkFetcher: Failed to fetch chunk {chunk_index}: {e}")
             self._error = e
@@ -236,38 +241,38 @@ def get_chunk_link(self, chunk_index: int) -> Optional["ExternalLink"]:
     def restart_from_chunk(self, chunk_index: int):
         """
         Restart the LinkFetcher from a specific chunk index.
-
+
         This method handles both cases:
         1. LinkFetcher is done/closed but we need to restart it
         2. LinkFetcher is active but we need it to start from the expired chunk
-
+
         The key insight: we need to clear all chunks >= restart_chunk_index
         so that _get_next_chunk_index() returns the correct next chunk.
-
+
         Args:
             chunk_index: The chunk index to restart from
         """
         logger.debug(f"LinkFetcher: Restarting from chunk {chunk_index}")
-
+
         # Stop the current worker if running
         self.stop()
-
+
         with self._link_data_update:
             # Clear error state
             self._error = None
-
+
             # 🔥 CRITICAL: Clear all chunks >= restart_chunk_index
             # This ensures _get_next_chunk_index() works correctly
            self._clear_chunks_from_index(chunk_index)
-
+
            # Now fetch the restart chunk (and potentially its batch)
            # This becomes our new "max chunk" and starting point
            try:
                self._fetch_and_add_links(chunk_index)
            except Exception as e:
                # Error already logged and set by _fetch_and_add_links
                raise e
-
+
        # Start the worker again - now _get_next_chunk_index() will work correctly
        self.start()
        logger.debug(f"LinkFetcher: Successfully restarted from chunk {chunk_index}")
@@ -294,7 +299,7 @@ def _worker_loop(self):
    def start(self):
        if self._worker_thread and self._worker_thread.is_alive():
            return  # Already running
-
+
        self._shutdown_event.clear()
        self._worker_thread = threading.Thread(target=self._worker_loop)
        self._worker_thread.start()
```
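The restart logic above follows a strict stop, clear, refetch, start order, with the worker deriving its next chunk from the mapping it maintains. Below is a minimal, self-contained sketch of that pattern; `Fetcher`, `fetch_batch`, and the dict-backed store are illustrative stand-ins, not the driver's actual classes, and `stop()` is simplified (the real one also signals a shutdown event):

```python
import threading

class Fetcher:
    """Toy restartable link fetcher; mirrors the stop/clear/refetch/start order."""

    def __init__(self, fetch_batch):
        self.fetch_batch = fetch_batch   # callable: start index -> list of (index, link)
        self.links = {}                  # chunk index -> link
        self.cond = threading.Condition()
        self.worker = None

    def start(self):
        if self.worker and self.worker.is_alive():
            return  # already running
        self.worker = threading.Thread(target=self._loop, daemon=True)
        self.worker.start()

    def stop(self):
        # Simplified: the real stop() also sets a shutdown event and notifies waiters.
        if self.worker:
            self.worker.join()

    def _loop(self):
        while True:
            # Next chunk is derived from the mapping, which is why stale
            # entries >= the restart index must be cleared before restarting.
            next_index = max(self.links, default=-1) + 1
            batch = self.fetch_batch(next_index)
            if not batch:
                break
            with self.cond:
                self.links.update(batch)
                self.cond.notify_all()

    def restart_from(self, index):
        self.stop()                                         # 1. stop the worker
        with self.cond:
            self.links = {i: l for i, l in self.links.items() if i < index}  # 2. clear >= index
            self.links.update(self.fetch_batch(index))      # 3. refetch the restart chunk
        self.start()                                        # 4. resume the worker

data = {i: f"link-{i}" for i in range(6)}

def fetch_batch(start, size=2):
    return [(i, data[i]) for i in range(start, min(start + size, len(data)))]

f = Fetcher(fetch_batch)
f.start()
f.stop()                   # toy worker drains all six chunks, then exits
f.restart_from(3)          # simulate expiry: drop chunks >= 3 and refetch
f.stop()
print(sorted(f.links))     # [0, 1, 2, 3, 4, 5]
```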
```diff
@@ -376,22 +381,24 @@ def __init__(
         # Initialize table and position
         self.table = self._create_next_table()

-    def _handle_expired_link(self, expired_link: TSparkArrowResultLink) -> TSparkArrowResultLink:
+    def _handle_expired_link(
+        self, expired_link: TSparkArrowResultLink
+    ) -> TSparkArrowResultLink:
         """
         Handle expired link for SEA backend.
-
+
         For SEA backend, we can handle expired links robustly by:
         1. Cancelling all pending downloads
         2. Finding the chunk index for the expired link
         3. Restarting the LinkFetcher from that chunk
         4. Returning the requested link
-
+
         Args:
             expired_link: The expired link
-
+
         Returns:
             A new link with the same row offset
-
+
         Raises:
             Error: If unable to fetch new link
         """
@@ -400,14 +407,19 @@ def _handle_expired_link(self, expired_link: TSparkArrowResultLink) -> TSparkArr
                 expired_link.startRowOffset, expired_link.rowCount
             )
         )
-
+
+        if not self.download_manager:
+            raise ValueError("Download manager not initialized")
+
         try:
             # Step 1: Cancel all pending downloads
             self.download_manager.cancel_all_downloads()
             logger.debug("SeaCloudFetchQueue: Cancelled all pending downloads")
-
+
             # Step 2: Find which chunk contains the expired link
-            target_chunk_index = self._find_chunk_index_for_row_offset(expired_link.startRowOffset)
+            target_chunk_index = self._find_chunk_index_for_row_offset(
+                expired_link.startRowOffset
+            )
             if target_chunk_index is None:
                 # If we can't find the chunk, we may need to search more broadly
                 # For now, let's assume it's a reasonable chunk based on the row offset
@@ -419,31 +431,38 @@ def _handle_expired_link(self, expired_link: TSparkArrowResultLink) -> TSparkArr
                 )
                 # Try to estimate chunk index - this is a heuristic
                 target_chunk_index = 0  # Start from beginning as fallback
-
+
             # Step 3: Restart LinkFetcher from the target chunk
             # This handles both stopped and active LinkFetcher cases
             self.link_fetcher.restart_from_chunk(target_chunk_index)
-
+
             # Step 4: Find and return the link that matches the expired link's row offset
             # After restart, the chunk should be available
-            for chunk_index, external_link in self.link_fetcher.chunk_index_to_link.items():
+            for (
+                chunk_index,
+                external_link,
+            ) in self.link_fetcher.chunk_index_to_link.items():
                 if external_link.row_offset == expired_link.startRowOffset:
-                    new_thrift_link = self.link_fetcher._convert_to_thrift_link(external_link)
+                    new_thrift_link = self.link_fetcher._convert_to_thrift_link(
+                        external_link
+                    )
                     logger.debug(
                         "SeaCloudFetchQueue: Found replacement link for offset {}, row count {}".format(
                             new_thrift_link.startRowOffset, new_thrift_link.rowCount
                         )
                     )
                     return new_thrift_link
-
+
             # If we still can't find it, raise an error
             logger.error(
                 "SeaCloudFetchQueue: Could not find replacement link for row offset {} after restart".format(
                     expired_link.startRowOffset
                 )
             )
-            raise Error(f"CloudFetch link has expired and could not be renewed for offset {expired_link.startRowOffset}")
-
+            raise Error(
+                f"CloudFetch link has expired and could not be renewed for offset {expired_link.startRowOffset}"
+            )
+
         except Exception as e:
             logger.error(
                 "SeaCloudFetchQueue: Error handling expired link: {}".format(str(e))
@@ -456,18 +475,18 @@ def _handle_expired_link(self, expired_link: TSparkArrowResultLink) -> TSparkArr
     def _find_chunk_index_for_row_offset(self, row_offset: int) -> Optional[int]:
         """
         Find the chunk index that contains the given row offset.
-
+
         Args:
             row_offset: The row offset to find
-
+
         Returns:
             The chunk index, or None if not found
         """
         # Search through our known chunks to find the one containing this row offset
         for chunk_index, external_link in self.link_fetcher.chunk_index_to_link.items():
             if external_link.row_offset == row_offset:
                 return chunk_index
-
+
         # If not found in known chunks, return None and let the caller handle it
         return None
```
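Step 4's replacement lookup is a linear scan of the chunk mapping keyed on row offset. A sketch of just that lookup, under assumed shapes (a dict of chunk index to objects exposing `row_offset`, mirroring the fields used above; the `ExternalLink` dataclass here is a stand-in, not the driver's class):

```python
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class ExternalLink:
    """Stand-in carrying only the field the scan needs."""
    row_offset: int
    url: str

def find_replacement(
    chunk_index_to_link: Dict[int, ExternalLink], expired_offset: int
) -> Optional[ExternalLink]:
    """Return the refreshed link whose row_offset matches, or None."""
    for link in chunk_index_to_link.values():
        if link.row_offset == expired_offset:
            return link
    return None

links = {0: ExternalLink(0, "https://example/0"), 1: ExternalLink(100, "https://example/1")}
assert find_replacement(links, 100).url == "https://example/1"
assert find_replacement(links, 999) is None  # the caller raises Error in this case
```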

src/databricks/sql/cloudfetch/download_manager.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -41,7 +41,7 @@ def __init__(

         self._downloadable_result_settings = DownloadableResultSettings(
             is_lz4_compressed=lz4_compressed,
-            expired_link_callback=expired_link_callback
+            expired_link_callback=expired_link_callback,
         )
         self._ssl_options = ssl_options

@@ -126,22 +126,22 @@ def add_link(self, link: TSparkArrowResultLink):
     def cancel_all_downloads(self):
         """
         Cancel all pending downloads and clear the download queue.
-
+
         This method is typically called when links have expired and we need to
         cancel all pending downloads before fetching new links.
         """
         logger.debug("ResultFileDownloadManager: cancelling all downloads")
-
+
         # Cancel all pending download tasks
         cancelled_count = 0
         for task in self._download_tasks:
             if task.cancel():
                 cancelled_count += 1
-
+
         logger.debug(
             f"ResultFileDownloadManager: cancelled {cancelled_count} out of {len(self._download_tasks)} downloads"
         )
-
+
         # Clear the download tasks and pending links
         self._download_tasks.clear()
         self._pending_links.clear()
```
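The cancelled-vs-total logging above reflects `concurrent.futures.Future.cancel()` semantics, assuming the download tasks are futures submitted to a thread pool: `cancel()` succeeds only for tasks still waiting in the queue, never for ones already running. A quick, standalone demonstration of that behavior:

```python
import time
from concurrent.futures import ThreadPoolExecutor

def slow_download(i: int) -> int:
    time.sleep(0.5)  # stand-in for an HTTP fetch
    return i

pool = ThreadPoolExecutor(max_workers=1)
tasks = [pool.submit(slow_download, i) for i in range(4)]

time.sleep(0.1)  # let the first task start running
cancelled = sum(1 for t in tasks if t.cancel())
# Typically 3 of 4: the task already executing cannot be cancelled.
print(f"cancelled {cancelled} out of {len(tasks)} downloads")
pool.shutdown(wait=True)
```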

src/databricks/sql/cloudfetch/downloader.py

Lines changed: 11 additions & 3 deletions
```diff
@@ -56,11 +56,11 @@ class DownloadableResultSettings:
         expired_link_callback (Callable): Callback function to handle expired links. Must return a new link.
     """

+    expired_link_callback: Callable[[TSparkArrowResultLink], TSparkArrowResultLink]
     is_lz4_compressed: bool
     link_expiry_buffer_secs: int = 0
     download_timeout: int = 60
     max_consecutive_file_download_retries: int = 0
-    expired_link_callback: Callable[[TSparkArrowResultLink], TSparkArrowResultLink] = None


 class ResultSetDownloadHandler:
```
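This field move is the "minor type fix" from the commit message: `@dataclass` raises at class-creation time if a field without a default follows fields that have defaults, and the old `= None` workaround contradicted the non-Optional `Callable` annotation. A minimal reproduction with toy field types, not the real settings class:

```python
from dataclasses import dataclass
from typing import Callable

try:
    @dataclass
    class Bad:
        is_lz4_compressed: bool
        link_expiry_buffer_secs: int = 0
        expired_link_callback: Callable[[str], str]  # no default after a default
except TypeError as e:
    print(e)  # non-default argument 'expired_link_callback' follows default argument

@dataclass
class Good:
    expired_link_callback: Callable[[str], str]  # required fields come first
    is_lz4_compressed: bool
    link_expiry_buffer_secs: int = 0
```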
```diff
@@ -90,7 +90,10 @@ def run(self) -> DownloadedFile:

         # Check if link is already expired or is expiring
         ResultSetDownloadHandler._validate_link(
-            self.link, self.settings.link_expiry_buffer_secs, self.settings.expired_link_callback, self
+            self.link,
+            self.settings.link_expiry_buffer_secs,
+            self.settings.expired_link_callback,
+            self,
         )

         session = requests.Session()
@@ -158,7 +161,12 @@ def _is_link_expired(link: TSparkArrowResultLink, expiry_buffer_secs: int) -> bo
         )

     @staticmethod
-    def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int, expired_link_callback: Callable, handler_instance):
+    def _validate_link(
+        link: TSparkArrowResultLink,
+        expiry_buffer_secs: int,
+        expired_link_callback: Callable,
+        handler_instance,
+    ):
         """
         Check if a link has expired or will expire, and handle expired links via callback.
```

src/databricks/sql/utils.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -387,16 +387,18 @@ def __init__(
         # Initialize table and position
         self.table = self._create_next_table()

-    def _handle_expired_link(self, expired_link: TSparkArrowResultLink) -> TSparkArrowResultLink:
+    def _handle_expired_link(
+        self, expired_link: TSparkArrowResultLink
+    ) -> TSparkArrowResultLink:
         """
         Handle expired link for Thrift backend.
-
+
         For Thrift backend, we cannot fetch new links, so we raise an error.
         This maintains the existing behavior for Thrift.
-
+
         Args:
             expired_link: The expired link
-
+
         Raises:
             Error: Always raises an error indicating the link has expired
         """
```

tests/unit/test_downloader.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -24,7 +24,9 @@ class DownloaderTests(unittest.TestCase):
     def test_run_link_expired(self, mock_time):
         settings = Mock()
         settings.link_expiry_buffer_secs = 0
-        settings.expired_link_callback = Mock(side_effect=Error("CloudFetch link has expired"))
+        settings.expired_link_callback = Mock(
+            side_effect=Error("CloudFetch link has expired")
+        )
         result_link = Mock()
         # Already expired
         result_link.expiryTime = 999
@@ -41,7 +43,9 @@ def test_run_link_expired(self, mock_time):
     @patch("time.time", return_value=1000)
     def test_run_link_past_expiry_buffer(self, mock_time):
         settings = Mock(link_expiry_buffer_secs=5)
-        settings.expired_link_callback = Mock(side_effect=Error("CloudFetch link has expired"))
+        settings.expired_link_callback = Mock(
+            side_effect=Error("CloudFetch link has expired")
+        )
         result_link = Mock()
         # Within the expiry buffer time
         result_link.expiryTime = 1004
```
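The `side_effect=Error(...)` pattern above makes the mocked callback raise as soon as the downloader invokes it, which is how these tests pin down the no-renewal path. A condensed, standalone version of the same idiom (the `Error` class is a stand-in):

```python
from unittest.mock import Mock

class Error(Exception):
    """Stand-in for databricks.sql.exc.Error."""

callback = Mock(side_effect=Error("CloudFetch link has expired"))

try:
    callback("some-expired-link")   # calling the mock raises the configured error
except Error as e:
    print(e)                        # CloudFetch link has expired

callback.assert_called_once_with("some-expired-link")
```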
