Many weights are stored as a single tensor for compute/optimization efficiency but are logically several distinct parameters. Examples:
- fused QKV (and key/value in GQA)
- gated MLP layer-1 (gate + up)
- MoE expert stacks
- SSM in_proj (z, x, B, C, dt splits)
Today the engine treats each as one opaque parameter. The goal is to make it aware of the internal structure so each logical sub-parameter can be configured independently: lr_scale, initialization, weight decay, PEFT enablement (LoRA on a subset of the split). Bias-split is possible but lower priority.
This is also architecturally adjacent to the converter weight-mapping work — splitting in_proj / QKV is the same concern that shows up in the hybrid-Mamba and Gemma converters.
Prototype approach (#366, closed — preserved here in case the branch is pruned)
The central abstraction is a ConcatenatedParameterMeta (subclass of TensorMeta → ParameterMeta) that holds its constituent sub-parameter metas plus the concat axis (dim_index), built via from_metas(metas, dim_index=0, dim_name=...). It validates that all constituent dims match except on the concat axis, and sets that axis to a ConcatenatedTensorDim. Everything that needs the parts treated independently routes through this one structure:
- Initialization: init_parameter splits the buffer and delegates to each constituent meta's own init_parameter on its slice, so each sub-param keeps its own init method. (requires_global_initialization forced on.)
- Optimizer: a new metas_for_grad property exposes the constituents to the param-group builder. When sub-params differ in lr_scale/weight_decay it sets a _split_optimization flag (only valid on dim_index==0, since the optimizer needs contiguous chunks) and the param-group loop in stage_base.py walks metas_for_grad, grouping each contiguous buffer slice by its own (weight_decay, lr_scale).
- PEFT: concatenate_linear_layers(...) fuses already-built Linear/{Input,Output}ParallelLinear layers into one over the concatenated meta, and resolves per-sub-layer apply_peft flags: uniform → single enable/disable; mixed → builds out_channel_ranges from the concat dim's get_split_ranges(global_=True). PEFT's apply_linear/lora_linear were generalized from a single (out_channel_begin, out_channel_end) to out_channel_ranges: tuple[(begin, end), ...], so LoRA can target multiple disjoint output-channel ranges.
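To make the initialization and optimizer bullets concrete, here is a minimal, framework-free sketch of the idea. It is not the prototype's code: SubParamMeta, ConcatMeta, flat_slices and param_groups are hypothetical names, the buffer is a plain Python list standing in for a row-major tensor, and the sketch only handles dim_index == 0 (the prototype restricts only split optimization to that axis).

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class SubParamMeta:
    """Hypothetical stand-in for one constituent ParameterMeta."""
    name: str
    shape: tuple[int, ...]
    init: Callable[[int], list[float]]  # returns n initial values for a flat slice
    lr_scale: float = 1.0
    weight_decay: bool = True

    @property
    def numel(self) -> int:
        count = 1
        for size in self.shape:
            count *= size
        return count


@dataclass(frozen=True)
class ConcatMeta:
    """Hypothetical, simplified analogue of ConcatenatedParameterMeta."""
    metas: tuple[SubParamMeta, ...]
    dim_index: int
    shape: tuple[int, ...]

    @classmethod
    def from_metas(cls, metas, dim_index=0):
        # All dims must match except on the concat axis.
        reference = metas[0].shape
        for meta in metas[1:]:
            same_rank = len(meta.shape) == len(reference)
            if not same_rank or any(
                a != b
                for axis, (a, b) in enumerate(zip(meta.shape, reference))
                if axis != dim_index
            ):
                raise ValueError(f"{meta.name}: shape {meta.shape} does not match {reference}")
        shape = list(reference)
        shape[dim_index] = sum(meta.shape[dim_index] for meta in metas)
        return cls(tuple(metas), dim_index, tuple(shape))

    def flat_slices(self):
        # In a row-major buffer, sub-params are contiguous only for dim_index == 0.
        if self.dim_index != 0:
            raise ValueError("flat slices are contiguous only when dim_index == 0")
        slices, begin = [], 0
        for meta in self.metas:
            slices.append((begin, begin + meta.numel))
            begin += meta.numel
        return slices

    def init_parameter(self, buffer):
        # Delegate each slice to its own constituent's init method.
        for meta, (begin, end) in zip(self.metas, self.flat_slices()):
            buffer[begin:end] = meta.init(end - begin)

    def param_groups(self):
        # Group contiguous slices by their own (weight_decay, lr_scale).
        groups = {}
        for meta, flat_slice in zip(self.metas, self.flat_slices()):
            groups.setdefault((meta.weight_decay, meta.lr_scale), []).append(flat_slice)
        return groups


query = SubParamMeta("query", (4, 8), lambda n: [0.02] * n)
key_value = SubParamMeta("key_value", (2, 8), lambda n: [1.0] * n, lr_scale=0.5, weight_decay=False)
qkv = ConcatMeta.from_metas([query, key_value], dim_index=0)

buffer = [0.0] * 48  # flat storage for the fused (6, 8) weight
qkv.init_parameter(buffer)
groups = qkv.param_groups()
```

Here groups maps (True, 1.0) to the query slice (0, 32) and (False, 0.5) to the key_value slice (32, 48), which is the shape of information a param-group builder would need to give each slice its own optimizer settings.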
ConcatenatedTensorDim gained split_tensor / merge_tensors / get_split_ranges helpers (local or global), and local_to_global/global_to_local were refactored onto them.
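The split helpers and the mixed apply_peft case can be illustrated with a small sketch. The function names mirror the helpers named above, but the signatures here are assumptions made for illustration (free functions taking explicit sizes, a list of rows standing in for a tensor split along dim 0, no local/global tensor-parallel distinction); resolve_apply_peft is a hypothetical name for the flag-resolution step.

```python
def get_split_ranges(sizes):
    """Return (begin, end) offsets of each constituent along the concat axis."""
    ranges, begin = [], 0
    for size in sizes:
        ranges.append((begin, begin + size))
        begin += size
    return tuple(ranges)


def split_tensor(rows, sizes):
    """Split a fused weight (list of rows) into its constituents."""
    return [rows[begin:end] for begin, end in get_split_ranges(sizes)]


def merge_tensors(parts):
    """Inverse of split_tensor: concatenate constituents along dim 0."""
    return [row for part in parts for row in part]


def resolve_apply_peft(sizes, apply_peft):
    """Uniform flags collapse to a bool; mixed flags become output-channel ranges."""
    if len(sizes) != len(apply_peft):
        raise ValueError("one apply_peft flag is needed per constituent")
    if all(apply_peft):
        return True
    if not any(apply_peft):
        return False
    return tuple(
        split_range
        for split_range, enabled in zip(get_split_ranges(sizes), apply_peft)
        if enabled
    )


# Fused key/value with LoRA on the value half only.
kv_ranges = resolve_apply_peft((2, 2), (False, True))

# Fused QKV with LoRA on query and value: two disjoint ranges.
qkv_sizes = (4, 2, 2)
qkv_ranges = resolve_apply_peft(qkv_sizes, (True, False, True))

# Splitting and merging round-trips the fused weight.
rows = [[float(i)] * 3 for i in range(8)]
parts = split_tensor(rows, qkv_sizes)
round_trip = merge_tensors(parts)
```

With these inputs, kv_ranges is ((2, 4),) and qkv_ranges is ((0, 4), (6, 8)); the second case is the one a single (out_channel_begin, out_channel_end) pair could not express.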
Notably, this supersedes two prior hacks: (1) the old lr_scale-as-tuple-of-floats on ParameterMeta (which split a param into equal contiguous chunks for the optimizer) is removed — lr_scale becomes a plain scalar and combine_lr_scales a simple product; (2) the attention layer's old "build a fused key_value linear then bolt PEFT onto the value half via out_channel_begin" path is replaced by building key/value separately and fusing with concatenate_linear_layers(..., default_apply_peft=(False, True)).
The concrete demonstration is attention QKV/KV; Megatron-init compatibility was handled with a small _get_init_method helper that picks the first constituent's init for a concatenated meta.
Open scope from the prototype
Simplify and harden concatenate_linear_layers (PEFT especially — several TODOs flag rough edges, e.g. LoRA on concatenated layers, naming of the original constituent params), extend coverage to the remaining fused linears (gated MLP, MoE, SSM in_proj), and round out tests. Treat the prototype as a reference for the approach, not for reuse — it predates the linear/PEFT layer rework on main and would need a full rebase.
Drafted by Claude Opus 4.8 from the #366 diff.