Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
02fd39c
Add tool to evaluate layer-wise numerical-error propagation
jlamypoirier May 26, 2026
4dd6c14
Collapse to a single config; require a checkpoint
jlamypoirier May 27, 2026
5ebea33
Expose `model:` alongside `pretrained:` in the tool config
jlamypoirier May 27, 2026
4c444d8
Inherit PretrainedGPTModelConfig; use Config update mechanism
jlamypoirier May 27, 2026
35206a6
Expand HF metadata allowlist for newer transformers configs
jlamypoirier May 27, 2026
bde1efa
Reshape console table for readability
jlamypoirier May 27, 2026
8099b51
Merge tensor+kind, fix decimal precision in console table
jlamypoirier May 27, 2026
dbd7702
Switch back to fixed-decimal formatting in the table
jlamypoirier May 27, 2026
152ffc3
Wipe per-variant experiment dir before each run
jlamypoirier May 27, 2026
7e98500
Support pre-generated memmap dataset; misc table-format polish
jlamypoirier May 28, 2026
173ae0d
Print per-variant summary at the end of the run
jlamypoirier May 28, 2026
005fd62
Reshape end-of-run summary: variants × aggregations, relative only
jlamypoirier May 28, 2026
c594658
Clarify intermediate aggregation in summary header
jlamypoirier May 28, 2026
3159f73
Split summary across fw/bw rows; one extra precision digit
jlamypoirier May 28, 2026
6ef153e
Two-row column header in summary; chronological column order
jlamypoirier May 28, 2026
7327932
Add fp32_lm_head flag for vLLM precision parity
jlamypoirier May 28, 2026
76335df
Extract layer-name labels for summary first/last columns
jlamypoirier May 28, 2026
8122946
Add `debug_hidden_states_log` to capture named tensors via output_hid…
jlamypoirier May 28, 2026
4633bfd
Capture logit gradients; expose them in the summary
jlamypoirier May 28, 2026
9ca1711
Place logits after head in bw summary; widen format for sub-percent v…
jlamypoirier May 28, 2026
f2655f3
Pick per-column decimals to guarantee ≥2 sig figs
jlamypoirier May 28, 2026
7f8ef96
Tighten summary table spacing
jlamypoirier May 28, 2026
08b1637
Support HF Hub model ids in pretrained.path
jlamypoirier May 28, 2026
77eae22
Add example precision-evaluation configs
jlamypoirier May 28, 2026
efa95b1
Drop bf16_no_fp32_gradients variant from example configs
jlamypoirier May 28, 2026
46bc5b8
Add weight gradients to per-variant report tables
jlamypoirier May 28, 2026
bef2f0d
Separate fw/bw/grad rows in per-variant tables
jlamypoirier May 28, 2026
4fecad4
Split summary into three tables (fw, bw, grad)
jlamypoirier May 28, 2026
4f47dc0
Split grad summary by parameter category
jlamypoirier May 28, 2026
5198c25
Per-tensor sample-density overrides in TensorLogsConfig
jlamypoirier May 28, 2026
312343e
Chosen-logprob loss, per-variant grad-scale auto-calibration, fp16 va…
jlamypoirier May 29, 2026
497c76c
Lean fixed-input runner + DeepSpeed-side precision comparison
jlamypoirier Jun 1, 2026
cecf7ae
Support random-init (model_weights=False) in both precision tools
jlamypoirier Jun 1, 2026
a6d9314
vLLM within-engine precision tool
jlamypoirier Jun 1, 2026
26cd8ab
Auto-bind vLLM fp32 head on tied-embedding models
jlamypoirier Jun 1, 2026
fc6072c
Cross-engine log-prob comparison tool + per-token log π persistence
jlamypoirier Jun 2, 2026
767e7eb
Cross-engine: all-pairwise mismatched group + fp16 precision
jlamypoirier Jun 2, 2026
2828c35
Cross-engine: full 2x2 head matrix per pair, corrected mismatch direc…
jlamypoirier Jun 2, 2026
ca0c7b2
Add forward-only inference mode to the precision tool
jlamypoirier Jun 2, 2026
f2b9d21
Add forward-only mode to the DeepSpeed precision tool
jlamypoirier Jun 2, 2026
32c9545
Add Qwen2.5-7B precision-evaluation config (forward-only)
jlamypoirier Jun 2, 2026
f8a3d20
Multi-sequence (per-sequence/GSPO) cross-engine log-prob mode
jlamypoirier Jun 3, 2026
7a203aa
Multi-sequence forward+backward layer-wise mode + network-FS-safe rmtree
jlamypoirier Jun 4, 2026
35a913d
Merge origin/main into jlp_evaluate_precision
jlamypoirier Jun 6, 2026
116dc06
Batch-size scaling experiment + read-only memmap and TF32 matmul fixes
jlamypoirier Jun 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions examples/evaluate_precision/batch_size/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Batch-size scaling experiment (Qwen2.5-0.5B)

Tests whether small-batch training matches large-batch when Adam hyperparameters are scaled
correctly (Marek et al., arXiv:2507.07101): only β2 must scale to hold the token half-life
fixed; β1 stays at 0.9; the learning rate is ~batch-insensitive.

Loss is compared **vs tokens seen**. All arms share one model init (the warmup checkpoint) and
one fresh, disjoint token stream (the `experiment` split), so the curves form a paired comparison.

## Arms (b=16, B=512, ratio 32; lengths in sequences of 2048 tokens)

| Arm | batch | lr | β1 | β2 | role |
|---|---|---|---|---|---|
| A | 512 | 1e-4 | 0.9 | 0.95 | large-batch reference |
| B | 16 | 1e-4 | 0.9 | 0.998398 (= 0.95^(1/32)) | paper scaling — only β2 moves |
| C | 16 | 3.125e-6 (= 1e-4/32) | 0.996875 | 0.9984375 | conservative — linear lr + linear (1−β) |
| D | 16 | 1e-4 | 0.9 | 0.95 | naive unscaled (strawman) |

lr held constant (no decay) at 1e-4 — the Qwen2.5-7B peak is 3e-4 cosine-to-10%; we use a lower
constant value since constant-lr has no decay phase to lower the effective rate.

Predictions: B overlays A; D degrades (wrong steady-state second-moment averaging); C sits above
B if scaling lr down 32× undertrains. `arm_base.yaml` defaults to D.

## Sequence

```bash
# 0. Tokenize FineWeb-Edu -> 3 disjoint splits
fast-llm prepare gpt_memmap --config prepare.yaml

# 1. Warmup (4 GPUs, ~1h). Kill at ~1h; note the latest checkpoint iteration <ITER>.
torchrun --nproc_per_node=4 -m fast_llm train gpt --config warmup.yaml

# 2. Set pretrained.path in arm_base.yaml to experiments/batch_size/warmup/checkpoint/<ITER>,
# then launch the four arms in parallel, one GPU each:
CUDA_VISIBLE_DEVICES=0 fast-llm train gpt --config arm_base.yaml \
run.experiment_dir=experiments/batch_size/arm_A \
schedule.depth_first_micro_batches=512 training.train_iters=40000 \
training.checkpoint.interval=1000 training.evaluators.validation.interval=190 &

CUDA_VISIBLE_DEVICES=1 fast-llm train gpt --config arm_base.yaml \
run.experiment_dir=experiments/batch_size/arm_B \
optimizer.beta_2=0.998398 &

CUDA_VISIBLE_DEVICES=2 fast-llm train gpt --config arm_base.yaml \
run.experiment_dir=experiments/batch_size/arm_C \
optimizer.beta_2=0.9984375 optimizer.beta_1=0.996875 optimizer.learning_rate.base=3.125e-6 &

CUDA_VISIBLE_DEVICES=3 fast-llm train gpt --config arm_base.yaml \
run.experiment_dir=experiments/batch_size/arm_D &
```

