feat(scan_layers): dynamically set scan_layers from checkpoint metadata#4344
feat(scan_layers): dynamically set scan_layers from checkpoint metadata#4344RexBearIU wants to merge 1 commit into
Conversation
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
shralex
left a comment
There was a problem hiding this comment.
Thanks Jacky!
Can you add some comments to wherever scan_layers is described (types.py, base.yml) describing this behavior when resuming from a checkpoint.
Also, can we change docs in appropriate places to explain this new behavior. I remember I added a bunch of warnings related to scan_layers previously to the docs, these might need to be updated since you no longer need to specify scan_layers when resuming from a checkpoint.
|
|
||
|
|
||
| def verify_and_sync_scan_layers(config): | ||
| """Verify and sync scan_layers based on checkpoint metadata.""" |
There was a problem hiding this comment.
There seem to be some repeated code and checks here, would something like this work where we extract the metadata first
def verify_and_sync_scan_layers(config):
"""Verify and sync scan_layers based on checkpoint metadata."""
if not config.load_parameters_path:
return config
custom_metadata = checkpointing.load_checkpoint_metadata(config.load_parameters_path)
saved_scan_layers = custom_metadata.get("scan_layers")
if not isinstance(saved_scan_layers, bool):
return config
Extract the Pydantic config (or use direct config if already a Pydantic model)
pydantic_config = getattr(config, "_pydantic_config", config)
model_fields_set = getattr(pydantic_config, "model_fields_set", None)
If model metadata tracking isn't supported, fall back to matching check (True)
is_explicit = "scan_layers" in model_fields_set if model_fields_set is not None else True
if is_explicit:
if saved_scan_layers != config.scan_layers:
raise ValueError(
f"Configuration mismatch: Your run specifies scan_layers={config.scan_layers}, "
f"but the checkpoint was saved with scan_layers={saved_scan_layers}."
)
else:
new_pydantic_config = pydantic_config.model_copy(update={"scan_layers": saved_scan_layers})
# Wrap back in HyperParameters if the original config was wrapped
if getattr(config, "_pydantic_config", None) is not None:
config = pyconfig.HyperParameters(new_pydantic_config)
else:
config = new_pydantic_config
return config
There was a problem hiding this comment.
Thank you so much for the feedback! I've refactored the function using your elegant fallback structure—it is indeed much cleaner and completely avoids the repetitive code.
I've also added comments explaining the automatic synchronization behavior to both types.py and base.yml, and updated the outdated manual scan_layers warnings/guides throughout our documentation (including checkpoints.md, convert_checkpoint.md, run_maxtext_localhost.md, and all post-training tutorials).
dc53711 to
220893d
Compare
ca08a1a to
2ddeddf
Compare
2ddeddf to
0985ebb
Compare
shralex
left a comment
There was a problem hiding this comment.
Thanks for the changes. I made a few suggestions to update comments. The main missing thing is updates to conversion scripts: to_huggingface should auto-detect the setting as well from the provided checkpoint (and error if its setting conflicts with the checkpoint:
Consider calling the new helper function in to_huggingface.py's main function right after initialising the config:
Initialize maxtext config
config = pyconfig.initialize_pydantic(argv)
Auto-resolve scan_layers from checkpoint metadata if not explicitly provided
config = verify_and_sync_scan_layers(config)
to_maxtext should write the scan_layers metadata into the checkpoint it creates, but in the current implementation, it does not actually do so. I believe we should provide a config when calling save_weights_to_checkpoint,
|
|
||
| - `model_name`: The specific model identifier. It must match a supported entry in the MaxText [globals.py](https://github.com/AI-Hypercomputer/maxtext/blob/16b684840db9b96b19e24e84ac49f06af7204ae3/src/maxtext/utils/globals.py#L46C1-L46C7). | ||
| - `scan_layers`: Controls whether the output uses a scanned (`scan_layers=true`) or unscanned (`scan_layers=false`) checkpoint format. Refer [to the Checkpoints guide](checkpoints) for more information. **IMPORTANT:** This setting *must* match the `scan_layers` value used during model training or loading. A mismatch will cause PyTree loading errors (though MaxText will intercept these and raise a descriptive `ValueError` explaining the mismatch). | ||
| - `scan_layers`: Controls whether the output uses a scanned (`scan_layers=true`) or unscanned (`scan_layers=false`) checkpoint format. Refer [to the Checkpoints guide](checkpoints) for more information. **Note:** When resuming or loading a checkpoint in MaxText, this setting will automatically synchronize from checkpoint metadata unless explicitly overridden, meaning you do not have to manually specify it during model execution. |
There was a problem hiding this comment.
The text sounds like a user can override what's in the checkpoint metadata.
Can we say instead:
Note: When resuming or loading a checkpoint in MaxText, this setting will automatically be loaded from the checkpoint metadata, meaning you do not have to manually specify it during model execution. If you do explicitly specify a value for scan_layers, it must match the checkpoint's saved configuration, or a ValueError mismatch error will be raised.
|
|
||
| - **Error:** `Type ShapeDtypeStruct is not a valid JAX type` or generic **PyTree structure/shape mismatches** (e.g., Orbax reporting `"X/Y paths matched"`, such as `143/145 paths`). | ||
|
|
||
| - **Cause: Configuration mismatch** (e.g., `scan_layers`) between the checkpoint conversion script (e.g., `to_maxtext.py` or `to_huggingface.py`) and the trainer/inference runner (e.g., `train.py`). |
There was a problem hiding this comment.
Lets not remove the examples.
Cause: Configuration mismatch (e.g., scan_layers) between the checkpoint conversion script (e.g., to_maxtext.py or to_huggingface.py) and the trainer/inference runner (e.g., train.py). Since MaxText automatically loads scan_layers from the checkpoint's saved metadata, you should only encounter this error if you explicitly set a mismatching value on the command line).
Actually, should we also auto-detect the setting in to_huggingface ?
| - **Cause: Configuration mismatch** (e.g., `scan_layers`) between the checkpoint conversion script and explicit user configurations. (Note: MaxText automatically synchronizes `scan_layers` from the checkpoint's saved metadata when resuming, so you only encounter this error if you explicitly set a mismatching value on the command line). | ||
|
|
||
| - **Solution:** Ensure the `scan_layers` flag is set to the exact same value (`True` or `False`) in both the conversion command and your training/execution command. | ||
| - **Solution:** Omit the `scan_layers` parameter from your training or execution command to allow MaxText to automatically resolve it from the checkpoint metadata, or ensure any explicitly specified `scan_layers` parameter matches the format of the loaded checkpoint. |
There was a problem hiding this comment.
Solution: Omit the scan_layers parameter from your checkpoint conversion command to allow MaxText to automatically resolve it from the checkpoint metadata, or ensure any explicitly specified scan_layers parameter matches the format of the loaded checkpoint.
|
|
||
| > [!IMPORTANT] | ||
| > **PyTree Structure Compatibility:** Because JAX expects the loaded PyTree structure to exactly match the model's instantiated structure, the value of the `scan_layers` flag during execution (training, SFT, RL, DPO, or decoding) **must** match the format of the checkpoint being loaded. A mismatch will cause PyTree loading or shape/path mismatch errors (which MaxText will intercept to raise a descriptive `ValueError` pointing to the scan_layers setting). | ||
| > **Automatic scan_layers Resolution:** MaxText automatically detects and synchronizes `scan_layers` from the checkpoint's saved metadata when resuming (via `load_parameters_path`) if you do not explicitly specify `scan_layers` on the command-line. If you explicitly specify a value for `scan_layers` that conflicts with the checkpoint format, MaxText will raise a descriptive `ValueError` mismatch error to prevent JAX PyTree structure or shape mismatch errors during loading. |
There was a problem hiding this comment.
Suggested:
Automatic scan_layers Resolution: MaxText automatically loads scan_layers from the checkpoint's metadata when resuming (via load_parameters_path) so it isn't necessary to specify scan_layers on resume. If provided, scan_layers must match the checkpoint metadata, otherwise a ValueError error is raised.
| > **Matching the `scan_layers` Parameter:** | ||
| > The `scan_layers` setting during your fine-tuning run **must match** the setting used when creating the checkpoint at `MAXTEXT_CKPT_PATH`. | ||
| > **Automatic `scan_layers` Resolution:** | ||
| > MaxText automatically detects and synchronizes `scan_layers` from the checkpoint's saved metadata when resuming (via `load_parameters_path`) if you do not explicitly specify it on the command-line. |
There was a problem hiding this comment.
I would say "loads" instead of "detects and synchronizes" here and everywhere
| optimizer_memory_host_offload: false | ||
| parameter_memory_host_offload: false | ||
| scan_layers: true # We recommend setting this to false when using pipeline parallelism, instead scanning the PP iterations. | ||
| scan_layers: true # We recommend setting this to false when using pipeline parallelism, instead scanning the PP iterations. Automatically syncs with checkpoint when resuming unless overridden. |
There was a problem hiding this comment.
I propose:
Whether to use jax.lax.scan over layers (stacked/unstacked checkpoint). We recommend setting this to false when using pipeline parallelism, instead scanning the PP iterations. When resuming from a checkpoint, this flag is auto-determined from metadata.
| True, | ||
| description=( | ||
| "Whether to use jax.lax.scan over layers. Automatically syncs" | ||
| " with the checkpoint when resuming unless overridden." |
There was a problem hiding this comment.
Whether to use jax.lax.scan over layers (stacked/unstacked checkpoint). When resuming from a checkpoint, this flag is auto-determined from metadata.
Description
This PR is a follow-up to PR #4304 and PR #4340. It introduces dynamic synchronization of
scan_layersconfiguration from checkpoint metadata when a checkpoint path is provided andscan_layersis not explicitly specified by the user on the command-line, in environment variables, or as programmatic keyword arguments.Key Changes:
src/maxtext/configs/pyconfig.py):env_keys).__pydantic_fields_set__to only the union of explicit CLI, environment, and keyword argument keys. This avoids the default Pydantic behavior where unpacking a dictionary of default configurations marks every single field as explicitly set.src/maxtext/utils/model_creation_utils.py):verify_and_sync_scan_layers(config).from_pretrained, resolved the active config viaconfig = verify_and_sync_scan_layers(config)without verbose comments, keeping the code highly readable.scan_layersis explicitly specified by the user and mismatches checkpoint metadata, aValueErrorconfiguration mismatch error is raised.Tests
TestVerifyAndSyncScanLayersinsidetests/unit/model_creation_utils_test.py:test_sync_to_false_when_implicit: Asserts thatscan_layersdynamically updates toFalsefrom checkpoint metadata when not explicitly specified by the user.test_sync_to_true_when_implicit: Asserts thatscan_layersstays/updates toTruewhen checkpoint metadata isTrue.test_explicit_match_raises_no_error: Asserts that providing matching explicit values does not trigger validation errors.test_explicit_mismatch_raises_value_error: Asserts that an explicit conflict raises the expectedValueErrorconfiguration mismatch error.model_creation_utils_test.pypass successfully:Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.