11from __future__ import annotations
22
33from abc import ABC
4- from typing import List , Optional , Tuple , Union
4+ from typing import List , Optional , Tuple , Union , TYPE_CHECKING
55
66from databricks .sql .cloudfetch .download_manager import ResultFileDownloadManager
77
1212
1313import dateutil
1414
15- from databricks .sql .backend .sea .backend import SeaDatabricksClient
16- from databricks .sql .backend .sea .models .base import (
17- ExternalLink ,
18- ResultData ,
19- ResultManifest ,
20- )
15+ if TYPE_CHECKING :
16+ from databricks .sql .backend .sea .backend import SeaDatabricksClient
17+ from databricks .sql .backend .sea .models .base import (
18+ ExternalLink ,
19+ ResultData ,
20+ ResultManifest ,
21+ )
2122from databricks .sql .backend .sea .utils .constants import ResultFormat
2223from databricks .sql .exc import ProgrammingError , ServerOperationError
2324from databricks .sql .thrift_api .TCLIService .ttypes import TSparkArrowResultLink
@@ -109,7 +110,7 @@ def __init__(
109110 result_data : ResultData ,
110111 max_download_threads : int ,
111112 ssl_options : SSLOptions ,
112- sea_client : " SeaDatabricksClient" ,
113+ sea_client : SeaDatabricksClient ,
113114 statement_id : str ,
114115 total_chunk_count : int ,
115116 lz4_compressed : bool = False ,
@@ -140,6 +141,7 @@ def __init__(
140141
141142 self ._sea_client = sea_client
142143 self ._statement_id = statement_id
144+ self ._total_chunk_count = total_chunk_count
143145
144146 logger .debug (
145147 "SeaCloudFetchQueue: Initialize CloudFetch loader for statement {}, total chunks: {}" .format (
@@ -154,12 +156,11 @@ def __init__(
154156 return None
155157
156158 # Track the current chunk we're processing
157- self ._current_chunk_link = first_link
158-
159+ self ._current_chunk_index = 0
159160 # Initialize table and position
160- self .table = self ._create_table_from_link (self . _current_chunk_link )
161+ self .table = self ._create_table_from_link (first_link )
161162
162- def _convert_to_thrift_link (self , link : " ExternalLink" ) -> TSparkArrowResultLink :
163+ def _convert_to_thrift_link (self , link : ExternalLink ) -> TSparkArrowResultLink :
163164 """Convert SEA external links to Thrift format for compatibility with existing download manager."""
164165 # Parse the ISO format expiration time
165166 expiry_time = int (dateutil .parser .parse (link .expiration ).timestamp ())
@@ -172,36 +173,24 @@ def _convert_to_thrift_link(self, link: "ExternalLink") -> TSparkArrowResultLink
172173 httpHeaders = link .http_headers or {},
173174 )
174175
175- def _progress_chunk_link (self ) :
176+ def _get_chunk_link (self , chunk_index : int ) -> Optional [ ExternalLink ] :
176177 """Progress to the next chunk link."""
177- if not self ._current_chunk_link :
178- return None
179-
180- next_chunk_index = self ._current_chunk_link .next_chunk_index
181-
182- if next_chunk_index is None :
183- self ._current_chunk_link = None
178+ if chunk_index >= self ._total_chunk_count :
184179 return None
185180
186181 try :
187- self ._current_chunk_link = self ._sea_client .get_chunk_link (
188- self ._statement_id , next_chunk_index
189- )
182+ return self ._sea_client .get_chunk_link (self ._statement_id , chunk_index )
190183 except Exception as e :
191184 raise ServerOperationError (
192- f"Error fetching link for chunk { next_chunk_index } : { e } " ,
185+ f"Error fetching link for chunk { chunk_index } : { e } " ,
193186 {
194187 "operation-id" : self ._statement_id ,
195188 "diagnostic-info" : None ,
196189 },
197190 )
198191
199- logger .debug (
200- f"SeaCloudFetchQueue: Progressed to link for chunk { next_chunk_index } : { self ._current_chunk_link } "
201- )
202-
203192 def _create_table_from_link (
204- self , link : " ExternalLink"
193+ self , link : ExternalLink
205194 ) -> Union ["pyarrow.Table" , None ]:
206195 """Create a table from a link."""
207196
@@ -215,11 +204,8 @@ def _create_table_from_link(
215204
216205 def _create_next_table (self ) -> Union ["pyarrow.Table" , None ]:
217206 """Create next table by retrieving the logical next downloaded file."""
218-
219- self ._progress_chunk_link ()
220-
221- if not self ._current_chunk_link :
222- logger .debug ("SeaCloudFetchQueue: No current chunk link, returning" )
207+ self ._current_chunk_index += 1
208+ next_chunk_link = self ._get_chunk_link (self ._current_chunk_index )
209+ if not next_chunk_link :
223210 return None
224-
225- return self ._create_table_from_link (self ._current_chunk_link )
211+ return self ._create_table_from_link (next_chunk_link )
0 commit comments