Enable Compile for Shard Tensor #1682
…ch api shifts from breaking for us. Also increase conv image size for better test stability
One other notable thing here: our ring attention implementation is not traceable at the moment. Streams and overlapped communication/computation are not traceable. I think we can resolve that, though it's not here quite yet.
Greptile Summary
This PR enables
Important Files Changed
Reviews (1): Last reviewed commit: "This commit enables torch.compile for mo..."
| # This has to be funcol collectives, below, to be | ||
| # conmpatible with torc.compile. |
Two typos in this comment: "conmpatible" → "compatible" and "torc.compile" → "torch.compile".
| # This has to be funcol collectives, below, to be | |
| # conmpatible with torc.compile. | |
| # This has to be funcol collectives, below, to be | |
| # compatible with torch.compile. |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| if needs_redistribute: | ||
|     self = self.redistribute(placements=new_placements) | ||
|
| if self.grad_fn is not None: | ||
|     return torch.Tensor.backward(self, *args, **kwargs) | ||
|
| return self.to_local().backward(*args, **kwargs) | ||
|
|
| ### TODO | ||
| ### Do we still need this? | ||
| ### I think we do not - CJA | ||
|
|
| class FSDPOutputTensorAdapter(nn.Module): | ||
|     """Wrap a module and convert ShardTensor outputs to torch.Tensor.""" | ||
|
|     def __init__(self, module: nn.Module) -> None: | ||
|         super().__init__() | ||
|         self.module = module | ||
|
|     def forward(self, *args, **kwargs): | ||
|         out = self.module(*args, **kwargs) | ||
|         return out.to_local() if isinstance(out, ShardTensor) else out | ||
|
|
| def wrap_for_fsdp(module: nn.Module) -> nn.Module: | ||
|     """Return a module wrapper that exposes tensor outputs for FSDP hooks.""" | ||
|     return FSDPOutputTensorAdapter(module) | ||
|
|
| def distribute_over_domain_for_fsdp( | ||
|     module: nn.Module, | ||
|     device_mesh: DeviceMesh, | ||
|     partition_fn: (Callable[[str, nn.Module, DeviceMesh], None] | None) = None, | ||
| ) -> nn.Module: | ||
|     """Distribute a module over a domain mesh and adapt outputs for FSDP.""" | ||
|     distributed_module = distribute_module( | ||
|         module, | ||
|         device_mesh=device_mesh, | ||
|         partition_fn=partition_fn, | ||
|     ) | ||
|     return wrap_for_fsdp(distributed_module) |
Uncommitted-design symbols exported publicly
FSDPOutputTensorAdapter, wrap_for_fsdp, and distribute_over_domain_for_fsdp are added with a ### TODO / Do we still need this? / I think we do not comment but are simultaneously exported from physicsnemo/domain_parallel/__init__.py. Shipping an export under a "probably dead" comment makes it hard for downstream users to know whether to depend on it, and removes the ability to quietly delete it later without a breaking-change notice.
PhysicsNeMo Pull Request
This PR enables compilation for ShardTensor. It still depends on the refactor of shard tensor in #1556, which has to merge first. That PR is included here for now, so the diff will be a little smaller in the end.
The main requirements for enabling torch.compile are:
- `setup_context` style autograd functions everywhere
- `torch.Size` in favor of `int`. This primarily impacts the ShardTensorSpec objects.
- `torch_flatten` and its inverse are available for tracing

There is also a minor bug fix here in the knn op that shows up when the shard tensor spec determines the size of sharding shapes automatically.
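As a rough illustration of the first requirement, the sketch below shows the `setup_context` style of `torch.autograd.Function`, where `forward` takes no `ctx` and the state needed for `backward` is saved in a separate static method. The op itself (`ScaleBy`) is a hypothetical stand-in and does not appear in the PR.

```python
import torch


class ScaleBy(torch.autograd.Function):
    """Hypothetical op: multiply a tensor by a Python float."""

    @staticmethod
    def forward(x, factor):
        # No ctx argument here: forward only computes the output.
        return x * factor

    @staticmethod
    def setup_context(ctx, inputs, output):
        # State needed by backward is stashed here, separately from forward.
        _, factor = inputs
        ctx.factor = factor

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; the float factor gets None.
        return grad_output * ctx.factor, None


x = torch.ones(3, requires_grad=True)
y = ScaleBy.apply(x, 2.0)
y.sum().backward()
```

After the backward pass, `x.grad` holds the scale factor in every entry, since the derivative of `factor * x` with respect to `x` is `factor`.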
Two test suites are added: one to test compilation "in general" and one to verify that each op we're wrapping is traceable via the compiler. It's implicit that DTensor fallback ops are compilable, though we don't really verify that ourselves ...
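A per-op traceability check of the kind described could look roughly like the sketch below. This is not the PR's test code: `assert_traceable` and `example_op` are hypothetical names, and the sketch assumes `backend="eager"` so that only tracing is exercised, with `fullgraph=True` turning any graph break into an error.

```python
import torch


def assert_traceable(fn, *args):
    """Compile fn with fullgraph=True and compare against eager execution."""
    # backend="eager" runs the captured graph without code generation, so this
    # checks tracing only; fullgraph=True raises on any graph break.
    compiled = torch.compile(fn, fullgraph=True, backend="eager")
    expected = fn(*args)
    actual = compiled(*args)
    torch.testing.assert_close(actual, expected)
    return actual


def example_op(x):
    # Stand-in for a wrapped op; not taken from the PR.
    return torch.nn.functional.relu(x) + 1.0


result = assert_traceable(example_op, torch.tensor([-1.0, 0.0, 2.0]))
```

A real suite would presumably parametrize this over each wrapped op and feed it ShardTensor inputs on a device mesh rather than a plain tensor.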
Description
Checklist
Dependencies
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.