`num_samples` is identical across arms (512×40000 = 16×1280000 = 20,480,000 ≈ 42B tokens, within the
~94B `experiment` split), so the shuffled token stream is identical — only the batching differs. The
42B is a cap; monitor the curves and stop when the comparison is conclusive. Compare loss at matched
tokens (step × micro_batch_size × depth_first), as a paired difference vs A plus held-out validation loss.

Logging goes to W&B `servicenow-team/batch_size_experiments` (group `arms`; warmup in group `warmup`);
each arm is a separate run named after its experiment_dir.
62 changes: 62 additions & 0 deletions examples/evaluate_precision/batch_size/arm_base.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Stage 2 — branch arms from the warmup checkpoint. Each arm is a SEPARATE single-GPU run,
# launched with per-arm overrides (see README.md run matrix). This base file holds the shared
# settings and defaults to arm D (naive small batch: 16 sequences, unscaled hyperparameters).
#
# Arms inherit the architecture from the warmup checkpoint (pretrained.load_config: model),
# start fresh (completed_steps = 0) so all arms consume the SAME `experiment` token stream from
# offset 0, and differ only in: schedule.depth_first_micro_batches, optimizer betas/lr,
# training.train_iters (chosen so num_samples is identical across arms), eval interval, and dir.
#
# CUDA_VISIBLE_DEVICES=0 fast-llm train gpt --config arm_base.yaml <ARM OVERRIDES>
pretrained:
path: experiments/batch_size/warmup/checkpoint/LATEST # set to the ~1h checkpoint iteration
format: distributed
model_weights: true
optimizer_state: false # cold optimizer; each arm warms its own v under its own beta_2
load_config: model
training:
train_iters: 1280000 # small-batch cap; num_samples = 2 x 1280000 = 2,560,000 (~42B tokens at micro_batch_size 16384, within the ~94B experiment split). Monitor and stop.
num_workers: 8
wandb:
entity_name: servicenow-team
project_name: batch_size_experiments
group_name: arms
logs:
interval: 10
checkpoint:
interval: 30000 # ~1B tokens for small arms
keep: 2
evaluators:
validation:
type: loss
interval: 6000 # ~200M-token cadence for small arms
iterations: 50
schedule:
depth_first_micro_batches: 2 # small batch = 2 x 16384 = 32768 tokens (large arms override to 64 -> 1,048,576)
data:
datasets:
training:
type: file
path: data/batch_size/fineweb_edu_qwen/fast_llm_config_experiment.yaml
validation:
type: file
path: data/batch_size/fineweb_edu_qwen/fast_llm_config_validation.yaml
micro_batch_size: 16384 # tokens packed per micro-batch (throughput knob; documents are masked per-doc within the pack). Same value for all arms -> identical packing.
optimizer:
learning_rate:
base: 1.0e-04
decay_style: constant
warmup_iterations: 0
weight_decay: 0.0
beta_1: 0.9
beta_2: 0.95
epsilon: 1.0e-08
gradient_norm_clipping: 5.0 # raised from 1.0: grad norms ~1, so clipping ~never fires -> avoids batch-size-dependent differential clipping; still a safety net vs spikes
model:
multi_stage:
zero_stage: 1 # single-GPU: sharding is a no-op; 0 is not a valid value (range 1-3)
distributed:
compute_dtype: bf16
seed: 7 # shared across arms -> identical data order
run:
experiment_dir: experiments/batch_size/arm_D
22 changes: 22 additions & 0 deletions examples/evaluate_precision/batch_size/prepare.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Stage 0 — tokenize FineWeb-Edu with the Qwen2.5 tokenizer into a Fast-LLM memmap dataset,
# partitioned into three disjoint splits:
# warmup — the throwaway from-scratch prefix (stage 1)
# experiment — fresh, unseen-by-warmup data, shared by every arm (stage 2)
# validation — held out, shared by all runs for comparable eval-loss curves
#
# fast-llm prepare gpt_memmap --config examples/evaluate_precision/batch_size/prepare.yaml
output_path: data/batch_size/fineweb_edu_qwen
dataset:
path: HuggingFaceFW/fineweb-edu
config_name: sample-100BT
split: train
trust_remote_code: true
tokenizer:
path: Qwen/Qwen2.5-0.5B
# Qwen has no native BOS; skip it and separate documents by EOS (<|endoftext|>) only (PR #534).
add_bos: false
add_eos: true
splits:
warmup: 0.05
experiment: 0.94
validation: 0.01
81 changes: 81 additions & 0 deletions examples/evaluate_precision/batch_size/warmup.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Stage 1 — from-scratch warmup of a Qwen2.5-0.5B model to a mid-early-training branch point.
# This prefix is throwaway: only the final checkpoint is reused, as `pretrained` init for the arms.
# Run data-parallel on 4 GPUs; stop at ~1h wall-clock and branch from the last checkpoint.
# torchrun --nproc_per_node=4 -m fast_llm train gpt --config examples/evaluate_precision/batch_size/warmup.yaml
#
# Global batch = micro_batch_size (2048) x depth_first_micro_batches (128) x data_parallel (4)
# = 512 sequences of length 2048 (matches Qwen2.5's pretraining batch).
training:
train_iters: 1000 # cap only — kill at ~1h and use the latest checkpoint
num_workers: 8
wandb:
entity_name: servicenow-team
project_name: batch_size_experiments
group_name: warmup
logs:
interval: 10
checkpoint:
interval: 100 # frequent saves so we can branch near the 1h mark
keep: 3
schedule:
depth_first_micro_batches: 128
data:
datasets:
training:
type: file
path: data/batch_size/fineweb_edu_qwen/fast_llm_config_warmup.yaml
micro_batch_size: 2048
optimizer:
learning_rate:
base: 1.0e-04
decay_style: constant
warmup_iterations: 100
weight_decay: 0.0
beta_1: 0.9
beta_2: 0.95
epsilon: 1.0e-08
gradient_norm_clipping: 1.0
model:
base_model:
embeddings:
vocab_size: 151936
dropout: 0.0
decoder:
block:
mixer:
type: attention
rotary:
type: default
theta: 1000000
heads: 14
head_groups: 2
head_size: 64
add_linear_biases: false
query_layer: {bias: {enabled: true}}
key_layer: {bias: {enabled: true}}
value_layer: {bias: {enabled: true}}
dense_layer: {bias: {enabled: false}}
dropout: 0.0
mlp:
intermediate_size: 4864
add_linear_biases: false
gated: true
activation: silu
normalization:
type: rms_norm
epsilon: 1.0e-06
dropout: 0.0
num_blocks: 24
head:
normalization:
type: rms_norm
epsilon: 1.0e-06
hidden_size: 896
tied_embedding_weight: true
multi_stage:
zero_stage: 2
distributed:
compute_dtype: bf16
seed: 1234
run:
experiment_dir: experiments/batch_size/warmup
25 changes: 25 additions & 0 deletions examples/evaluate_precision/qwen.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Precision-evaluation config on Qwen2.5-0.5B — the model used for the Fast-LLM vs DeepSpeed
# precision-pattern comparison (DeepSpeed side: tools/evaluate_precision_deepspeed.py).
#
# Run with:
# python -m tools.evaluate_precision -c examples/evaluate_precision/qwen.yaml
pretrained:
path: Qwen/Qwen2.5-0.5B
format: qwen2
output_dir: /tmp/fast_llm_tests/evaluate_precision/qwen_features
sequence_length: 2048
variants:
# Maps to the DeepSpeed harness's `bf16_head_bf16` (compute bf16, lm head in compute dtype).
bf16:
model.distributed.compute_dtype: bfloat16
# Maps to the DeepSpeed harness's `bf16` (compute bf16, fp32 lm head — the stack default).
bf16_fp32_lm_head:
model.distributed.compute_dtype: bfloat16
model.base_model.head.fp32_lm_head: true
# Maps to the DeepSpeed harness's `fp16_head_fp16`.
fp16:
model.distributed.compute_dtype: float16
# Maps to the DeepSpeed harness's `fp16`.
fp16_fp32_lm_head:
model.distributed.compute_dtype: float16
model.base_model.head.fp32_lm_head: true
28 changes: 28 additions & 0 deletions examples/evaluate_precision/qwen_7b.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Precision-evaluation config on Qwen2.5-7B. Unlike the 0.5B model, the 7B has untied embeddings,
# so the LM head is a real parameter and an fp32 head genuinely changes the logits (on tied models
# it can be a no-op). `forward_only` runs a single inference-mode forward so the fp32 reference fits
# in memory — forward+backward+Adam in fp32 would not.
#
# Run with:
# python -m tools.evaluate_precision -c examples/evaluate_precision/qwen_7b.yaml
pretrained:
path: Qwen/Qwen2.5-7B
format: qwen2
output_dir: /tmp/fast_llm_tests/evaluate_precision/qwen_7b
sequence_length: 2048
forward_only: true
variants:
# compute bf16, lm head in compute dtype.
bf16:
model.distributed.compute_dtype: bfloat16
# compute bf16, fp32 lm head (the stack default).
bf16_fp32_lm_head:
model.distributed.compute_dtype: bfloat16
model.base_model.head.fp32_lm_head: true
# compute fp16, lm head in compute dtype.
fp16:
model.distributed.compute_dtype: float16
# compute fp16, fp32 lm head.
fp16_fp32_lm_head:
model.distributed.compute_dtype: float16
model.base_model.head.fp32_lm_head: true
35 changes: 35 additions & 0 deletions examples/evaluate_precision/qwen_multi_seq.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Multi-sequence (per-sequence / GSPO) precision config on Qwen2.5-0.5B. Each MATH-500 example becomes one
# independent sequence (chat-templated problem + reference solution); the tool runs one inference-mode
# forward per sequence and saves per-sequence log π. The cross-engine tool then reports the across-sequence
# distribution of the length-normalized log-ratio (the GSPO-relevant quantity), alongside the per-token
# (GRPO) view. `input_dataset` implies forward-only; the other engines consume the saved inputs.pt.
#
# Run with:
# python -m tools.evaluate_precision -c examples/evaluate_precision/qwen_multi_seq.yaml
# python -m tools.evaluate_precision_deepspeed --model Qwen/Qwen2.5-0.5B \
# --inputs-file <output_dir>/inputs.pt --output-dir <output_dir>/ds
# python -m tools.evaluate_precision_vllm --model Qwen/Qwen2.5-0.5B --attention-backend TRITON_ATTN \
# --inputs-file <output_dir>/inputs.pt --output-dir <output_dir>/vllm
# python -m tools.evaluate_precision_cross_engine --fast-llm-dir <output_dir> \
# --deepspeed-dir <output_dir>/ds --vllm-dir <output_dir>/vllm --inputs-file <output_dir>/inputs.pt
pretrained:
path: Qwen/Qwen2.5-0.5B
format: qwen2
output_dir: /tmp/fast_llm_tests/evaluate_precision/qwen_multi_seq
input_dataset: HuggingFaceH4/MATH-500
num_sequences: 256
sequence_length: 2048
# Forward-only: one inference forward per sequence, saving per-sequence log π for the cross-engine
# comparison (no per-layer tables). The layer-wise variant (qwen_multi_seq_layerwise.yaml) sets this false.
forward_only: true
variants:
bf16:
model.distributed.compute_dtype: bfloat16
bf16_fp32_lm_head:
model.distributed.compute_dtype: bfloat16
model.base_model.head.fp32_lm_head: true
fp16:
model.distributed.compute_dtype: float16
fp16_fp32_lm_head:
model.distributed.compute_dtype: float16
model.base_model.head.fp32_lm_head: true
43 changes: 43 additions & 0 deletions examples/evaluate_precision/qwen_multi_seq_layerwise.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Layer-wise precision over multiple real sequences, on Qwen2.5-0.5B. Reproduces the single-input
# layer-wise tables (per-layer forward activation / backward gradient / parameter-gradient RMS vs an
# fp32 reference), but aggregated across many real MATH-500 sequences. Each example becomes one
# independent sequence (chat-templated problem + reference solution); the tool runs a full
# forward+backward per sequence with the per-layer debug logs on, then averages each per-tensor metric
# across sequences (the report also keeps the across-sequence std/max of the relative RMS). Averaging
# over sequences damps per-token noise (~1/√(N·T)) while systematic per-layer bias persists — the
# GSPO-relevant view of the layer-wise precision.
#
# Run with:
# python -m tools.evaluate_precision -c examples/evaluate_precision/qwen_multi_seq_layerwise.yaml
pretrained:
path: Qwen/Qwen2.5-0.5B
format: qwen2
output_dir: /tmp/fast_llm_tests/evaluate_precision/qwen_multi_seq_layerwise
input_dataset: HuggingFaceH4/MATH-500
num_sequences: 32
sequence_length: 2048
# Forward+backward per sequence: capture the full per-layer activation / gradient tables. (The
# logprob-only cross-engine config, qwen_multi_seq.yaml, sets this true for a forward-only run.)
forward_only: false
variants:
# Baseline bf16: compute_dtype=bf16 + Fast-LLM defaults (fp32 grad accumulation, bf16 residual, bf16 head).
bf16:
model.distributed.compute_dtype: bfloat16
# Full-precision residual stream on.
bf16_fp32_residual:
model.distributed.compute_dtype: bfloat16
model.base_model.embeddings.full_precision_residual: true
# fp32 LM-head matmul on (the stack default).
bf16_fp32_lm_head:
model.distributed.compute_dtype: bfloat16
model.base_model.head.fp32_lm_head: true
# Both stability features on (most precise bf16-compute configuration).
bf16_max_precision:
model.distributed.compute_dtype: bfloat16
model.base_model.embeddings.full_precision_residual: true
model.base_model.head.fp32_lm_head: true
fp16:
model.distributed.compute_dtype: float16
fp16_fp32_lm_head:
model.distributed.compute_dtype: float16
model.base_model.head.fp32_lm_head: true
20 changes: 20 additions & 0 deletions examples/evaluate_precision/qwen_multi_seq_layerwise_256.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Layer-wise precision over 256 real sequences, trimmed to the variants that matter for the
# systematic-gradient-bias question: reference fp32 + bf16 + fp16 only. Forward+backward per sequence,
# per-tensor metrics averaged across sequences. More sequences lower the noise floor on the bias estimate
# (floor ~ 1/sqrt(N_sequences)); 256 puts the measured bias well above the floor.
#
# Run with:
# python -m tools.evaluate_precision -c examples/evaluate_precision/qwen_multi_seq_layerwise_256.yaml
pretrained:
path: Qwen/Qwen2.5-0.5B
format: qwen2
output_dir: /tmp/fast_llm_tests/evaluate_precision/qwen_multi_seq_layerwise_256
input_dataset: HuggingFaceH4/MATH-500
num_sequences: 256
sequence_length: 2048
forward_only: false
variants:
bf16:
model.distributed.compute_dtype: bfloat16
fp16:
model.distributed.compute_dtype: float16
Loading
Loading