@@ -27,7 +27,7 @@ use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
2727use super :: arrow_struct_to_literal;
2828use super :: record_batch_projector:: RecordBatchProjector ;
2929use crate :: arrow:: type_to_arrow_type;
30- use crate :: spec:: { Literal , PartitionSpecRef , SchemaRef , Struct , StructType , Type } ;
30+ use crate :: spec:: { Literal , PartitionKey , PartitionSpecRef , SchemaRef , Struct , StructType , Type } ;
3131use crate :: transform:: { BoxedTransformFunction , create_transform_function} ;
3232use crate :: { Error , ErrorKind , Result } ;
3333
@@ -141,7 +141,7 @@ impl RecordBatchPartitionSplitter {
141141 }
142142
143143 /// Split the record batch into multiple record batches based on the partition spec.
144- pub fn split ( & self , batch : & RecordBatch ) -> Result < Vec < ( Struct , RecordBatch ) > > {
144+ pub fn split ( & self , batch : & RecordBatch ) -> Result < Vec < ( PartitionKey , RecordBatch ) > > {
145145 let source_columns = self . projector . project_column ( batch. columns ( ) ) ?;
146146 let partition_columns = source_columns
147147 . into_iter ( )
@@ -172,8 +172,15 @@ impl RecordBatchPartitionSplitter {
172172 filter. into ( )
173173 } ;
174174
175+ // Create PartitionKey from the partition struct
176+ let partition_key = PartitionKey :: new (
177+ self . partition_spec . as_ref ( ) . clone ( ) ,
178+ self . schema . clone ( ) ,
179+ row,
180+ ) ;
181+
175182 // filter the RecordBatch
176- partition_batches. push ( ( row , filter_record_batch ( batch, & filter_array) ?) ) ;
183+ partition_batches. push ( ( partition_key , filter_record_batch ( batch, & filter_array) ?) ) ;
177184 }
178185
179186 Ok ( partition_batches)
@@ -243,8 +250,8 @@ mod tests {
243250 let mut partitioned_batches = partition_splitter
244251 . split ( & batch)
245252 . expect ( "Failed to split RecordBatch" ) ;
246- partitioned_batches. sort_by_key ( |( row , _) | {
247- if let PrimitiveLiteral :: Int ( i) = row . fields ( ) [ 0 ]
253+ partitioned_batches. sort_by_key ( |( partition_key , _) | {
254+ if let PrimitiveLiteral :: Int ( i) = partition_key . data ( ) . fields ( ) [ 0 ]
248255 . as_ref ( )
249256 . unwrap ( )
250257 . as_primitive_literal ( )
@@ -292,7 +299,7 @@ mod tests {
292299
293300 let partition_values = partitioned_batches
294301 . iter ( )
295- . map ( |( row , _) | row . clone ( ) )
302+ . map ( |( partition_key , _) | partition_key . data ( ) . clone ( ) )
296303 . collect :: < Vec < _ > > ( ) ;
297304 // check partition value is struct(1), struct(2), struct(3)
298305 assert_eq ! ( partition_values, vec![
0 commit comments