[NNX] Checkpoint restore: error on weights missing from the checkpoint#4342

Draft
ecnal-cienet wants to merge 2 commits into main from feat/nnx-ckpt-missing-param-policy

@ecnal-cienet

Collaborator

Description

Restoring a Linen-layout checkpoint into an NNX state filled every absent leaf with a default value. NNX-only rngs/dropout leaves are legitimately absent from a Linen layout, but a weight that is genuinely missing from disk was treated the same way and silently zero-filled. That is an accuracy loss with no error or warning, and it also slips past the ShapeDtypeStruct guard, because the default is a concrete zeros array, not a struct.

This PR distinguishes the two cases in _populate_pure_dict_from_partial by tracking each leaf's path and whether it sits under an rngs/dropout subtree:

  • rngs/dropout absent → filled with a deterministic default, as before.
  • any other absent weight → governed by the new checkpoint_missing_param_policy:
    • "error" (default): raise naming the parameter path and expected shape.
    • "warn": log a WARNING listing the zero-filled path and continue.

"error" is the default because a zero-filled weight is silent accuracy loss; opt into "warn" only when zero-filling is intended.
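
The behavior described above can be sketched in plain Python. This is a minimal illustration, not the code in this PR: `populate`, `LeafSpec`, `zeros`, and `NNX_ONLY_KEYS` are hypothetical names, nested lists stand in for arrays, and the real `_populate_pure_dict_from_partial` may be structured differently.

```python
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)

# Hypothetical: subtree keys treated as NNX-only state that a Linen layout lacks.
NNX_ONLY_KEYS = ("rngs", "dropout")


@dataclass(frozen=True)
class LeafSpec:
    """Stand-in for an abstract target leaf (shape only, for illustration)."""
    shape: tuple


def zeros(shape):
    """Nested list of zeros with the given shape (stand-in for a zeros array)."""
    if not shape:
        return 0.0
    return [zeros(shape[1:]) for _ in range(shape[0])]


def populate(target, partial, policy="error", path=(), under_nnx_only=False):
    """Fill the structure of `target` from `partial`, applying the policy to absent leaves."""
    if isinstance(target, dict):
        out = {}
        for key, sub in target.items():
            sub_partial = partial.get(key) if isinstance(partial, dict) else None
            out[key] = populate(
                sub, sub_partial, policy, path + (key,),
                under_nnx_only or key in NNX_ONLY_KEYS,
            )
        return out
    if partial is not None:
        return partial  # Leaf is present in the checkpoint: pass it through.
    if under_nnx_only:
        return zeros(target.shape)  # rngs/dropout: deterministic default.
    name = "/".join(path)
    if policy == "error":
        raise ValueError(
            f"Parameter '{name}' (expected shape {target.shape}) "
            "is missing from the checkpoint."
        )
    if policy == "warn":
        logger.warning("Zero-filling missing parameter '%s'", name)
        return zeros(target.shape)
    raise ValueError(f"Unknown missing-param policy: {policy!r}")
```

With a target of `{"decoder": {"w": LeafSpec((2,))}, "rngs": {"dropout": {"key": LeafSpec(())}}}`, a partial that contains only `decoder/w` restores cleanly under either policy, while an empty partial raises under `"error"` and zero-fills `decoder/w` under `"warn"`.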

Tests

  • tests/unit/checkpointing_nnx_load_test.py::TestMissingParamPolicy — present weights pass through; rng/dropout absence defaults even under error; a missing weight raises naming the path; warn zero-fills and logs; default policy is error.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jul 2, 2026

Codecov Report

❌ Patch coverage is 84.21053% with 3 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/maxtext/common/checkpointing.py | 83.33% | 3 Missing ⚠️ |

@ecnal-cienet force-pushed the feat/nnx-ckpt-missing-param-policy branch from e0ceee0 to 7ea06cc on July 2, 2026 22:09
… mismatch

When Orbax skips a parameter on a structural/shape mismatch it leaves an
unmaterialized ShapeDtypeStruct in the restored state, which otherwise surfaces
as a cryptic compile error deep in the first train_step. Report the offending
parameter path and shape from _assert_no_shaped_dtype_struct so the mismatch
points straight at the config (emb_dim/mlp_dim/layers/scan_layers).
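
A check of this kind could look roughly like the sketch below. It is illustrative only: `AbstractLeaf` is a hypothetical stand-in for `jax.ShapeDtypeStruct`, the helper names are invented here, and the actual `_assert_no_shaped_dtype_struct` may differ in both structure and message wording.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AbstractLeaf:
    """Hypothetical stand-in for jax.ShapeDtypeStruct: shape and dtype, no data."""
    shape: tuple
    dtype: str


def find_unmaterialized(tree, path=()):
    """Return (path, shape, dtype) for every leaf that is still abstract."""
    if isinstance(tree, dict):
        found = []
        for key, sub in tree.items():
            found.extend(find_unmaterialized(sub, path + (key,)))
        return found
    if isinstance(tree, AbstractLeaf):
        return [("/".join(path), tree.shape, tree.dtype)]
    return []


def assert_all_materialized(tree):
    """Raise with the offending paths and shapes instead of failing later at compile time."""
    bad = find_unmaterialized(tree)
    if bad:
        lines = [f"  {p}: shape={s}, dtype={d}" for p, s, d in bad]
        raise ValueError(
            "Restored state contains unmaterialized parameters; "
            "check the model config against the checkpoint:\n" + "\n".join(lines)
        )
```

Reporting every offending path at once, rather than stopping at the first, makes it easier to see whether the mismatch is a single parameter or a whole layer stack.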
Restoring a Linen-layout checkpoint into an NNX state filled every absent leaf
with a default, so a weight genuinely missing from disk was silently zero-filled
-- silent accuracy loss that slipped past the ShapeDtypeStruct guard. Distinguish
the NNX-only rngs/dropout subtrees (legitimately absent from a Linen layout, still
defaulted) from real missing weights, which are now governed by the new
checkpoint_missing_param_policy: 'error' (default) raises naming the parameter
path and shape, 'warn' logs and zero-fills.
@ecnal-cienet force-pushed the feat/nnx-ckpt-missing-param-policy branch from 7ea06cc to a381bf8 on July 2, 2026 22:25