Skip to content

Commit 4c7fe64

Browse files
committed
fix broadcast
1 parent 533a7ad commit 4c7fe64

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,11 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
341341

342342
case op: BroadcastHashJoinExec
343343
if CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(conf) &&
344-
op.children.forall(isCometNative) =>
344+
// check has columnar broadcast child
345+
op.children.exists {
346+
case CometSinkPlaceHolder(_, _, _, true) => true
347+
case _ => false
348+
} =>
345349
newPlanWithProto(op) { case (newPlan, operator) =>
346350
CometBroadcastHashJoinExec(
347351
operator,
@@ -464,15 +468,15 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
464468
// For AQE broadcast stage on a Comet broadcast exchange
465469
case s @ BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) =>
466470
newPlanWithProto(s) { case (newPlan, operator) =>
467-
CometSinkPlaceHolder(operator, newPlan, newPlan)
471+
CometSinkPlaceHolder(operator, newPlan, newPlan, isBroadcast = true)
468472
}
469473

470474
case s @ BroadcastQueryStageExec(
471475
_,
472476
ReusedExchangeExec(_, _: CometBroadcastExchangeExec),
473477
_) =>
474478
newPlanWithProto(s) { case (newPlan, operator) =>
475-
CometSinkPlaceHolder(operator, newPlan, newPlan)
479+
CometSinkPlaceHolder(operator, newPlan, newPlan, isBroadcast = true)
476480
}
477481

478482
// `CometBroadcastExchangeExec`'s broadcast output is not compatible with Spark's broadcast
@@ -487,7 +491,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
487491
QueryPlanSerde.operator2Proto(b) match {
488492
case Some(nativeOp) =>
489493
val cometOp = CometBroadcastExchangeExec(b, b.output, b.mode, b.child)
490-
CometSinkPlaceHolder(nativeOp, b, cometOp)
494+
CometSinkPlaceHolder(nativeOp, b, cometOp, isBroadcast = true)
491495
case None => b
492496
}
493497
case other => other
@@ -706,7 +710,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
706710

707711
// Remove placeholders
708712
newPlan = newPlan.transform {
709-
case CometSinkPlaceHolder(_, _, s) => s
713+
case CometSinkPlaceHolder(_, _, s, _) => s
710714
case CometScanWrapper(_, s) => s
711715
}
712716

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -998,7 +998,8 @@ case class CometScanWrapper(override val nativeOp: Operator, override val origin
998998
case class CometSinkPlaceHolder(
999999
override val nativeOp: Operator, // Must be a Scan
10001000
override val originalPlan: SparkPlan,
1001-
child: SparkPlan)
1001+
child: SparkPlan,
1002+
isBroadcast: Boolean = false)
10021003
extends CometUnaryExec {
10031004
override val serializedPlanOpt: SerializedPlan = SerializedPlan(None)
10041005

0 commit comments

Comments
 (0)