
[NNX] Checkpoint restore: name the parameter path on ShapeDtypeStruct mismatch#4341

Open
ecnal-cienet wants to merge 1 commit into main from feat/nnx-ckpt-sds-error-path

Conversation


@ecnal-cienet ecnal-cienet commented Jul 2, 2026

Collaborator

Description

When Orbax skips a parameter during an NNX checkpoint restore (a structural/shape mismatch under partial_restore), it leaves an unmaterialized jax.ShapeDtypeStruct in the restored state. Today the guard that catches this reports only the raw struct, so the failure surfaces as a cryptic, low-level compile/allocation error deep in the first train_step.

This makes _assert_no_shaped_dtype_struct carry the tree path and report the offending parameter and its shape/dtype, pointing straight at the config that doesn't match the checkpoint (e.g. emb_dim, mlp_dim, num layers, scan_layers).
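A minimal sketch of a path-aware guard of this kind, assuming the restored state is an ordinary pytree. The function name `assert_no_shape_dtype_struct`, the error wording, and the example tree are hypothetical illustrations and not the actual MaxText implementation. It relies only on `jax.tree_util.tree_flatten_with_path` and `jax.tree_util.keystr`.

```python
import jax
import jax.numpy as jnp


def assert_no_shape_dtype_struct(state):
  """Raise if any leaf of `state` is still an abstract jax.ShapeDtypeStruct.

  Hypothetical helper for illustration; the real guard in MaxText may differ.
  """
  # ShapeDtypeStruct is not a registered pytree node, so it shows up as a leaf.
  leaves_with_path, _ = jax.tree_util.tree_flatten_with_path(state)
  for path, leaf in leaves_with_path:
    if isinstance(leaf, jax.ShapeDtypeStruct):
      # keystr renders the key path as a readable string.
      name = jax.tree_util.keystr(path)
      raise ValueError(
          f"Parameter {name} was not restored from the checkpoint: found "
          f"ShapeDtypeStruct(shape={leaf.shape}, dtype={leaf.dtype}). "
          "Check that the model config matches the checkpoint."
      )


# Example tree (assumed layout): one concrete array, one unmaterialized leaf.
example_state = {
    "decoder": {
        "embed": jnp.zeros((8, 4)),
        "mlp": {"kernel": jax.ShapeDtypeStruct((4, 8), jnp.float32)},
    }
}
```

Calling `assert_no_shape_dtype_struct(example_state)` raises a `ValueError` whose message names the `['decoder']['mlp']['kernel']` path together with the expected shape and dtype, which is the information needed to spot a config/checkpoint mismatch.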

Tests

  • tests/unit/checkpointing_nnx_load_test.py::TestShapeDtypeStructPath — a surviving ShapeDtypeStruct raises naming its path and shape; concrete arrays pass.
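The two behaviors described above could be exercised roughly as follows. This is a sketch under stated assumptions: `check_restored` is a hypothetical stand-in for the guard, and the class name, tree layout, and message format are illustrative rather than taken from the actual test file.

```python
import unittest

import jax
import jax.numpy as jnp


def check_restored(state):
  # Hypothetical stand-in for the guard under test; not the MaxText function.
  for path, leaf in jax.tree_util.tree_flatten_with_path(state)[0]:
    if isinstance(leaf, jax.ShapeDtypeStruct):
      raise ValueError(
          f"unrestored parameter {jax.tree_util.keystr(path)}: "
          f"shape={leaf.shape}, dtype={leaf.dtype}"
      )


class ShapeDtypeStructPathSketch(unittest.TestCase):

  def test_surviving_struct_names_path_and_shape(self):
    # A leaf left abstract by the restore should be reported by path and shape.
    state = {"params": {"embed": jax.ShapeDtypeStruct((16, 4), jnp.float32)}}
    with self.assertRaises(ValueError) as ctx:
      check_restored(state)
    message = str(ctx.exception)
    self.assertIn("['params']['embed']", message)
    self.assertIn("(16, 4)", message)

  def test_concrete_arrays_pass(self):
    # A fully materialized state should pass without raising.
    check_restored({"params": {"embed": jnp.zeros((16, 4))}})
```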

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.


codecov Bot commented Jul 2, 2026


Codecov Report

❌ Patch coverage is 50.00000% with 3 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/maxtext/common/checkpointing.py | 50.00% | 3 Missing ⚠️ |


@ecnal-cienet ecnal-cienet force-pushed the feat/nnx-ckpt-sds-error-path branch from 2db41cb to ee725b5 Compare July 2, 2026 22:08
… mismatch

When Orbax skips a parameter on a structural/shape mismatch it leaves an
unmaterialized ShapeDtypeStruct in the restored state, which otherwise surfaces
as a cryptic compile error deep in the first train_step. Report the offending
parameter path and shape from _assert_no_shaped_dtype_struct so the mismatch
points straight at the config (emb_dim/mlp_dim/layers/scan_layers).