Fix Ideogram4MRoPE collapsing under torch.autocast (compute rotary in float32)#13922
Fix Ideogram4MRoPE collapsing under torch.autocast (compute rotary in float32)#13922HaozheZhang6 wants to merge 11 commits into
Ideogram4MRoPE collapsing under torch.autocast (compute rotary in float32)#13922Conversation
…y in float32)
Ideogram4 builds image-token positions as IMAGE_POSITION_OFFSET (65536) + (t, h, w).
`Ideogram4MRoPE.forward` casts its operands to float32, but the rotary matmul (and
cos/sin) is on autocast's downcast list, so under torch.autocast("cuda", bfloat16) —
common in training and pipeline code — it runs in bfloat16 anyway. bfloat16's step at
65536 is 512, so every image position in a <=512 grid rounds to the same value: all
image tokens get identical rotary embeddings, spatial information is lost, and the
decoded image degenerates to a flat color.
Wrap the frequency computation in torch.autocast(enabled=False) so the rotary
embeddings are always computed in float32, matching how transformers guards its RoPE
modules. Added a regression test that fails on main and passes with the fix.
Fixes huggingface#13920
|
before committing that (and thereby closing my report), please consider that other modules might be affected, just not as bad. bfloat16 becomes inaccurate for integers starting 257.0 (which is rounded to 256.0). that's within the range of text token ids |
|
You're right — confirmed bf16 rounds 257→256, 259→260, so text positions past 256 lose precision in any RoPE that matmuls raw position ids under autocast. Ideogram4 is just the pathological case: the 65536 offset collapses a whole ≤512-wide grid onto a single value, where the others degrade gradually instead of all-at-once. I'd checked the other diffusers transformers — Ideogram4 is the only RoPE with a large position offset, so the only catastrophic one — but the gradual loss you describe is real for the rest. I can extend the same |
| # IMAGE_POSITION_OFFSET (65536), so an ambient autocast would otherwise run the matmul and | ||
| # cos/sin in bfloat16, rounding every image position to the same value and collapsing the | ||
| # rotary embeddings (all spatial information is lost). | ||
| with torch.autocast(device_type=position_ids.device.type, enabled=False): |
There was a problem hiding this comment.
We don't use autocast within our modeling implementation like this.
There was a problem hiding this comment.
Good catch — dropped the autocast guard and compute the freqs in float64 instead, which autocast doesn't downcast (matching the float64 rope path Flux uses). The autocast and float32 paths come out bit-identical (max|Δ| = 0), and the regression test still passes.
There was a problem hiding this comment.
We don't use
autocastwithin our modeling implementation like this.
@sayakpaul Why?
Maybe this can be reconsidered.
- It's the right solution. Casting to float64 as @HaozheZhang6's AI suggested below is a bad workaround
- This is what huggingface transformers does: https://github.com/huggingface/transformers/blob/08a7ef05bcf9723cb2e58855afb8dc2c799323ff/src/transformers/models/qwen3_vl/modular_qwen3_vl.py#L304
There was a problem hiding this comment.
After some internal discussion, we decided that using torch.autocast(..., enabled=False) makes sense here, so the original implementation which uses it is fine.
Per review: replace the torch.autocast(enabled=False) guard with a float64 computation, which autocast does not downcast — matching the float64 rope path used elsewhere (Flux). The autocast and float32 paths stay bit-identical (max|delta|=0).
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
dg845
left a comment
There was a problem hiding this comment.
Thanks for the PR! Would you be willing to extend the fix to other RoPE modules? Using torch.autocast(enabled=False) would be fine, as described in #13922 (comment).
Per review, use torch.autocast(enabled=False) around the rotary matmul (as the original implementation did) rather than computing in float64, and adopt the clearer comment describing the bfloat16 collapse at the 65536 offset.
|
Done — switched Ideogram4 back to On extending it: a fair number of the other transformer RoPE modules build their freqs the same way (matmul of raw position ids), so they'd downcast under an ambient autocast too — most won't collapse as hard as Ideogram4's 65536 offset, but they lose precision once positions pass ~257 in bf16. Happy to wrap each in |
| # matmul to bfloat16, the image positions will collapse to only a few distinct values because bfloat16 cannot | ||
| # represent consecutive integers at this value (after pos 65536 each 512-integer block will collapse to the | ||
| # same value), which causes the image to become essentially flat. Therefore, we need to disable autocast here. | ||
| with torch.autocast(device_type=position_ids.device.type, enabled=False): |
There was a problem hiding this comment.
I think it would be better to tighten the torch.autocast region to just the freqs matmul, since it's the operation that is actually precision-sensitive and needs the guard. So maybe something like
# <explanatory comment from above>
pos = position_ids.permute(2, 0, 1).to(dtype=torch.float32)
inv_freq = self.inv_freq.to(dtype=torch.float32)[None, None, :, None].expand(3, batch_size, -1, 1)
with torch.autocast(device_type=position_ids.device.type, enabled=False):
freqs = inv_freq @ pos.unsqueeze(2)
freqs = freqs.transpose(2, 3) # (3, B, L, inv_freq_size)
# Rest of the implementation (setting up interleaved mrope, cos/sin call)
...|
Hi @HaozheZhang6, I think fixing all of the RoPE modules that build their |
Extend the Ideogram4 fix: ernie_image's `rope` and helios's `get_frequency_batched` build rotary freqs with an unguarded float32 einsum over raw position ids. Under an ambient autocast the einsum runs in bfloat16 on CUDA, which cannot represent consecutive integers past 256, so positions degrade — the same bug, matching the guards mochi/omnigen already have. Wrap each in torch.autocast(enabled=False).
Cosmos3VLTextRotaryEmbedding builds its interleaved-mrope freqs with an unguarded position-id matmul (same shape as Ideogram4), so an ambient autocast downcasts it to bfloat16 and collapses positions past 256. Wrap in torch.autocast(enabled=False).
|
Went through every RoPE module and extended the guard to the unguarded ones. Fixed (wrapped in Already handled, left alone: The
|
What does this PR do?
Fixes #13920
Ideogram4MRoPEproduces collapsed rotary embeddings undertorch.autocast, so denoising inside an autocast context (common in training, and when users wrap pipeline calls) renders a flat single-color image.Root cause
Image-token positions are
IMAGE_POSITION_OFFSET (65536) + (t, h, w).Ideogram4MRoPE.forwardcasts its operands to float32, but the frequency matmul is on autocast's downcast list, so undertorch.autocast("cuda", torch.bfloat16)it executes in bfloat16 anyway. bfloat16's representable step at 65536 is 512, so every image position in a ≤512-wide grid rounds to the same value — all image tokens get identical rotary embeddings, spatial information is lost, and sampling degenerates to a flat field.Reproduced with the weight-free snippet from the issue (
max |cos_autocast − cos_fp32| ≈ 1.93, distinct positions become equal).Fix
Wrap the frequency computation in
torch.autocast(device_type=..., enabled=False)so the rotary embeddings are always computed in float32 regardless of an ambient autocast — the same guardtransformersapplies to its RoPE modules. After the fix the autocast and float32 paths are bit-identical (max |Δ| = 0.0).Scope is
Ideogram4MRoPE, the catastrophic case (others noted in the issue are far milder without the 65536 offset). Happy to extend the same guard to the sibling RoPE modules in a follow-up if you'd like.Tests
Added
test_ideogram4_mrope_is_autocast_invariant— it fails onmain(collapsed positions) and passes with the fix. Full file green:Before submitting
Ideogram4MRoPEbreaks undertorch.autocast: all image positions collapse, producing flat single-color images #13920Who can review?
@DN6 @sayakpaul