Skip to content

feat(scan_layers): dynamically set scan_layers from checkpoint metadata#4344

Open
RexBearIU wants to merge 1 commit into
mainfrom
jackyf/dynamic-scan-layers-metadata
Open

feat(scan_layers): dynamically set scan_layers from checkpoint metadata#4344
RexBearIU wants to merge 1 commit into
mainfrom
jackyf/dynamic-scan-layers-metadata

Conversation

@RexBearIU

@RexBearIU RexBearIU commented Jul 3, 2026

Copy link
Copy Markdown
Collaborator

Description

This PR is a follow-up to PR #4304 and PR #4340. It introduces dynamic synchronization of scan_layers configuration from checkpoint metadata when a checkpoint path is provided and scan_layers is not explicitly specified by the user on the command-line, in environment variables, or as programmatic keyword arguments.

Key Changes:

  • Robust Tracking of Explicit User Configurations (src/maxtext/configs/pyconfig.py):
    • Keeps track of overridden configuration parameters from the environment (env_keys).
    • Sets Pydantic's internal __pydantic_fields_set__ to only the union of explicit CLI, environment, and keyword argument keys. This avoids the default Pydantic behavior where unpacking a dictionary of default configurations marks every single field as explicitly set.
  • Helper for Sync and Verification (src/maxtext/utils/model_creation_utils.py):
    • Extracted dynamic configuration resolution and verification into verify_and_sync_scan_layers(config).
    • Inside from_pretrained, resolved the active config via config = verify_and_sync_scan_layers(config) without verbose comments, keeping the code highly readable.
    • If scan_layers is explicitly specified by the user and mismatches checkpoint metadata, a ValueError configuration mismatch error is raised.

Tests

  • Added a comprehensive suite of unit tests in TestVerifyAndSyncScanLayers inside tests/unit/model_creation_utils_test.py:
    • test_sync_to_false_when_implicit: Asserts that scan_layers dynamically updates to False from checkpoint metadata when not explicitly specified by the user.
    • test_sync_to_true_when_implicit: Asserts that scan_layers stays/updates to True when checkpoint metadata is True.
    • test_explicit_match_raises_no_error: Asserts that providing matching explicit values does not trigger validation errors.
    • test_explicit_mismatch_raises_value_error: Asserts that an explicit conflict raises the expected ValueError configuration mismatch error.
  • Verified all tests in model_creation_utils_test.py pass successfully:
    PYTHONPATH=src python3 -m pytest tests/unit/model_creation_utils_test.py

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jul 3, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 89.47368% with 2 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/utils/model_creation_utils.py 89.47% 1 Missing and 1 partial ⚠️

📢 Thoughts on this report? Let us know!

Base automatically changed from jackyf/cleanup-scan-layers-mismatch to main July 3, 2026 03:38

@shralex shralex left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Jacky!

Can you add some comments to wherever scan_layers is described (types.py, base.yml) describing this behavior when resuming from a checkpoint.

Also, can we change docs in appropriate places to explain this new behavior. I remember I added a bunch of warnings related to scan_layers previously to the docs, these might need to be updated since you no longer need to specify scan_layers when resuming from a checkpoint.



def verify_and_sync_scan_layers(config):
"""Verify and sync scan_layers based on checkpoint metadata."""

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There seem to be some repeated code and checks here, would something like this work where we extract the metadata first

def verify_and_sync_scan_layers(config):
"""Verify and sync scan_layers based on checkpoint metadata."""
if not config.load_parameters_path:
return config

custom_metadata = checkpointing.load_checkpoint_metadata(config.load_parameters_path)
saved_scan_layers = custom_metadata.get("scan_layers")
if not isinstance(saved_scan_layers, bool):
return config

Extract the Pydantic config (or use direct config if already a Pydantic model)

pydantic_config = getattr(config, "_pydantic_config", config)
model_fields_set = getattr(pydantic_config, "model_fields_set", None)

If model metadata tracking isn't supported, fall back to matching check (True)

is_explicit = "scan_layers" in model_fields_set if model_fields_set is not None else True

