Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,11 @@ struct DevicePriority<Device::Type::kCambricon> {
static constexpr int value = 5;
};

// Classifies where the memory behind a pointer resides. `kUnknown` has no
// explicit initializer, so it takes the value following `kDevice`.
enum class MemorySpace : uint8_t { kHost = 0, kDevice = 1, kUnknown };

// Reports the `MemorySpace` of `ptr` for the device type `kDev`.
//
// Only declared here: there is no generic definition, so each device type
// that uses this must supply its own explicit specialization.
template <Device::Type kDev>
MemorySpace GetMemorySpace(const void *ptr);

} // namespace infini::ccl

#endif // INFINI_CCL_DEVICE_H_
34 changes: 34 additions & 0 deletions src/metax/checks.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef INFINI_CCL_METAX_CHECKS_H_
#define INFINI_CCL_METAX_CHECKS_H_

// clang-format off
#include <mcr/mc_runtime.h>
// clang-format on

#include <iostream>

#include "return_status_impl.h"

// Forwards `result` to `detail::CheckMacaImpl` together with the file name and
// line number of the call site. The expression yields the `ReturnStatus`
// returned by that function.
#define INFINI_CHECK_MACA(result) \
::infini::ccl::detail::CheckMacaImpl((result), __FILE__, __LINE__)

namespace infini::ccl {

namespace detail {

// Verifies the outcome of a MACA runtime call.
//
// `maca_result` is the status to inspect; `file` and `line` identify the call
// site and are only used in the diagnostic. Returns `ReturnStatus::kSuccess`
// when the status equals `mcSuccess`; any other status is written to
// `std::cerr` and the process is terminated with `std::abort()`.
inline ReturnStatus CheckMacaImpl(mcError_t maca_result, const char *file,
                                  int line) {
  // Fast path: nothing to report.
  if (maca_result == mcSuccess) {
    return ReturnStatus::kSuccess;
  }

  // Return value intentionally discarded. NOTE(review): presumably this
  // resets the runtime's recorded error state -- confirm against MACA docs.
  mcGetLastError();

  std::cerr << "MACA error code: " << maca_result << " at line " << line
            << " in " << file << std::endl;
  std::abort();
}

} // namespace detail

} // namespace infini::ccl

#endif // INFINI_CCL_METAX_CHECKS_H_
22 changes: 22 additions & 0 deletions src/metax/device_.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,35 @@
#ifndef INFINI_CCL_METAX_DEVICE__H_
#define INFINI_CCL_METAX_DEVICE__H_

// clang-format off
#include <mcr/mc_runtime.h>
// clang-format on

#include "checks.h"
#include "device.h"

namespace infini::ccl {

// Marks the Metax device type as enabled: `DeviceEnabled<kMetax>::value` is
// `true` wherever this header is included.
template <>
struct DeviceEnabled<Device::Type::kMetax> : std::true_type {};

// Metax specialization: classifies `ptr` as host or device memory using the
// MACA pointer-attribute query.
//
// Returns `MemorySpace::kHost` for a null pointer (no runtime call is made)
// and for every memory type other than device, managed, or array; those three
// map to `MemorySpace::kDevice`. A failing attribute query terminates the
// process through `INFINI_CHECK_MACA`.
//
// Declared `inline` because an explicit specialization of a function template
// is an ordinary function definition; without it, including this header from
// more than one translation unit produces duplicate-symbol link errors.
template <>
inline MemorySpace GetMemorySpace<Device::Type::kMetax>(const void* ptr) {
  if (!ptr) {
    return MemorySpace::kHost;
  }

  // Value-initialized so no indeterminate field is ever read.
  mcPointerAttribute_t attr{};
  INFINI_CHECK_MACA(mcPointerGetAttributes(&attr, ptr));

  if (attr.type == mcMemoryTypeDevice || attr.type == mcMemoryTypeManaged ||
      attr.type == mcMemoryTypeArray) {
    return MemorySpace::kDevice;
  }

  return MemorySpace::kHost;
}

} // namespace infini::ccl

#endif // INFINI_CCL_METAX_DEVICE__H_
32 changes: 32 additions & 0 deletions src/nvidia/checks.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#ifndef INFINI_CCL_NVIDIA_CHECKS_H_
#define INFINI_CCL_NVIDIA_CHECKS_H_

#include <cuda_runtime.h>

#include <iostream>

#include "return_status_impl.h"

