Engine support for concatenated weights with per-sub-parameter configuration #533

@jlamypoirier

Many weights are stored as a single tensor for compute/optimization efficiency but are logically several distinct parameters. Examples:

  • fused QKV (and key/value in GQA)
  • gated MLP layer-1 (gate + up)
  • MoE expert stacks
  • SSM in_proj (z, x, B, C, dt splits)

Today the engine treats each as one opaque parameter. The goal is to make it aware of the internal structure so each logical sub-parameter can be configured independently: lr_scale, initialization, weight decay, PEFT enablement (LoRA on a subset of the split). Bias-split is possible but lower priority.

This is also architecturally adjacent to the converter weight-mapping work — splitting in_proj / QKV is the same concern that shows up in the hybrid-Mamba and Gemma converters.

Prototype approach (#366, closed — preserved here in case the branch is pruned)

The central abstraction is a ConcatenatedParameterMeta (subclass of TensorMetaParameterMeta) that holds its constituent sub-parameter metas plus the concat axis (dim_index), built via from_metas(metas, dim_index=0, dim_name=...). It validates that all constituent dims match except on the concat axis, and sets that axis to a ConcatenatedTensorDim. Everything that needs the parts treated independently routes through this one structure:

  • Initialization: init_parameter splits the buffer and delegates to each constituent meta's own init_parameter on its slice, so each sub-param keeps its own init method. (requires_global_initialization is forced on.)
  • Optimizer: a new metas_for_grad property exposes the constituents to the param-group builder. When sub-params differ in lr_scale/weight_decay it sets a _split_optimization flag (only valid on dim_index==0, since the optimizer needs contiguous chunks), and the param-group loop in stage_base.py walks metas_for_grad, grouping each contiguous buffer slice by its own (weight_decay, lr_scale).
  • PEFT: concatenate_linear_layers(...) fuses already-built Linear/{Input,Output}ParallelLinear layers into one over the concatenated meta, and resolves per-sub-layer apply_peft flags: uniform → single enable/disable; mixed → builds out_channel_ranges from the concat dim's get_split_ranges(global_=True). PEFT's apply_linear/lora_linear were generalized from a single (out_channel_begin, out_channel_end) to out_channel_ranges: tuple[(begin, end), ...], so LoRA can target multiple disjoint output-channel ranges.
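
To make the optimizer bullet concrete, here is a minimal pure-Python sketch of grouping contiguous buffer slices by (weight_decay, lr_scale). It is not the code from stage_base.py: SubParamMeta, build_param_groups and the example sizes are hypothetical, and the flat row-major buffer layout is an assumption made for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SubParamMeta:
    # Hypothetical stand-in for one constituent parameter meta.
    name: str
    rows: int  # size along the concat axis (dim 0)
    cols: int  # size of the shared, non-concatenated dimension
    lr_scale: float = 1.0
    weight_decay: bool = True


def build_param_groups(metas):
    """Group flat-buffer slices by (weight_decay, lr_scale).

    With concatenation on dim 0 of a row-major buffer, every sub-parameter
    occupies one contiguous [begin, end) range of elements, which is why the
    split is only assumed to work for dim_index == 0.
    """
    groups = {}
    offset = 0
    for meta in metas:
        numel = meta.rows * meta.cols
        key = (meta.weight_decay, meta.lr_scale)
        groups.setdefault(key, []).append((meta.name, offset, offset + numel))
        offset += numel
    return groups


# Illustrative fused QKV: 8 query rows, 2 key rows, 2 value rows, 4 columns each.
qkv = [
    SubParamMeta("query", rows=8, cols=4),
    SubParamMeta("key", rows=2, cols=4, lr_scale=0.5),
    SubParamMeta("value", rows=2, cols=4, lr_scale=0.5, weight_decay=False),
]
param_groups = build_param_groups(qkv)
```

Each key of param_groups corresponds to one optimizer param group, and each entry names the buffer slice that belongs to it. Concatenating on any other axis would interleave the sub-parameters in memory, so no single slice per sub-parameter would exist.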

ConcatenatedTensorDim gained split_tensor / merge_tensors / get_split_ranges helpers (local or global), and local_to_global/global_to_local were refactored onto them.
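
The helpers named above can be pictured with a toy re-implementation on plain Python lists. The function names mirror the ones in this issue, but the signatures, the tp_size parameter, and the assumption that every sub-dim is sharded evenly across tensor-parallel ranks are illustrative guesses, not the prototype's actual API.

```python
def get_split_ranges(sizes, tp_size=1, global_=True):
    """Return (begin, end) ranges of each sub-dim along the concat axis.

    Assumption: in the local view, each sub-dim is divided evenly by tp_size.
    """
    if not global_:
        for size in sizes:
            if size % tp_size:
                raise ValueError(f"size {size} is not divisible by tp_size {tp_size}")
        sizes = [size // tp_size for size in sizes]
    ranges = []
    begin = 0
    for size in sizes:
        ranges.append((begin, begin + size))
        begin += size
    return ranges


def split_tensor(rows, sizes, tp_size=1, global_=True):
    """Cut a list of rows (standing in for a tensor) into its constituents."""
    return [rows[b:e] for b, e in get_split_ranges(sizes, tp_size, global_)]


def merge_tensors(parts):
    """Inverse of split_tensor: concatenate the constituents back together."""
    return [row for part in parts for row in part]


sizes = (8, 2, 2)  # illustrative query, key, value sizes
global_ranges = get_split_ranges(sizes)
local_ranges = get_split_ranges(sizes, tp_size=2, global_=False)
rows = list(range(12))
parts = split_tensor(rows, sizes)
```

Under these assumptions a local shard is the concatenation of each sub-dim's local slice, which is why a local/global conversion has to split first, convert each part, and then merge.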

Notably, this supersedes two prior hacks: (1) the old lr_scale-as-tuple-of-floats on ParameterMeta (which split a param into equal contiguous chunks for the optimizer) is removed — lr_scale becomes a plain scalar and combine_lr_scales a simple product; (2) the attention layer's old "build a fused key_value linear then bolt PEFT onto the value half via out_channel_begin" path is replaced by building key/value separately and fusing with concatenate_linear_layers(..., default_apply_peft=(False, True)).
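
As a rough illustration of how per-sub-layer apply_peft flags could resolve into out_channel_ranges, consider the sketch below. resolve_peft_targets and its return convention are hypothetical; only the uniform/mixed distinction and the tuple-of-ranges shape come from the description above.

```python
def resolve_peft_targets(sizes, apply_peft):
    """Return (enabled, out_channel_ranges) for a fused layer.

    Uniform flags collapse to a single enable/disable with no ranges.
    Mixed flags keep PEFT enabled and list the output-channel ranges of
    the sub-layers that asked for it.
    """
    if len(sizes) != len(apply_peft):
        raise ValueError("need one apply_peft flag per sub-layer")
    if all(apply_peft):
        return True, None
    if not any(apply_peft):
        return False, None
    ranges = []
    begin = 0
    for size, enabled in zip(sizes, apply_peft):
        if enabled:
            ranges.append((begin, begin + size))
        begin += size
    return True, tuple(ranges)


# Fused key/value with LoRA on the value half only (sizes are illustrative).
kv_result = resolve_peft_targets((4, 4), (False, True))
# Fused QKV with LoRA on query and value, which gives two disjoint ranges.
qkv_result = resolve_peft_targets((8, 2, 2), (True, False, True))
```

The second call shows why a single (out_channel_begin, out_channel_end) pair is not enough: query and value are separated by key, so the target channels are not contiguous.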

The concrete demonstration is attention QKV/KV; Megatron-init compatibility was handled with a small _get_init_method helper that picks the first constituent's init for a concatenated meta.
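
The two initialization paths mentioned in this issue, per-slice delegation and the first-constituent fallback, can be contrasted in a small sketch. The functions here are hypothetical stand-ins operating on a plain list, and the init callables are placeholders rather than real init methods.

```python
def init_concatenated(buffer, sizes, init_fns):
    """Fill each slice of `buffer` using its own constituent's init function."""
    begin = 0
    for size, init_fn in zip(sizes, init_fns):
        buffer[begin:begin + size] = [init_fn(i) for i in range(size)]
        begin += size
    return buffer


def get_init_method(init_fns):
    """Fallback described above: the first constituent's init stands in
    for the whole concatenated parameter."""
    return init_fns[0]


def zeros(index):
    # Placeholder init: every element is 0.0.
    return 0.0


def ones(index):
    # Placeholder init: every element is 1.0.
    return 1.0


sizes = (4, 2)
init_fns = (zeros, ones)
per_slice = init_concatenated([None] * 6, sizes, init_fns)
fallback = get_init_method(init_fns)
```

The fallback drops the distinction between constituents, so in this sketch it only matches the per-slice result when all constituents share the same init.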

Open scope from the prototype

Simplify and harden concatenate_linear_layers (PEFT especially — several TODOs flag rough edges, e.g. LoRA on concatenated layers, naming of the original constituent params), extend coverage to the remaining fused linears (gated MLP, MoE, SSM in_proj), and round out tests. Treat the prototype as a reference for the approach, not for reuse — it predates the linear/PEFT layer rework on main and would need a full rebase.

Drafted by Claude Opus 4.8 from the #366 diff.