Skip to content

Commit 2cd802e

Browse files
reduce repeated init
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 47bb758 commit 2cd802e

File tree

3 files changed

+15
-42
lines changed

3 files changed

+15
-42
lines changed

src/databricks/sql/backend/sea/queue.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -280,22 +280,7 @@ def _worker_loop(self):
280280
self._link_data_update.notify_all()
281281

282282
def _restart_from_expired_link(self, link: TSparkArrowResultLink):
283-
self.stop()
284-
285-
with self._link_data_update:
286-
self.download_manager.cancel_tasks_from_offset(link.startRowOffset)
287-
288-
chunks_to_restart = []
289-
for chunk_index, l in self.chunk_index_to_link.items():
290-
if l.row_offset < link.startRowOffset:
291-
continue
292-
chunks_to_restart.append(chunk_index)
293-
for chunk_index in chunks_to_restart:
294-
self.chunk_index_to_link.pop(chunk_index)
295-
296-
self.start()
297-
298-
def _restart_from_expired_link(self, link: TSparkArrowResultLink):
283+
"""Restart the link fetcher from the expired link."""
299284
self.stop()
300285

301286
with self._link_data_update:
@@ -363,6 +348,7 @@ def __init__(
363348
schema_bytes=None,
364349
lz4_compressed=lz4_compressed,
365350
description=description,
351+
expiry_callback=self._expiry_callback,
366352
)
367353

368354
logger.debug(
@@ -376,14 +362,6 @@ def __init__(
376362
# Track the current chunk we're processing
377363
self._current_chunk_index = 0
378364

379-
self.download_manager = ResultFileDownloadManager(
380-
links=[],
381-
max_download_threads=max_download_threads,
382-
lz4_compressed=lz4_compressed,
383-
ssl_options=ssl_options,
384-
expiry_callback=self._expiry_callback,
385-
)
386-
387365
self.link_fetcher = None
388366
if total_chunk_count > 0:
389367
self.link_fetcher = LinkFetcher(
@@ -402,6 +380,8 @@ def _expiry_callback(self, link: TSparkArrowResultLink):
402380
logger.info(
403381
f"SeaCloudFetchQueue: Link expired, restarting from offset {link.startRowOffset}"
404382
)
383+
if not self.link_fetcher:
384+
return
405385
self.link_fetcher._restart_from_expired_link(link)
406386

407387
def _create_next_table(self) -> Union["pyarrow.Table", None]:

src/databricks/sql/cloudfetch/download_manager.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,29 +14,28 @@
1414

1515
logger = logging.getLogger(__name__)
1616

17-
T = TypeVar('T')
17+
T = TypeVar("T")
1818

1919

2020
class TaskWithMetadata(Generic[T]):
2121
"""
2222
Wrapper around Future that stores additional metadata (the link).
2323
Provides type-safe access to both the Future result and the associated link.
2424
"""
25-
25+
2626
def __init__(self, future: Future[T], link: TSparkArrowResultLink):
2727
self.future = future
2828
self.link = link
29-
29+
3030
def result(self, timeout: Optional[float] = None) -> T:
3131
"""Get the result of the Future, blocking if necessary."""
3232
return self.future.result(timeout)
33-
33+
3434
def cancel(self) -> bool:
3535
"""Cancel the Future if possible."""
3636
return self.future.cancel()
3737

3838

39-
4039
class ResultFileDownloadManager:
4140
def __init__(
4241
self,

src/databricks/sql/utils.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from __future__ import annotations
2-
from typing import Dict, List, Optional, Union
2+
from typing import Callable, Dict, List, Optional, Union
33

44
from dateutil import parser
55
import datetime
@@ -219,6 +219,7 @@ def __init__(
219219
schema_bytes: Optional[bytes] = None,
220220
lz4_compressed: bool = True,
221221
description: List[Tuple] = [],
222+
expiry_callback: Callable[[TSparkArrowResultLink], None] = lambda _: None,
222223
):
223224
"""
224225
Initialize the base CloudFetchQueue.
@@ -247,6 +248,7 @@ def __init__(
247248
max_download_threads=max_download_threads,
248249
lz4_compressed=lz4_compressed,
249250
ssl_options=ssl_options,
251+
expiry_callback=expiry_callback,
250252
)
251253

252254
def next_n_rows(self, num_rows: int) -> "pyarrow.Table":
@@ -373,6 +375,7 @@ def __init__(
373375
schema_bytes=schema_bytes,
374376
lz4_compressed=lz4_compressed,
375377
description=description,
378+
expiry_callback=self._expiry_callback,
376379
)
377380

378381
self.start_row_index = start_row_offset
@@ -392,21 +395,12 @@ def __init__(
392395
)
393396
self.download_manager.add_link(result_link)
394397

395-
def expiry_callback(link: TSparkArrowResultLink):
396-
raise Error("Cloudfetch link has expired")
397-
398-
# Initialize download manager
399-
self.download_manager = ResultFileDownloadManager(
400-
links=self.result_links,
401-
max_download_threads=self.max_download_threads,
402-
lz4_compressed=self.lz4_compressed,
403-
ssl_options=self._ssl_options,
404-
expiry_callback=expiry_callback,
405-
)
406-
407398
# Initialize table and position
408399
self.table = self._create_next_table()
409400

401+
def _expiry_callback(self, link: TSparkArrowResultLink):
402+
raise Error("Cloudfetch link has expired")
403+
410404
def _create_next_table(self) -> Union["pyarrow.Table", None]:
411405
logger.debug(
412406
"ThriftCloudFetchQueue: Trying to get downloaded file for row {}".format(

0 commit comments

Comments (0)