diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala index df937dbbf3..500f9996b9 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala @@ -21,15 +21,14 @@ package org.apache.spark.sql.comet import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, CumeDist, CurrentRow, DenseRank, Expression, Lag, Lead, Literal, MakeDecimal, NamedExpression, NthValue, NTile, PercentRank, RangeFrame, Rank, RowFrame, RowNumber, SortOrder, SpecialFrameBoundary, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, WindowExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Cast, CumeDist, CurrentRow, DenseRank, Divide, Expression, Lag, Lead, Literal, MakeDecimal, NamedExpression, NthValue, NTile, PercentRank, RangeFrame, Rank, RowFrame, RowNumber, SortOrder, SpecialFrameBoundary, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, UnscaledValue, WindowExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Complete, Count, First, Last, Max, Min, Sum} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DateType, DecimalType, LongType, NumericType} -import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types.{DataType, DateType, Decimal, DecimalType, DoubleType, LongType, NumericType} import com.google.common.base.Objects @@ -50,12 +49,11 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] { childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { val output = op.child.output - val winExprs: Array[WindowExpression] = op.windowExpression.map { - case Alias(w: WindowExpression, _) => w - case Alias(MakeDecimal(w: WindowExpression, _, _, _), _) => w - case other => - withFallbackReason(op, s"Unsupported window expression: $other", other) + val winExprs: Array[WindowExpressionInfo] = op.windowExpression.map { expr => + extractWindowExpression(expr).getOrElse { + withFallbackReason(op, s"Unsupported window expression: $expr", expr) return None + } }.toArray if (winExprs.length != op.windowExpression.length) { @@ -80,7 +78,9 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] { // operator itself carries a fallback attribution. Without this, the plan // prints a bare `Window` and the real reason lives on a sub-expression // that isn't obvious in the standard explain output. - val failing = winExprs.toSeq.zip(windowExprProto).collect { case (we, None) => we } ++ + val failing = winExprs.toSeq.zip(windowExprProto).collect { case (we, None) => + we.windowExpression + } ++ op.partitionSpec.zip(partitionExprs).collect { case (e, None) => e } ++ op.orderSpec.zip(sortOrders).collect { case (e, None) => e } withFallbackReason(op, failing: _*) @@ -88,10 +88,111 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] { } } - private def windowExprToProto( + private case class WindowExpressionInfo( + windowExpression: WindowExpression, + resultDataType: DataType) + + private def extractWindowExpression(expr: Expression): Option[WindowExpressionInfo] = { + expr match { + case Alias(child, _) => + extractWindowExpression(child) + case w: WindowExpression => + Some(WindowExpressionInfo(w, w.dataType)) + case m @ MakeDecimal(child, _, _, _) => + for { + info <- extractWindowExpression(child) + rewritten <- restoreDecimalAggregateInput(info.windowExpression) + } yield { + info.copy(windowExpression = rewritten, resultDataType = m.dataType) + } + case c @ Cast(Divide(child, divisor, _), _: DecimalType, _, _) => + for { + info <- extractWindowExpression(child) + rewritten <- restoreDecimalAverageInput(info.windowExpression, divisor) + } yield { + info.copy(windowExpression = rewritten, resultDataType = c.dataType) + } + case _ => + None + } + } + + // Spark's DecimalAggregates rule wraps decimal SUM / AVG window aggregates + // around UnscaledValue plus rescaling arithmetic. Comet's native decimal + // aggregates expect the original decimal child, so restore that child only + // for the exact wrapper shapes emitted by DecimalAggregates. + private def restoreDecimalAggregateInput( + windowExpr: WindowExpression): Option[WindowExpression] = { + restoreDecimalSumInput(windowExpr).orElse(restoreDecimalAverageInput(windowExpr).map(_._1)) + } + + private def restoreDecimalAverageInput( windowExpr: WindowExpression, + divisor: Expression): Option[WindowExpression] = { + restoreDecimalAverageInput(windowExpr) + .filter { case (_, scale) => + isExpectedDecimalAverageDivisor(divisor, scale) + } + .map(_._1) + } + + private def restoreDecimalAverageInput( + windowExpr: WindowExpression): Option[(WindowExpression, Int)] = { + var scale: Option[Int] = None + val rewritten = windowExpr + .transform { case agg @ AggregateExpression(avg: Average, _, _, _, _) => + avg.child match { + case UnscaledValue(child) => + child.dataType match { + case dt: DecimalType => + scale = Some(dt.scale) + agg.copy(aggregateFunction = avg.copy(child = child)) + case _ => + agg + } + case _ => + agg + } + } + .asInstanceOf[WindowExpression] + scale.map(rewritten -> _) + } + + private def restoreDecimalSumInput(windowExpr: WindowExpression): Option[WindowExpression] = { + var restored = false + val rewritten = windowExpr + .transform { case agg @ AggregateExpression(sum: Sum, _, _, _, _) => + sum.child match { + case UnscaledValue(child) if child.dataType.isInstanceOf[DecimalType] => + restored = true + agg.copy(aggregateFunction = sum.copy(child = child)) + case _ => + agg + } + } + .asInstanceOf[WindowExpression] + if (restored) { + Some(rewritten) + } else { + None + } + } + + private def isExpectedDecimalAverageDivisor(divisor: Expression, scale: Int): Boolean = { + val expected = BigDecimal(10).pow(scale) + divisor match { + case Literal(value: java.lang.Double, DoubleType) => + BigDecimal(value.toString) == expected + case _ => + false + } + } + + private def windowExprToProto( + windowExprInfo: WindowExpressionInfo, output: Seq[Attribute], conf: SQLConf): Option[OperatorOuterClass.WindowExpr] = { + val windowExpr = windowExprInfo.windowExpression val aggregateExpressions: Array[AggregateExpression] = windowExpr.flatMap { expr => expr match { @@ -453,7 +554,7 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] { val spec = OperatorOuterClass.WindowSpecDefinition.newBuilder().setFrameSpecification(frame).build() - val resultTypeProto = serializeDataType(windowExpr.dataType) + val resultTypeProto = serializeDataType(windowExprInfo.resultDataType) if (builtinFunc.isDefined) { val b = OperatorOuterClass.WindowExpr diff --git a/spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala index 04078a309b..0e6fcf8946 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala @@ -19,17 +19,24 @@ package org.apache.comet.exec +import scala.util.Random + import org.scalactic.source.Position import org.scalatest.Tag import org.apache.hadoop.fs.Path import org.apache.spark.sql.{CometTestBase, Row} +import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Divide, Expression, MakeDecimal, WindowExpression} import org.apache.spark.sql.comet.CometWindowExec +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.window.{WindowExec => SparkWindowExec} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.{count, lead, sum} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DecimalType import org.apache.comet.CometConf +import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus class CometWindowExecSuite extends CometTestBase { @@ -50,6 +57,66 @@ class CometWindowExecSuite extends CometTestBase { } } + private def assertCometWindowExecExists(plan: SparkPlan): Unit = { + val cometWindowExecs = collect(plan) { case w: CometWindowExec => + w + } + assert(cometWindowExecs.nonEmpty) + } + + private def sparkWindowExpressions(plan: SparkPlan): Seq[Expression] = { + collect(plan) { case w: SparkWindowExec => + w.windowExpression + }.flatten + } + + private def assertSparkPlanHasDecimalAvgRewrite(plan: SparkPlan): Unit = { + val windowExprs = sparkWindowExpressions(plan) + assert( + windowExprs.exists { + case Alias(Cast(Divide(_: WindowExpression, _, _), _: DecimalType, _, _), _) => true + case _ => false + }, + s"Expected Spark decimal AVG rewrite in window expressions, but found: $windowExprs") + } + + private def assertSparkPlanHasDecimalSumRewrite(plan: SparkPlan): Unit = { + val windowExprs = sparkWindowExpressions(plan) + assert( + windowExprs.exists { + case Alias(MakeDecimal(_: WindowExpression, _, _, _), _) => true + case _ => false + }, + s"Expected Spark decimal SUM rewrite in window expressions, but found: $windowExprs") + } + + private def randomDecimalString(rng: Random, precision: Int, scale: Int): String = { + val maxIntegerDigits = math.max(1, precision - scale) + val integerDigits = 1 + rng.nextInt(math.min(maxIntegerDigits, 6)) + val integerPart = randomDigits(rng, integerDigits, allowLeadingZero = false) + val fractionalPart = randomDigits(rng, scale, allowLeadingZero = true) + val sign = if (rng.nextBoolean()) "-" else "" + if (scale == 0) { + s"$sign$integerPart" + } else { + s"$sign$integerPart.$fractionalPart" + } + } + + private def randomDigits(rng: Random, length: Int, allowLeadingZero: Boolean): String = { + val digits = new StringBuilder(length) + (0 until length).foreach { pos => + val digit = + if (pos == 0 && !allowLeadingZero) { + 1 + rng.nextInt(9) + } else { + rng.nextInt(10) + } + digits.append(('0' + digit).toChar) + } + digits.toString() + } + test("lead/lag should return the default value if the offset row does not exist") { withSQLConf( CometConf.COMET_ENABLED.key -> "true", @@ -354,6 +421,112 @@ class CometWindowExecSuite extends CometTestBase { } } + test("window: decimal AVG with PARTITION BY and ORDER BY") { + withTempDir { dir => + Seq((1, "10.10"), (1, "20.25"), (1, "30.33"), (1, "41.00"), (2, "11.11"), (2, "22.22")) + .toDF("g", "raw_v") + .selectExpr("g", "CAST(raw_v AS DECIMAL(10,2)) AS v") + .repartition(1) + .write + .mode("overwrite") + .parquet(dir.toString) + + spark.read.parquet(dir.toString).createOrReplaceTempView("dec_avg") + val df = sql(""" + SELECT g, v, run_avg + FROM ( + SELECT g, v, + AVG(v) OVER ( + PARTITION BY g + ORDER BY v + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + ) AS run_avg + FROM dec_avg + ) + ORDER BY g, v + """) + val (sparkPlan, cometPlan) = checkSparkAnswerAndOperator(df) + if (isSpark40Plus) { + assertSparkPlanHasDecimalAvgRewrite(sparkPlan) + } + assertCometWindowExecExists(cometPlan) + } + } + + test("window: decimal AVG fuzz with PARTITION BY and ORDER BY") { + Seq((9, 1), (9, 4), (10, 2), (11, 3), (11, 6)).foreach { case (precision, scale) => + withTempDir { dir => + val rng = new Random(precision * 31L + scale) + (0 until 120) + .map { i => + (i % 7, i, randomDecimalString(rng, precision, scale)) + } + .toDF("g", "ord", "raw_v") + .selectExpr("g", "ord", s"CAST(raw_v AS DECIMAL($precision,$scale)) AS v") + .repartition(1) + .write + .mode("overwrite") + .parquet(dir.toString) + + spark.read.parquet(dir.toString).createOrReplaceTempView("dec_avg_fuzz") + val df = sql(""" + SELECT g, ord, v, run_avg + FROM ( + SELECT g, ord, v, + AVG(v) OVER ( + PARTITION BY g + ORDER BY ord + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + ) AS run_avg + FROM dec_avg_fuzz + ) + ORDER BY g, ord + """) + val (sparkPlan, cometPlan) = checkSparkAnswerAndOperator(df) + if (isSpark40Plus) { + assertSparkPlanHasDecimalAvgRewrite(sparkPlan) + } + assertCometWindowExecExists(cometPlan) + } + } + } + + test("window: decimal SUM with PARTITION BY and ORDER BY") { + withTempDir { dir => + Seq( + (1, 1, "10.10"), + (1, 2, "20.25"), + (1, 3, "-5.35"), + (2, 1, "11.11"), + (2, 2, "22.22"), + (2, 3, "33.33")) + .toDF("g", "ord", "raw_v") + .selectExpr("g", "ord", "CAST(raw_v AS DECIMAL(8,2)) AS v") + .repartition(1) + .write + .mode("overwrite") + .parquet(dir.toString) + + spark.read.parquet(dir.toString).createOrReplaceTempView("dec_sum") + val df = sql(""" + SELECT g, ord, v, run_sum + FROM ( + SELECT g, ord, v, + SUM(v) OVER ( + PARTITION BY g + ORDER BY ord + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + ) AS run_sum + FROM dec_sum + ) + ORDER BY g, ord + """) + val (sparkPlan, cometPlan) = checkSparkAnswerAndOperator(df) + assertSparkPlanHasDecimalSumRewrite(sparkPlan) + assertCometWindowExecExists(cometPlan) + } + } + test("window: MIN and MAX with ORDER BY") { withTempDir { dir => (0 until 30)