Skip to content

Cast batch token ids to int64 in the data loader#540

Merged
jlamypoirier merged 1 commit into
mainfrom
jlp_token-ids-int64
Jun 12, 2026
Merged

Cast batch token ids to int64 in the data loader#540
jlamypoirier merged 1 commit into
mainfrom
jlp_token-ids-int64

Conversation

@jlamypoirier

Copy link
Copy Markdown
Collaborator

Authored by Claude Opus 4.8 (1M context), at Joel's direction.

Problem

get_unsigned_integer_type stores token ids in the narrowest integer type that fits the vocab — int16 for any vocab < 2**15. torch.embedding and the cross-entropy loss only accept int32/int64 indices, so training a model whose vocab is below 32768 crashes on the first forward step:

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.ShortTensor instead (while checking arguments for embedding)

examples/mistral.yaml (vocab 32000, type: random data) triggers it directly. It is not random-data-specific — any dataset with a sub-32768 vocab hits it.

Root cause

Token ids flow dataset → TokenDocument.tokens → TokenBatch → token_ids kwarg / labels with no dtype normalization. TokenBatch._get_model_input's meta branch already declares tokens as int64, but the real (non-meta) collation path kept the dataset's narrow dtype — so the real data wasn't honoring the int64 contract the meta path already assumes.

Fix

Normalize at the single collation chokepoint, TokenBatch.from_documents, so both token_ids and the derived labels (a clone of the batch tokens) are int64:

-            tokens=torch.cat(tokens),
+            tokens=torch.cat(tokens).to(torch.int64),

Testing

Ran examples/mistral.yaml (Mistral-7B, type: random) for 10 steps on 8 GPUs: trains cleanly (lm_head_loss ≈ 10.87 ≈ ln(32000), grad norm ~4.0, no NaN/skipped iterations). Without the fix it crashes on step 1.

Separately, the same run surfaced what looks like a pre-existing, unrelated issue — the throughput line reports negative model FLOP/s for this config. Not addressed here; happy to file it separately.

🤖 Generated with Claude Code

Token ids are stored in the narrowest integer type that fits the vocab (int16 for vocab < 2**15 via get_unsigned_integer_type), but torch.embedding and the cross-entropy loss require int32/int64 indices. Any dataset with a sub-32768 vocab therefore crashed on the first forward step with "Expected ... Long, Int; but got torch.cuda.ShortTensor". examples/mistral.yaml triggers it directly (vocab 32000).

TokenBatch._get_model_input's meta branch already declares tokens as int64, but the real collation path kept the dataset's narrow dtype. Normalize at the single chokepoint, TokenBatch.from_documents, so both token_ids and the derived labels honor the int64 contract.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@jlamypoirier jlamypoirier merged commit a7ddbd8 into main Jun 12, 2026
3 checks passed
@jlamypoirier jlamypoirier deleted the jlp_token-ids-int64 branch June 12, 2026 18:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant