diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py
index 4421c4770..ca8077a33 100644
--- a/src/databricks/sql/cloudfetch/downloader.py
+++ b/src/databricks/sql/cloudfetch/downloader.py
@@ -51,12 +51,14 @@ class DownloadableResultSettings:
         link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs.
         download_timeout (int): Timeout for download requests. Default 60 secs.
         max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down.
+        min_cloudfetch_download_speed (float): Threshold in MB/s below which to log warning. Default 0.1 MB/s.
     """
 
     is_lz4_compressed: bool
     link_expiry_buffer_secs: int = 0
     download_timeout: int = 60
     max_consecutive_file_download_retries: int = 0
+    min_cloudfetch_download_speed: float = 0.1
 
 
 class ResultSetDownloadHandler:
@@ -90,6 +92,8 @@ def run(self) -> DownloadedFile:
             self.link, self.settings.link_expiry_buffer_secs
         )
 
+        start_time = time.time()
+
         with self._http_client.execute(
             method=HttpMethod.GET,
             url=self.link.fileLink,
@@ -102,6 +106,13 @@ def run(self) -> DownloadedFile:
 
             # Save (and decompress if needed) the downloaded file
             compressed_data = response.content
+
+            # Log download metrics
+            download_duration = time.time() - start_time
+            self._log_download_metrics(
+                self.link.fileLink, len(compressed_data), download_duration
+            )
+
             decompressed_data = (
                 ResultSetDownloadHandler._decompress_data(compressed_data)
                 if self.settings.is_lz4_compressed
@@ -128,6 +139,32 @@ def run(self) -> DownloadedFile:
             self.link.rowCount,
         )
 
+    def _log_download_metrics(
+        self, url: str, bytes_downloaded: int, duration_seconds: float
+    ):
+        """Log download speed (MB/s) at INFO level, and WARN when below the configured threshold."""
+        # MB/s; clamp duration to avoid ZeroDivisionError on sub-resolution-fast downloads
+        speed_mb_per_sec = (bytes_downloaded / (1024 * 1024)) / max(duration_seconds, 1e-9)
+
+        # Strip the query string: pre-signed links carry auth tokens that must not be logged
+        url_endpoint = url.split("?")[0]
+        logger.info(
+            "CloudFetch download completed: %.4f MB/s, %d bytes in %.3fs from %s",
+            speed_mb_per_sec,
+            bytes_downloaded,
+            duration_seconds,
+            url_endpoint,
+        )
+
+        # WARN level logging if below threshold (log the stripped endpoint, never the full URL)
+        if speed_mb_per_sec < self.settings.min_cloudfetch_download_speed:
+            logger.warning(
+                "CloudFetch download slower than threshold: %.4f MB/s (threshold: %.1f MB/s) from %s",
+                speed_mb_per_sec,
+                self.settings.min_cloudfetch_download_speed,
+                url_endpoint,
+            )
+
     @staticmethod
     def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int):
         """
diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py
index 1013ba999..a0aa2fed4 100644
--- a/tests/unit/test_downloader.py
+++ b/tests/unit/test_downloader.py
@@ -23,6 +23,17 @@ class DownloaderTests(unittest.TestCase):
     Unit tests for checking downloader logic.
     """
 
+    def _setup_time_mock_for_download(self, mock_time, end_time):
+        """Helper to setup time mock that handles logging system calls."""
+        call_count = [0]
+        def time_side_effect():
+            call_count[0] += 1
+            if call_count[0] <= 2:  # First two calls (validation, start_time)
+                return 1000
+            else:  # All subsequent calls (logging, duration calculation)
+                return end_time
+        mock_time.side_effect = time_side_effect
+
     @patch("time.time", return_value=1000)
     def test_run_link_expired(self, mock_time):
         settings = Mock()
@@ -75,13 +86,17 @@ def test_run_get_response_not_ok(self, mock_time):
             d.run()
         self.assertTrue("404" in str(context.exception))
 
-    @patch("time.time", return_value=1000)
+    @patch("time.time")
     def test_run_uncompressed_successful(self, mock_time):
+        self._setup_time_mock_for_download(mock_time, 1000.5)
+
         http_client = DatabricksHttpClient.get_instance()
         file_bytes = b"1234567890" * 10
         settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
         settings.is_lz4_compressed = False
+        settings.min_cloudfetch_download_speed = 1.0
         result_link = Mock(bytesNum=100, expiryTime=1001)
+        result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=abc123"
         with patch.object(
             http_client,
             "execute",
@@ -95,15 +110,19 @@ def test_run_uncompressed_successful(self, mock_time):
 
             assert file.file_bytes == b"1234567890" * 10
 
-    @patch("time.time", return_value=1000)
+    @patch("time.time")
     def test_run_compressed_successful(self, mock_time):
+        self._setup_time_mock_for_download(mock_time, 1000.2)
+
         http_client = DatabricksHttpClient.get_instance()
         file_bytes = b"1234567890" * 10
         compressed_bytes = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
         settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
         settings.is_lz4_compressed = True
+        settings.min_cloudfetch_download_speed = 1.0
         result_link = Mock(bytesNum=100, expiryTime=1001)
+        result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=xyz789"
         with patch.object(
             http_client,
             "execute",