Skip to content

fix: support Spark 4 decimal window avg#4749

Open
manuzhang wants to merge 2 commits into
apache:mainfrom
manuzhang:codex/issue-4731-decimal-window-avg
Open

fix: support Spark 4 decimal window avg#4749
manuzhang wants to merge 2 commits into
apache:mainfrom
manuzhang:codex/issue-4731-decimal-window-avg

Conversation

@manuzhang

@manuzhang manuzhang commented Jun 29, 2026

Copy link
Copy Markdown
Member

Which issue does this PR close?

Closes #4731.

Rationale for this change

Spark 4 applies DecimalAggregates to decimal window aggregates. For decimal AVG, Spark can represent the result as Cast(Divide(WindowExpression, Literal(10^scale, DoubleType)), DecimalType, ...); for decimal SUM, Spark can wrap the window expression in MakeDecimal when the widened precision still fits in a long.

Comet's window converter previously extracted bare WindowExpression values and MakeDecimal(WindowExpression, ...), but it did not handle Spark 4's decimal AVG cast/divide wrapper. That caused decimal window AVG to fall back to Spark instead of using native Comet execution.

What changes are included in this PR?

This PR updates CometWindowExec to recognize Spark's decimal window aggregate wrappers from DecimalAggregates while preserving Spark's final result type:

  • unwraps decimal AVG only for the exact Spark shape with Average(UnscaledValue(decimalChild)) and a Literal(10^scale, DoubleType) divisor
  • unwraps decimal SUM from Spark's MakeDecimal(WindowExpression(...)) wrapper when the inner aggregate is Sum(UnscaledValue(decimalChild))
  • falls back for unexpected wrapper shapes instead of assuming they are safe to rewrite
  • adds tests for decimal AVG, randomized decimal AVG precision/scale combinations, and decimal SUM

How are these changes tested?

The window tests compare Spark and Comet answers and assert that the executed Comet plan contains CometWindowExec, so they verify the decimal paths do not silently fall back.

  • git diff --check
  • make core
  • ./mvnw test -Pjdk17 -Dtest=none -Dsuites="org.apache.comet.exec.CometWindowExecSuite window: decimal AVG with PARTITION BY and ORDER BY" -Dscalastyle.skip=true -DfailIfNoTests=false
  • ./mvnw test -Pjdk17 -Dtest=none -Dsuites="org.apache.comet.exec.CometWindowExecSuite window: decimal AVG fuzz with PARTITION BY and ORDER BY" -Dscalastyle.skip=true -DfailIfNoTests=false
  • ./mvnw test -Pjdk17 -Dtest=none -Dsuites="org.apache.comet.exec.CometWindowExecSuite window: decimal SUM with PARTITION BY and ORDER BY" -Dscalastyle.skip=true -DfailIfNoTests=false
  • ./mvnw test -Pspark-4.0 -Pjdk17 -Dtest=none -Dsuites="org.apache.comet.exec.CometWindowExecSuite window: decimal AVG with PARTITION BY and ORDER BY" -Dscalastyle.skip=true -DfailIfNoTests=false
  • ./mvnw test -Pjdk17,spark-3.5 -Dtest=none -Dsuites="org.apache.comet.exec.CometWindowExecSuite window: decimal AVG with PARTITION BY and ORDER BY" -Dscalastyle.skip=true -DfailIfNoTests=false
  • JAVA_HOME=/opt/homebrew/Cellar/openjdk@17/17.0.9/libexec/openjdk.jdk/Contents/Home ./mvnw test -Pspark-3.4 -Dtest=none -Dsuites="org.apache.comet.exec.CometWindowExecSuite window: decimal AVG with PARTITION BY and ORDER BY" -Dscalastyle.skip=true -DfailIfNoTests=false

Co-authored-by: @codex

@manuzhang manuzhang force-pushed the codex/issue-4731-decimal-window-avg branch from bea8bf7 to c401884 Compare June 29, 2026 15:36
@andygrove andygrove requested a review from comphead June 29, 2026 16:03

@comphead comphead left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @manuzhang makes sense to me.
Appreciate if you can enclose the exact message of fallback to the PR description

CI pending.

@andygrove andygrove left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for tracking this down. I verified the wrapper shapes against the Spark DecimalAggregates rule in Optimizer.scala (v4.2.0-preview4), and the diagnosis and overall approach look right. I left a few inline comments. The main one to consider is the precise-decimal vs double-average difference, since that is the part most likely to surface a real Spark mismatch and it is not visible in the current test. The other two are about tightening the match and adding SUM coverage.

Comment on lines +108 to +113
case c @ Cast(Divide(child, _, _), _: DecimalType, _, _) =>
extractWindowExpression(child).map { info =>
info.copy(
windowExpression = restoreDecimalAggregateInput(info.windowExpression),
resultDataType = c.dataType)
}

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This match extracts the window expression from the numerator and discards the divisor. That substitution is only correct for the exact shape DecimalAggregates emits: divisor Literal(10^scale, DoubleType) and inner Average(UnscaledValue(decimal)). For any other Cast(Divide(windowExpr, X), decimal) the divisor would be silently dropped and the result would be wrong.

I believe it is safe today, because DecimalAggregates is the only rule that injects a Cast/Divide directly inside a WindowExec.windowExpression (it runs in the optimizer, after ExtractWindowExpressions has pulled window expressions into the operator). A user-written CAST(SUM(x) OVER w / 2 AS DECIMAL) keeps its cast and divide in the Project above the window, so it never reaches here.

Could we tighten the guard to verify the divisor is the expected 10^scale literal and the inner is Average(UnscaledValue(_)), and fall back otherwise? That way an unexpected shape degrades to a Spark fallback instead of a silently dropped divide. A short comment naming DecimalAggregates as the source of this shape would also help the next reader.

Comment on lines +125 to +128
case agg @ AggregateExpression(avg: Average, _, _, _, _) =>
avg.child match {
case UnscaledValue(child) if child.dataType.isInstanceOf[DecimalType] =>
agg.copy(aggregateFunction = avg.copy(child = child))

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Restoring the decimal child here makes Comet run its precise avg_decimal UDAF (decimal sum at decimal(p+10, s), decimal division with HALF_UP). But the plan that DecimalAggregates produced computes the average in Double: Average(UnscaledValue(e)) is a double average, divided by 10^scale as a double, then cast to decimal(p+4, s+4). So for the case this PR targets we are substituting precise decimal arithmetic for Spark's floating-point arithmetic, and the two can round differently at the s+4 scale.

Spark only takes this double path when prec + 4 <= 15, so values stay in double's exact integer range and any divergence would be rare and in the last digit, but it is not guaranteed to be zero. For contrast, the regular non-window decimal AVG stays faithful: the HashAggregate computes the double average and a Project does the divide and cast, matching Spark exactly. That asymmetry is what worries me here.

Could we add a fuzz or property test that runs many random decimal values across a few precision and scale combinations through window AVG and compares against Spark? The clean values in the new test all happen to round the same way in both schemes, so they would not catch this.

Comment on lines +132 to +138
case agg @ AggregateExpression(sum: Sum, _, _, _, _) =>
sum.child match {
case UnscaledValue(child) if child.dataType.isInstanceOf[DecimalType] =>
agg.copy(aggregateFunction = sum.copy(child = child))
case _ =>
agg
}

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This Sum branch and the MakeDecimal result-type handling are a behavior change from the old case Alias(MakeDecimal(w: WindowExpression, _, _, _), _) => w, which extracted the bare window expression and used windowExpr.dataType. The new test only covers AVG, so this branch is unexercised. Could we add a decimal SUM over a window with small enough precision to trigger the MakeDecimal rewrite (prec + 10 <= 18)? That exercises the other half of the new extraction logic.

Spark 4 represents decimal window AVG results with a Cast(Divide(WindowExpression, ...)) wrapper. Teach Comet window serialization to unwrap that shape while preserving the wrapper result type, and add a regression test for decimal AVG over partitioned ordered windows.

Co-authored-by: Codex <codex@openai.com>
@manuzhang manuzhang force-pushed the codex/issue-4731-decimal-window-avg branch from c401884 to 63e5e51 Compare June 30, 2026 00:41
@manuzhang manuzhang requested review from andygrove and comphead June 30, 2026 04:48
@andygrove

Copy link
Copy Markdown
Member

Thanks for the updates. The divisor guard plus the two new tests address all three of my earlier comments.

On the precise-decimal vs double-average question I raised, I dug into it further and I am now satisfied there is no real divergence to worry about. I built the worst case on purpose: 32 rows in one partition (a power-of-two count, so sum / count is exact in double), thirty-one 0.0000 and one 0.0003, where the average 3 / 32 / 10^4 = 0.000009375 lands exactly on the HALF_UP boundary at the result scale. Both Spark and Comet return 0.00000938.

The reason is Spark's cast from the double result back to decimal. It goes through Decimal.apply(BigDecimal.valueOf(d)) in Cast.scala, which uses Double.toString, the shortest round-trip string. That recovers 0.000009375 rather than the exact binary value 0.00000937499..., so it rounds up and matches Comet's exact avg_decimal. This is exactly why Spark gates the rewrite on prec + 4 <= 15 (MAX_DOUBLE_DIGITS): within 15 significant digits the double round-trips to a unique short decimal equal to the exactly-rounded value. To be thorough I also ran a large targeted search over rounding-boundary inputs across a range of scales and counts and found zero mismatches.

So the approach of restoring the decimal child is sound within the range the rewrite applies to, and the fuzz test is good defensive coverage on top of that. No further concerns from me.

@andygrove andygrove left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @manuzhang

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

AVG(decimal) over a window always falls back to Spark on Spark 4.x (AvgDecimal window branch is dead)

3 participants