Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/infiniop/ops/softmax/metax/softmax_metax.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __SOFTMAX_METAX_H__
#define __SOFTMAX_METAX_H__

#include "../softmax.h"

DESCRIPTOR(metax)

#endif // __SOFTMAX_METAX_H__
162 changes: 162 additions & 0 deletions src/infiniop/ops/softmax/metax/softmax_metax.maca
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
#include "../../../devices/metax/metax_common.h"
#include "softmax_metax.h"

#ifdef ENABLE_METAX_MC_API
#include <cub/block/block_reduce.cuh>
#else
#include <hccub/block/block_reduce.cuh>
#endif
#include "../../../devices/metax/metax_kernel_common.h"

#include "../cuda/kernel.cuh"

template <typename Tdata, unsigned int BLOCK_SIZE>
INFINIOP_METAX_KERNEL blockSoftmax(
Tdata *y, const Tdata *x,
size_t dimsize,
ptrdiff_t stride) {
blockSoftmaxKernel<Tdata, BLOCK_SIZE>(x, y, dimsize, stride);
}

template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y, int numPerThreadx>
INFINIOP_METAX_KERNEL warpSoftmax(
Tdata *y, const Tdata *x,
size_t othersize,
size_t dimsize,
ptrdiff_t stride) {
warpSoftmaxKernel<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y, numPerThreadx>(x, y, othersize, dimsize, stride);
}

namespace op::softmax::metax {

struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
delete _opaque;
}

infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc,
int axis) {
auto info = SoftmaxInfo::create(y_desc, x_desc, axis);
CHECK_RESULT(info);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}

template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t dtype,
size_t othersize, size_t dimsize, ptrdiff_t stride,
hcStream_t stream) {
int num_blocks = (int)othersize;
if (dtype == INFINI_DTYPE_F16) {
if (dimsize > 1024) {
blockSoftmax<half, BLOCK_SIZE>
<<<num_blocks, BLOCK_SIZE, 0, stream>>>((half *)y, (const half *)x,
dimsize, stride);
} else if (dimsize > 31) {
constexpr unsigned int BLOCK_SIZE_x = 32;
constexpr unsigned int BLOCK_SIZE_y = 32;
constexpr int numPerThreadx = 32;
int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y;
dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpSoftmax<half, BLOCK_SIZE_x, BLOCK_SIZE_y, numPerThreadx>
<<<grid_dim, block_dim, 0, stream>>>((half *)y, (const half *)x,
othersize, dimsize, stride);
} else {
constexpr unsigned int BLOCK_SIZE_x = 16;
constexpr unsigned int BLOCK_SIZE_y = 32;
constexpr int numPerThreadx = 2;
int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y;
dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpSoftmax<half, BLOCK_SIZE_x, BLOCK_SIZE_y, numPerThreadx>
<<<grid_dim, block_dim, 0, stream>>>((half *)y, (const half *)x,
othersize, dimsize, stride);
}

} else if (dtype == INFINI_DTYPE_BF16) {
if (dimsize > 1024) {
blockSoftmax<cuda_bfloat16, BLOCK_SIZE>
<<<num_blocks, BLOCK_SIZE, 0, stream>>>((cuda_bfloat16 *)y, (const cuda_bfloat16 *)x,
dimsize, stride);
} else if (dimsize > 31) {
constexpr unsigned int BLOCK_SIZE_x = 32;
constexpr unsigned int BLOCK_SIZE_y = 32;
constexpr int numPerThreadx = 32;
int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y;
dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpSoftmax<cuda_bfloat16, BLOCK_SIZE_x, BLOCK_SIZE_y, numPerThreadx>
<<<grid_dim, block_dim, 0, stream>>>((cuda_bfloat16 *)y, (const cuda_bfloat16 *)x,
othersize, dimsize, stride);
} else {
constexpr unsigned int BLOCK_SIZE_x = 16;
constexpr unsigned int BLOCK_SIZE_y = 32;
constexpr int numPerThreadx = 2;
int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y;
dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpSoftmax<cuda_bfloat16, BLOCK_SIZE_x, BLOCK_SIZE_y, numPerThreadx>
<<<grid_dim, block_dim, 0, stream>>>((cuda_bfloat16 *)y, (const cuda_bfloat16 *)x,
othersize, dimsize, stride);
}

} else if (dtype == INFINI_DTYPE_F32) {
if (dimsize > 1024) {
blockSoftmax<float, BLOCK_SIZE>
<<<num_blocks, BLOCK_SIZE, 0, stream>>>((float *)y, (const float *)x,
dimsize, stride);
} else if (dimsize > 31) {
constexpr unsigned int BLOCK_SIZE_x = 32;
constexpr unsigned int BLOCK_SIZE_y = 32;
constexpr int numPerThreadx = 32;
int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y;
dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpSoftmax<float, BLOCK_SIZE_x, BLOCK_SIZE_y, numPerThreadx>
<<<grid_dim, block_dim, 0, stream>>>((float *)y, (const float *)x,
othersize, dimsize, stride);
} else {
constexpr unsigned int BLOCK_SIZE_x = 16;
constexpr unsigned int BLOCK_SIZE_y = 32;
constexpr int numPerThreadx = 2;
int num_block_x = (num_blocks + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y;
dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1);
dim3 grid_dim(num_block_x, 1, 1);
warpSoftmax<float, BLOCK_SIZE_x, BLOCK_SIZE_y, numPerThreadx>
<<<grid_dim, block_dim, 0, stream>>>((float *)y, (const float *)x,
othersize, dimsize, stride);
}
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *y,
const void *x,
void *stream_) const {
hcStream_t stream = (hcStream_t)stream_;
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(
y, x, _info.dtype, _info.othersize, _info.dimsize, _info.stride, stream));
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_512>(
y, x, _info.dtype, _info.othersize, _info.dimsize, _info.stride, stream));
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}

} // namespace op::softmax::metax
16 changes: 16 additions & 0 deletions src/infiniop/ops/softmax/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
#include "nvidia/softmax_nvidia.cuh"
#endif

#ifdef ENABLE_METAX_API
#include "metax/softmax_metax.h"
#endif

__INFINI_C infiniStatus_t infiniopCreateSoftmaxDescriptor(
infiniopHandle_t handle,
infiniopSoftmaxDescriptor_t *desc_ptr,
Expand Down Expand Up @@ -43,6 +47,9 @@ __INFINI_C infiniStatus_t infiniopCreateSoftmaxDescriptor(
#endif
#ifdef ENABLE_CAMBRICON_API
CREATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down Expand Up @@ -74,6 +81,9 @@ __INFINI_C infiniStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescrip
#endif
#ifdef ENABLE_CAMBRICON_API
GET(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down Expand Up @@ -110,6 +120,9 @@ __INFINI_C infiniStatus_t infiniopSoftmax(
#endif
#ifdef ENABLE_CAMBRICON_API
CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down Expand Up @@ -141,6 +154,9 @@ __INFINI_C infiniStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescri
#endif
#ifdef ENABLE_CAMBRICON_API
DESTROY(INFINI_DEVICE_CAMBRICON, bang);
#endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down