Fsdp2 stormscope #1671
Greptile Summary
This PR migrates StormScope/StormCast from FSDP1 (
Important Files Changed
    with torch.no_grad():
        for p in model.parameters():
            if p.is_contiguous():
                continue
            p.data = p.data.contiguous()
When use_shard_tensor=True, distribute_module has already been called and model.parameters() yields DTensor-backed nn.Parameters. Assigning p.data = p.data.contiguous() on a DTensor parameter is not documented PyTorch API; Tensor.set_() (which backs the .data setter) with a DTensor argument may silently strip the DTensor's mesh/placements metadata, breaking the subsequent fully_shard call. In practice distribute_tensor normalises contiguity internally, so the guard p.is_contiguous() is usually True for DTensor params and the assignment is skipped, but making the skip explicit prevents a silent breakage if that behaviour changes.
    - with torch.no_grad():
    -     for p in model.parameters():
    -         if p.is_contiguous():
    -             continue
    -         p.data = p.data.contiguous()
    + with torch.no_grad():
    +     for p in model.parameters():
    +         if isinstance(p.data, DTensor):
    +             continue  # distribute_module already normalises DTensor local shards
    +         if p.is_contiguous():
    +             continue
    +         p.data = p.data.contiguous()
Add test_contiguity_channels_last verifying that the FSDP2 contiguity workaround makes all parameters standard-contiguous while preserving channels-last output activations (cuDNN fast path). Signed-off-by: Negin Sobhani <nsobhani@nvidia.com>
DomainParallelNoiseScheduler now requires shard_dim; pass shard_dim=2 in test_wrapper_timesteps_replicated. Signed-off-by: Negin Sobhani <nsobhani@nvidia.com>
FSDP2 raises NotImplementedError on non-contiguous parameters. Models using channels_last memory format have 4-D conv params with non-standard strides. Force standard contiguity on parameter storage before fully_shard(); kernels still use channels_last when inputs arrive in that layout, retaining the cuDNN fast path. Also sort imports (removed unused Callable). Signed-off-by: Negin Sobhani <nsobhani@nvidia.com>
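The workaround in this commit can be tried on CPU without FSDP2. The following is a minimal sketch under stated assumptions: the helper name `force_contiguous_params` and the toy `Conv2d` are illustrative and not taken from the PR. It only shows that converting a module to channels_last leaves the 4-D conv weight non-contiguous in the standard sense, and that rewriting `p.data` restores standard contiguity while the module still accepts channels_last input.

```python
import torch
import torch.nn as nn


def force_contiguous_params(model: nn.Module) -> int:
    """Rewrite non-contiguous parameter storage in standard layout.

    Returns the number of parameters that were rewritten.
    (Hypothetical helper name; mirrors the loop quoted in the review.)
    """
    changed = 0
    with torch.no_grad():
        for p in model.parameters():
            if p.is_contiguous():
                continue
            p.data = p.data.contiguous()
            changed += 1
    return changed


# Toy model: a single conv layer converted to channels_last.
model = nn.Conv2d(3, 8, kernel_size=3, padding=1).to(memory_format=torch.channels_last)

# The 4-D weight now has channels_last strides, so the default check fails.
weight_was_contiguous = model.weight.is_contiguous()

num_changed = force_contiguous_params(model)
all_contiguous = all(p.is_contiguous() for p in model.parameters())

# Inputs can still arrive in channels_last layout after the rewrite.
x = torch.randn(2, 3, 16, 16).to(memory_format=torch.channels_last)
y = model(x)
# The output layout depends on the backend; it can be inspected with
# y.is_contiguous(memory_format=torch.channels_last).
```

Only the weight is rewritten here, because the 1-D bias is unaffected by the memory-format conversion.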
Verify that the FSDP2 contiguity workaround makes all parameters standard-contiguous while channels_last input still produces channels_last output activations (cuDNN fast path retained). Signed-off-by: Negin Sobhani <nsobhani@nvidia.com>
- Add _materialize_dtensors_to_full() to gather remaining DTensors after get_model_state_dict(full_state_dict=True), which is not fully honoured by FSDP2 on some PyTorch versions.
- Extend _has_non_fsdp_dtensors for FSDPModule with empty placements to avoid KeyError: 'state.0.step' via the manual load path.
- Handle cross-mesh DTensor inputs in _redistribute_sd_for_dtensor and _redistribute_optim_sd_for_dtensor (peel to_local, then re-distribute onto the new mesh).
- Add _materialize_optimizer_state_for_dcp() to pre-populate optimizer state placeholders before set_optimizer_state_dict.

Fixes RuntimeError: found no DeviceMesh, KeyError: 'state.0.step', and size-mismatch errors on cross-mesh checkpoint reloads.

Signed-off-by: Negin Sobhani <nsobhani@nvidia.com>
Add four multi-GPU tests covering FSDP2 (fully_shard) scenarios:
- Basic 1-D mesh model+optimizer+scheduler round-trip
- channels_last convolutions with the contiguity workaround
- 2-D mesh (ddp x domain) with ShardTensor
- GradScaler state preservation under FSDP2 + amp autocast

Signed-off-by: Negin Sobhani <nsobhani@nvidia.com>
Replace fragile __bases__[1] assumption with an MRO walk that skips FSDPModule/nn.Module/object. Resilient to future PyTorch changes that may insert intermediate mixins in the dynamically-generated class. Addresses greptile review comment on PR NVIDIA#1671. Signed-off-by: Negin Sobhani <nsobhani@nvidia.com>
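The idea behind the MRO walk can be illustrated with plain Python stand-ins. This is a sketch only: `Module` and `FSDPModule` below are placeholder classes standing in for `nn.Module` and PyTorch's `FSDPModule`, and `original_class` is a hypothetical name, not the function in the PR. It assumes the wrapped class is generated dynamically with `FSDPModule` and the user class as bases, and shows why indexing `__bases__[1]` breaks if the base order changes while the walk does not.

```python
class Module:
    """Stand-in for torch.nn.Module (assumption for this sketch)."""


class FSDPModule:
    """Stand-in for PyTorch's FSDPModule (assumption for this sketch)."""


class Net(Module):
    """A user-defined model class."""


def original_class(cls: type) -> type:
    """Return the user class underneath a dynamically generated FSDP class.

    Walks the MRO and skips FSDPModule, Module, and object instead of
    assuming the user class sits at a fixed position in __bases__.
    """
    if not issubclass(cls, FSDPModule):
        return cls  # not wrapped, nothing to unwrap
    skip = (FSDPModule, Module, object)
    for base in cls.__mro__[1:]:  # [1:] skips the generated class itself
        if base in skip:
            continue
        return base
    raise TypeError(f"no user class found in MRO of {cls.__name__}")


# Layout assumed today: bases are (FSDPModule, user class).
Wrapped = type("FSDPNet", (FSDPModule, Net), {})

# Hypothetical future layout with the bases in a different order.
Reordered = type("FSDPNet", (Net, FSDPModule), {})
```

With `Reordered`, `__bases__[1]` is `FSDPModule` rather than `Net`, whereas the walk returns `Net` for both layouts.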
PhysicsNeMo Pull Request
Migrates StormScope off FSDP1 onto FSDP2 (fully_shard/FSDPModule)....
Description
FSDP1's flat-param machinery doesn't compose with ShardTensor/DTensor. The refactored ShardTensor in #1556 breaks FSDP1's backward pass in StormCast, so domain parallelism was not working with FSDP (only DDP).

Fun fact: StormCast was using FSDP(model, sharding_strategy=NO_SHARD), which is effectively DDP with extra overhead. Moving to FSDP2 (fully_shard) gives us real sharding capability and DTensor compatibility.

Checklist
Dependencies
Tests Added
examples/weather/stormcast/test_training.py
- test_contiguity_channels_last: verifies contiguity fix preserves channels-last activations (cuDNN fast path)

test/utils/test_checkpoint_distributed.py (4 new multi-GPU tests)
- test_fsdp2_checkpoint_roundtrip: fully_shard model+optimizer+scheduler save/load
- test_fsdp2_checkpoint_channels_last_roundtrip
- test_fsdp2_shard_tensor_checkpoint_roundtrip
- test_fsdp2_grad_scaler_checkpoint

Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.