Apply lengthy-operation timeout to state-dict checkpoint save/export #543
Merged
StateDictCheckpointHandler.save() ran its weight-gather collectives under the default DistributedConfig timeout (60s) rather than the lengthy-operation training timeout, unlike load(), which wraps its collectives in SafeLoad(timeout=config.timeout). Saving is rank-0-serialized, so on a large model or slow storage a gather on the waiting ranks can exceed 60s and the NCCL watchdog aborts with a barrier desync. Wrap the save gather loop in set_timeout(world_group, config.timeout) to match the load path.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Authored by Claude Opus 4.8 (via Claude Code), reviewed by @jlamypoirier.
Problem
`StateDictCheckpointHandler.save()` (used for `fast_llm`- and HuggingFace-format checkpoint export via `Trainer` and the offline `convert`) runs its per-parameter weight-gather collectives under the default `DistributedConfig.timeout` (60 s) instead of the lengthy-operation `training.timeout` (default 3600 s).
The load path does not have this problem: `load()` wraps its collectives in `SafeLoad(timeout=config.timeout)`, and `_save_checkpoint` already passes `training.timeout` to the surrounding barriers and to `get_save_config(...)`. Nothing, however, applied that timeout to the gather collectives inside `save()`.
Saving a state-dict checkpoint is rank-0-serialized: rank 0 writes the `model_*.safetensors` files while the other ranks wait, with gathers interleaved between files. On a large model and/or slow or networked storage, a gather on the waiting ranks can exceed 60 s, at which point the NCCL collective watchdog fires and training aborts with a barrier desync.
The distributed (`distributed`-format) checkpoint is unaffected because every rank writes its own shard in parallel, so no single collective stalls past 60 s.
Observed on a multi-node run (16×H100, 2 nodes) exporting a ~7B model to networked storage: the step-200 distributed checkpoint completed fine, while the `fast_llm` export at step 400 aborted ~60 s into the save.
Fix
Wrap the gather loop in `save()` with `set_timeout(world_group, config.timeout)`, mirroring the load path. No config change is needed: `training.timeout` already defaults to 3600 s and flows into `config.timeout`.
Notes / scope
`iter_tensors()`/`iter_checkpoint()` (the streaming weight-broadcast consumer) shares the same gather pattern but has different control flow (a generator driven by an external consumer) and its own timeout handling; left out of scope here. Worth a follow-up look.
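For readers unfamiliar with the pattern described under Fix, here is a minimal, self-contained sketch of a scoped timeout override around a gather loop. It is not the Fast-LLM implementation: `FakeGroup`, `save_sketch`, and the body of `set_timeout` are hypothetical stand-ins written for illustration. Only the call shape `set_timeout(world_group, config.timeout)` comes from this PR.

```python
from contextlib import contextmanager
from dataclasses import dataclass


@dataclass
class FakeGroup:
    """Hypothetical stand-in for a process group; it only tracks a timeout in seconds."""

    timeout: float = 60.0


@contextmanager
def set_timeout(group, timeout):
    """Illustrative only: override the group's timeout inside the block, restore it on exit."""
    if group is None or timeout is None:
        # Nothing to override (for example, a single-process run).
        yield
        return
    previous = group.timeout
    group.timeout = timeout
    try:
        yield
    finally:
        # Restore the short default even if a gather raises.
        group.timeout = previous


def save_sketch(world_group, parameter_names, lengthy_timeout):
    """Record the timeout in effect for each simulated per-parameter gather."""
    timeouts_seen = []
    with set_timeout(world_group, lengthy_timeout):
        for _name in parameter_names:
            # A real save would gather this parameter's shards here,
            # and rank 0 would write them to disk between gathers.
            timeouts_seen.append(world_group.timeout)
    return timeouts_seen


group = FakeGroup(timeout=60.0)
seen = save_sketch(group, ["layer_0.weight", "layer_1.weight"], lengthy_timeout=3600.0)
```

In this sketch every simulated gather runs under the 3600 s value, and the group returns to its 60 s default once the loop exits, so collectives outside the save keep the shorter timeout.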