@@ -48,30 +48,7 @@ def remaining_rows(self):
4848 pass
4949
5050
51- class JsonQueue (ResultSetQueue ):
52- """Queue implementation for JSON_ARRAY format data."""
53-
54- def __init__ (self , data_array ):
55- """Initialize with JSON array data."""
56- self .data_array = data_array
57- self .cur_row_index = 0
58- self .n_valid_rows = len (data_array )
59-
60- def next_n_rows (self , num_rows ):
61- """Get the next n rows from the data array."""
62- length = min (num_rows , self .n_valid_rows - self .cur_row_index )
63- slice = self .data_array [self .cur_row_index :self .cur_row_index + length ]
64- self .cur_row_index += length
65- return slice
66-
67- def remaining_rows (self ):
68- """Get all remaining rows from the data array."""
69- slice = self .data_array [self .cur_row_index :]
70- self .cur_row_index += len (slice )
71- return slice
72-
73-
74- class ResultSetQueueFactory (ABC ):
51+ class ThriftResultSetQueueFactory (ABC ):
7552 @staticmethod
7653 def build_queue (
7754 row_set_type : Optional [TSparkRowSetType ] = None ,
@@ -81,57 +58,38 @@ def build_queue(
8158 ssl_options : Optional [SSLOptions ] = None ,
8259 lz4_compressed : bool = True ,
8360 description : Optional [List [List [Any ]]] = None ,
84- # SEA specific parameters
85- sea_result_data : Optional [Any ] = None ,
8661 ) -> ResultSetQueue :
8762 """
88- Factory method to build a result set queue.
89-
90- This method is extended to handle both Thrift and SEA result formats.
91- For SEA, the sea_result_data parameter is used instead of row_set_type and t_row_set.
92-
63+ Factory method to build a result set queue for Thrift backend.
64+
9365 Args:
94- # Thrift parameters
9566 row_set_type (enum): Row set type (Arrow, Column, or URL).
9667 t_row_set (TRowSet): Result containing arrow batches, columns, or cloud fetch links.
97-
98- # Common parameters
9968 arrow_schema_bytes (bytes): Bytes representing the arrow schema.
10069 lz4_compressed (bool): Whether result data has been lz4 compressed.
10170 description (List[List[Any]]): Hive table schema description.
10271 max_download_threads (int): Maximum number of downloader thread pool threads.
10372 ssl_options (SSLOptions): SSLOptions object for CloudFetchQueue
104-
105- # SEA parameters
106- sea_result_data (ResultData): Result data from SEA response
107-
73+
10874 Returns:
10975 ResultSetQueue
11076 """
111- # Handle SEA result data
112- if sea_result_data is not None :
113- if sea_result_data .data :
114- # INLINE disposition with JSON_ARRAY format
115- return JsonQueue (sea_result_data .data )
116- elif sea_result_data .external_links :
117- # EXTERNAL_LINKS disposition (not implemented yet)
118- raise NotImplementedError (
119- "EXTERNAL_LINKS disposition is not supported yet"
120- )
121- else :
122- # Empty result set
123- return JsonQueue ([])
124-
125- # Handle Thrift result data (existing implementation)
126- if row_set_type == TSparkRowSetType .ARROW_BASED_SET and t_row_set is not None and arrow_schema_bytes is not None :
77+ # Handle Thrift result data
78+ if (
79+ row_set_type == TSparkRowSetType .ARROW_BASED_SET
80+ and t_row_set is not None
81+ and arrow_schema_bytes is not None
82+ ):
12783 arrow_table , n_valid_rows = convert_arrow_based_set_to_arrow_table (
12884 t_row_set .arrowBatches , lz4_compressed , arrow_schema_bytes
12985 )
13086 converted_arrow_table = convert_decimals_in_arrow_table (
13187 arrow_table , description
13288 )
13389 return ArrowQueue (converted_arrow_table , n_valid_rows )
134- elif row_set_type == TSparkRowSetType .COLUMN_BASED_SET and t_row_set is not None :
90+ elif (
91+ row_set_type == TSparkRowSetType .COLUMN_BASED_SET and t_row_set is not None
92+ ):
13593 column_table , column_names = convert_column_based_set_to_column_table (
13694 t_row_set .columns , description
13795 )
@@ -141,7 +99,13 @@ def build_queue(
14199 )
142100
143101 return ColumnQueue (ColumnTable (converted_column_table , column_names ))
144- elif row_set_type == TSparkRowSetType .URL_BASED_SET and t_row_set is not None and arrow_schema_bytes is not None and max_download_threads is not None and ssl_options is not None :
102+ elif (
103+ row_set_type == TSparkRowSetType .URL_BASED_SET
104+ and t_row_set is not None
105+ and arrow_schema_bytes is not None
106+ and max_download_threads is not None
107+ and ssl_options is not None
108+ ):
145109 return CloudFetchQueue (
146110 schema_bytes = arrow_schema_bytes ,
147111 start_row_offset = t_row_set .startRowOffset ,
@@ -155,6 +119,56 @@ def build_queue(
155119 raise AssertionError ("Row set type is not valid" )
156120
157121
122+ class SeaResultSetQueueFactory (ABC ):
123+ @staticmethod
124+ def build_queue (
125+ sea_result_data : Any ,
126+ description : Optional [List [List [Any ]]] = None ,
127+ ) -> ResultSetQueue :
128+ """
129+ Factory method to build a result set queue for SEA backend.
130+
131+ Args:
132+ sea_result_data (ResultData): Result data from SEA response
133+ description (List[List[Any]]): Column descriptions
134+
135+ Returns:
136+ ResultSetQueue: The appropriate queue for the result data
137+ """
138+ if sea_result_data .data :
139+ # INLINE disposition with JSON_ARRAY format
140+ return JsonQueue (sea_result_data .data )
141+ elif sea_result_data .external_links :
142+ # EXTERNAL_LINKS disposition (not implemented yet)
143+ raise NotImplementedError ("EXTERNAL_LINKS disposition is not supported yet" )
144+ else :
145+ # Empty result set
146+ return JsonQueue ([])
147+
148+
149+ class JsonQueue (ResultSetQueue ):
150+ """Queue implementation for JSON_ARRAY format data."""
151+
152+ def __init__ (self , data_array ):
153+ """Initialize with JSON array data."""
154+ self .data_array = data_array
155+ self .cur_row_index = 0
156+ self .n_valid_rows = len (data_array )
157+
158+ def next_n_rows (self , num_rows ):
159+ """Get the next n rows from the data array."""
160+ length = min (num_rows , self .n_valid_rows - self .cur_row_index )
161+ slice = self .data_array [self .cur_row_index : self .cur_row_index + length ]
162+ self .cur_row_index += length
163+ return slice
164+
165+ def remaining_rows (self ):
166+ """Get all remaining rows from the data array."""
167+ slice = self .data_array [self .cur_row_index :]
168+ self .cur_row_index += len (slice )
169+ return slice
170+
171+
158172class ColumnTable :
159173 def __init__ (self , column_table , column_names ):
160174 self .column_table = column_table
0 commit comments