feat(compute): on-device bf16 broadcast + scalar ops (capture-safe) #165
Merged
bf16 broadcast binary ops (Add/Sub/Mul/Div with mismatched shapes) and scalar ops (AddScalar/MulScalar/DivScalar) fell back to the CPU engine, whose host memcpy breaks CUDA-graph capture. This blocked capture-ON bf16 CrossAsset training at QKL2Norm, which does a column-broadcast Mul(x, inv) and AddScalar(eps).

Route these bf16 ops through on-device f32 conversion instead of the CPU: convert operands bf16->f32 (existing BF16ToF32 kernel), run the existing f32 broadcast/scalar kernel, convert the result f32->bf16 -- all on the engine stream, so the op stays on the GPU and capturable. Computing in f32 matches the bf16 GEMM/reduction convention (f32 accumulation, bf16 storage).

No new CUDA kernels and no .so change: reuses the f32 kernels (raw-pointer ABI) + the existing conversion kernels. Scratch is arena-allocated and freed stream-ordered, the same pattern as getDevicePtr's FP16->F32 scratch.

- gpuBroadcastOp: both kernel-exec sites go through execBroadcast2D (bf16 -> convert-via-f32, f32 -> direct).
- gpuBroadcast4DOp: signals not-handled for bf16 (B=1 CrossAsset uses only 2D broadcasts; bf16 4D broadcast is a follow-up) so it never runs the f32 4D kernel on 2-byte data.
- gpuAddScalar/gpuMulScalar/gpuDivScalar: bf16 -> gpuScalarOpBF16.

CUDA-gated parity test: bf16 column-broadcast Mul + AddScalar vs f32 reference.

Completes the bf16 GPU op surface for capture-ON CrossAsset training. ADR-075 L4.
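The routing rules in the list above can be sketched as a small dtype dispatch. This is only an illustration of the described behaviour, not the actual ztensor code: the names dtype, dtF32, dtBF16, route2D and route4D are hypothetical, and the real functions launch kernels rather than return strings.

```go
package main

import "fmt"

// dtype is a hypothetical element-type tag used only for this sketch.
type dtype int

const (
	dtF32 dtype = iota
	dtBF16
)

// route2D mirrors the described 2D rule: bf16 converts via f32,
// f32 runs the existing kernel directly.
func route2D(dt dtype) string {
	if dt == dtBF16 {
		return "convert-via-f32"
	}
	return "direct"
}

// route4D mirrors the described 4D rule: bf16 is reported as not handled,
// so the f32 4D kernel never sees 2-byte elements.
func route4D(dt dtype) (path string, handled bool) {
	if dt == dtBF16 {
		return "", false
	}
	return "direct", true
}

func main() {
	fmt.Println(route2D(dtBF16), route2D(dtF32))
	_, ok := route4D(dtBF16)
	fmt.Println("bf16 4D handled:", ok)
}
```

Returning an explicit handled flag, rather than silently running the f32 kernel, lets the caller choose its own fallback for the 4D case.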
The CrossAttention test op (added in cc6948a) captured its Forward intermediates (softmax weights attn + Q/K/V) via closure and assumed no arena reset between Forward and Backward. The testing/parity reset-between-fwd-bwd schedule DOES Reset the arena between the passes, so those arena-backed tensors were freed and Backward read garbage -- TestRun_HostArenaStress_RegistryGreen failed with CrossAttention bwd max_abs=+Inf (regressed on main; passes at v1.15.0..v1.17.1).

Add an optional opNode.extraSaves hook: when set, Forward registers the returned intermediates with the Saver (graph.SaverAware, ADR 006) alongside the output, so they survive the reset (the same mechanism layerNorm/groupNorm/adaLN/timestepEmbed already use). CrossAttention sets it to pin attn + Q/K/V. gradcheck (no reset between passes) is unaffected.

Unblocks ztensor main CI. Unrelated to the bf16 broadcast/scalar change in this branch's other commit; folded in to restore a green main in one cycle.
What
bf16 broadcast binary ops (Add/Sub/Mul/Div, mismatched shapes) and scalar ops (AddScalar/MulScalar/DivScalar) now run on-device instead of falling back to the CPU engine.
Why
The CPU fallback's host memcpy breaks CUDA-graph capture. This blocked capture-ON bf16 CrossAsset training at QKL2Norm, which does a column-broadcast Mul(x, inv) and AddScalar(eps): both went bf16 → CPU → capture failure.
How
Route these bf16 ops through on-device f32 conversion: BF16ToF32 operands → existing f32 broadcast/scalar kernel → F32ToBF16 result, all on the engine stream. No new CUDA kernels and no .so change; this reuses the f32 kernels (raw-pointer ABI) + existing conversion kernels. Computing in f32 matches the bf16 GEMM/reduction convention. Scratch is arena-allocated, freed stream-ordered (same pattern as getDevicePtr's FP16→F32 scratch).
- gpuBroadcastOp: both kernel-exec sites go through execBroadcast2D (bf16 → convert-via-f32; f32 → direct, unchanged).
- gpuBroadcast4DOp: signals not-handled for bf16 (B=1 CrossAsset uses only 2D broadcasts; bf16 4D is a follow-up) so it never runs the f32 4D kernel on 2-byte data.
- gpuAddScalar/gpuMulScalar/gpuDivScalar: bf16 → gpuScalarOpBF16.
Verification
compute tests green. TestGPUBF16_BroadcastAndScalarParity (column-broadcast Mul + AddScalar): ok compute 3.383s.
Completes the bf16 GPU op surface for capture-ON CrossAsset training. Final piece of the chain (ztensor v1.16.0 NT/TN, v1.17.0 transpose, v1.17.1 reshape; zerfoo v1.53.1 grad-accum). ADR-075 L4.
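For readers who want to see the numerics of the convert-via-f32 path, here is a CPU-only sketch of a column-broadcast Mul followed by AddScalar. It is not the GPU code: the bf16 type and every function name are hypothetical, the conversion assumes round-to-nearest-even (the rounding mode of the real F32ToBF16 kernel is not stated in this PR), and plain Go loops stand in for the f32 kernels.

```go
package main

import (
	"fmt"
	"math"
)

// bf16 is a hypothetical storage type: the upper 16 bits of a float32.
type bf16 uint16

// toF32 widens exactly; the low 16 mantissa bits become zero.
func toF32(b bf16) float32 { return math.Float32frombits(uint32(b) << 16) }

// fromF32 narrows with round-to-nearest-even (an assumption of this sketch).
func fromF32(f float32) bf16 {
	if f != f { // NaN: return a quiet NaN
		return 0x7FC0
	}
	bits := math.Float32bits(f)
	bits += 0x7FFF + ((bits >> 16) & 1)
	return bf16(bits >> 16)
}

func widen(x []bf16) []float32 {
	out := make([]float32, len(x))
	for i, v := range x {
		out[i] = toF32(v)
	}
	return out
}

func narrow(x []float32) []bf16 {
	out := make([]bf16, len(x))
	for i, v := range x {
		out[i] = fromF32(v)
	}
	return out
}

// mulColBroadcastF32 stands in for the f32 broadcast kernel:
// out[r*cols+c] = x[r*cols+c] * col[r].
func mulColBroadcastF32(x, col []float32, rows, cols int) []float32 {
	out := make([]float32, rows*cols)
	for r := 0; r < rows; r++ {
		for c := 0; c < cols; c++ {
			out[r*cols+c] = x[r*cols+c] * col[r]
		}
	}
	return out
}

// mulColBroadcastBF16 follows the described route: widen, compute in f32, narrow.
func mulColBroadcastBF16(x, col []bf16, rows, cols int) []bf16 {
	return narrow(mulColBroadcastF32(widen(x), widen(col), rows, cols))
}

// addScalarBF16 applies the same route to a scalar add.
func addScalarBF16(x []bf16, s float32) []bf16 {
	f := widen(x)
	for i := range f {
		f[i] += s
	}
	return narrow(f)
}

func main() {
	x := narrow([]float32{1, 2, 3, 4})    // 2x2, row-major
	inv := narrow([]float32{0.5, 0.25})   // 2x1 column
	y := mulColBroadcastBF16(x, inv, 2, 2)
	fmt.Println(widen(y), widen(addScalarBF16(y, 1)))
}
```

The inputs are chosen to be exactly representable in bf16, so the bf16 result equals the f32 reference here; a parity test on arbitrary data would compare within a tolerance instead.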