[NNX] Checkpoint restore: error on weights missing from the checkpoint#4342
Draft
ecnal-cienet wants to merge 2 commits into
Draft
[NNX] Checkpoint restore: error on weights missing from the checkpoint#4342ecnal-cienet wants to merge 2 commits into
ecnal-cienet wants to merge 2 commits into
Conversation
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
e0ceee0 to
7ea06cc
Compare
… mismatch When Orbax skips a parameter on a structural/shape mismatch it leaves an unmaterialized ShapeDtypeStruct in the restored state, which otherwise surfaces as a cryptic compile error deep in the first train_step. Report the offending parameter path and shape from _assert_no_shaped_dtype_struct so the mismatch points straight at the config (emb_dim/mlp_dim/layers/scan_layers).
Restoring a Linen-layout checkpoint into an NNX state filled every absent leaf with a default, so a weight genuinely missing from disk was silently zero-filled -- silent accuracy loss that slipped past the ShapeDtypeStruct guard. Distinguish the NNX-only rngs/dropout subtrees (legitimately absent from a Linen layout, still defaulted) from real missing weights, which are now governed by the new checkpoint_missing_param_policy: 'error' (default) raises naming the parameter path and shape, 'warn' logs and zero-fills.
7ea06cc to
a381bf8
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Restoring a Linen-layout checkpoint into an NNX state filled every absent leaf with a default value. NNX-only
rngs/dropoutare legitimately absent from a Linen layout, but so is a weight that is genuinely missing from disk — and that weight was then silently zero-filled, a silent accuracy loss that also slips past theShapeDtypeStructguard (the default is a concrete zeros array, not a struct).This distinguishes the two cases in
_populate_pure_dict_from_partialby tracking the path and whether it sits under anrngs/dropoutsubtree:checkpoint_missing_param_policy:"error"(default): raise naming the parameter path and expected shape."warn": log a WARNING listing the zero-filled path and continue."error"is the default because a zero-filled weight is silent accuracy loss; opt into"warn"only when zero-filling is intended.Tests
tests/unit/checkpointing_nnx_load_test.py::TestMissingParamPolicy— present weights pass through; rng/dropout absence defaults even undererror; a missing weight raises naming the path;warnzero-fills and logs; default policy iserror.Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.