// Forwards `result` to `detail::CheckCudaImpl` together with the file name and
// line number of the call site. The expression yields the `ReturnStatus`
// returned by that function.
#define INFINI_CHECK_CUDA(result) \
::infini::ccl::detail::CheckCudaImpl((result), __FILE__, __LINE__)

namespace infini::ccl {

namespace detail {

// Verifies the outcome of a CUDA runtime call.
//
// `cuda_result` is the status to inspect; `file` and `line` identify the call
// site and are only used in the diagnostic. Returns `ReturnStatus::kSuccess`
// when the status equals `cudaSuccess`. Any other status is reported on
// `std::cerr` -- numeric code plus the runtime's symbolic name and description,
// so the failure can be read without looking the code up -- and the process is
// terminated with `std::abort()`.
inline ReturnStatus CheckCudaImpl(cudaError_t cuda_result, const char *file,
                                  int line) {
  if (cuda_result != cudaSuccess) {
    // Return value intentionally discarded; per the CUDA runtime docs this
    // call resets the runtime's last-error state.
    cudaGetLastError();
    std::cerr << "CUDA error code: " << cuda_result << " ("
              << cudaGetErrorName(cuda_result) << ": "
              << cudaGetErrorString(cuda_result) << ") at line " << line
              << " in " << file << std::endl;
    std::abort();
  }
  return ReturnStatus::kSuccess;
}

} // namespace detail

} // namespace infini::ccl

#endif // INFINI_CCL_NVIDIA_CHECKS_H_
19 changes: 19 additions & 0 deletions src/nvidia/device_.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,32 @@
#ifndef INFINI_CCL_NVIDIA_DEVICE__H_
#define INFINI_CCL_NVIDIA_DEVICE__H_

#include <cuda_runtime.h>

#include "checks.h"
#include "device.h"

namespace infini::ccl {

// Marks the NVIDIA device type as enabled: `DeviceEnabled<kNvidia>::value` is
// `true` wherever this header is included.
template <>
struct DeviceEnabled<Device::Type::kNvidia> : std::true_type {};

// NVIDIA specialization: classifies `ptr` as host or device memory using
// `cudaPointerGetAttributes`.
//
// Returns `MemorySpace::kHost` for a null pointer (no runtime call is made)
// and for every memory type other than device or managed; those two map to
// `MemorySpace::kDevice`.
//
// CUDA runtimes older than 11.0 answer `cudaErrorInvalidValue` for ordinary
// host pointers the runtime has never seen, instead of succeeding with
// `cudaMemoryTypeUnregistered`. That result is treated as host memory rather
// than aborting; every other failure still terminates the process through
// `INFINI_CHECK_CUDA`.
//
// Declared `inline` because an explicit specialization of a function template
// is an ordinary function definition; without it, including this header from
// more than one translation unit produces duplicate-symbol link errors.
template <>
inline MemorySpace GetMemorySpace<Device::Type::kNvidia>(const void* ptr) {
  if (!ptr) {
    return MemorySpace::kHost;
  }

  // Value-initialized so no indeterminate field is ever read.
  cudaPointerAttributes attr{};
  const cudaError_t query_status = cudaPointerGetAttributes(&attr, ptr);

  if (query_status == cudaErrorInvalidValue) {
    // Legacy behavior for unregistered host memory. Reset the runtime's
    // last-error state so the failure does not leak into later checks.
    cudaGetLastError();
    return MemorySpace::kHost;
  }
  INFINI_CHECK_CUDA(query_status);

  if (attr.type == cudaMemoryTypeDevice || attr.type == cudaMemoryTypeManaged) {
    return MemorySpace::kDevice;
  }

  return MemorySpace::kHost;
}

} // namespace infini::ccl

