Skip to content

Enable Compile for Shard Tensor#1682

Open
coreyjadams wants to merge 6 commits into
NVIDIA:mainfrom
coreyjadams:shard_tensor_compile
Open

Enable Compile for Shard Tensor#1682
coreyjadams wants to merge 6 commits into
NVIDIA:mainfrom
coreyjadams:shard_tensor_compile

Conversation

@coreyjadams
Copy link
Copy Markdown
Collaborator

PhysicsNeMo Pull Request

This PR enables compilation for ShardTensor. It depends on the refactor of shard tensor #1556 still - that has to merge first - so the diff will be a little smaller in the end, since that PR is included here too for now.

The main requirements for enabling torch.compile are:

  • ensuring we use the newer setup_context style autograd functions everywhere
  • avoiding torch.Size in favor of int. This primarily impacts the ShardTensorSpec objects.
  • ensuring torch_flatten and it's inverse are available for tracing
  • making available a function for coercing the tangent objects for the backward pass to the same shape as the objects themselves. For ShardTensor, care is taken to avoid non-traceable redistributions here. We also needed to cache shard hints for the backward pass in a way that was compiler friendly.

There is also a minor bug fix here in the knn op that shows up in shard tensor spec determining the size of sharding shapes automatically.

Two test suites are added: one to test compilation "in general" and one to verify that each op we're wrapping is tracable via the compiler. It's implicit that DTensor fallback ops are compilable, though verification of that is not really done by us ...

Description

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@coreyjadams
Copy link
Copy Markdown
Collaborator Author

One other notable thing here: our ring attention implementation is not traceable at the moment. streams and overlapped communication / computation is not traceable. I think we can resolve that though it's not here quite yet.

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 28, 2026

Greptile Summary

This PR enables torch.compile support for ShardTensor by migrating every custom torch.autograd.Function subclass from the old forward(ctx, ...) API to the new forward(...) + setup_context(ctx, inputs, output) API required by AOTAutograd, and by replacing torch.Size with plain tuple[int, ...] throughout _sharding_shapes to avoid dynamo's symbolic-shape fakeification path.

  • All 12 autograd.Function subclasses migrated to new-style setup_context; dist.all_reduce / dist.all_to_all calls replaced with funcol functional collectives so AOT backward graphs avoid raw ProcessGroup references.
  • ShardTensor gains __tensor_flatten__/__tensor_unflatten__, tangent coercion hooks, and property overrides for requires_grad, grad, grad_fn, is_leaf that shield against re-entrant __torch_function__ dispatch.
  • Three new FSDP-adapter symbols are added and exported but carry an inline ### TODO / I think we do not [need this] comment — worth resolving before merge.

Important Files Changed

Filename Overview
physicsnemo/domain_parallel/shard_tensor.py Major refactor: migrates to new-style autograd API, adds tensor flatten/unflatten, tangent coercion hooks, and property overrides. Adds FSDPOutputTensorAdapter with a TODO comment questioning its necessity but still exporting it publicly.
physicsnemo/domain_parallel/_shard_redistribute.py ShardRedistribute migrated to new-style setup_context; forward populates _sharding_shapes on target_spec using chunk arithmetic to avoid blocking collectives under compile.
physicsnemo/domain_parallel/_shard_tensor_spec.py Changes _sharding_shapes type from torch.Size to plain tuple[int,...] throughout to avoid dynamo symbolic-shape special-casing; compute_sharding_shapes_from_chunking_global_shape refactored.
physicsnemo/domain_parallel/custom_ops/_reductions.py ShardedSum/ShardedMean migrated to new-style setup_context with DisableTorchFunctionSubclass shielding; new build_reduction_result helper avoids from_local overhead.
physicsnemo/domain_parallel/shard_utils/halo.py HaloPadding/UnHaloPadding migrated to new-style setup_context; all_to_all halo collective refactored to funcol.all_to_all_single_autograd for compile compatibility.
physicsnemo/domain_parallel/shard_utils/normalization_patches.py PartialGroupNorm migrated to new-style setup_context; dist.all_reduce replaced with funcol.all_reduce to avoid raw ProcessGroup references in AOT backward graph.
physicsnemo/domain_parallel/init.py Exports three new FSDP adapter symbols and unbind_wrapper; stub fallbacks added for when ST_AVAILABLE is False.
test/domain_parallel/ops/test_compile_ops.py New test file: compile-traceability tests for each refactored autograd op using aot_eager backend.
test/domain_parallel/test_compile.py New test file: tests for coerce_same_metadata_as_tangent covering uneven shard round-trips and a full compile+backward smoke test.

Reviews (1): Last reviewed commit: "This commit enables torch.compile for mo..." | Re-trigger Greptile

Comment on lines +650 to +651
# This has to be funcol collectives, below, to be
# conmpatible with torc.compile.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Two typos in this comment: "conmpatible" → "compatible" and "torc.compile" → "torch.compile".

Suggested change
# This has to be funcol collectives, below, to be
# conmpatible with torc.compile.
# This has to be funcol collectives, below, to be
# compatible with torch.compile.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines 1257 to +1299
if needs_redistribute:
self = self.redistribute(placements=new_placements)

if self.grad_fn is not None:
return torch.Tensor.backward(self, *args, **kwargs)

return self.to_local().backward(*args, **kwargs)


### TODO
### Do we still need this?
### I think we do not - CJA


class FSDPOutputTensorAdapter(nn.Module):
"""Wrap a module and convert ShardTensor outputs to torch.Tensor."""

def __init__(self, module: nn.Module) -> None:
super().__init__()
self.module = module

def forward(self, *args, **kwargs):
out = self.module(*args, **kwargs)
return out.to_local() if isinstance(out, ShardTensor) else out


def wrap_for_fsdp(module: nn.Module) -> nn.Module:
"""Return a module wrapper that exposes tensor outputs for FSDP hooks."""
return FSDPOutputTensorAdapter(module)


def distribute_over_domain_for_fsdp(
module: nn.Module,
device_mesh: DeviceMesh,
partition_fn: (Callable[[str, nn.Module, DeviceMesh], None] | None) = None,
) -> nn.Module:
"""Distribute a module over a domain mesh and adapt outputs for FSDP."""
distributed_module = distribute_module(
module,
device_mesh=device_mesh,
partition_fn=partition_fn,
)
return wrap_for_fsdp(distributed_module)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Uncommitted-design symbols exported publicly

FSDPOutputTensorAdapter, wrap_for_fsdp, and distribute_over_domain_for_fsdp are added with a ### TODO / Do we still need this? / I think we do not comment but are simultaneously exported from physicsnemo/domain_parallel/__init__.py. Shipping an export under a "probably dead" comment makes it hard for downstream users to know whether to depend on it, and removes the ability to quietly delete it later without a breaking-change notice.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant