Optimize expert-parallel weight distribution logic when EP is enabled #200
…ert parallel FSDP loading
Code Review
This pull request updates the expert parallel FSDP strategy to support DTensor sharding during weight loading. While the changes introduce necessary logic for sharding tensors across mesh ranks, the implementation of _scatter_ep_expert_tensor is highly inefficient. Using dist.broadcast in a loop over all ranks creates O(N^2) communication complexity and excessive memory overhead. Furthermore, the inclusion of explicit device synchronizations within these loops will significantly degrade performance. Finally, there is a discrepancy between the PR description's claim of hierarchical loading and the implementation, which still relies solely on global rank 0.
for rank in range(1, world_size):
    recv_tensor = torch.empty(_rank_local_shape(rank), device=device_type, dtype=source_dtype)
    dist.broadcast(recv_tensor, src=0)
    torch_util.synchronize()
    if current_rank == rank:
        local_tensor.copy_(recv_tensor)
    del recv_tensor
The implementation of _scatter_ep_expert_tensor using dist.broadcast in a loop over all ranks is highly inefficient.
- Communication Complexity: Every rank participates in every broadcast, so each rank receives the shards intended for all other ranks and discards them. With N ranks (N being the world size) this amounts to O(N^2) shard deliveries in total, roughly N times the data a point-to-point transfer would move. For large clusters (e.g., 1024+ GPUs), this will be a massive bottleneck.
- Memory Overhead: Every rank performs a GPU allocation (torch.empty) in every iteration of the loop to match the shape of the broadcasted chunk. This adds significant allocation/deallocation overhead.
Consider using dist.scatter if the shards can be prepared on rank 0, or stick to point-to-point dist.send/dist.recv which only moves O(Weight) data in total. If send/recv is unstable, a hierarchical broadcast (to local rank 0s first) would be more appropriate.
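To make the volume comparison concrete, here is a rough cost model in plain Python. It is only a sketch: the function names are hypothetical, it counts tensor elements delivered to receivers rather than measuring real HCCL/NCCL traffic, and it assumes rank 0 already holds every shard.

```python
def broadcast_loop_volume(shard_sizes):
    """Elements received across all ranks when rank 0 broadcasts each
    non-source shard to the whole group (the pattern in the reviewed loop)."""
    world_size = len(shard_sizes)
    # One broadcast per rank 1..N-1; each broadcast is delivered to the
    # N-1 non-source ranks, whether or not they need that shard.
    return sum(shard_sizes[r] * (world_size - 1) for r in range(1, world_size))


def point_to_point_volume(shard_sizes):
    """Elements received when rank 0 sends each rank only its own shard."""
    return sum(shard_sizes[1:])


if __name__ == "__main__":
    sizes = [1000] * 8  # hypothetical: 8 ranks, equally sized shards
    print(broadcast_loop_volume(sizes), point_to_point_volume(sizes))
```

Under this model the broadcast loop moves N-1 times as much data as point-to-point delivery, which is where the quadratic growth in the comment comes from.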
is_rank0 = (dist.get_rank() == 0)
expert_shard_specs = expert_shard_specs or {}
rank_to_ep_rank = rank_to_ep_rank or {}
rank_to_ep_fsdp_rank = rank_to_ep_fsdp_rank or {}

source_metadata = None
if is_rank0:
The PR description states that this change supports using each node's local rank0 as the weight loading source to reduce global rank 0 fan-out pressure. However, the implementation in _broadcast_sharded_state_dict still relies exclusively on global rank 0 (dist.get_rank() == 0) as the source for all metadata and weight broadcasts. This does not appear to implement the hierarchical loading logic mentioned in the description.
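For illustration, one way the per-node source selection described in the PR could be computed is sketched below. This is not the PR's code: the helper names are hypothetical, and the sketch assumes ranks are assigned contiguously per node with the same number of ranks on every node.

```python
from collections import defaultdict


def node_source_rank(rank, ranks_per_node):
    """Global rank of the local-rank-0 process on the same node as `rank`.

    Assumes contiguous rank assignment per node (an assumption, not
    something the PR confirms).
    """
    return (rank // ranks_per_node) * ranks_per_node


def build_node_groups(world_size, ranks_per_node):
    """Map each node's source rank to the list of ranks it would feed."""
    if world_size % ranks_per_node != 0:
        raise ValueError("world_size must be a multiple of ranks_per_node")
    groups = defaultdict(list)
    for rank in range(world_size):
        groups[node_source_rank(rank, ranks_per_node)].append(rank)
    return dict(groups)
```

In a real job, each rank list could be passed to `dist.new_group(ranks=...)`, and the source check would become `dist.get_rank() == node_source_rank(dist.get_rank(), ranks_per_node)` instead of `dist.get_rank() == 0`, so that each node's local rank 0 loads and distributes weights only within its own node.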
    torch_util.synchronize()
    del chunk, local_chunk, chunk_gpu
else:
    dist.recv(local_tensor, src=0)
world_size = dist.get_world_size()
for rank in range(1, world_size):
    recv_tensor = torch.empty(_rank_local_shape(rank), device=device_type, dtype=source_dtype)
    dist.broadcast(recv_tensor, src=0)
    torch_util.synchronize()
Explicitly calling torch_util.synchronize() (which likely performs a device-wide synchronization) inside the loop after every broadcast will significantly degrade performance. Standard dist.broadcast calls on CUDA/NPU tensors are already synchronized with respect to the communication stream. Draining the GPU command pipeline in every iteration will make the initialization process extremely slow, especially for models with many layers.
PR type
PR information
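Optimizes weight initialization and distribution when EP parallelism and memory_efficient_init=True are both enabled. Previously, in multi-node EP + FSDP2 setups, weight initialization relied mainly on global rank 0 loading the full weights and distributing the expert weights to all ranks. As the number of machines and ranks grows, this global fan-out tends to exhaust HCCL communication resources, cause broadcast/send timeouts, or hang initialization.
This PR adjusts the memory efficient init flow for the EP case:
- When memory_efficient_init=True and EP is enabled, each node's local rank 0 can serve as the weight loading source for that node.
This change is mainly intended to improve initialization stability for DeepSeek V4 / large MoE models in multi-node EP + FSDP2 setups.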
Experiment results