Skip to content

Extend LoRA for Gemma4#3969

Open
RexBearIU wants to merge 1 commit into
mainfrom
jackyf/gemma4-lora
Open

Extend LoRA for Gemma4#3969
RexBearIU wants to merge 1 commit into
mainfrom
jackyf/gemma4-lora

Conversation

@RexBearIU
Copy link
Copy Markdown
Collaborator

Description

This PR extends the recent LoRA support to accurately target and process Gemma 4 architectures (including MoE).

Gemma 4 introduces complex nested structures (like scanned_blocks and layers_remainder) and unique chat template behaviors (such as the <|channel>thought block) that are incompatible with standard LoRA targeting and data
processing. Furthermore, MoE models require dynamic metadata synchronization during forward passes which is broken by aggressive NNX graph caching.

This PR addresses these challenges by:

  • Adding accurate regex mapping for Gemma 4 standard and MoE LoRA targets in lora_module_path.yml.
  • Implementing a thought channel bypass in input_pipeline_utils.py to prevent validation failures when the generation prompt includes the <|channel>thought block.
  • Dynamically disabling NNX graph caching in train_sft.py specifically for MoE models (where experts > 1) to allow necessary metadata synchronization.

Tests

  • Added unit tests for the Gemma 4 tokenizer bypass in tests/post_training/unit/sft_data_processing_test.py (test_tokenizer_gemma4_thought_channel_bypass).
  • Verified caching behavior changes by running Gemma-4 MoE LoRA tuning on TPU.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov
Copy link
Copy Markdown

codecov Bot commented May 22, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

@RexBearIU RexBearIU force-pushed the jackyf/gemma4-lora branch from 2bc8632 to ab61640 Compare May 22, 2026 07:38
def test_tokenizer_wo_generation_prompt(self):
verify_chat_template_generation_prompt_logic(self.llama2_tokenizer)

def test_tokenizer_gemma4_thought_channel_bypass(self):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test expects to not fail with TemplateError or ValueError. Can you add an assertion for this so that it is readable what this test actually verifies?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated verify_chat_template_generation_prompt_logic to return True on success, and wrapped all three test cases in explicit self.assertTrue() assertions.

This cleanly verifies that validation succeeds and keeps all tests uniform. All tests pass successfully!"

@RexBearIU RexBearIU force-pushed the jackyf/gemma4-lora branch 2 times, most recently from 61626bd to ef50ff7 Compare May 28, 2026 08:41
actual_prefix_in_full_turn = full_turn_ids[len(prompt_wo_gen_ids) : len(prompt_wo_gen_ids) + len(assistant_prefix)]

if actual_prefix_in_full_turn != assistant_prefix:
# Allow the generation prompt to include a thought channel block (e.g., for Gemma 4).
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic looks like a hacky approach to support Gemma4. I am working on a generalized logic to support any model that requires specific prefix shifting. I will send out the PR soon.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, it was definitely a workaround. Thanks for taking the lead on a generalized solution! I'll track #4010 and we can use that approach instead.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've rebased this PR on top of #4010

@RexBearIU RexBearIU force-pushed the jackyf/gemma4-lora branch from ef50ff7 to 5fd616b Compare May 29, 2026 07:48
@RexBearIU RexBearIU force-pushed the jackyf/gemma4-lora branch from 5fd616b to 6a64bd0 Compare June 1, 2026 06:59
@github-actions
Copy link
Copy Markdown

github-actions Bot commented Jun 2, 2026

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

Copy link
Copy Markdown

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

## 📋 Review Summary

This Pull Request successfully extends LoRA support for Gemma 4 architectures and addresses a critical issue with NNX graph caching in MoE models. However, there is a significant discrepancy between the PR description and the actual changes, as several mentioned files and unit tests are missing from the diff.

🔍 General Feedback

  • Missing Implementation: The PR description mentions a "thought channel bypass" in input_pipeline_utils.py and new unit tests in tests/post_training/unit/sft_data_processing_test.py, but these files are not included in the PR. Please ensure all intended changes are staged and pushed.
  • Consistency across Trainers: The dynamic disabling of NNX graph caching is a great addition for MoE stability; consider applying this same logic to DPO, RL, and Distillation trainers to ensure consistent behavior across the post-training suite.
  • LoRA Targeting: The regex for Gemma 4 LoRA targeting is comprehensive but should be monitored to ensure it doesn't become overly broad as the architecture evolves.

with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
# Disable NNX graph caching for MoE models (where experts > 1) to allow
# necessary dynamic metadata synchronization during forward passes (e.g., in jax.lax.scan).
enable_nnx_cache = getattr(mt_config, "num_experts", 1) <= 1
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Disabling NNX graph caching for MoE models is a critical fix for metadata synchronization. However, this change is currently only applied to the SFT trainer. Other post-training trainers, such as DPO, RL, and Distillation, may also be affected when training MoE models and should ideally receive a similar update for consistency and to avoid potential issues.

gemma2: "decoder/layers/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/layers/(mlp_local|mlp_global)/(wi_0|wi_1|wo)"
gemma3: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))"
gemma4: "decoder/(scanned_blocks|layers_remainder)/layers.*/.*(self_attention/(query|key|value|out)|mlp/.*(MoeBlock_0|wi_0|wi_1|wo|shared_experts/(wi_0|wi_1|wo)))"
olmo3: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟢 The regex for `gemma4` includes `MoeBlock_0` within the MLP group. Since the `.*` wildcard already allows for matching nested paths, `MoeBlock_0` may be redundant unless there are specific parameters directly under it that require targeting. Additionally, please ensure that `layers.*/.*` is sufficiently precise to avoid matching unintended components in deeper architectures.

deepseek2: "decoder/(dense_layers|moe_stack)/self_attention/(query|out|wkv_a|wkv_b)|decoder/(dense_layers|moe_stack)/(mlp|shared_experts)/(wi_0|wi_1|wo)"
gemma2: "decoder/layers/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/layers/(mlp_local|mlp_global)/(wi_0|wi_1|wo)"
gemma3: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))"
gemma4: "decoder/(scanned_blocks|layers_remainder)/layers.*/.*(self_attention/(query|key|value|out)|mlp/.*(MoeBlock_0|wi_0|wi_1|wo|shared_experts/(wi_0|wi_1|wo)))"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pattern also adapts the MoE router: probably by accident. The MoeBlock_0 term matches everything inside the routed-MoE block, including the router/gate that decides which expert each token goes to. You normally don't want LoRA on the router. The actual expert weights (wi_0, wi_1, wo) are already matched by the other terms in the pattern, so you can just delete MoeBlock_0|. Experts and shared experts still get LoRA; the router no longer does. (Checked against the parameter tree.)

with mesh, nn_partitioning.axis_rules(mt_config.logical_axis_rules):
# Disable NNX graph caching for MoE models (where experts > 1) to allow
# necessary dynamic metadata synchronization during forward passes (e.g., in jax.lax.scan).
enable_nnx_cache = getattr(mt_config, "num_experts", 1) <= 1
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For MaxText flags just directly use dot notation:
mt_config.num_experts

Also why do you need to disable it for MoE, what happens if you don't?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any tests for Gemma4 Lora?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants