
Fsdp2 stormscope #1671

Open
negin513 wants to merge 14 commits into NVIDIA:main from negin513:fsdp2-stormcast

negin513 (Member) commented May 26, 2026

PhysicsNeMo Pull Request

Migrates StormScope off FSDP1 onto FSDP2 (fully_shard / FSDPModule)....

Description

FSDP1's flat-param machinery doesn't compose with ShardTensor / DTensor. The refactored ShardTensor in #1556 breaks FSDP1's backward pass in StormCast, so domain parallelism worked only with DDP, not with FSDP.

Fun fact: StormCast was using FSDP(model, sharding_strategy=NO_SHARD) — which is effectively DDP with extra overhead. Moving to FSDP2 (fully_shard) gives us real sharding capability and DTensor compatibility.

Checklist

Dependencies

Tests Added

examples/weather/stormcast/test_training.py

  • test_contiguity_channels_last: verifies contiguity fix preserves channels-last activations (cuDNN fast path)

test/utils/test_checkpoint_distributed.py (4 new multi-GPU tests)

| Test | Validates |
| --- | --- |
| test_fsdp2_checkpoint_roundtrip | Basic fully_shard model+optimizer+scheduler save/load |
| test_fsdp2_checkpoint_channels_last_roundtrip | channels_last conv + contiguity fix + cross-rank parity |
| test_fsdp2_shard_tensor_checkpoint_roundtrip | 2-D mesh (ddp × domain) with ShardTensor (4-GPU) |
| test_fsdp2_grad_scaler_checkpoint | GradScaler state under FSDP2 + amp autocast |

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@negin513 negin513 requested a review from CharlelieLrt as a code owner May 26, 2026 17:09
copy-pr-bot Bot commented May 26, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Comment thread examples/weather/stormcast/utils/parallel.py Outdated
greptile-apps Bot (Contributor) commented May 26, 2026

Greptile Summary

This PR migrates StormScope/StormCast from FSDP1 (FullyShardedDataParallel with NO_SHARD) to FSDP2 (fully_shard / FSDPModule), motivated by FSDP1's flat-param machinery being incompatible with the refactored ShardTensor/DTensor in #1556.

  • parallel.py: replaces the FSDP(NO_SHARD) wrapper with fully_shard, adds a pre-shard loop that forces standard contiguity on channels-last parameters (FSDP2 rejects non-contiguous params in PyTorch ≤ 2.10), and updates the docstring/return type to FSDPModule.
  • checkpoint.py: adds _unwrapped_class_name to recover the original class name from FSDP2's dynamically-generated subclass, extends _has_non_fsdp_dtensors with a degenerate-mesh guard for FSDP2 (size-1 mesh axes break DCP's broadcast), and threads both helpers through _unique_model_names.

Important Files Changed

| Filename | Overview |
| --- | --- |
| physicsnemo/utils/checkpoint.py | Adds FSDPModule (FSDP2) support: new _unwrapped_class_name helper, extended _has_non_fsdp_dtensors logic for degenerate meshes, and updated _unique_model_names to use the new helper; _is_distributed_model does not add an explicit FSDPModule check. |
| examples/weather/stormcast/utils/parallel.py | Replaces FSDP1 (NO_SHARD) with FSDP2 fully_shard; adds a pre-shard contiguity normalization loop that also iterates over DTensor parameters when use_shard_tensor=True. |
| examples/weather/stormcast/utils/trainer.py | Minor comment and log-message updates to reflect FSDP2 terminology; no logic changes. |
| examples/weather/stormcast/test_training.py | Removes now-unused FSDP1 imports (StateDictType, ShardedStateDictConfig, ShardedOptimStateDictConfig); no test logic changes. |

Comments Outside Diff (1)

  1. physicsnemo/utils/checkpoint.py, lines 70-74

    P2 _is_distributed_model relies solely on the DTensor parameter check to detect FSDP2 models. In practice fully_shard always converts parameters to DTensors, so this works, but a model with no learnable parameters would return False even after FSDP2 wrapping, silently routing it through the non-distributed checkpoint path. Adding an explicit FSDPModule branch makes the intent clear and is defensive against edge cases.
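The explicit branch suggested in this comment can be sketched as follows. This is a minimal illustration under stated assumptions, not the code in checkpoint.py: `FSDPModule` and `DTensor` below are empty stand-in classes for the PyTorch types of the same names, and `TinyModel`, `is_distributed_model`, and `WrappedTinyModel` are hypothetical names chosen for the example.

```python
class FSDPModule:
    """Stand-in (assumption) for torch.distributed.fsdp.FSDPModule."""


class DTensor:
    """Stand-in (assumption) for torch.distributed.tensor.DTensor."""


class TinyModel:
    """Hypothetical model exposing only the parameters() method used below."""

    def __init__(self, params=()):
        self._params = list(params)

    def parameters(self):
        return iter(self._params)


def is_distributed_model(model) -> bool:
    # Explicit branch: an FSDP2-wrapped model counts as distributed
    # even when it has no learnable parameters.
    if isinstance(model, FSDPModule):
        return True
    # Fallback: any DTensor parameter marks the model as distributed.
    return any(isinstance(p, DTensor) for p in model.parameters())


# Mimic a wrapper that swaps the model's class for a dynamic FSDPModule subclass.
WrappedTinyModel = type("FSDPTinyModel", (FSDPModule, TinyModel), {})

param_free_wrapped = WrappedTinyModel()        # wrapped, no parameters
param_free_plain = TinyModel()                 # not wrapped, no parameters
sharded_plain = TinyModel(params=[DTensor()])  # not wrapped, DTensor parameter
```

With only the DTensor check, `param_free_wrapped` would be classified as non-distributed; the `isinstance` branch covers that edge case without changing the result for the other two models.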

Reviews (1): Last reviewed commit: "improving checkpoint"

Comment thread physicsnemo/utils/checkpoint.py
Comment on lines +270 to +274
    with torch.no_grad():
        for p in model.parameters():
            if p.is_contiguous():
                continue
            p.data = p.data.contiguous()

P2 When use_shard_tensor=True, distribute_module has already been called and model.parameters() yields DTensor-backed nn.Parameters. Assigning p.data = p.data.contiguous() on a DTensor parameter is not documented PyTorch API; Tensor.set_() (which backs the .data setter) with a DTensor argument may silently strip the DTensor's mesh/placements metadata, breaking the subsequent fully_shard call. In practice distribute_tensor normalises contiguity internally so the guard p.is_contiguous() is usually True for DTensor params and the assignment is skipped — but making the skip explicit prevents a silent breakage if that behaviour changes.

Suggested change

    - with torch.no_grad():
    -     for p in model.parameters():
    -         if p.is_contiguous():
    -             continue
    -         p.data = p.data.contiguous()
    + with torch.no_grad():
    +     for p in model.parameters():
    +         if isinstance(p.data, DTensor):
    +             continue  # distribute_module already normalises DTensor local shards
    +         if p.is_contiguous():
    +             continue
    +         p.data = p.data.contiguous()

@negin513 negin513 changed the title Fsdp2 stormscope Fsdp2 stormscope [WIP] May 26, 2026
negin513 and others added 8 commits May 26, 2026 11:26
Add test_contiguity_channels_last verifying that the FSDP2 contiguity
workaround makes all parameters standard-contiguous while preserving
channels-last output activations (cuDNN fast path).

Signed-off-by: Negin Sobhani <nsobhani@nvidia.com>
DomainParallelNoiseScheduler now requires shard_dim; pass shard_dim=2
in test_wrapper_timesteps_replicated.

Signed-off-by: Negin Sobhani <nsobhani@nvidia.com>
FSDP2 raises NotImplementedError on non-contiguous parameters.
Models using channels_last memory format have 4-D conv params with
non-standard strides. Force standard contiguity on parameter storage
before fully_shard(); kernels still use channels_last when inputs
arrive in that layout, retaining the cuDNN fast path.
Also sort imports (removed unused Callable).

Signed-off-by: Negin Sobhani <nsobhani@nvidia.com>
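The "non-standard strides" mentioned in the commit message above can be illustrated without PyTorch. The sketch below computes element strides for a 4-D shape in row-major order and in channels-last (NHWC) order, then applies a contiguity check that ignores size-1 dimensions. All function names and the example weight shapes are hypothetical and exist only for this illustration.

```python
from typing import Sequence, Tuple


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major (standard contiguous) strides, measured in elements."""
    strides = []
    running = 1
    for size in reversed(shape):
        strides.append(running)
        running *= size
    return tuple(reversed(strides))


def channels_last_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Strides of a 4-D (N, C, H, W) tensor whose storage is laid out as NHWC."""
    n, c, h, w = shape
    return (h * w * c, 1, w * c, c)


def is_standard_contiguous(shape: Sequence[int], strides: Sequence[int]) -> bool:
    """True if the strides match row-major order; size-1 dims are ignored."""
    expected = 1
    for size, stride in zip(reversed(shape), reversed(strides)):
        if size == 1:
            continue
        if stride != expected:
            return False
        expected *= size
    return True


conv3x3 = (8, 3, 3, 3)  # hypothetical conv weight: (out_ch, in_ch, kH, kW)
conv1x1 = (8, 3, 1, 1)  # hypothetical 1x1 conv weight

print(contiguous_strides(conv3x3))     # (27, 9, 3, 1)
print(channels_last_strides(conv3x3))  # (27, 1, 9, 3)
```

Under this check the 3x3 weight in channels-last layout is not standard-contiguous, while the 1x1 weight is, because its spatial dimensions have size 1. That matches the description that only some 4-D conv parameters need to be rewritten before sharding.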
Verify that the FSDP2 contiguity workaround makes all parameters
standard-contiguous while channels_last input still produces
channels_last output activations (cuDNN fast path retained).

Signed-off-by: Negin Sobhani <nsobhani@nvidia.com>
- Add _materialize_dtensors_to_full() to gather remaining DTensors
  after get_model_state_dict(full_state_dict=True) which is not
  fully honoured by FSDP2 on some PyTorch versions.
- Extend _has_non_fsdp_dtensors for FSDPModule with empty placements
  to avoid KeyError: 'state.0.step' via the manual load path.
- Handle cross-mesh DTensor inputs in _redistribute_sd_for_dtensor
  and _redistribute_optim_sd_for_dtensor (peel to_local, then
  re-distribute onto the new mesh).
- Add _materialize_optimizer_state_for_dcp() to pre-populate
  optimizer state placeholders before set_optimizer_state_dict.
Fixes RuntimeError: found no DeviceMesh, KeyError: 'state.0.step',
and size-mismatch errors on cross-mesh checkpoint reloads.

Signed-off-by: Negin Sobhani <nsobhani@nvidia.com>
Add four multi-GPU tests covering FSDP2 (fully_shard) scenarios:
- Basic 1-D mesh model+optimizer+scheduler round-trip
- channels_last convolutions with the contiguity workaround
- 2-D mesh (ddp x domain) with ShardTensor
- GradScaler state preservation under FSDP2 + amp autocast

Signed-off-by: Negin Sobhani <nsobhani@nvidia.com>
negin513 added 2 commits May 28, 2026 14:50
Replace fragile __bases__[1] assumption with an MRO walk that skips
FSDPModule/nn.Module/object. Resilient to future PyTorch changes that
may insert intermediate mixins in the dynamically-generated class.

Addresses greptile review comment on PR NVIDIA#1671.

Signed-off-by: Negin Sobhani <nsobhani@nvidia.com>
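The MRO walk described in the commit message above could look roughly like this. It is a sketch under assumptions, not the helper in checkpoint.py: `Module` and `FSDPModule` are empty stand-ins for `torch.nn.Module` and the FSDP2 mixin, the wrapped classes are built by hand with `type(...)` to imitate a dynamically generated subclass, and `unwrapped_class_name` and `UNet` are hypothetical names.

```python
class Module:
    """Stand-in (assumption) for torch.nn.Module."""


class FSDPModule:
    """Stand-in (assumption) for the FSDP2 mixin class."""


class UNet(Module):
    """Hypothetical user model."""


def unwrapped_class_name(model) -> str:
    cls = type(model)
    if not isinstance(model, FSDPModule):
        return cls.__name__
    # Skip the dynamic class itself (mro[0]), then skip the framework classes.
    for base in cls.__mro__[1:]:
        if base not in (FSDPModule, Module, object):
            return base.__name__
    return cls.__name__


# Imitate a dynamically generated wrapper class with bases (FSDPModule, UNet).
wrapped = type("FSDPUNet", (FSDPModule, UNet), {})()

# Same wrapper with the base order flipped: indexing __bases__[1] would now
# return FSDPModule, while the MRO walk still finds UNet.
reordered = type("FSDPUNet", (UNet, FSDPModule), {})()

print(unwrapped_class_name(wrapped))    # UNet
print(unwrapped_class_name(reordered))  # UNet
print(unwrapped_class_name(UNet()))     # UNet
```

The walk depends only on which classes appear in the MRO, not on their position in `__bases__`, which is the property the commit message relies on.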
@negin513 negin513 changed the title Fsdp2 stormscope [WIP] Fsdp2 stormscope May 29, 2026