fix: support Spark 4 decimal window avg #4749
comphead
left a comment
Thanks @manuzhang, this makes sense to me.
I'd appreciate it if you could include the exact fallback message in the PR description.
CI pending.
andygrove
left a comment
Thanks for tracking this down. I verified the wrapper shapes against the Spark DecimalAggregates rule in Optimizer.scala (v4.2.0-preview4), and the diagnosis and overall approach look right. I left a few inline comments. The main one to consider is the precise-decimal vs double-average difference, since that is the part most likely to surface a real Spark mismatch and it is not visible in the current test. The other two are about tightening the match and adding SUM coverage.
    case c @ Cast(Divide(child, _, _), _: DecimalType, _, _) =>
      extractWindowExpression(child).map { info =>
        info.copy(
          windowExpression = restoreDecimalAggregateInput(info.windowExpression),
          resultDataType = c.dataType)
      }
This match extracts the window expression from the numerator and discards the divisor. That substitution is only correct for the exact shape DecimalAggregates emits: divisor Literal(10^scale, DoubleType) and inner Average(UnscaledValue(decimal)). For any other Cast(Divide(windowExpr, X), decimal) the divisor would be silently dropped and the result would be wrong.
I believe it is safe today, because DecimalAggregates is the only rule that injects a Cast/Divide directly inside a WindowExec.windowExpression (it runs in the optimizer, after ExtractWindowExpressions has pulled window expressions into the operator). A user-written CAST(SUM(x) OVER w / 2 AS DECIMAL) keeps its cast and divide in the Project above the window, so it never reaches here.
Could we tighten the guard to verify the divisor is the expected 10^scale literal and the inner is Average(UnscaledValue(_)), and fall back otherwise? That way an unexpected shape degrades to a Spark fallback instead of a silently dropped divide. A short comment naming DecimalAggregates as the source of this shape would also help the next reader.
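The tightened guard requested above could look roughly like the following sketch. It uses toy case classes as stand-ins for the Catalyst nodes so that it runs without Spark; `Col`, `DoubleLit`, `CastToDecimal`, `WindowExpr`, and `DecimalAvgGuard` are hypothetical names for illustration, not the classes or helpers used in the PR.

```scala
// Toy stand-ins for Catalyst nodes; these are NOT the real Spark classes.
sealed trait Expr
case class Col(name: String, precision: Int, scale: Int) extends Expr // a decimal column
case class UnscaledValue(child: Expr) extends Expr
case class Average(child: Expr) extends Expr
case class WindowExpr(function: Expr) extends Expr
case class DoubleLit(value: Double) extends Expr
case class Divide(left: Expr, right: Expr) extends Expr
case class CastToDecimal(child: Expr, precision: Int, scale: Int) extends Expr

object DecimalAvgGuard {
  // Shape emitted by DecimalAggregates for decimal AVG:
  //   Cast(Divide(window(Average(UnscaledValue(decimal))), 10^scale as double), decimal)
  // Returns the window expression with its decimal input restored,
  // or None so the caller can fall back instead of dropping the divisor.
  def unwrap(e: Expr): Option[WindowExpr] = e match {
    case CastToDecimal(Divide(WindowExpr(Average(UnscaledValue(c: Col))), DoubleLit(d)), _, _)
        if d == math.pow(10, c.scale) =>
      Some(WindowExpr(Average(c)))
    case _ =>
      None
  }
}
```

With this shape check, a divisor such as `DoubleLit(2.0)` on a scale-2 column yields `None`, which corresponds to a Spark fallback rather than a silently dropped divide.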
    case agg @ AggregateExpression(avg: Average, _, _, _, _) =>
      avg.child match {
        case UnscaledValue(child) if child.dataType.isInstanceOf[DecimalType] =>
          agg.copy(aggregateFunction = avg.copy(child = child))
Restoring the decimal child here makes Comet run its precise avg_decimal UDAF (decimal sum at decimal(p+10, s), decimal division with HALF_UP). But the plan that DecimalAggregates produced computes the average in Double: Average(UnscaledValue(e)) is a double average, divided by 10^scale as a double, then cast to decimal(p+4, s+4). So for the case this PR targets we are substituting precise decimal arithmetic for Spark's floating-point arithmetic, and the two can round differently at the s+4 scale.
Spark only takes this double path when prec + 4 <= 15, so values stay in double's exact integer range and any divergence would be rare and in the last digit, but it is not guaranteed to be zero. For contrast, the regular non-window decimal AVG stays faithful: the HashAggregate computes the double average and a Project does the divide and cast, matching Spark exactly. That asymmetry is what worries me here.
Could we add a fuzz or property test that runs many random decimal values across a few precision and scale combinations through window AVG and compares against Spark? The clean values in the new test all happen to round the same way in both schemes, so they would not catch this.
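As a rough illustration of the comparison such a fuzz test would make, the sketch below models the two arithmetic schemes with plain JVM types and counts random inputs on which they disagree. It does not run Spark or Comet. The names `AvgPaths`, `doubleAvg`, `decimalAvg`, and `countMismatches` are hypothetical, and the double-to-decimal cast is modeled as `BigDecimal.valueOf` followed by HALF_UP rounding, which is an assumption about Spark's behavior rather than a verified reproduction of it.

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}
import scala.util.Random

object AvgPaths {
  // Double path: average the unscaled longs as doubles, divide by 10^scale as a
  // double, then round to scale + 4. Assumption: the double-to-decimal cast is
  // modeled as BigDecimal.valueOf (shortest decimal string) plus HALF_UP.
  def doubleAvg(unscaled: Seq[Long], scale: Int): JBigDecimal = {
    val avg = unscaled.map(_.toDouble).sum / unscaled.size
    JBigDecimal.valueOf(avg / math.pow(10, scale)).setScale(scale + 4, RoundingMode.HALF_UP)
  }

  // Decimal path: exact decimal sum, then decimal division rounded HALF_UP at scale + 4.
  def decimalAvg(unscaled: Seq[Long], scale: Int): JBigDecimal = {
    val sum = unscaled.map(u => JBigDecimal.valueOf(u, scale)).reduce((a, b) => a.add(b))
    sum.divide(JBigDecimal.valueOf(unscaled.size.toLong), scale + 4, RoundingMode.HALF_UP)
  }

  // Counts random inputs (unscaled magnitudes below 10^precision) on which the paths disagree.
  def countMismatches(precision: Int, scale: Int, rows: Int, trials: Int, seed: Long): Int = {
    val rnd = new Random(seed)
    val bound = math.pow(10, precision)
    (1 to trials).count { _ =>
      val values = Seq.fill(rows) {
        val magnitude = (rnd.nextDouble() * bound).toLong
        if (rnd.nextBoolean()) magnitude else -magnitude
      }
      doubleAvg(values, scale) != decimalAvg(values, scale)
    }
  }
}
```

Calling `AvgPaths.countMismatches(precision = 9, scale = 2, rows = 7, trials = 1000, seed = 42L)` for a few precision and scale pairs gives a quick signal of whether the two schemes ever round differently under this model; a real test would still need to compare Comet against Spark end to end.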
    case agg @ AggregateExpression(sum: Sum, _, _, _, _) =>
      sum.child match {
        case UnscaledValue(child) if child.dataType.isInstanceOf[DecimalType] =>
          agg.copy(aggregateFunction = sum.copy(child = child))
        case _ =>
          agg
      }
This Sum branch and the MakeDecimal result-type handling are a behavior change from the old case Alias(MakeDecimal(w: WindowExpression, _, _, _), _) => w, which extracted the bare window expression and used windowExpr.dataType. The new test only covers AVG, so this branch is unexercised. Could we add a decimal SUM over a window with small enough precision to trigger the MakeDecimal rewrite (prec + 10 <= 18)? That exercises the other half of the new extraction logic.
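To make the two precision thresholds from this review concrete, here is a small standalone model of when each rewrite applies and what result type it produces. The names (`Rewrite`, `SumViaLong`, `AvgViaDouble`, `DecimalRewriteModel`) are hypothetical; only the thresholds (prec + 10 <= 18 for SUM, prec + 4 <= 15 for AVG) and the result types come from the comments above.

```scala
// Hypothetical mini-model of the precision thresholds discussed in this review.
sealed trait Rewrite
case class SumViaLong(resultPrecision: Int, resultScale: Int) extends Rewrite   // MakeDecimal wrapper
case class AvgViaDouble(resultPrecision: Int, resultScale: Int) extends Rewrite // Cast(Divide(...)) wrapper
case object NoRewrite extends Rewrite

object DecimalRewriteModel {
  val MaxLongDigits = 18   // widened SUM must still fit in a long
  val MaxDoubleDigits = 15 // widened AVG must stay within double's exact integer range

  def forSum(prec: Int, scale: Int): Rewrite =
    if (prec + 10 <= MaxLongDigits) SumViaLong(prec + 10, scale) else NoRewrite

  def forAvg(prec: Int, scale: Int): Rewrite =
    if (prec + 4 <= MaxDoubleDigits) AvgViaDouble(prec + 4, scale + 4) else NoRewrite
}
```

Under this model a decimal(8, 2) column triggers both rewrites, while a decimal(10, 2) column triggers only the AVG rewrite, so a SUM test needs an input precision of 8 or less to exercise the MakeDecimal branch.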
Spark 4 represents decimal window AVG results with a Cast(Divide(WindowExpression, ...)) wrapper. Teach Comet window serialization to unwrap that shape while preserving the wrapper result type, and add a regression test for decimal AVG over partitioned ordered windows.
Co-authored-by: Codex <codex@openai.com>
Thanks for the updates. The divisor guard plus the two new tests address all three of my earlier comments.
On the precise-decimal vs double-average question I raised, I dug into it further and I am now satisfied there is no real divergence to worry about. I built the worst case on purpose: 32 rows in one partition (a power-of-two count, so [...]
The reason is Spark's cast from the [...]
So the approach of restoring the decimal child is sound within the range the rewrite applies to, and the fuzz test is good defensive coverage on top of that. No further concerns from me.
Which issue does this PR close?
Closes #4731.
Rationale for this change
Spark 4 applies DecimalAggregates to decimal window aggregates. For decimal AVG, Spark can represent the result as Cast(Divide(WindowExpression, Literal(10^scale, DoubleType)), DecimalType, ...); for decimal SUM, Spark can wrap the window expression in MakeDecimal when the widened precision still fits in a long.
Comet's window converter previously extracted bare WindowExpression values and MakeDecimal(WindowExpression, ...), but it did not handle Spark 4's decimal AVG cast/divide wrapper. That caused decimal window AVG to fall back to Spark instead of using native Comet execution.
What changes are included in this PR?
This PR updates CometWindowExec to recognize Spark's decimal window aggregate wrappers from DecimalAggregates while preserving Spark's final result type:
- [...] AVG only for the exact Spark shape with Average(UnscaledValue(decimalChild)) and a Literal(10^scale, DoubleType) divisor
- [...] SUM from Spark's MakeDecimal(WindowExpression(...)) wrapper when the inner aggregate is Sum(UnscaledValue(decimalChild))
- [...] AVG, randomized decimal AVG precision/scale combinations, and decimal SUM
How are these changes tested?
The window tests compare Spark and Comet answers and assert that the executed Comet plan contains CometWindowExec, so they verify the decimal paths do not silently fall back.
git diff --check
make core
./mvnw test -Pjdk17 -Dtest=none -Dsuites="org.apache.comet.exec.CometWindowExecSuite window: decimal AVG with PARTITION BY and ORDER BY" -Dscalastyle.skip=true -DfailIfNoTests=false
./mvnw test -Pjdk17 -Dtest=none -Dsuites="org.apache.comet.exec.CometWindowExecSuite window: decimal AVG fuzz with PARTITION BY and ORDER BY" -Dscalastyle.skip=true -DfailIfNoTests=false
./mvnw test -Pjdk17 -Dtest=none -Dsuites="org.apache.comet.exec.CometWindowExecSuite window: decimal SUM with PARTITION BY and ORDER BY" -Dscalastyle.skip=true -DfailIfNoTests=false
./mvnw test -Pspark-4.0 -Pjdk17 -Dtest=none -Dsuites="org.apache.comet.exec.CometWindowExecSuite window: decimal AVG with PARTITION BY and ORDER BY" -Dscalastyle.skip=true -DfailIfNoTests=false
./mvnw test -Pjdk17,spark-3.5 -Dtest=none -Dsuites="org.apache.comet.exec.CometWindowExecSuite window: decimal AVG with PARTITION BY and ORDER BY" -Dscalastyle.skip=true -DfailIfNoTests=false
JAVA_HOME=/opt/homebrew/Cellar/openjdk@17/17.0.9/libexec/openjdk.jdk/Contents/Home ./mvnw test -Pspark-3.4 -Dtest=none -Dsuites="org.apache.comet.exec.CometWindowExecSuite window: decimal AVG with PARTITION BY and ORDER BY" -Dscalastyle.skip=true -DfailIfNoTests=false
Co-authored-by: @codex