if is_explicit:
if saved_scan_layers != config.scan_layers:
raise ValueError(
f"Configuration mismatch: Your run specifies scan_layers={config.scan_layers}, "
f"but the checkpoint was saved with scan_layers={saved_scan_layers}."
)
else:
new_pydantic_config = pydantic_config.model_copy(update={"scan_layers": saved_scan_layers})
# Wrap back in HyperParameters if the original config was wrapped
if getattr(config, "_pydantic_config", None) is not None:
config = pyconfig.HyperParameters(new_pydantic_config)
else:
config = new_pydantic_config

return config

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much for the feedback! I've refactored the function using your elegant fallback structure—it is indeed much cleaner and completely avoids the repetitive code.

I've also added comments explaining the automatic synchronization behavior to both types.py and base.yml, and updated the outdated manual scan_layers warnings/guides throughout our documentation (including checkpoints.md, convert_checkpoint.md, run_maxtext_localhost.md, and all post-training tutorials).

@RexBearIU RexBearIU force-pushed the jackyf/dynamic-scan-layers-metadata branch from dc53711 to 220893d Compare July 3, 2026 04:48
@RexBearIU RexBearIU requested a review from jacoguzo as a code owner July 3, 2026 04:48
@RexBearIU RexBearIU force-pushed the jackyf/dynamic-scan-layers-metadata branch 2 times, most recently from ca08a1a to 2ddeddf Compare July 3, 2026 04:59
@RexBearIU RexBearIU force-pushed the jackyf/dynamic-scan-layers-metadata branch from 2ddeddf to 0985ebb Compare July 3, 2026 05:00

