Skip to content

fix(compute): keep bf16 GPUStorage reshape on-device (capture-safe)#162

Merged
dndungu merged 1 commit into
mainfrom
fix/bf16-reshape-gpu-view
Jun 17, 2026
Merged

fix(compute): keep bf16 GPUStorage reshape on-device (capture-safe)#162
dndungu merged 1 commit into
mainfrom
fix/bf16-reshape-gpu-view

Conversation

@dndungu

@dndungu dndungu commented Jun 17, 2026

Copy link
Copy Markdown
Contributor

What

GPUEngine.Reshape now takes the zero-copy GPU view path for bf16 GPUStorage[bf16], not just f32.

Why

Reshape's GPUStorage[T] view branch was gated on isFloat32[T](), so a bf16 tensor fell through to e.cpu.Reshape, producing a host tensor. That host tensor then forced the next op — the Transpose feeding QKL2Norm — onto the CPU engine, whose host memcpy breaks CUDA-graph capture (operation would make the legacy stream depend on a capturing blocking stream). So even with the native bf16 transpose kernels (v1.17.0), the bf16 CrossAsset GPU bench still could not capture.

Fix

Reshape is a pure metadata/view operation (no data movement), valid for any element type backed by GPUStorage[T]. Allow bf16 on the GPU view path: isFloat32[T]() || isBFloat16[T](). Go-only change — no kernel, no .so rebuild.

Verification

  • Build + vet + non-CUDA compute tests green.
  • CUDA-gated TestGPUBF16_ReshapeStaysOnDevice: a GPU-resident bf16 tensor reshaped stays *GPUStorage[bf16] (not CPU), data preserved.

Final piece letting the bf16 CrossAsset GPU backward run with CUDA-graph capture ON (representative s/epoch). Chain: ztensor v1.16.0 (NT/TN) + v1.17.0 (transpose kernels) + zerfoo v1.53.1 (grad-accum). ADR-075 lever L4.

GPUEngine.Reshape only took the zero-copy GPUStorage[T] view path for float32
(isFloat32[T] gate); a bf16 GPUStorage[bf16] tensor fell through to e.cpu.Reshape,
producing a host tensor. That host tensor then forced the next op -- the Transpose
feeding QKL2Norm -- onto the CPU engine, whose host memcpy breaks CUDA-graph
capture ("operation would make the legacy stream depend on a capturing blocking
stream"). So even with native bf16 transpose kernels (v1.17.0), the bf16 CrossAsset
GPU bench still could not capture.

Reshape is a pure metadata/view operation (no data movement), valid for any
element type backed by GPUStorage[T]. Allow bf16 on the GPU view path
(isFloat32[T] || isBFloat16[T]). CUDA-gated test asserts a GPU-resident bf16
tensor reshaped stays *GPUStorage[bf16] with data preserved.

Final piece letting the bf16 CrossAsset GPU backward run with CUDA-graph capture
ON. ADR-075 lever L4.
@dndungu dndungu merged commit 508f01d into main Jun 17, 2026
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant