35
35
use arrow:: buffer:: MutableBuffer ;
36
36
use arrow:: compute:: BatchCoalescer ;
37
37
use futures:: { ready, StreamExt } ;
38
+ use log:: debug;
38
39
use std:: sync:: Arc ;
39
40
use std:: task:: Poll ;
40
41
@@ -59,8 +60,6 @@ use datafusion_expr::JoinType;
59
60
60
61
use futures:: Stream ;
61
62
62
- const OUTPUT_BUFFER_LIMIT : usize = 8192 ;
63
-
64
63
/// Simplified state enum without inner structs
65
64
#[ derive( Debug , Clone , Copy ) ]
66
65
enum NLJState {
@@ -279,6 +278,46 @@ impl NLJStream {
279
278
let cur_right_batch = unwrap_or_internal_err ! ( right_batch) ;
280
279
281
280
// ==== Setup unmatched indices ====
281
+ // If Right Mark
282
+ // ----
283
+ if self . join_type == JoinType :: RightMark {
284
+ // For RightMark, output all right rows, left is null where bitmap is unset, right is 0..N
285
+ let right_row_count = cur_right_batch. num_rows ( ) ;
286
+ let right_indices = UInt64Array :: from_iter_values ( 0 ..right_row_count as u64 ) ;
287
+ // TODO(now-perf): directly copy the null buffer to make this step
288
+ // faster
289
+ let mut left_indices_builder = UInt32Builder :: new ( ) ;
290
+ for i in 0 ..right_row_count {
291
+ if bitmap. value ( i) {
292
+ left_indices_builder. append_value ( i as u32 ) ;
293
+ } else {
294
+ left_indices_builder. append_null ( ) ;
295
+ }
296
+ }
297
+ let left_indices = left_indices_builder. finish ( ) ;
298
+
299
+ let left_data = self . buffered_left_data . as_ref ( ) . ok_or_else ( || {
300
+ internal_datafusion_err ! ( "LeftData should be available" )
301
+ } ) ?;
302
+ let left_batch = left_data. batch ( ) ;
303
+ let empty_left_batch = RecordBatch :: new_empty ( left_batch. schema ( ) . clone ( ) ) ;
304
+
305
+ let result_batch = build_batch_from_indices_maybe_empty (
306
+ & self . schema ,
307
+ & cur_right_batch, // swapped: right is build side
308
+ & empty_left_batch,
309
+ & right_indices,
310
+ & left_indices,
311
+ & self . column_indices ,
312
+ JoinSide :: Right ,
313
+ ) ?;
314
+
315
+ self . current_right_batch_matched = None ;
316
+ return Ok ( result_batch) ;
317
+ }
318
+
319
+ // Non Right Mark
320
+ // ----
282
321
// TODO(polish): now the actual length of bitmap might be longer than
283
322
// the actual in-use. So we have to use right batch length here to
284
323
// iterate through the bitmap
@@ -362,7 +401,7 @@ impl NLJStream {
362
401
inner_table,
363
402
join_metrics,
364
403
buffered_left_data : None ,
365
- output_buffer : Box :: new ( BatchCoalescer :: new ( schema, OUTPUT_BUFFER_LIMIT ) ) ,
404
+ output_buffer : Box :: new ( BatchCoalescer :: new ( schema, cfg_batch_size ) ) ,
366
405
cfg_batch_size,
367
406
current_right_batch : None ,
368
407
current_right_batch_matched,
@@ -404,6 +443,7 @@ impl Stream for NLJStream {
404
443
// better performance, by buffering many build-side batches
405
444
// up front.
406
445
NLJState :: BufferingLeft => {
446
+ debug ! ( "[NLJState] Entering: {:?}" , self . state) ;
407
447
match ready ! ( self . inner_table. get_shared( cx) ) {
408
448
Ok ( left_data) => {
409
449
self . buffered_left_data = Some ( left_data) ;
@@ -439,9 +479,16 @@ impl Stream for NLJStream {
439
479
// the next `EmitUnmatched` phase to check if there is any
440
480
// special handling (e.g., in cases like left join).
441
481
NLJState :: FetchingRight => {
482
+ debug ! ( "[NLJState] Entering: {:?}" , self . state) ;
442
483
match ready ! ( self . outer_table. poll_next_unpin( cx) ) {
443
484
Some ( Ok ( right_batch) ) => {
444
485
let right_batch_size = right_batch. num_rows ( ) ;
486
+
487
+ // Skip the empty batch
488
+ if right_batch_size == 0 {
489
+ continue ;
490
+ }
491
+
445
492
self . current_right_batch = Some ( right_batch) ;
446
493
447
494
// TOOD(polish): make it more understandable
@@ -482,10 +529,12 @@ impl Stream for NLJStream {
482
529
// After it has done with the current right batch, it will
483
530
// go to FetchRight state to check what to do next.
484
531
NLJState :: ProbeJoin => {
532
+ debug ! ( "[NLJState] Entering: {:?}" , self . state) ;
485
533
// Return any completed batches first
486
534
if self . output_buffer . has_completed_batch ( ) {
487
535
if let Some ( batch) = self . output_buffer . next_completed_batch ( ) {
488
- return Poll :: Ready ( Some ( Ok ( batch) ) ) ;
536
+ let poll = Poll :: Ready ( Some ( Ok ( batch) ) ) ;
537
+ return self . join_metrics . baseline . record_poll ( poll) ;
489
538
}
490
539
}
491
540
@@ -523,6 +572,7 @@ impl Stream for NLJStream {
523
572
// Precondition: we have checked the join type so that it's
524
573
// possible to output right unmatched (e.g. it's right join)
525
574
NLJState :: EmitRightUnmatched => {
575
+ debug ! ( "[NLJState] Entering: {:?}" , self . state) ;
526
576
debug_assert ! ( self . current_right_batch. is_some( ) ) ;
527
577
debug_assert ! ( self . current_right_batch_matched. is_some( ) ) ;
528
578
@@ -546,10 +596,12 @@ impl Stream for NLJStream {
546
596
// the same state, to check if there are any more final
547
597
// results to output.
548
598
NLJState :: EmitLeftUnmatched => {
599
+ debug ! ( "[NLJState] Entering: {:?}" , self . state) ;
549
600
// Return any completed batches first
550
601
if self . output_buffer . has_completed_batch ( ) {
551
602
if let Some ( batch) = self . output_buffer . next_completed_batch ( ) {
552
- return Poll :: Ready ( Some ( Ok ( batch) ) ) ;
603
+ let poll = Poll :: Ready ( Some ( Ok ( batch) ) ) ;
604
+ return self . join_metrics . baseline . record_poll ( poll) ;
553
605
}
554
606
}
555
607
@@ -570,10 +622,12 @@ impl Stream for NLJStream {
570
622
}
571
623
572
624
NLJState :: Done => {
625
+ debug ! ( "[NLJState] Entering: {:?}" , self . state) ;
573
626
// Return any remaining completed batches before final termination
574
627
if self . output_buffer . has_completed_batch ( ) {
575
628
if let Some ( batch) = self . output_buffer . next_completed_batch ( ) {
576
- return Poll :: Ready ( Some ( Ok ( batch) ) ) ;
629
+ let poll = Poll :: Ready ( Some ( Ok ( batch) ) ) ;
630
+ return self . join_metrics . baseline . record_poll ( poll) ;
577
631
}
578
632
}
579
633
@@ -607,12 +661,6 @@ impl NLJStream {
607
661
. ok_or_else ( || internal_datafusion_err ! ( "Right batch should be available" ) ) ?
608
662
. clone ( ) ;
609
663
610
- // Skip this empty batch and continue probing the next one
611
- let right_row_count = right_batch. num_rows ( ) ;
612
- if right_row_count == 0 {
613
- return Ok ( true ) ;
614
- }
615
-
616
664
// stop probing, the caller will go to the next state
617
665
if self . l_index >= left_data. batch ( ) . num_rows ( ) {
618
666
return Ok ( false ) ;
@@ -669,7 +717,7 @@ impl NLJStream {
669
717
// ========
670
718
let start_idx = self . emit_cursor as usize ;
671
719
let end_idx =
672
- std:: cmp:: min ( start_idx + OUTPUT_BUFFER_LIMIT , left_batch. num_rows ( ) ) ;
720
+ std:: cmp:: min ( start_idx + self . cfg_batch_size , left_batch. num_rows ( ) ) ;
673
721
674
722
let result_batch = self . process_unmatched_rows ( left_data, start_idx, end_idx) ?;
675
723
@@ -681,33 +729,3 @@ impl NLJStream {
681
729
Ok ( true )
682
730
}
683
731
}
684
-
685
- #[ cfg( test) ]
686
- pub ( crate ) mod tests {
687
- use super :: * ;
688
- use arrow:: datatypes:: { DataType , Field } ;
689
-
690
- #[ test]
691
- fn test_nlj_basic_compilation ( ) {
692
- let _schema = Arc :: new ( Schema :: new ( vec ! [
693
- Field :: new( "l_id" , DataType :: Int32 , false ) ,
694
- Field :: new( "r_id" , DataType :: Int32 , false ) ,
695
- ] ) ) ;
696
-
697
- let _column_indices = vec ! [
698
- ColumnIndex {
699
- index: 0 ,
700
- side: JoinSide :: Left ,
701
- } ,
702
- ColumnIndex {
703
- index: 0 ,
704
- side: JoinSide :: Right ,
705
- } ,
706
- ] ;
707
-
708
- // Verify OUTPUT_BUFFER_LIMIT constant
709
- assert_eq ! ( OUTPUT_BUFFER_LIMIT , 8192 ) ;
710
-
711
- println ! ( "Test passed: simplified NLJ structures compile correctly" ) ;
712
- }
713
- }
0 commit comments