#endif
87 changes: 60 additions & 27 deletions src/ompi/impl/broadcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,48 +25,81 @@ class BroadcastImpl<BackendType::kOmpi, device_type> {
using Rt = Runtime<kDev>;

auto *inst = static_cast<OmpiInstance *>(comm->inter_comm());
if (!inst || inst->handle == MPI_COMM_NULL) {
LOG("Invalid OpenMPI communicator instance for Broadcast.");
return ReturnStatus::kInternalError;
size_t total_bytes = 0;
ReturnStatus status = ValidateArgs(inst, count, data_type, total_bytes);
if (status != ReturnStatus::kSuccess) {
return status;
}

size_t type_size = kDataTypeToSize.at(data_type);
if (count > std::numeric_limits<size_t>::max() / type_size) {
LOG("Broadcast byte size overflow.");
return ReturnStatus::kInvalidArgument;
}
const int rank = comm->rank();
const bool is_root = (rank == root);
const bool is_out_of_place = (send_buff != recv_buff);

size_t total_bytes = count * type_size;
void *host_buf = std::malloc(total_bytes);
if (!host_buf) {
LOG("Failed to allocate host buffer for Broadcast staging.");
return ReturnStatus::kSystemError;
}
// Resolve memory topology using the buffer guaranteed to be valid on all
// nodes.
const MemorySpace space = GetMemorySpace<kDev>(recv_buff);
const bool is_device = (space == MemorySpace::kDevice);
char *active_buf = nullptr;

if (is_device) {
active_buf = static_cast<char *>(std::malloc(total_bytes));
if (!active_buf) {
LOG("Failed to allocate host buffer for `Broadcast` staging.");
return ReturnStatus::kSystemError;
}

if (comm->rank() == root) {
CHECK_STATUS(Rt, Rt::Memcpy(host_buf, send_buff, total_bytes,
Rt::MemcpyDeviceToHost));
CHECK_STATUS(Rt, Rt::StreamSynchronize(static_cast<Rt::Stream>(stream)));
if (is_root) {
CHECK_STATUS(Rt, Rt::Memcpy(active_buf, send_buff, total_bytes,
Rt::MemcpyDeviceToHost));
CHECK_STATUS(Rt,
Rt::StreamSynchronize(static_cast<Rt::Stream>(stream)));
}
} else {
if (is_root && is_out_of_place && recv_buff != nullptr) {
std::memcpy(recv_buff, send_buff, total_bytes);
}
active_buf = static_cast<char *>(recv_buff);
}

auto *bytes = static_cast<char *>(host_buf);
size_t offset = 0;
constexpr size_t kMaxMpiCount =
static_cast<size_t>(std::numeric_limits<int>::max());

while (offset < total_bytes) {
size_t chunk = total_bytes - offset;
if (chunk > kMaxMpiCount) {
chunk = kMaxMpiCount;
}
INFINI_CHECK_MPI(MPI_Bcast(bytes + offset, static_cast<int>(chunk),
size_t chunk = std::min(total_bytes - offset, kMaxMpiCount);
INFINI_CHECK_MPI(MPI_Bcast(active_buf + offset, static_cast<int>(chunk),
MPI_BYTE, root, inst->handle));
offset += chunk;
}

CHECK_STATUS(Rt, Rt::Memcpy(recv_buff, host_buf, total_bytes,
Rt::MemcpyHostToDevice));
if (is_device) {
if (!is_root || is_out_of_place) {
CHECK_STATUS(Rt, Rt::Memcpy(recv_buff, active_buf, total_bytes,
Rt::MemcpyHostToDevice));
}
std::free(active_buf);
}

return ReturnStatus::kSuccess;
}

private:
// Validates the communicator and computes the broadcast payload size.
//
// `inst` is the OpenMPI communicator instance, `count` the number of
// elements, and `data_type` the element type, whose size is looked up in
// `kDataTypeToSize`. On success `out_total_bytes` receives
// `count * element size` and `ReturnStatus::kSuccess` is returned; on any
// failure `out_total_bytes` is left untouched.
//
// Returns `kInternalError` when `inst` is null or its handle is
// `MPI_COMM_NULL`, and `kInvalidArgument` when the element size is zero or
// the byte count would not fit in `size_t`.
//
// This helper owns no memory, so it frees nothing: the staging buffer is
// allocated and released by the caller.
static ReturnStatus ValidateArgs(const OmpiInstance *inst, size_t count,
                                 DataType data_type,
                                 size_t &out_total_bytes) {
  if (!inst || inst->handle == MPI_COMM_NULL) {
    LOG("Invalid OpenMPI communicator instance for Broadcast.");
    return ReturnStatus::kInternalError;
  }

  const size_t type_size = kDataTypeToSize.at(data_type);
  if (type_size == 0) {
    // Reject before the division below, which would otherwise divide by zero.
    LOG("Broadcast data type has zero size.");
    return ReturnStatus::kInvalidArgument;
  }

  // Division-based check: detects overflow without computing the product.
  if (count > std::numeric_limits<size_t>::max() / type_size) {
    LOG("Broadcast byte size overflow.");
    return ReturnStatus::kInvalidArgument;
  }

  out_total_bytes = count * type_size;
  return ReturnStatus::kSuccess;
}
};
Expand Down
Loading