@shralex shralex left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the changes. I made a few suggestions to update comments. The main missing thing is updates to conversion scripts: to_huggingface should auto-detect the setting as well from the provided checkpoint (and error if its setting conflicts with the checkpoint:

Consider calling the new helper function in to_huggingface.py's main function right after initialising the config:

Initialize maxtext config

config = pyconfig.initialize_pydantic(argv)

Auto-resolve scan_layers from checkpoint metadata if not explicitly provided

config = verify_and_sync_scan_layers(config)

to_maxtext should write the scan_layers metadata into the checkpoint it creates, but in the current implementation, it does not actually do so. I believe we should provide a config when calling save_weights_to_checkpoint,


- `model_name`: The specific model identifier. It must match a supported entry in the MaxText [globals.py](https://github.com/AI-Hypercomputer/maxtext/blob/16b684840db9b96b19e24e84ac49f06af7204ae3/src/maxtext/utils/globals.py#L46C1-L46C7).
- `scan_layers`: Controls whether the output uses a scanned (`scan_layers=true`) or unscanned (`scan_layers=false`) checkpoint format. Refer [to the Checkpoints guide](checkpoints) for more information. **IMPORTANT:** This setting *must* match the `scan_layers` value used during model training or loading. A mismatch will cause PyTree loading errors (though MaxText will intercept these and raise a descriptive `ValueError` explaining the mismatch).
- `scan_layers`: Controls whether the output uses a scanned (`scan_layers=true`) or unscanned (`scan_layers=false`) checkpoint format. Refer [to the Checkpoints guide](checkpoints) for more information. **Note:** When resuming or loading a checkpoint in MaxText, this setting will automatically synchronize from checkpoint metadata unless explicitly overridden, meaning you do not have to manually specify it during model execution.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The text sounds like a user can override what's in the checkpoint metadata.

Can we say instead:

Note: When resuming or loading a checkpoint in MaxText, this setting will automatically be loaded from the checkpoint metadata, meaning you do not have to manually specify it during model execution. If you do explicitly specify a value for scan_layers, it must match the checkpoint's saved configuration, or a ValueError mismatch error will be raised.


- **Error:** `Type ShapeDtypeStruct is not a valid JAX type` or generic **PyTree structure/shape mismatches** (e.g., Orbax reporting `"X/Y paths matched"`, such as `143/145 paths`).

- **Cause: Configuration mismatch** (e.g., `scan_layers`) between the checkpoint conversion script (e.g., `to_maxtext.py` or `to_huggingface.py`) and the trainer/inference runner (e.g., `train.py`).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets not remove the examples.

Cause: Configuration mismatch (e.g., scan_layers) between the checkpoint conversion script (e.g., to_maxtext.py or to_huggingface.py) and the trainer/inference runner (e.g., train.py). Since MaxText automatically loads scan_layers from the checkpoint's saved metadata, you should only encounter this error if you explicitly set a mismatching value on the command line).

Actually, should we also auto-detect the setting in to_huggingface ?

- **Cause: Configuration mismatch** (e.g., `scan_layers`) between the checkpoint conversion script and explicit user configurations. (Note: MaxText automatically synchronizes `scan_layers` from the checkpoint's saved metadata when resuming, so you only encounter this error if you explicitly set a mismatching value on the command line).

- **Solution:** Ensure the `scan_layers` flag is set to the exact same value (`True` or `False`) in both the conversion command and your training/execution command.
- **Solution:** Omit the `scan_layers` parameter from your training or execution command to allow MaxText to automatically resolve it from the checkpoint metadata, or ensure any explicitly specified `scan_layers` parameter matches the format of the loaded checkpoint.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Solution: Omit the scan_layers parameter from your checkpoint conversion command to allow MaxText to automatically resolve it from the checkpoint metadata, or ensure any explicitly specified scan_layers parameter matches the format of the loaded checkpoint.


> [!IMPORTANT]
> **PyTree Structure Compatibility:** Because JAX expects the loaded PyTree structure to exactly match the model's instantiated structure, the value of the `scan_layers` flag during execution (training, SFT, RL, DPO, or decoding) **must** match the format of the checkpoint being loaded. A mismatch will cause PyTree loading or shape/path mismatch errors (which MaxText will intercept to raise a descriptive `ValueError` pointing to the scan_layers setting).
> **Automatic scan_layers Resolution:** MaxText automatically detects and synchronizes `scan_layers` from the checkpoint's saved metadata when resuming (via `load_parameters_path`) if you do not explicitly specify `scan_layers` on the command-line. If you explicitly specify a value for `scan_layers` that conflicts with the checkpoint format, MaxText will raise a descriptive `ValueError` mismatch error to prevent JAX PyTree structure or shape mismatch errors during loading.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested:

Automatic scan_layers Resolution: MaxText automatically loads scan_layers from the checkpoint's metadata when resuming (via load_parameters_path) so it isn't necessary to specify scan_layers on resume. If provided, scan_layers must match the checkpoint metadata, otherwise a ValueError error is raised.

> **Matching the `scan_layers` Parameter:**
> The `scan_layers` setting during your fine-tuning run **must match** the setting used when creating the checkpoint at `MAXTEXT_CKPT_PATH`.
> **Automatic `scan_layers` Resolution:**
> MaxText automatically detects and synchronizes `scan_layers` from the checkpoint's saved metadata when resuming (via `load_parameters_path`) if you do not explicitly specify it on the command-line.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say "loads" instead of "detects and synchronizes" here and everywhere

optimizer_memory_host_offload: false
parameter_memory_host_offload: false
scan_layers: true # We recommend setting this to false when using pipeline parallelism, instead scanning the PP iterations.
scan_layers: true # We recommend setting this to false when using pipeline parallelism, instead scanning the PP iterations. Automatically syncs with checkpoint when resuming unless overridden.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I propose:

Whether to use jax.lax.scan over layers (stacked/unstacked checkpoint). We recommend setting this to false when using pipeline parallelism, instead scanning the PP iterations. When resuming from a checkpoint, this flag is auto-determined from metadata.

True,
description=(
"Whether to use jax.lax.scan over layers. Automatically syncs"
" with the checkpoint when resuming unless overridden."

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whether to use jax.lax.scan over layers (stacked/unstacked checkpoint). When resuming from a checkpoint, this flag is auto-determined from metadata.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants