feat(nnx): support native Flax NNX PEFT/LoRA training loop#4346

Draft
RexBearIU wants to merge 3 commits into main from nnx-lora-support

@RexBearIU

Collaborator

Description

This PR implements native parameter-efficient fine-tuning (PEFT) and LoRA training support within the Flax NNX training loop inside MaxText (train.py and train_utils.py).

Previously, training inside train.py with Flax NNX was limited to full-parameter fine-tuning. This was due to:

  1. setup_train_loop inside train_utils.py initializing nnx.Optimizer with wrt=nnx.Param, which allocated optimizer states for the entire base model.
  2. train_step and diff_wrapper inside train.py hardcoding model splits using nnx.Param.
  3. Sharding extraction inside sharding.py hardcoding nnx.Param extraction, which failed to cleanly separate base parameters from LoRA parameters.
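To make the cost in item 1 concrete, here is a back-of-the-envelope sketch in plain Python. It assumes an Adam-style optimizer that keeps two moment buffers per trainable parameter, and a hypothetical 4096x4096 projection with LoRA rank 8. Neither the optimizer choice nor the dimensions come from this PR.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """LoRA factors the weight update into A (d_in x rank) and B (rank x d_out)."""
    return d_in * rank + rank * d_out


def adam_state_count(trainable_params: int) -> int:
    """Assumption: two moment buffers (first and second) per trainable parameter."""
    return 2 * trainable_params


# Hypothetical layer shape and rank, chosen only for illustration.
d_in, d_out, rank = 4096, 4096, 8

base_params = d_in * d_out
lora_params = lora_param_count(d_in, d_out, rank)

print("optimizer entries, full fine-tuning:", adam_state_count(base_params))
print("optimizer entries, LoRA only:       ", adam_state_count(lora_params))
print("reduction factor:", base_params // lora_params)
```

Under these assumptions the optimizer state for this one layer shrinks by a factor of 256 when only the adapter is tracked.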

This PR addresses those limitations through the following design:

  • Dynamic Parameter Split/Filter Target: We define wrt dynamically based on configuration (nnx.LoRAParam if config.lora.enable_lora is True, otherwise nnx.Param). This is used consistently across optimizer initialization, training step splits/merges, and sharding parameter extraction.
  • LoRA Adapter Injection & Restoration: In setup_train_loop, Qwix LoRA adapters are dynamically injected when config.lora.enable_lora is True. On fresh training runs (no previous checkpoint step), pre-trained adapters can be loaded from lora.lora_restore_path.
  • Mesh Context Tracing Safety: Updated apply_lora_to_model inside lora_utils.py to only invoke jax.set_mesh(mesh) if tracing is not currently active (jax_core.trace_state_clean()), avoiding compilation errors during eager/eval paths under NNX.
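The dynamic filter target from the first bullet can be sketched without any dependencies. The classes below are stand-ins for nnx.Param and nnx.LoRAParam, and LoraConfig, select_wrt, and filter_params are hypothetical names used only for illustration; they are not the helpers added in this PR.

```python
from dataclasses import dataclass
from typing import Dict, Type


class Param:
    """Stand-in for nnx.Param (not the real Flax class)."""

    def __init__(self, value: float):
        self.value = value


class LoRAParam(Param):
    """Stand-in for nnx.LoRAParam, modelled here as a Param subclass."""


@dataclass
class LoraConfig:
    """Hypothetical mirror of the config.lora section."""

    enable_lora: bool = False


def select_wrt(lora_cfg: LoraConfig) -> Type[Param]:
    """Pick the variable type that the optimizer and train step should target."""
    return LoRAParam if lora_cfg.enable_lora else Param


def filter_params(params: Dict[str, Param], wrt: Type[Param]) -> Dict[str, Param]:
    """Keep only the entries whose variable type matches the selected target."""
    return {name: p for name, p in params.items() if isinstance(p, wrt)}


params = {
    "dense/kernel": Param(1.0),
    "dense/lora_a": LoRAParam(0.1),
    "dense/lora_b": LoRAParam(0.0),
}

lora_trainable = filter_params(params, select_wrt(LoraConfig(enable_lora=True)))
full_trainable = filter_params(params, select_wrt(LoraConfig(enable_lora=False)))

print(sorted(lora_trainable))  # adapter entries only
print(sorted(full_trainable))  # every entry
```

Computing the target once and passing it to the optimizer, the train step, and sharding extraction keeps the three call sites from drifting apart.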

Tests

The changes were tested on CPU with the test suites listed below, and no regressions were observed.

  1. Integration tests (tests/integration/setup_train_loop_nnx_test.py):

    • Added test_pure_nnx_setup_with_lora_enabled which asserts that setup_train_loop correctly instantiates a model with LoRA adapters injected and the optimizer configured with wrt=nnx.LoRAParam.
    • Added test_train_step_updates_only_lora_weights which runs a full forward and backward pass for a single training step and asserts that only the adapter parameters (nnx.LoRAParam) are modified, while the base weights (nnx.Param) remain unchanged.
    • Updated the tiny NNX testing config (_tiny_nnx_pyconfig) to use attention="dot_product" so that CPU backend runs are fully supported without relying on TPU/GPU Pallas attention kernels.
    • Run command:
      JAX_PLATFORMS='cpu' PYTHONPATH=src pytest tests/integration/setup_train_loop_nnx_test.py -k SetupTrainLoopNNXLoraTest
  2. Unit tests (tests/unit/lora_utils_nnx_test.py):

    • Added test_sharding_extracts_only_lora_params to verify that sharding.maybe_update_params_sharding_with_opt extracts only nnx.LoRAParam under LoRA configuration.
    • Run command:
      JAX_PLATFORMS='cpu' PYTHONPATH=src pytest tests/unit/lora_utils_nnx_test.py
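The invariant checked by test_train_step_updates_only_lora_weights can be illustrated with a scalar toy model in plain Python. This is a sketch rather than the actual test: lora_sgd_step, the scalar weights, and the hand-derived gradients are all assumptions made for illustration, and plain SGD stands in for whatever optimizer the training config selects.

```python
def lora_sgd_step(weights, x, target, lr=0.1):
    """One SGD step on a scalar LoRA model: pred = (w + lora_a * lora_b) * x.

    Gradients are computed only for the adapter factors; the base weight w
    is passed through untouched, as it would be when it is excluded from wrt.
    """
    w, a, b = weights["w"], weights["lora_a"], weights["lora_b"]
    pred = (w + a * b) * x
    err = pred - target  # dL/dpred for L = 0.5 * (pred - target) ** 2
    grad_a = err * x * b
    grad_b = err * x * a
    return {"w": w, "lora_a": a - lr * grad_a, "lora_b": b - lr * grad_b}


# lora_b starts at zero, a common LoRA initialisation, so the adapter
# contributes nothing to the prediction before the first update.
before = {"w": 2.0, "lora_a": 0.5, "lora_b": 0.0}
after = lora_sgd_step(before, x=1.0, target=3.0)

assert after["w"] == before["w"]            # base weight is frozen
assert after["lora_b"] != before["lora_b"]  # adapter weight moved
```

Note that with lora_b initialised to zero, the gradient for lora_a is zero on the first step, so a check of this kind is more robust when it asserts that at least one adapter weight changed rather than all of them.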

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jul 3, 2026


Codecov Report

❌ Patch coverage is 71.42857% with 4 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
src/maxtext/utils/train_utils.py | 60.00% | 3 Missing and 1 partial ⚠️
