
feat(compute): on-device bf16 broadcast + scalar ops (capture-safe)#165

Merged
dndungu merged 2 commits into main from feat/bf16-broadcast-scalar-via-f32 on Jun 17, 2026



@dndungu dndungu commented Jun 17, 2026


What

bf16 broadcast binary ops (Add/Sub/Mul/Div, mismatched shapes) and scalar ops (AddScalar/MulScalar/DivScalar) now run on-device instead of falling back to the CPU engine.

Why

The CPU fallback's host memcpy breaks CUDA-graph capture. This blocked capture-ON bf16 CrossAsset training at QKL2Norm, which does a column-broadcast Mul(x, inv) and AddScalar(eps): in bf16 both ops fell back to the CPU, so capture failed.

How

Route these bf16 ops through on-device f32 conversion: BF16ToF32 operands → existing f32 broadcast/scalar kernel → F32ToBF16 result, all on the engine stream. No new CUDA kernels and no .so change — reuses the f32 kernels (raw-pointer ABI) + existing conversion kernels. Computing in f32 matches the bf16 GEMM/reduction convention. Scratch is arena-allocated, freed stream-ordered (same pattern as getDevicePtr's FP16→F32 scratch).
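To make the convert-via-f32 route concrete, here is a minimal host-side Go sketch of the same arithmetic: widen bf16 to f32, do the column-broadcast multiply in f32, then narrow back to bf16. It is not the PR's code. The function names (`bf16ToF32`, `f32ToBF16`, `mulColBroadcastBF16`) are hypothetical, the round-to-nearest-even narrowing is an assumption about what the F32ToBF16 kernel does, and the real path runs as CUDA kernels on the engine stream rather than as Go loops.

```go
package main

import (
	"fmt"
	"math"
)

// bf16ToF32 widens a bf16 bit pattern to float32. bf16 is the upper 16 bits
// of an IEEE-754 float32, so widening is a left shift.
func bf16ToF32(b uint16) float32 {
	return math.Float32frombits(uint32(b) << 16)
}

// f32ToBF16 narrows a float32 to bf16 with round-to-nearest-even
// (assumed rounding mode; the actual kernel may differ).
func f32ToBF16(f float32) uint16 {
	bits := math.Float32bits(f)
	if f != f { // NaN: keep it a quiet NaN after truncation
		return uint16(bits>>16) | 0x0040
	}
	rounding := uint32(0x7FFF) + ((bits >> 16) & 1)
	return uint16((bits + rounding) >> 16)
}

// mulColBroadcastBF16 computes out[r][c] = x[r][c] * inv[r] for a rows x cols
// bf16 matrix x and a rows x 1 bf16 column inv, computing in f32.
func mulColBroadcastBF16(x []uint16, rows, cols int, inv []uint16) []uint16 {
	out := make([]uint16, len(x))
	for r := 0; r < rows; r++ {
		s := bf16ToF32(inv[r]) // operand widened to f32
		for c := 0; c < cols; c++ {
			i := r*cols + c
			out[i] = f32ToBF16(bf16ToF32(x[i]) * s) // f32 multiply, bf16 store
		}
	}
	return out
}

func main() {
	x := []uint16{f32ToBF16(1), f32ToBF16(2), f32ToBF16(3), f32ToBF16(4)} // 2x2
	inv := []uint16{f32ToBF16(0.5), f32ToBF16(0.25)}                      // 2x1
	for _, b := range mulColBroadcastBF16(x, 2, 2, inv) {
		fmt.Println(bf16ToF32(b))
	}
}
```

Because the operands are widened exactly and the product is rounded only once on store, this matches the "f32 accumulation, bf16 storage" convention the description refers to.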

  • gpuBroadcastOp: both kernel-exec sites go through execBroadcast2D (bf16 → convert-via-f32; f32 → direct, unchanged).
  • gpuBroadcast4DOp: signals not-handled for bf16 (B=1 CrossAsset uses only 2D broadcasts; bf16 4D is a follow-up) so it never runs the f32 4D kernel on 2-byte data.
  • gpuAddScalar/gpuMulScalar/gpuDivScalar: bf16 → gpuScalarOpBF16.
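The routing rules in the list above can be sketched as a small dispatch table. This is an illustrative Go sketch only: `dtype`, `path`, `route2D`, and `route4D` are hypothetical names, not the actual signatures of gpuBroadcastOp or gpuBroadcast4DOp, and what the caller does on "not handled" is not shown.

```go
package main

import "fmt"

// dtype is a hypothetical element-type tag for this sketch.
type dtype int

const (
	f32 dtype = iota
	bf16
)

// path names the execution route an op takes in this sketch.
type path string

const (
	direct path = "f32 kernel, direct"
	viaF32 path = "bf16 -> f32 kernel -> bf16"
)

// route2D mirrors the 2D broadcast and scalar rule: bf16 goes through the
// f32 conversion path, f32 runs the existing kernel unchanged.
func route2D(dt dtype) path {
	if dt == bf16 {
		return viaF32
	}
	return direct
}

// route4D reports handled=false for bf16 so the caller can take another
// route instead of launching the f32 4D kernel on 2-byte elements.
func route4D(dt dtype) (p path, handled bool) {
	if dt == bf16 {
		return "", false
	}
	return direct, true
}

func main() {
	fmt.Println(route2D(bf16))
	fmt.Println(route2D(f32))
	_, ok := route4D(bf16)
	fmt.Println("bf16 4D handled:", ok)
}
```

Returning an explicit not-handled signal, rather than guessing, is what keeps a kernel written for 4-byte elements from ever reading a 2-byte buffer.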

Verification

  • Build + vet + non-CUDA compute tests green.
  • GB10 (sm_121): full bf16 suite GREEN incl. new TestGPUBF16_BroadcastAndScalarParity (column-broadcast Mul + AddScalar) — ok compute 3.383s.
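For readers without a CUDA device, the shape of such a parity check can be sketched on the host. This is not TestGPUBF16_BroadcastAndScalarParity itself: `roundBF16` and `maxRelErr` are hypothetical helpers, the input values are arbitrary, and the 2^-8 relative tolerance is an assumption based on bf16's 8 bits of significand precision, not the tolerance the real test uses.

```go
package main

import (
	"fmt"
	"math"
)

// roundBF16 rounds a float32 to the nearest bf16 value (ties to even) and
// returns it widened back to float32. NaN inputs are not handled here.
func roundBF16(f float32) float32 {
	bits := math.Float32bits(f)
	bits += 0x7FFF + ((bits >> 16) & 1)
	return math.Float32frombits(bits &^ 0xFFFF)
}

// maxRelErr returns the largest |got-want|/|want|, skipping zero references.
func maxRelErr(got, want []float32) float64 {
	worst := 0.0
	for i := range want {
		if want[i] == 0 {
			continue
		}
		e := math.Abs(float64(got[i])-float64(want[i])) / math.Abs(float64(want[i]))
		if e > worst {
			worst = e
		}
	}
	return worst
}

func main() {
	const eps = float32(1e-6)
	raw := []float32{0.1, 1.7, -3.3, 100.5}
	got := make([]float32, len(raw))
	want := make([]float32, len(raw))
	for i, v := range raw {
		x := roundBF16(v)           // inputs are stored as bf16
		want[i] = x + eps           // f32 reference result
		got[i] = roundBF16(x + eps) // bf16 path: f32 compute, one rounding on store
	}
	const tol = 1.0 / 256 // assumed tolerance: bf16 unit roundoff 2^-8
	fmt.Println("within tolerance:", maxRelErr(got, want) <= tol)
}
```

Since the bf16 path rounds only once, when the f32 result is stored, its error against the f32 reference is bounded by a single bf16 rounding step for values in the normal range.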

Completes the bf16 GPU op surface for capture-ON CrossAsset training. Final piece of the chain (ztensor v1.16.0 NT/TN, v1.17.0 transpose, v1.17.1 reshape; zerfoo v1.53.1 grad-accum). ADR-075 L4.

dndungu added 2 commits June 17, 2026 00:24
bf16 broadcast binary ops (Add/Sub/Mul/Div with mismatched shapes) and scalar
ops (AddScalar/MulScalar/DivScalar) fell back to the CPU engine, whose host
memcpy breaks CUDA-graph capture. This blocked capture-ON bf16 CrossAsset
training at QKL2Norm, which does a column-broadcast Mul(x, inv) and AddScalar(eps).

Route these bf16 ops through on-device f32 conversion instead of the CPU: convert
operands bf16->f32 (existing BF16ToF32 kernel), run the existing f32 broadcast/
scalar kernel, convert the result f32->bf16 -- all on the engine stream, so the
op stays on the GPU and capturable. Computing in f32 matches the bf16
GEMM/reduction convention (f32 accumulation, bf16 storage). No new CUDA kernels
and no .so change: reuses the f32 kernels (raw-pointer ABI) + the existing
conversion kernels. Scratch is arena-allocated and freed stream-ordered, the
same pattern as getDevicePtr's FP16->F32 scratch.

- gpuBroadcastOp: both kernel-exec sites go through execBroadcast2D (bf16 ->
  convert-via-f32, f32 -> direct).
- gpuBroadcast4DOp: signals not-handled for bf16 (B=1 CrossAsset uses only 2D
  broadcasts; bf16 4D broadcast is a follow-up) so it never runs the f32 4D
  kernel on 2-byte data.
- gpuAddScalar/gpuMulScalar/gpuDivScalar: bf16 -> gpuScalarOpBF16.

CUDA-gated parity test: bf16 column-broadcast Mul + AddScalar vs f32 reference.
Completes the bf16 GPU op surface for capture-ON CrossAsset training. ADR-075 L4.

The CrossAttention test op (added in cc6948a) captured its Forward intermediates
(softmax weights attn + Q/K/V) via closure and assumed no arena reset between
Forward and Backward. The testing/parity reset-between-fwd-bwd schedule DOES
Reset the arena between the passes, so those arena-backed tensors were freed and
Backward read garbage -- TestRun_HostArenaStress_RegistryGreen failed with
CrossAttention bwd max_abs=+Inf (regressed on main; passes at v1.15.0..v1.17.1).

Add an optional opNode.extraSaves hook: when set, Forward registers the returned
intermediates with the Saver (graph.SaverAware, ADR 006) alongside the output, so
they survive the reset (the same mechanism layerNorm/groupNorm/adaLN/timestepEmbed
already use). CrossAttention sets it to pin attn + Q/K/V. gradcheck (no reset
between passes) is unaffected.
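The failure mode and the fix can be simulated with a toy arena. Everything in this Go sketch is hypothetical: `tensor`, `arena`, the `pinned` flag, and `opNode.run` only imitate the idea of registering intermediates so they survive a reset, and do not reflect the real graph.SaverAware interface or the arena implementation.

```go
package main

import (
	"fmt"
	"math"
)

// tensor is a toy arena-backed buffer; pinned stands in for "registered with the Saver".
type tensor struct {
	data   []float32
	pinned bool
}

// arena tracks live allocations so reset can invalidate them.
type arena struct{ live []*tensor }

func (a *arena) alloc(n int) *tensor {
	t := &tensor{data: make([]float32, n)}
	a.live = append(a.live, t)
	return t
}

// reset poisons every unpinned buffer with NaN to imitate freed memory.
func (a *arena) reset() {
	keep := a.live[:0]
	for _, t := range a.live {
		if t.pinned {
			keep = append(keep, t)
			continue
		}
		for i := range t.data {
			t.data[i] = float32(math.NaN())
		}
	}
	a.live = keep
}

// opNode has an optional extraSaves hook returning intermediates to pin.
type opNode struct {
	forward    func(a *arena, x *tensor) *tensor
	extraSaves func() []*tensor
}

// run executes forward, pins the output, and pins any extra saves.
func (n *opNode) run(a *arena, x *tensor) *tensor {
	out := n.forward(a, x)
	out.pinned = true
	if n.extraSaves != nil {
		for _, t := range n.extraSaves() {
			t.pinned = true
		}
	}
	return out
}

// demo returns the closure-captured intermediate as read after a reset.
func demo(withHook bool) float32 {
	a := &arena{}
	var attn *tensor // forward intermediate captured by closure
	n := &opNode{forward: func(a *arena, x *tensor) *tensor {
		attn = a.alloc(1)
		attn.data[0] = 2 * x.data[0]
		out := a.alloc(1)
		out.data[0] = attn.data[0] + 1
		return out
	}}
	if withHook {
		n.extraSaves = func() []*tensor { return []*tensor{attn} }
	}
	n.run(a, &tensor{data: []float32{3}})
	a.reset() // the reset-between-fwd-bwd schedule
	return attn.data[0]
}

func main() {
	fmt.Println("without hook:", demo(false))
	fmt.Println("with hook:   ", demo(true))
}
```

Without the hook, the value a backward pass would read has been invalidated by the reset; with it, the intermediate is preserved, which is the behavior the commit relies on.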

Unblocks ztensor main CI. Unrelated to the bf16 broadcast/scalar change in this
branch's other commit; folded in to restore a green main in one cycle.
@dndungu dndungu merged commit ed96180 into main Jun 17, 2026
1 check passed
