Extend LoRA for Gemma4#3969
Conversation
Codecov Report✅ All modified and coverable lines are covered by tests. 📢 Thoughts on this report? Let us know! |
2bc8632 to
ab61640
Compare
| def test_tokenizer_wo_generation_prompt(self): | ||
| verify_chat_template_generation_prompt_logic(self.llama2_tokenizer) | ||
|
|
||
| def test_tokenizer_gemma4_thought_channel_bypass(self): |
There was a problem hiding this comment.
This test expects to not fail with TemplateError or ValueError. Can you add an assertion for this so that it is readable what this test actually verifies?
There was a problem hiding this comment.
I updated verify_chat_template_generation_prompt_logic to return True on success, and wrapped all three test cases in explicit self.assertTrue() assertions.
This cleanly verifies that validation succeeds and keeps all tests uniform. All tests pass successfully!"
61626bd to
ef50ff7
Compare
| actual_prefix_in_full_turn = full_turn_ids[len(prompt_wo_gen_ids) : len(prompt_wo_gen_ids) + len(assistant_prefix)] | ||
|
|
||
| if actual_prefix_in_full_turn != assistant_prefix: | ||
| # Allow the generation prompt to include a thought channel block (e.g., for Gemma 4). |
There was a problem hiding this comment.
This logic looks like a hacky approach to support Gemma4. I am working on a generalized logic to support any model that requires specific prefix shifting. I will send out the PR soon.
There was a problem hiding this comment.
Agreed, it was definitely a workaround. Thanks for taking the lead on a generalized solution! I'll track #4010 and we can use that approach instead.
ef50ff7 to
5fd616b
Compare
5fd616b to
6a64bd0
Compare
|
🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
There was a problem hiding this comment.
This Pull Request successfully extends LoRA support for Gemma 4 architectures and addresses a critical issue with NNX graph caching in MoE models. However, there is a significant discrepancy between the PR description and the actual changes, as several mentioned files and unit tests are missing from the diff.
🔍 General Feedback
- Missing Implementation: The PR description mentions a "thought channel bypass" in
input_pipeline_utils.pyand new unit tests intests/post_training/unit/sft_data_processing_test.py, but these files are not included in the PR. Please ensure all intended changes are staged and pushed. - Consistency across Trainers: The dynamic disabling of NNX graph caching is a great addition for MoE stability; consider applying this same logic to DPO, RL, and Distillation trainers to ensure consistent behavior across the post-training suite.
- LoRA Targeting: The regex for Gemma 4 LoRA targeting is comprehensive but should be monitored to ensure it doesn't become overly broad as the architecture evolves.
| with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules): | ||
| # Disable NNX graph caching for MoE models (where experts > 1) to allow | ||
| # necessary dynamic metadata synchronization during forward passes (e.g., in jax.lax.scan). | ||
| enable_nnx_cache = getattr(mt_config, "num_experts", 1) <= 1 |
There was a problem hiding this comment.
| gemma2: "decoder/layers/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/layers/(mlp_local|mlp_global)/(wi_0|wi_1|wo)" | ||
| gemma3: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))" | ||
| gemma4: "decoder/(scanned_blocks|layers_remainder)/layers.*/.*(self_attention/(query|key|value|out)|mlp/.*(MoeBlock_0|wi_0|wi_1|wo|shared_experts/(wi_0|wi_1|wo)))" | ||
| olmo3: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))" |
There was a problem hiding this comment.
| deepseek2: "decoder/(dense_layers|moe_stack)/self_attention/(query|out|wkv_a|wkv_b)|decoder/(dense_layers|moe_stack)/(mlp|shared_experts)/(wi_0|wi_1|wo)" | ||
| gemma2: "decoder/layers/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/layers/(mlp_local|mlp_global)/(wi_0|wi_1|wo)" | ||
| gemma3: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))" | ||
| gemma4: "decoder/(scanned_blocks|layers_remainder)/layers.*/.*(self_attention/(query|key|value|out)|mlp/.*(MoeBlock_0|wi_0|wi_1|wo|shared_experts/(wi_0|wi_1|wo)))" |
There was a problem hiding this comment.
The pattern also adapts the MoE router: probably by accident. The MoeBlock_0 term matches everything inside the routed-MoE block, including the router/gate that decides which expert each token goes to. You normally don't want LoRA on the router. The actual expert weights (wi_0, wi_1, wo) are already matched by the other terms in the pattern, so you can just delete MoeBlock_0|. Experts and shared experts still get LoRA; the router no longer does. (Checked against the parameter tree.)
| with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules): | ||
| # Disable NNX graph caching for MoE models (where experts > 1) to allow | ||
| # necessary dynamic metadata synchronization during forward passes (e.g., in jax.lax.scan). | ||
| enable_nnx_cache = getattr(mt_config, "num_experts", 1) <= 1 |
There was a problem hiding this comment.
For MaxText flags just directly use dot notation:
mt_config.num_experts
Also why do you need to disable it for MoE, what happens if you don't?
There was a problem hiding this comment.
Are there any tests for Gemma4 Lora?
Description
This PR extends the recent LoRA support to accurately target and process Gemma 4 architectures (including MoE).
Gemma 4 introduces complex nested structures (like scanned_blocks and layers_remainder) and unique chat template behaviors (such as the <|channel>thought block) that are incompatible with standard LoRA targeting and data
processing. Furthermore, MoE models require dynamic metadata synchronization during forward passes which is broken by aggressive NNX graph caching.
This PR addresses these challenges by:
Tests
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.