Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 112 additions & 11 deletions spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@ package org.apache.spark.sql.comet

import scala.jdk.CollectionConverters._

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, CumeDist, CurrentRow, DenseRank, Expression, Lag, Lead, Literal, MakeDecimal, NamedExpression, NthValue, NTile, PercentRank, RangeFrame, Rank, RowFrame, RowNumber, SortOrder, SpecialFrameBoundary, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, WindowExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Cast, CumeDist, CurrentRow, DenseRank, Divide, Expression, Lag, Lead, Literal, MakeDecimal, NamedExpression, NthValue, NTile, PercentRank, RangeFrame, Rank, RowFrame, RowNumber, SortOrder, SpecialFrameBoundary, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, UnscaledValue, WindowExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Complete, Count, First, Last, Max, Min, Sum}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DateType, DecimalType, LongType, NumericType}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.sql.types.{DataType, DateType, Decimal, DecimalType, DoubleType, LongType, NumericType}

import com.google.common.base.Objects

Expand All @@ -50,12 +49,11 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] {
childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
val output = op.child.output

val winExprs: Array[WindowExpression] = op.windowExpression.map {
case Alias(w: WindowExpression, _) => w
case Alias(MakeDecimal(w: WindowExpression, _, _, _), _) => w
case other =>
withFallbackReason(op, s"Unsupported window expression: $other", other)
val winExprs: Array[WindowExpressionInfo] = op.windowExpression.map { expr =>
extractWindowExpression(expr).getOrElse {
withFallbackReason(op, s"Unsupported window expression: $expr", expr)
return None
}
}.toArray

if (winExprs.length != op.windowExpression.length) {
Expand All @@ -80,18 +78,121 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] {
// operator itself carries a fallback attribution. Without this, the plan
// prints a bare `Window` and the real reason lives on a sub-expression
// that isn't obvious in the standard explain output.
val failing = winExprs.toSeq.zip(windowExprProto).collect { case (we, None) => we } ++
val failing = winExprs.toSeq.zip(windowExprProto).collect { case (we, None) =>
we.windowExpression
} ++
op.partitionSpec.zip(partitionExprs).collect { case (e, None) => e } ++
op.orderSpec.zip(sortOrders).collect { case (e, None) => e }
withFallbackReason(op, failing: _*)
None
}
}

private def windowExprToProto(
// Pairs a window expression with the data type that is serialized as its result type.
//
// `windowExpression` is the expression handed to `windowExprToProto`. `resultDataType` is
// the window expression's own type when it was found bare (or only under an Alias), and the
// wrapper's type when `extractWindowExpression` found it under MakeDecimal or Cast(Divide(...)).
private case class WindowExpressionInfo(
windowExpression: WindowExpression,
resultDataType: DataType)

// Locates the WindowExpression inside `expr`, looking through the wrappers this serde
// understands:
//   - Alias: transparent, the search continues in its child.
//   - MakeDecimal: accepted only when the decimal aggregate input can be restored; the
//     reported result type becomes the MakeDecimal's type.
//   - Cast(Divide(_, divisor), DecimalType): accepted only when the decimal average input
//     can be restored and the divisor passes validation; the reported result type becomes
//     the Cast's type.
// Any other shape yields None.
private def extractWindowExpression(expr: Expression): Option[WindowExpressionInfo] =
  expr match {
    case window: WindowExpression =>
      Some(WindowExpressionInfo(window, window.dataType))

    case Alias(aliased, _) =>
      extractWindowExpression(aliased)

    case makeDecimal @ MakeDecimal(inner, _, _, _) =>
      extractWindowExpression(inner).flatMap { found =>
        restoreDecimalAggregateInput(found.windowExpression).map { restored =>
          found.copy(windowExpression = restored, resultDataType = makeDecimal.dataType)
        }
      }

    case cast @ Cast(Divide(inner, divisor, _), _: DecimalType, _, _) =>
      extractWindowExpression(inner).flatMap { found =>
        restoreDecimalAverageInput(found.windowExpression, divisor).map { restored =>
          found.copy(windowExpression = restored, resultDataType = cast.dataType)
        }
      }

    case _ =>
      None
  }

// Spark's DecimalAggregates rule rewrites decimal SUM / AVG window aggregates to run
// over UnscaledValue(child) and wraps the window expression in rescaling arithmetic.
// Comet's native decimal aggregates expect the original decimal child, so restore that
// child only for the exact wrapper shapes emitted by DecimalAggregates.
//
// This entry point serves the MakeDecimal wrapper, which is the SUM shape. The AVG shape
// is Cast(Divide(...)) and is handled by the two-argument restoreDecimalAverageInput,
// which also validates the divisor. Restoring an AVG child here would accept a wrapper
// that DecimalAggregates never emits and would skip the divisor check, so only SUM is
// restored.
//
// Returns the window expression with the SUM child restored, or None when no
// Sum(UnscaledValue(decimal)) aggregate is present.
private def restoreDecimalAggregateInput(
    windowExpr: WindowExpression): Option[WindowExpression] = {
  restoreDecimalSumInput(windowExpr)
}

// Restores the decimal child of an AVG window aggregate, but only when `divisor` is the
// value expected for the scale of that restored child. Returns None when there is nothing
// to restore or when the divisor does not pass isExpectedDecimalAverageDivisor.
private def restoreDecimalAverageInput(
    windowExpr: WindowExpression,
    divisor: Expression): Option[WindowExpression] =
  restoreDecimalAverageInput(windowExpr).collect {
    case (restored, inputScale) if isExpectedDecimalAverageDivisor(divisor, inputScale) =>
      restored
  }

// Rewrites every Average(UnscaledValue(child)) aggregate inside `windowExpr` whose child
// is decimal-typed into Average(child). Returns the rewritten window expression together
// with the scale of the restored decimal child, or None when no aggregate was rewritten.
private def restoreDecimalAverageInput(
    windowExpr: WindowExpression): Option[(WindowExpression, Int)] = {
  // Set from inside the transform callback; stays None when nothing matched.
  var restoredScale: Option[Int] = None

  val unwrapped = windowExpr
    .transform { case aggExpr @ AggregateExpression(avgFn: Average, _, _, _, _) =>
      val unscaledInput = avgFn.child match {
        case UnscaledValue(input) => Some(input)
        case _ => None
      }
      unscaledInput
        .flatMap { input =>
          input.dataType match {
            case decimalType: DecimalType =>
              restoredScale = Some(decimalType.scale)
              Some(aggExpr.copy(aggregateFunction = avgFn.copy(child = input)))
            case _ =>
              None
          }
        }
        // Leave the aggregate untouched when it is not the expected shape.
        .getOrElse(aggExpr)
    }
    .asInstanceOf[WindowExpression]

  restoredScale.map(inputScale => (unwrapped, inputScale))
}

// Rewrites every Sum(UnscaledValue(child)) aggregate inside `windowExpr` whose child is
// decimal-typed into Sum(child). Returns the rewritten window expression, or None when
// no aggregate was rewritten.
private def restoreDecimalSumInput(windowExpr: WindowExpression): Option[WindowExpression] = {
  // Flipped from inside the transform callback when a matching aggregate is rewritten.
  var sawUnscaledDecimalSum = false

  val unwrapped = windowExpr
    .transform { case aggExpr @ AggregateExpression(sumFn: Sum, _, _, _, _) =>
      sumFn.child match {
        case UnscaledValue(input) =>
          input.dataType match {
            case _: DecimalType =>
              sawUnscaledDecimalSum = true
              aggExpr.copy(aggregateFunction = sumFn.copy(child = input))
            case _ =>
              aggExpr
          }
        case _ =>
          aggExpr
      }
    }
    .asInstanceOf[WindowExpression]

  if (sawUnscaledDecimalSum) Some(unwrapped) else None
}

// Returns true only when `divisor` is a non-null Double literal whose value equals
// 10^scale. Any other expression returns false.
//
// NaN and +/-Infinity are rejected up front: their string forms are not parseable by
// BigDecimal, so without the guard this predicate would throw NumberFormatException
// instead of answering false.
private def isExpectedDecimalAverageDivisor(divisor: Expression, scale: Int): Boolean = {
  val expected = BigDecimal(10).pow(scale)
  divisor match {
    case Literal(value: java.lang.Double, DoubleType)
        if !value.isNaN() && !value.isInfinite() =>
      // Compare through the decimal string form so the check is numeric and exact.
      BigDecimal(value.toString) == expected
    case _ =>
      false
  }
}

private def windowExprToProto(
windowExprInfo: WindowExpressionInfo,
output: Seq[Attribute],
conf: SQLConf): Option[OperatorOuterClass.WindowExpr] = {
val windowExpr = windowExprInfo.windowExpression

val aggregateExpressions: Array[AggregateExpression] = windowExpr.flatMap { expr =>
expr match {
Expand Down Expand Up @@ -453,7 +554,7 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] {
val spec =
OperatorOuterClass.WindowSpecDefinition.newBuilder().setFrameSpecification(frame).build()

val resultTypeProto = serializeDataType(windowExpr.dataType)
val resultTypeProto = serializeDataType(windowExprInfo.resultDataType)

if (builtinFunc.isDefined) {
val b = OperatorOuterClass.WindowExpr
Expand Down
173 changes: 173 additions & 0 deletions spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,24 @@

package org.apache.comet.exec

import scala.util.Random

import org.scalactic.source.Position
import org.scalatest.Tag

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{CometTestBase, Row}
import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Divide, Expression, MakeDecimal, WindowExpression}
import org.apache.spark.sql.comet.CometWindowExec
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.window.{WindowExec => SparkWindowExec}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{count, lead, sum}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DecimalType

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus

class CometWindowExecSuite extends CometTestBase {

Expand All @@ -50,6 +57,66 @@ class CometWindowExecSuite extends CometTestBase {
}
}

// Asserts that `plan` contains at least one CometWindowExec node. The failure message
// prints the plan so a fallback to another window operator can be diagnosed from the
// test output.
private def assertCometWindowExecExists(plan: SparkPlan): Unit = {
  val cometWindowExecs = collect(plan) { case w: CometWindowExec =>
    w
  }
  assert(
    cometWindowExecs.nonEmpty,
    s"Expected a CometWindowExec in the plan, but found:\n$plan")
}

// Gathers the window expressions of every Spark WindowExec node found in `plan`.
private def sparkWindowExpressions(plan: SparkPlan): Seq[Expression] = {
  val windowNodes = collect(plan) { case node: SparkWindowExec => node }
  windowNodes.flatMap(_.windowExpression)
}

// Asserts that some window expression in the Spark plan has the
// Alias(Cast(Divide(WindowExpression, _) AS DecimalType)) shape.
private def assertSparkPlanHasDecimalAvgRewrite(plan: SparkPlan): Unit = {
  val windowExprs = sparkWindowExpressions(plan)
  val rewritten = windowExprs.collectFirst {
    case e @ Alias(Cast(Divide(_: WindowExpression, _, _), _: DecimalType, _, _), _) => e
  }
  assert(
    rewritten.isDefined,
    s"Expected Spark decimal AVG rewrite in window expressions, but found: $windowExprs")
}

// Asserts that some window expression in the Spark plan has the
// Alias(MakeDecimal(WindowExpression, ...)) shape.
private def assertSparkPlanHasDecimalSumRewrite(plan: SparkPlan): Unit = {
  val windowExprs = sparkWindowExpressions(plan)
  val rewritten = windowExprs.collectFirst {
    case e @ Alias(MakeDecimal(_: WindowExpression, _, _, _), _) => e
  }
  assert(
    rewritten.isDefined,
    s"Expected Spark decimal SUM rewrite in window expressions, but found: $windowExprs")
}

// Builds a random decimal literal as text: an optional leading "-", an integer part of
// 1 to min(max(1, precision - scale), 6) digits with no leading zero, and, when scale is
// non-zero, a "." followed by exactly `scale` fractional digits.
//
// The draws from `rng` happen in a fixed order (integer digit count, integer digits,
// fractional digits, sign), so a given seed always produces the same string.
private def randomDecimalString(rng: Random, precision: Int, scale: Int): String = {
  val integerDigitCap = math.min(math.max(1, precision - scale), 6)
  val integerPart = randomDigits(rng, 1 + rng.nextInt(integerDigitCap), allowLeadingZero = false)
  val fractionalPart = randomDigits(rng, scale, allowLeadingZero = true)
  val sign = if (rng.nextBoolean()) "-" else ""
  val magnitude = if (scale == 0) integerPart else s"$integerPart.$fractionalPart"
  sign + magnitude
}

// Produces `length` random decimal digits as a string. When `allowLeadingZero` is false
// the first digit is drawn from 1-9; every other digit is drawn from 0-9.
private def randomDigits(rng: Random, length: Int, allowLeadingZero: Boolean): String =
  (0 until length)
    .foldLeft(new StringBuilder(length)) { (acc, pos) =>
      // Smallest digit allowed at this position: 1 for a non-zero leading digit, else 0.
      val lowest = if (pos == 0 && !allowLeadingZero) 1 else 0
      acc.append(('0' + lowest + rng.nextInt(10 - lowest)).toChar)
    }
    .toString()

test("lead/lag should return the default value if the offset row does not exist") {
withSQLConf(
CometConf.COMET_ENABLED.key -> "true",
Expand Down Expand Up @@ -354,6 +421,112 @@ class CometWindowExecSuite extends CometTestBase {
}
}

// Running decimal AVG over a partitioned, ordered ROWS frame: results must match Spark and
// the window must run as a CometWindowExec.
test("window: decimal AVG with PARTITION BY and ORDER BY") {
  withTempDir { dir =>
    // Drop the temp view on exit so it does not outlive the directory it reads from or
    // leak into other tests.
    withTempView("dec_avg") {
      Seq((1, "10.10"), (1, "20.25"), (1, "30.33"), (1, "41.00"), (2, "11.11"), (2, "22.22"))
        .toDF("g", "raw_v")
        .selectExpr("g", "CAST(raw_v AS DECIMAL(10,2)) AS v")
        .repartition(1)
        .write
        .mode("overwrite")
        .parquet(dir.toString)

      spark.read.parquet(dir.toString).createOrReplaceTempView("dec_avg")
      val df = sql("""
SELECT g, v, run_avg
FROM (
SELECT g, v,
AVG(v) OVER (
PARTITION BY g
ORDER BY v
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
) AS run_avg
FROM dec_avg
)
ORDER BY g, v
""")
      val (sparkPlan, cometPlan) = checkSparkAnswerAndOperator(df)
      // The Spark-side rewrite shape is only asserted on Spark 4.0+.
      if (isSpark40Plus) {
        assertSparkPlanHasDecimalAvgRewrite(sparkPlan)
      }
      assertCometWindowExecExists(cometPlan)
    }
  }
}

// Same running decimal AVG check over seeded pseudo-random data, for several
// (precision, scale) pairs.
test("window: decimal AVG fuzz with PARTITION BY and ORDER BY") {
  Seq((9, 1), (9, 4), (10, 2), (11, 3), (11, 6)).foreach { case (precision, scale) =>
    withTempDir { dir =>
      // Drop the temp view after each iteration; otherwise it stays registered against a
      // directory that withTempDir has already deleted.
      withTempView("dec_avg_fuzz") {
        // Seed derived from the type so every run generates the same rows.
        val rng = new Random(precision * 31L + scale)
        (0 until 120)
          .map { i =>
            (i % 7, i, randomDecimalString(rng, precision, scale))
          }
          .toDF("g", "ord", "raw_v")
          .selectExpr("g", "ord", s"CAST(raw_v AS DECIMAL($precision,$scale)) AS v")
          .repartition(1)
          .write
          .mode("overwrite")
          .parquet(dir.toString)

        spark.read.parquet(dir.toString).createOrReplaceTempView("dec_avg_fuzz")
        val df = sql("""
SELECT g, ord, v, run_avg
FROM (
SELECT g, ord, v,
AVG(v) OVER (
PARTITION BY g
ORDER BY ord
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
) AS run_avg
FROM dec_avg_fuzz
)
ORDER BY g, ord
""")
        val (sparkPlan, cometPlan) = checkSparkAnswerAndOperator(df)
        // The Spark-side rewrite shape is only asserted on Spark 4.0+.
        if (isSpark40Plus) {
          assertSparkPlanHasDecimalAvgRewrite(sparkPlan)
        }
        assertCometWindowExecExists(cometPlan)
      }
    }
  }
}

// Running decimal SUM over a partitioned, ordered ROWS frame: results must match Spark,
// the Spark plan must show the MakeDecimal rewrite, and the window must run as a
// CometWindowExec.
test("window: decimal SUM with PARTITION BY and ORDER BY") {
  withTempDir { dir =>
    // Drop the temp view on exit so it does not outlive the directory it reads from or
    // leak into other tests.
    withTempView("dec_sum") {
      Seq(
        (1, 1, "10.10"),
        (1, 2, "20.25"),
        (1, 3, "-5.35"),
        (2, 1, "11.11"),
        (2, 2, "22.22"),
        (2, 3, "33.33"))
        .toDF("g", "ord", "raw_v")
        .selectExpr("g", "ord", "CAST(raw_v AS DECIMAL(8,2)) AS v")
        .repartition(1)
        .write
        .mode("overwrite")
        .parquet(dir.toString)

      spark.read.parquet(dir.toString).createOrReplaceTempView("dec_sum")
      val df = sql("""
SELECT g, ord, v, run_sum
FROM (
SELECT g, ord, v,
SUM(v) OVER (
PARTITION BY g
ORDER BY ord
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
) AS run_sum
FROM dec_sum
)
ORDER BY g, ord
""")
      val (sparkPlan, cometPlan) = checkSparkAnswerAndOperator(df)
      assertSparkPlanHasDecimalSumRewrite(sparkPlan)
      assertCometWindowExecExists(cometPlan)
    }
  }
}

test("window: MIN and MAX with ORDER BY") {
withTempDir { dir =>
(0 until 30)
Expand Down
Loading