feat(nnx): support native Flax NNX PEFT/LoRA training loop #4346
Draft
RexBearIU wants to merge 3 commits into
Description
This PR implements native parameter-efficient fine-tuning (PEFT) and LoRA training support within the Flax NNX training loop inside MaxText (`train.py` and `train_utils.py`).

Previously, training inside `train.py` with Flax NNX was limited to full-parameter fine-tuning. This was due to:

- `setup_train_loop` inside `train_utils.py` initializing `nnx.Optimizer` with `wrt=nnx.Param`, which allocated optimizer states for the entire base model.
- `train_step` and `diff_wrapper` inside `train.py` hardcoding model splits using `nnx.Param`.
- `sharding.py` hardcoding `nnx.Param` extraction, which failed to cleanly separate base parameters from LoRA parameters.

This PR addresses those limitations through the following design:

- `wrt` is resolved dynamically based on configuration (`nnx.LoRAParam` if `config.lora.enable_lora` is True, otherwise `nnx.Param`). This is used consistently across optimizer initialization, training step splits/merges, and sharding parameter extraction.
- In `setup_train_loop`, Qwix LoRA adapters are injected dynamically when `config.lora.enable_lora` is True. On fresh training runs (no previous checkpoint step), pre-trained adapters can be loaded from `lora.lora_restore_path`.
- `apply_lora_to_model` inside `lora_utils.py` now invokes `jax.set_mesh(mesh)` only if tracing is not currently active (`jax_core.trace_state_clean()`), avoiding compilation errors during eager/eval paths under NNX.

Tests
The changes have been thoroughly tested on CPU with no regressions.
Integration tests (`tests/integration/setup_train_loop_nnx_test.py`):

- `test_pure_nnx_setup_with_lora_enabled` asserts that `setup_train_loop` correctly instantiates a model with LoRA adapters injected and the optimizer configured with `wrt=nnx.LoRAParam`.
- `test_train_step_updates_only_lora_weights` runs a full forward and backward pass for a single training step and asserts that only the adapter parameters (`nnx.LoRAParam`) are modified, while the base weights (`nnx.Param`) remain unchanged.
- `_tiny_nnx_pyconfig` is updated to use `attention="dot_product"` so that CPU backend runs are fully supported without relying on TPU/GPU Pallas attention kernels.
- Run with: `JAX_PLATFORMS='cpu' PYTHONPATH=src pytest tests/integration/setup_train_loop_nnx_test.py -k SetupTrainLoopNNXLoraTest`

Unit tests (`tests/unit/lora_utils_nnx_test.py`):

- `test_sharding_extracts_only_lora_params` verifies that `sharding.maybe_update_params_sharding_with_opt` extracts only `nnx.LoRAParam` under LoRA configuration.
- Run with: `JAX_PLATFORMS='cpu' PYTHONPATH=src pytest tests/unit/lora_utils_nnx_test.py`

Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.