Add NVFP4 per-token quantization recipe#3045
Draft
cael-ling wants to merge 10 commits into
Draft
Conversation
Rewrites the grouped multi-tensor cast as a K1 fused amax + K2 fused cast
pair and ships pytest correctness + sweep benches against the per-tensor
RHT+SR production baseline.
* common/cast/.../quantize_nvfp4_per_token_group.cu: K1+K2 fused
grouped kernel, reusing the single-tensor 4-stage TMA pipeline.
* common/gemm/nvfp4_per_token_post_scale.cu: row-wise post-scale
kernel for the cuBLASLT NVFP4 dequantize step (maybe updated due
to 2d quant of W).
* pytorch/csrc/extensions/nvfp4_per_token.cpp + pybind.cpp: new C++
grouped bulk binding and per-token GEMM entry; thin pybind layer.
* pytorch/custom_recipes/{gemm_nvfp4_per_token,
quantization_nvfp4_per_token_group}.py: Python wrappers.
* tests/pytorch/nvfp4/test_nvfp4_per_token{,_group}.py: byte-equal
cast tests + bf16-close GEMM tests.
* tests/pytorch/nvfp4/bench_nvfp4_per_token{,_group}.py: 6x3 sweep
over M in {1024..32768} x K in {2048,4096,8192}, eager + CUDA
Graphs columns, ratio against per-tensor RHT+SR baseline.
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
6f17fe4 to
928ab1c
Compare
for more information, see https://pre-commit.ci
…uped) Wire `with_rht` / `random_sign_mask_t` through the per-token K1 (amax) and K2 (encode) kernels for both single-tensor and grouped paths. with_rht=False is byte-equal to the pre-RHT code path; when true, applies a 16-pt RHT on the columnwise direction in both K1 and K2 (rowwise stays raw) with outer amax + inner SF self-consistent. Implementation: per-thread fp32 FHT on CUDA cores, branchless fp32 sign-bit XOR for the +/-1 sign diagonal, 0.25 normalization folded into block_amax / block_scale (bit-exact). Tests cover K1, K2, composite + grouped vs a PyTorch fp32 reference and byte-equality regressions. Benches gain a --rht flag (2-way default, 3-way under --rht). Perf vs prod NVFP4Quantizer(rht+sr), Graph mode, 18 shapes M up to 32K: * single tensor : 0.49x-0.77x (no RHT), 0.59x-0.88x (+RHT) * grouped (N=8) : 0.41x-0.77x (no RHT), 0.50x-0.94x (+RHT) Also drops unused THREADS_X_TR / THREADS_Y_TR (nvcc warning NVIDIA#177-D). Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Add an optional fused-swizzle path to the NVFP4 per-token K2 encode
kernel: when with_swizzle=True the rowwise scale_inv is emitted directly
in the cuBLAS LT 128Mx4K swizzled tile layout, skipping the downstream
nvte_swizzle_scaling_factors launch. The colwise scale_inv stays in the
compact M-major layout (rowwise-only fusion for now).
The new code path is gated by a kWithSwizzle template parameter on
per_token_encode_kernel. The scatter epilogue uses thread mapping
b=tid&3, ty=tid>>2 to give each warp a coalesced 128-byte gmem store,
and packs two K-tiles into one uint64_t SMEM load (2-way bank conflict
instead of 4-way). Pre-existing code path is byte-equal.
with_swizzle is threaded through nvte_nvfp4_per_token_{quantize,encode},
their PyTorch bindings, and the nvfp4_per_token_{quantize,encode} Python
recipes. nvfp4_per_token_gemm takes new a_sf_swizzled / b_sf_swizzled
flags so the caller opts into the fast path per operand (mirrors prod
NVFP4 GEMM's per-operand swizzle).
Add tex.nvfp4_per_token_swizzle_rowwise_sf -- a thin wrapper around
nvte_swizzle_scaling_factors that does one standalone per-operand
swizzle launch. Bench-only; lets --qs attribute swizzle cost separately
from K1+K2 and from cuBLAS LT GEMM.
Bench (bench_nvfp4_per_token.py): add --qs mode (K1+K2 + standalone
swizzle, no GEMM) with two modifiers -- --pair (2 operands, matches one
prod GEMM call's quant+swizzle pipeline) and --fuse (adds a per-token
(fuse) column for the K2-fused path). The existing --swizzle end-to-end
mode also gains the fused-swizzle column. --pair / --fuse auto-imply
--qs to avoid silent fall-through to the default --composite table.
Tests (test_nvfp4_per_token.py): byte-equality of the fused-swizzle
rowwise SF vs a pure-Python permutation reference, byte-equality of all
other outputs (FP4 data, colwise SF, row/col amax) vs with_swizzle=False,
and numerical equivalence of the end-to-end GEMM via both code paths.
Perf at K=N=4096, Graph mode: fused-swizzle path is ~7-35% faster than
the unfused per-token pipeline (--qs) and reaches up to ~2.6x faster
than per-tensor at small M.
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
The per-token cuBLASLt NVFP4 path needs a trailing post-scale kernel
(D *= alpha_a[i] * alpha_b[j]) that is HBM-bound on the M*N output. This
patch ships a forked-CUTLASS NVFP4 GEMM whose EVT epilogue folds the
per-row * per-col rescale into the in-TMEM accumulator -- a single launch
with no separate post-scale, no M*N HBM round-trip.
New C-API entry points (transformer_engine/common/gemm/nvfp4_cutlass_gemm.cu):
- nvte_nvfp4_cutlass_gemm: scalar (alpha, beta) NVFP4xNVFP4 -> BF16 GEMM
(CUTLASS analog of the cuBLASLt per-tensor path; used as test ground truth).
- nvte_nvfp4_cutlass_per_token_gemm: same mainloop, EVT epilogue
D[i,j] = bf16(NVFP4_DEQUANT_K * alpha_a[i] * alpha_b[j] * acc).
The outer 1/2688^2 factor (NVFP4 spec) is baked into the EVT explicitly,
matching the value cuBLASLt auto-folds via its amax slot.
Python bindings (tex.nvfp4_cutlass_gemm / tex.nvfp4_cutlass_per_token_gemm)
plus a/b_sf_swizzled flags for apples-to-apples --gemm-only benching.
Numerical correctness (tests/pytorch/nvfp4/test_nvfp4_cutlass_per_token_gemm.py):
- fused EVT == cuBLASLt per-token within bf16 ULP (rtol=2e-2), across
M,N,K = 256..1024.
- fused EVT with unity alphas == nvfp4_cutlass_gemm(alpha=1/2688^2) BIT-EXACT
(sanity check that the EVT tree and the baked constant are both correct).
Bench (tests/pytorch/nvfp4/bench_nvfp4_per_token.py --gemm-only) streamlined
to the only comparison that matters for shipping: ct_fused (per-token CUTLASS
fused) vs pten_gemm (prod per-tensor cuBLASLt), with the cf/pten ratio.
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Extends tests/pytorch/nvfp4/{bench,test}_nvfp4_cutlass_per_token_gemm
with end-to-end forward and backward coverage that aligns the prod
baseline with NVFP4BlockScaling real-ship defaults (input RHT-1D,
weight 2D no-RHT, grad RHT-cols + SR), so per-token (no RHT/SR) is
measured against an actually-shippable prod recipe rather than a
toy quantizer.
bench_nvfp4_per_token.py:
* --e2e-fwd: per-token quant (with_swizzle=True) + fused-EVT CUTLASS
GEMM vs NVFP4Quantizer + general_gemm (the real nn.Linear fwd
dispatch). Quant + GEMM inside the timing loop, N = K. Function
docstring carries an ASCII kernel-pipeline diagram for both paths
(per-call launch budget: per-token ~5 vs prod ~10).
* --e2e-bwd: real prod nn.Linear.bwd lifecycle. Timing loop = 1 x dY
quant + dgrad GEMM + wgrad GEMM; X and W are pre-quantized OUTSIDE
the loop (mirrors prod's reuse of fwd-saved QuantizedTensorStorage,
bwd never re-quantizes). pten side uses RHT-cols + SR grad
quantizer + general_gemm NN (dgrad) / NT (wgrad). Function docstring
carries an ASCII kernel-pipeline diagram (per-step launch budget:
per-token ~4 vs prod ~12).
* --gemm-only: 3-way table adds an lt_post column (cuBLASLt NVFP4 +
bf16 per-row*per-col post-scale, "Route 1") next to the existing
ct_fused fused-EVT path ("Route 2") and the prod pten_gemm
baseline. Headline ratio lp/cf decides whether to dispatch
per-token through cuBLASLt + post_scale or fused EVT; current
data shows ct_fused wins or ties at every shape we care about.
test_nvfp4_cutlass_per_token_gemm.py:
* Layer 2 fwd: per-token quant + fused-EVT GEMM vs BF16 fp32 ground
truth (rel_l2 < 0.30, robust to per-shape noise).
* Layer 3 fwd: dual-SNR table comparing per-token vs prod, both
measured against BF16 ground truth, with a per-token-vs-prod ratio.
* Layer 3 bwd: same dual-SNR pattern for dgrad and wgrad. Prod side
uses real-ship NVFP4BlockScaling grad quantizer (RHT cols + SR);
per-token side has no RHT/SR (numerical-floor comparison).
* Sanity micro-test for weight 2D quant plumbing through general_gemm
(catches breakage cheaper than the broader Layer 3 test).
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
for more information, see https://pre-commit.ci
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
This PR adds an NVFP4 per-token quantization fast path for bf16 inputs on Blackwell (SM100+) for Model Pre-training. Per-token uses per-row / per-column outer amax instead of the per-tensor scalar amax, which factors cleanly out of the GEMM K-summation and lets the inner GEMM stay on production cuBLASLT NVFP4 plus a thin trailing post-scale.
Status: draft. The cast kernel and GEMM composite, byte-equal-verified against a Python reference, and benched against the per-tensor (RHT + SR) recipe are still in progress. Partial experimental results are shown as follows.
Tests and benches
This PR ships four new pytest / benchmark files under
tests/pytorch/nvfp4/. All four require bf16 input andM % 128 == 0/K % 128 == 0; GEMM tests are gated by SM100 (Blackwell).test_nvfp4_per_token.py— single-tensor correctnesstest_nvfp4_per_token_group.py— group-tensor correctnessbench_nvfp4_per_token.py— single-tensor amax + quant benchWall-time benchmark of the single-tensor amax + quant composite (
tex.nvfp4_per_token_quantize, this PR) against the per-tensor RHT+SR production baseline (NVFP4Quantizer(rht=True, sr=True)viatex.quantize). Both sides use rowwise + columnwise. Single output table:Each row reports eager wall-time plus CUDA Graphs replay (kernel-only floor).
ratio < 1.0⇒ per-token is faster than the per-tensor baseline.bench_nvfp4_per_token_group.py— grouped amax + quant benchWall-time benchmark of the grouped amax + quant composite (
nvfp4_per_token_group_quantizePython wrapper backed by the new C++ bulk entry, this PR) against the per-tensor RHT+SR grouped production baseline (tex.split_quantize(...)withNVFP4Quantizer(rht=True, sr=True)per split). Layout identical to the single-tensor bench:Default sweep is 6 × 3 = 18 cases at fixed
N = 8equal splits (MoE-typical):sum_M ∈ {1024, 2048, 4096, 8192, 16384, 32768}(so per-splitM_i ∈ {128 … 4096}) ×K ∈ {2048, 4096, 8192}. CUDA Graphs replay reported on every row.Type of change
Changes
Please list the changes introduced in this PR:
Checklist: