Skip to content

Commit ab2c2c4

Browse files
add unit tests for LinkFetcher
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 76fd190 commit ab2c2c4

File tree

2 files changed

+156
-0
lines changed

2 files changed

+156
-0
lines changed

src/databricks/sql/backend/sea/queue.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def _worker_loop(self):
215215
links_downloaded = self._trigger_next_batch_download()
216216
if not links_downloaded:
217217
break
218+
self._link_data_update.notify_all()
218219

219220
def start(self):
220221
self._worker_thread = threading.Thread(target=self._worker_loop)

tests/unit/test_sea_queue.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from databricks.sql.exc import ProgrammingError, ServerOperationError
2525
from databricks.sql.types import SSLOptions
2626
from databricks.sql.utils import ArrowQueue
27+
import threading
28+
import time
2729

2830

2931
class TestJsonQueue:
@@ -570,3 +572,156 @@ def test_hybrid_disposition_with_compressed_attachment(
570572
assert isinstance(queue, ArrowQueue)
571573
mock_decompress.assert_called_once_with(compressed_data)
572574
mock_create_table.assert_called_once_with(decompressed_data, description)
575+
576+
577+
class TestLinkFetcher:
578+
"""Unit tests for the LinkFetcher helper class."""
579+
580+
@pytest.fixture
581+
def sample_links(self):
582+
"""Provide a pair of ExternalLink objects forming two sequential chunks."""
583+
link0 = ExternalLink(
584+
external_link="https://example.com/data/chunk0",
585+
expiration="2030-01-01T00:00:00.000000",
586+
row_count=100,
587+
byte_count=1024,
588+
row_offset=0,
589+
chunk_index=0,
590+
next_chunk_index=1,
591+
http_headers={"Authorization": "Bearer token0"},
592+
)
593+
594+
link1 = ExternalLink(
595+
external_link="https://example.com/data/chunk1",
596+
expiration="2030-01-01T00:00:00.000000",
597+
row_count=100,
598+
byte_count=1024,
599+
row_offset=100,
600+
chunk_index=1,
601+
next_chunk_index=None,
602+
http_headers={"Authorization": "Bearer token1"},
603+
)
604+
605+
return link0, link1
606+
607+
def _create_fetcher(
608+
self,
609+
initial_links,
610+
backend_mock=None,
611+
download_manager_mock=None,
612+
total_chunk_count=10,
613+
):
614+
"""Helper to create a LinkFetcher instance with supplied mocks."""
615+
if backend_mock is None:
616+
backend_mock = Mock()
617+
if download_manager_mock is None:
618+
download_manager_mock = Mock()
619+
620+
return (
621+
LinkFetcher(
622+
download_manager=download_manager_mock,
623+
backend=backend_mock,
624+
statement_id="statement-123",
625+
initial_links=list(initial_links),
626+
total_chunk_count=total_chunk_count,
627+
),
628+
backend_mock,
629+
download_manager_mock,
630+
)
631+
632+
def test_add_links_and_get_next_chunk_index(self, sample_links):
633+
"""Verify that initial links are stored and next chunk index is computed correctly."""
634+
link0, link1 = sample_links
635+
636+
fetcher, _backend, download_manager = self._create_fetcher([link0])
637+
638+
# add_link should have been called for the initial link
639+
download_manager.add_link.assert_called_once()
640+
641+
# Internal mapping should contain the link
642+
assert fetcher.chunk_index_to_link[0] == link0
643+
644+
# The next chunk index should be 1 (from link0.next_chunk_index)
645+
assert fetcher._get_next_chunk_index() == 1
646+
647+
# Add second link and validate it is present
648+
fetcher._add_links([link1])
649+
assert fetcher.chunk_index_to_link[1] == link1
650+
651+
def test_trigger_next_batch_download_success(self, sample_links):
652+
"""Check that _trigger_next_batch_download fetches and stores new links."""
653+
link0, link1 = sample_links
654+
655+
backend_mock = Mock()
656+
backend_mock.get_chunk_links = Mock(return_value=[link1])
657+
658+
fetcher, backend, download_manager = self._create_fetcher(
659+
[link0], backend_mock=backend_mock
660+
)
661+
662+
# Trigger download of the next chunk (index 1)
663+
success = fetcher._trigger_next_batch_download()
664+
665+
assert success is True
666+
backend.get_chunk_links.assert_called_once_with("statement-123", 1)
667+
assert fetcher.chunk_index_to_link[1] == link1
668+
# Two calls to add_link: one for initial link, one for new link
669+
assert download_manager.add_link.call_count == 2
670+
671+
def test_trigger_next_batch_download_error(self, sample_links):
672+
"""Ensure that errors from backend are captured and surfaced."""
673+
link0, _link1 = sample_links
674+
675+
backend_mock = Mock()
676+
backend_mock.get_chunk_links.side_effect = ServerOperationError(
677+
"Backend failure"
678+
)
679+
680+
fetcher, backend, download_manager = self._create_fetcher(
681+
[link0], backend_mock=backend_mock
682+
)
683+
684+
success = fetcher._trigger_next_batch_download()
685+
686+
assert success is False
687+
assert fetcher._error is not None
688+
689+
def test_get_chunk_link_waits_until_available(self, sample_links):
690+
"""Validate that get_chunk_link blocks until the requested link is available and then returns it."""
691+
link0, link1 = sample_links
692+
693+
backend_mock = Mock()
694+
# Configure backend to return link1 when requested for chunk index 1
695+
backend_mock.get_chunk_links = Mock(return_value=[link1])
696+
697+
fetcher, backend, download_manager = self._create_fetcher(
698+
[link0], backend_mock=backend_mock, total_chunk_count=2
699+
)
700+
701+
# Holder to capture the link returned from the background thread
702+
result_container = {}
703+
704+
def _worker():
705+
result_container["link"] = fetcher.get_chunk_link(1)
706+
707+
thread = threading.Thread(target=_worker)
708+
thread.start()
709+
710+
# Give the thread a brief moment to start and attempt to fetch (and therefore block)
711+
time.sleep(0.1)
712+
713+
# Trigger the backend fetch which will add link1 and notify waiting threads
714+
fetcher._trigger_next_batch_download()
715+
716+
thread.join(timeout=2)
717+
718+
# The thread should have finished and captured link1
719+
assert result_container.get("link") == link1
720+
721+
def test_get_chunk_link_out_of_range_returns_none(self, sample_links):
722+
"""Requesting a chunk index >= total_chunk_count should immediately return None."""
723+
link0, _ = sample_links
724+
725+
fetcher, _backend, _dm = self._create_fetcher([link0], total_chunk_count=1)
726+
727+
assert fetcher.get_chunk_link(10) is None

0 commit comments

Comments
 (0)