Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions src/infiniop/ops/quickgelu/cuda/kernel.cuh
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
#ifndef __QUICKGELU_CUDA_H__
#define __QUICKGELU_CUDA_H__

#include "../../../elementwise/nvidia/elementwise_nvidia.cuh"
#include <cuda_bf16.h>
#include <cuda_fp16.h>

namespace op::quickgelu::cuda {

typedef struct QuickGeluOp {
Expand All @@ -29,7 +25,7 @@ public:
half sigmoid = hrcp(denominator);
return __hmul(x, sigmoid);

} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
} else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
float xf = __bfloat162float(x);
float ax = alpha * xf;
float s = 1.0f / (1.0f + __expf(-ax));
Expand Down
8 changes: 8 additions & 0 deletions src/infiniop/ops/quickgelu/metax/quickgelu_metax.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __QUICKGELU_METAX_API_H__
#define __QUICKGELU_METAX_API_H__

#include "../../../elementwise/metax/elementwise_metax_api.h"

// Declares the QuickGelu descriptor for the METAX backend through the shared
// elementwise macro.
// NOTE(review): presumably this expands to the op::quickgelu::metax::Descriptor
// class whose members are defined in quickgelu_metax.maca — confirm against
// elementwise_metax_api.h.
ELEMENTWISE_DESCRIPTOR(quickgelu, metax)

#endif // __QUICKGELU_METAX_API_H__
60 changes: 60 additions & 0 deletions src/infiniop/ops/quickgelu/metax/quickgelu_metax.maca
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include "quickgelu_metax.h"

#include "../../../elementwise/metax/elementwise_metax.h"

#include "../cuda/kernel.cuh"

namespace op::quickgelu::metax {

// Defaulted destructor: this translation unit adds no custom teardown logic.
Descriptor::~Descriptor() = default;

/**
 * Validates the tensor descriptors and builds the METAX QuickGelu descriptor.
 *
 * @param handle_         Opaque operator handle; reinterpreted as device::metax::Handle.
 * @param desc_ptr        Destination for the new descriptor. It is not written directly
 *                        in this body. NOTE(review): presumably it is filled in by
 *                        CREATE_ELEMENTWISE_METAX_DESCRIPTOR — confirm.
 * @param out_desc        Output tensor descriptor; supplies the dtype and the output shape.
 * @param input_desc_vec  Input tensor descriptors; only element 0 is inspected here.
 *                        std::vector::at throws std::out_of_range when it is empty.
 * @return INFINI_STATUS_SUCCESS when control reaches the end of the function.
 *         NOTE(review): the CHECK_* macros and the descriptor macro presumably
 *         return an error status early on failure — confirm.
 */
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

    auto metax_handle = reinterpret_cast<device::metax::Handle *>(handle_);
    auto out_dtype = out_desc->dtype();

    // QuickGelu is unary, so only the first input descriptor is looked at.
    const auto &x_desc = input_desc_vec.at(0);
    const auto &y_shape = out_desc->shape();
    const auto &x_shape = x_desc->shape();

