2020package org .apache .comet .rules
2121
2222import scala .collection .mutable .ListBuffer
23-
2423import org .apache .spark .sql .SparkSession
2524import org .apache .spark .sql .catalyst .expressions .{Divide , DoubleLiteral , EqualNullSafe , EqualTo , Expression , FloatLiteral , GreaterThan , GreaterThanOrEqual , KnownFloatingPointNormalized , LessThan , LessThanOrEqual , NamedExpression , Remainder }
2625import org .apache .spark .sql .catalyst .expressions .aggregate .{Final , Partial }
@@ -30,16 +29,15 @@ import org.apache.spark.sql.catalyst.rules.Rule
3029import org .apache .spark .sql .comet ._
3130import org .apache .spark .sql .comet .execution .shuffle .{CometColumnarShuffle , CometNativeShuffle , CometShuffleExchangeExec , CometShuffleManager }
3231import org .apache .spark .sql .execution ._
33- import org .apache .spark .sql .execution .adaptive .{AdaptiveSparkPlanExec , AQEShuffleReadExec , BroadcastQueryStageExec , ShuffleQueryStageExec }
32+ import org .apache .spark .sql .execution .adaptive .{AQEShuffleReadExec , AdaptiveSparkPlanExec , BroadcastQueryStageExec , ShuffleQueryStageExec }
3433import org .apache .spark .sql .execution .aggregate .{BaseAggregateExec , HashAggregateExec , ObjectHashAggregateExec }
3534import org .apache .spark .sql .execution .command .ExecutedCommandExec
3635import org .apache .spark .sql .execution .datasources .v2 .V2CommandExec
37- import org .apache .spark .sql .execution .exchange .{BroadcastExchangeExec , ReusedExchangeExec , ShuffleExchangeExec }
36+ import org .apache .spark .sql .execution .exchange .{BroadcastExchangeExec , Exchange , ReusedExchangeExec , ShuffleExchangeExec }
3837import org .apache .spark .sql .execution .joins .{BroadcastHashJoinExec , ShuffledHashJoinExec , SortMergeJoinExec }
3938import org .apache .spark .sql .execution .window .WindowExec
4039import org .apache .spark .sql .internal .SQLConf
4140import org .apache .spark .sql .types ._
42-
4341import org .apache .comet .{CometConf , ExtendedExplainInfo }
4442import org .apache .comet .CometConf .COMET_EXEC_SHUFFLE_ENABLED
4543import org .apache .comet .CometSparkSessionExtensions ._
@@ -608,18 +606,24 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
608606 }
609607 }
610608
611- plan.transformUp { case op =>
612- val newOp = convertNode(op)
613- // if newOp is not columnar and newOp.children has columnar, we need to add columnar to row
614- if (! newOp.supportsColumnar && ! newOp.isInstanceOf [ColumnarToRowTransition ]) {
615- val newChildren = newOp.children.map {
609+ val newPlan = plan.transformUp { case op =>
610+ convertNode(op)
611+ }
612+
613+ // insert CometColumnarToRowExec if necessary
614+ newPlan.transformUp {
615+ case c2r : ColumnarToRowTransition => c2r
616+ case op if ! op.supportsColumnar =>
617+ val newChildren = op.children.map {
618+ // CometExec already handles columnar to row conversion internally
619+ // Don't explicitly add CometColumnarToRowExec helps broadcast reuse,
620+ // for plan like: BroadcastExchangeExec(CometExec)
621+ case cometExec : CometExec => cometExec
616622 case c if c.supportsColumnar => CometColumnarToRowExec (c)
617623 case other => other
618624 }
619- newOp.withNewChildren(newChildren)
620- } else {
621- newOp
622- }
625+ op.withNewChildren(newChildren)
626+ case o => o
623627 }
624628 }
625629
0 commit comments