diff --git a/src/device.h b/src/device.h index f877f03..67e1b72 100644 --- a/src/device.h +++ b/src/device.h @@ -155,6 +155,11 @@ struct DevicePriority { static constexpr int value = 5; }; +enum class MemorySpace : uint8_t { kHost = 0, kDevice = 1, kUnknown }; + +template +MemorySpace GetMemorySpace(const void *ptr); + } // namespace infini::ccl #endif // INFINI_CCL_DEVICE_H_ diff --git a/src/metax/checks.h b/src/metax/checks.h new file mode 100644 index 0000000..c88b261 --- /dev/null +++ b/src/metax/checks.h @@ -0,0 +1,34 @@ +#ifndef INFINI_CCL_METAX_CHECKS_H_ +#define INFINI_CCL_METAX_CHECKS_H_ + +// clang-format off +#include +// clang-format on + +#include + +#include "return_status_impl.h" + +#define INFINI_CHECK_MACA(result) \ + ::infini::ccl::detail::CheckMacaImpl((result), __FILE__, __LINE__) + +namespace infini::ccl { + +namespace detail { + +inline ReturnStatus CheckMacaImpl(mcError_t maca_result, const char *file, + int line) { + if (maca_result != mcSuccess) { + mcGetLastError(); + std::cerr << "MACA error code: " << maca_result << " at line " << line + << " in " << file << std::endl; + std::abort(); + } + return ReturnStatus::kSuccess; +} + +} // namespace detail + +} // namespace infini::ccl + +#endif // INFINI_CCL_METAX_CHECKS_H_ diff --git a/src/metax/device_.h b/src/metax/device_.h index 5902631..a14df84 100644 --- a/src/metax/device_.h +++ b/src/metax/device_.h @@ -1,6 +1,11 @@ #ifndef INFINI_CCL_METAX_DEVICE__H_ #define INFINI_CCL_METAX_DEVICE__H_ +// clang-format off +#include +// clang-format on + +#include "checks.h" #include "device.h" namespace infini::ccl { @@ -8,6 +13,23 @@ namespace infini::ccl { template <> struct DeviceEnabled : std::true_type {}; +template <> +MemorySpace GetMemorySpace(const void* ptr) { + if (!ptr) { + return MemorySpace::kHost; + } + + mcPointerAttribute_t attr; + INFINI_CHECK_MACA(mcPointerGetAttributes(&attr, ptr)); + + if (attr.type == mcMemoryTypeDevice || attr.type == mcMemoryTypeManaged || + attr.type == mcMemoryTypeArray) { + return MemorySpace::kDevice; + } + + return MemorySpace::kHost; +} + } // namespace infini::ccl #endif // INFINI_CCL_METAX_DEVICE__H_ diff --git a/src/nvidia/checks.h b/src/nvidia/checks.h new file mode 100644 index 0000000..c42a4e6 --- /dev/null +++ b/src/nvidia/checks.h @@ -0,0 +1,32 @@ +#ifndef INFINI_CCL_NVIDIA_CHECKS_H_ +#define INFINI_CCL_NVIDIA_CHECKS_H_ + +#include + +#include + +#include "return_status_impl.h" + +#define INFINI_CHECK_CUDA(result) \ + ::infini::ccl::detail::CheckCudaImpl((result), __FILE__, __LINE__) + +namespace infini::ccl { + +namespace detail { + +inline ReturnStatus CheckCudaImpl(cudaError_t cuda_result, const char *file, + int line) { + if (cuda_result != cudaSuccess) { + cudaGetLastError(); + std::cerr << "CUDA error code: " << cuda_result << " at line " << line + << " in " << file << std::endl; + std::abort(); + } + return ReturnStatus::kSuccess; +} + +} // namespace detail + +} // namespace infini::ccl + +#endif // INFINI_CCL_NVIDIA_CHECKS_H_ diff --git a/src/nvidia/device_.h b/src/nvidia/device_.h index 8192cfb..f73b71c 100644 --- a/src/nvidia/device_.h +++ b/src/nvidia/device_.h @@ -1,6 +1,9 @@ #ifndef INFINI_CCL_NVIDIA_DEVICE__H_ #define INFINI_CCL_NVIDIA_DEVICE__H_ +#include + +#include "checks.h" #include "device.h" namespace infini::ccl { @@ -8,6 +11,22 @@ namespace infini::ccl { template <> struct DeviceEnabled : std::true_type {}; +template <> +MemorySpace GetMemorySpace(const void* ptr) { + if (!ptr) { + return MemorySpace::kHost; + } + + cudaPointerAttributes attr; + INFINI_CHECK_CUDA(cudaPointerGetAttributes(&attr, ptr)); + + if (attr.type == cudaMemoryTypeDevice || attr.type == cudaMemoryTypeManaged) { + return MemorySpace::kDevice; + } + + return MemorySpace::kHost; +} + } // namespace infini::ccl #endif diff --git a/src/ompi/impl/broadcast.h b/src/ompi/impl/broadcast.h index f691954..7031005 100644 --- a/src/ompi/impl/broadcast.h +++ b/src/ompi/impl/broadcast.h @@ -25,48 +25,81 @@ class BroadcastImpl { using Rt = Runtime; auto *inst = static_cast(comm->inter_comm()); - if (!inst || inst->handle == MPI_COMM_NULL) { - LOG("Invalid OpenMPI communicator instance for Broadcast."); - return ReturnStatus::kInternalError; + size_t total_bytes = 0; + ReturnStatus status = ValidateArgs(inst, count, data_type, total_bytes); + if (status != ReturnStatus::kSuccess) { + return status; } - size_t type_size = kDataTypeToSize.at(data_type); - if (count > std::numeric_limits::max() / type_size) { - LOG("Broadcast byte size overflow."); - return ReturnStatus::kInvalidArgument; - } + const int rank = comm->rank(); + const bool is_root = (rank == root); + const bool is_out_of_place = (send_buff != recv_buff); - size_t total_bytes = count * type_size; - void *host_buf = std::malloc(total_bytes); - if (!host_buf) { - LOG("Failed to allocate host buffer for Broadcast staging."); - return ReturnStatus::kSystemError; - } + // Resolve memory topology using the buffer guaranteed to be valid on all + // nodes. + const MemorySpace space = GetMemorySpace(recv_buff); + const bool is_device = (space == MemorySpace::kDevice); + char *active_buf = nullptr; + + if (is_device) { + active_buf = static_cast(std::malloc(total_bytes)); + if (!active_buf) { + LOG("Failed to allocate host buffer for `Broadcast` staging."); + return ReturnStatus::kSystemError; + } - if (comm->rank() == root) { - CHECK_STATUS(Rt, Rt::Memcpy(host_buf, send_buff, total_bytes, - Rt::MemcpyDeviceToHost)); - CHECK_STATUS(Rt, Rt::StreamSynchronize(static_cast(stream))); + if (is_root) { + CHECK_STATUS(Rt, Rt::Memcpy(active_buf, send_buff, total_bytes, + Rt::MemcpyDeviceToHost)); + CHECK_STATUS(Rt, + Rt::StreamSynchronize(static_cast(stream))); + } + } else { + if (is_root && is_out_of_place && recv_buff != nullptr) { + std::memcpy(recv_buff, send_buff, total_bytes); + } + active_buf = static_cast(recv_buff); } - auto *bytes = static_cast(host_buf); size_t offset = 0; constexpr size_t kMaxMpiCount = static_cast(std::numeric_limits::max()); + while (offset < total_bytes) { - size_t chunk = total_bytes - offset; - if (chunk > kMaxMpiCount) { - chunk = kMaxMpiCount; - } - INFINI_CHECK_MPI(MPI_Bcast(bytes + offset, static_cast(chunk), + size_t chunk = std::min(total_bytes - offset, kMaxMpiCount); + INFINI_CHECK_MPI(MPI_Bcast(active_buf + offset, static_cast(chunk), MPI_BYTE, root, inst->handle)); offset += chunk; } - CHECK_STATUS(Rt, Rt::Memcpy(recv_buff, host_buf, total_bytes, - Rt::MemcpyHostToDevice)); + if (is_device) { + if (!is_root || is_out_of_place) { + CHECK_STATUS(Rt, Rt::Memcpy(recv_buff, active_buf, total_bytes, + Rt::MemcpyHostToDevice)); + } + std::free(active_buf); + } + + return ReturnStatus::kSuccess; + } + + private: + static ReturnStatus ValidateArgs(const OmpiInstance *inst, size_t count, + DataType data_type, + size_t &out_total_bytes) { + if (!inst || inst->handle == MPI_COMM_NULL) { + LOG("Invalid OpenMPI communicator instance for Broadcast."); + return ReturnStatus::kInternalError; + } + + size_t type_size = kDataTypeToSize.at(data_type); + if (count > std::numeric_limits::max() / type_size) { + LOG("Broadcast byte size overflow."); + return ReturnStatus::kInvalidArgument; + } + + out_total_bytes = count * type_size; - std::free(host_buf); return ReturnStatus::kSuccess; } };