    // The dtype check runs before the shape check.
    CHECK_DTYPE(out_dtype, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
    CHECK_SAME_SHAPE(y_shape, x_shape);

    // Hand over to the shared elementwise machinery for the METAX backend.
    CREATE_ELEMENTWISE_METAX_DESCRIPTOR(metax_handle, out_dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}

/**
 * Runs QuickGelu for the dtype stored in this descriptor.
 *
 * @param workspace       Caller-provided scratch buffer, forwarded to the launcher.
 * @param workspace_size  Size of that buffer; must be at least _workspace_size.
 * @param output          Output buffer, forwarded to the launcher.
 * @param inputs          Input buffers, forwarded to the launcher.
 * @param stream          Opaque stream handle, forwarded to the launcher.
 * @return INFINI_STATUS_INSUFFICIENT_WORKSPACE when the buffer is too small,
 *         INFINI_STATUS_BAD_TENSOR_DTYPE for a dtype outside the four handled
 *         cases, otherwise whatever the launcher returns.
 */
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    // Refuse to run with less scratch space than the descriptor recorded.
    if (_workspace_size > workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    // One launcher instantiation per supported element type.
    // NOTE(review): 256 is presumably the thread-block size — confirm against
    // the elementwise launcher.
    switch (_dtype) {
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<256, cuda::QuickGeluOp, cuda_bfloat16>(
            _info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F16:
        return _device_info->calculate<256, cuda::QuickGeluOp, half>(
            _info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<256, cuda::QuickGeluOp, float>(
            _info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<256, cuda::QuickGeluOp, double>(
            _info, workspace, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    // Every switch branch returns, so this line is never reached.
    return INFINI_STATUS_SUCCESS;
}
} // namespace op::quickgelu::metax
2 changes: 1 addition & 1 deletion src/infiniop/ops/quickgelu/nvidia/quickgelu_nvidia.cu
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ infiniStatus_t Descriptor::calculate(
_info, workspace, output, inputs, stream);

case INFINI_DTYPE_BF16:
return _device_info->calculate<256, cuda::QuickGeluOp, __nv_bfloat16>(
return _device_info->calculate<256, cuda::QuickGeluOp, cuda_bfloat16>(
_info, workspace, output, inputs, stream);

case INFINI_DTYPE_F32:
Expand Down
15 changes: 15 additions & 0 deletions src/infiniop/ops/quickgelu/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
#include "nvidia/quickgelu_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/quickgelu_metax.h"
#endif

__INFINI_C infiniStatus_t infiniopCreateQuickGeluDescriptor(
infiniopHandle_t handle,
Expand Down Expand Up @@ -38,6 +41,9 @@ __INFINI_C infiniStatus_t infiniopCreateQuickGeluDescriptor(
#endif
#ifdef ENABLE_HYGON_API
CREATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down Expand Up @@ -68,6 +74,9 @@ __INFINI_C infiniStatus_t infiniopGetQuickGeluWorkspaceSize(infiniopQuickGeluDes
#endif
#ifdef ENABLE_HYGON_API
GET(INFINI_DEVICE_HYGON, nvidia)
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down Expand Up @@ -104,6 +113,9 @@ __INFINI_C infiniStatus_t infiniopQuickGelu(
#endif
#ifdef ENABLE_HYGON_API
CALCULATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down Expand Up @@ -134,6 +146,9 @@ __INFINI_C infiniStatus_t infiniopDestroyQuickGeluDescriptor(infiniopQuickGeluDe
#endif
#ifdef ENABLE_HYGON_API
DELETE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down
29 changes: 29 additions & 0 deletions test/infiniop/libinfiniop/op_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,35 @@ def sigmoid_(lib):
]


@OpRegister.operator
def quickgelu_(lib):
    """Attach ctypes return and argument types to the four QuickGelu entry points of ``lib``."""
    # (symbol name, argument types); every entry point returns c_int32.
    signatures = (
        (
            "infiniopCreateQuickGeluDescriptor",
            [
                infiniopHandle_t,
                POINTER(infiniopOperatorDescriptor_t),
                infiniopTensorDescriptor_t,
                infiniopTensorDescriptor_t,
            ],
        ),
        (
            "infiniopGetQuickGeluWorkspaceSize",
            [
                infiniopOperatorDescriptor_t,
                POINTER(c_size_t),
            ],
        ),
        (
            "infiniopQuickGelu",
            [
                infiniopOperatorDescriptor_t,
                c_void_p,
                c_size_t,
                c_void_p,
                c_void_p,
                c_void_p,
            ],
        ),
        (
            "infiniopDestroyQuickGeluDescriptor",
            [
                infiniopOperatorDescriptor_t,
            ],
        ),
    )

    # Same order as the hand-written assignments: create, workspace, run, destroy.
    for symbol, argtypes in signatures:
        entry_point = getattr(lib, symbol)
        entry_point.restype = c_int32
        entry_point.argtypes = argtypes


@OpRegister.operator
def topksoftmax_(lib):
lib.infiniopCreateTopksoftmaxDescriptor.restype = c_int32
Expand Down
176 changes: 176 additions & 0 deletions test/infiniop/quickgelu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
LIBINFINIOP,
TestTensor,
get_test_devices,
check_error,
test_operator,
get_args,
debug,
get_tolerance,
profile_operation,
TestWorkspace,
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
infiniopOperatorDescriptor_t,
)
from enum import Enum, auto

# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
# Base cases before the in-place modes are attached. Each entry is
# (shape, x_stride, y_stride); the strides are handed to TestTensor as given.
# NOTE(review): a stride of None presumably means the default contiguous layout,
# and entries with a 0 stride are presumably the ones test() skips through
# y.is_broadcast() — confirm against TestTensor.
_TEST_CASES_ = [
# shape, x_stride, y_stride
((13, 4), None, None),
((13, 4), (10, 1), (10, 1)),
((13, 4), (0, 1), (0, 1)),
((13, 4, 4), None, None),
((13, 4, 4), (20, 4, 1), (20, 4, 1)),
((13, 4, 4), (4, 0, 1), (4, 0, 1)),
((16, 5632), None, None),
((16, 5632), (13312, 1), (13312, 1)),
((4, 4, 5632), None, None),
((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1)),
((4, 4, 56320), None, None),
]


class Inplace(Enum):
    """Where the operator writes its result relative to the input tensor."""

    # The result goes to a separately created output tensor.
    OUT_OF_PLACE = auto()
    # The output aliases the input tensor x.
    INPLACE_X = auto()


# Inplace options applied for each test case in _TEST_CASES_
_INPLACE = [Inplace.OUT_OF_PLACE, Inplace.INPLACE_X]

# Cross product: every base case in _TEST_CASES_ is paired with every in-place
# mode, with the mode appended as the last tuple element.
_TEST_CASES = [
    (*base_case, mode) for base_case in _TEST_CASES_ for mode in _INPLACE
]

# Data types used for testing
# Every test case is run once per dtype in this list (see the test_operator call
# in the __main__ block).
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32, InfiniDtype.BF16]

# Tolerance map for different data types
# Looked up per dtype through get_tolerance(); the resulting atol/rtol pair is
# what test() hands to torch.allclose.
_TOLERANCE_MAP = {
InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3},
InfiniDtype.F32: {"atol": 1e-6, "rtol": 1e-6},
InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2},
}

# Defaults for the run-time switches; the __main__ block overwrites all four
# with the values parsed by get_args().
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


def torch_quickgelu(y, x):
    """Write the reference QuickGELU of ``x`` into ``y``.

    quickgelu(x) = x * sigmoid(1.702 * x). The product is fully evaluated
    before it is copied into ``y``, so ``y`` may be the same tensor as ``x``.
    Returns None.
    """
    scaled = x * 1.702
    y.copy_(x * torch.sigmoid(scaled))


def test(
    handle,
    device,
    shape,
    x_stride=None,
    y_stride=None,
    inplace=Inplace.OUT_OF_PLACE,
    dtype=InfiniDtype.F16,
    sync=None,
):
    """Run one QuickGelu case and compare the library result with the PyTorch reference.

    Args:
        handle: Library handle forwarded to infiniopCreateQuickGeluDescriptor.
        device: Device identifier used to build the tensors and to label the log line.
        shape: Shape shared by the input and the output tensor.
        x_stride: Strides of the input tensor, or None.
        y_stride: Strides of the output tensor, or None.
        inplace: Inplace.INPLACE_X makes the output alias the input; any other
            value creates a separate output tensor.
        dtype: InfiniDtype member used for both tensors. The default is an
            InfiniDtype rather than a torch dtype because the value is used as
            a key into InfiniDtypeNames and _TOLERANCE_MAP.
        sync: Optional callable invoked once after the reference computation.

    The case is skipped silently when in-place is requested with differing
    strides, or when the output tensor reports is_broadcast().

    Raises:
        AssertionError: If the library output is not close to the reference
            within the dtype's tolerance.
    """
    x = TestTensor(shape, x_stride, dtype, device)
    if inplace == Inplace.INPLACE_X:
        # Aliasing only makes sense when both views share one layout.
        if x_stride != y_stride:
            return
        y = x
    else:
        y = TestTensor(shape, y_stride, dtype, device, mode="ones")

    if y.is_broadcast():
        return

    print(
        f"Testing QuickGelu on {InfiniDeviceNames[device]} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} "
        f"dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}"
    )

    # Reference result, written into the torch-side copy of y.
    torch_quickgelu(y.torch_tensor(), x.torch_tensor())

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateQuickGeluDescriptor(
            handle,
            ctypes.byref(descriptor),
            y.descriptor,
            x.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [x, y]:
        tensor.destroy_desc()

    # The registered argtypes declare this out-parameter as POINTER(c_size_t),
    # so the matching ctypes type is used here.
    workspace_size = ctypes.c_size_t(0)
    check_error(
        LIBINFINIOP.infiniopGetQuickGeluWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, y.device)

    def lib_quickgelu():
        # The last argument (the stream) is passed as None.
        check_error(
            LIBINFINIOP.infiniopQuickGelu(
                descriptor,
                workspace.data(),
                workspace.size(),
                y.data(),
                x.data(),
                None,
            )
        )

    lib_quickgelu()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)

    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: torch_quickgelu(y.torch_tensor(), x.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation(" lib", lambda: lib_quickgelu(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(LIBINFINIOP.infiniopDestroyQuickGeluDescriptor(descriptor))


if __name__ == "__main__":
    args = get_args()

    # Replace the module-level defaults with the command-line values.
    DEBUG, PROFILE = args.debug, args.profile
    NUM_PRERUN, NUM_ITERATIONS = args.num_prerun, args.num_iterations

    # Run every (case, in-place mode) combination for each dtype on each device.
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92m Test passed! \033[0m")