From 0054b77b130348ff87d2edb40e23a1f38200d63c Mon Sep 17 00:00:00 2001 From: LindseyMei <648816901@qq.com> Date: Thu, 25 Jun 2026 16:12:10 +0000 Subject: [PATCH] feat: support quickgelu operator on metax Add MetaX backend for the quickgelu elementwise operator, reusing the existing cuda::QuickGeluOp kernel through the elementwise MetaX descriptor. Changes: - Add quickgelu/metax/quickgelu_metax.{h,maca} - Wire MetaX into quickgelu/operator.cc - Clean up quickgelu/cuda/kernel.cuh: remove nvidia-specific elementwise include and use cuda_bfloat16 for cross-backend compatibility - Update nvidia/quickgelu_nvidia.cu to use cuda_bfloat16 - Register quickgelu ctypes bindings in test/libinfiniop/op_register.py - Add test/infiniop/quickgelu.py for correctness verification Verified with test/infiniop/quickgelu.py --metax on MetaX C500: passes accuracy check against torch reference (x * sigmoid(1.702 * x)) across shapes/strides and inplace/out-of-place for F16/F32/BF16. Signed-off-by: LindseyMei <648816901@qq.com> --- src/infiniop/ops/quickgelu/cuda/kernel.cuh | 6 +- .../ops/quickgelu/metax/quickgelu_metax.h | 8 + .../ops/quickgelu/metax/quickgelu_metax.maca | 60 ++++++ .../ops/quickgelu/nvidia/quickgelu_nvidia.cu | 2 +- src/infiniop/ops/quickgelu/operator.cc | 15 ++ test/infiniop/libinfiniop/op_register.py | 29 +++ test/infiniop/quickgelu.py | 176 ++++++++++++++++++ 7 files changed, 290 insertions(+), 6 deletions(-) create mode 100644 src/infiniop/ops/quickgelu/metax/quickgelu_metax.h create mode 100644 src/infiniop/ops/quickgelu/metax/quickgelu_metax.maca create mode 100644 test/infiniop/quickgelu.py diff --git a/src/infiniop/ops/quickgelu/cuda/kernel.cuh b/src/infiniop/ops/quickgelu/cuda/kernel.cuh index 2c13c4b9d..47b300868 100644 --- a/src/infiniop/ops/quickgelu/cuda/kernel.cuh +++ b/src/infiniop/ops/quickgelu/cuda/kernel.cuh @@ -1,10 +1,6 @@ #ifndef __QUICKGELU_CUDA_H__ #define __QUICKGELU_CUDA_H__ -#include "../../../elementwise/nvidia/elementwise_nvidia.cuh" -#include -#include - namespace op::quickgelu::cuda { typedef struct QuickGeluOp { @@ -29,7 +25,7 @@ public: half sigmoid = hrcp(denominator); return __hmul(x, sigmoid); - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { float xf = __bfloat162float(x); float ax = alpha * xf; float s = 1.0f / (1.0f + __expf(-ax)); diff --git a/src/infiniop/ops/quickgelu/metax/quickgelu_metax.h b/src/infiniop/ops/quickgelu/metax/quickgelu_metax.h new file mode 100644 index 000000000..9596ddf0e --- /dev/null +++ b/src/infiniop/ops/quickgelu/metax/quickgelu_metax.h @@ -0,0 +1,8 @@ +#ifndef __QUICKGELU_METAX_API_H__ +#define __QUICKGELU_METAX_API_H__ + +#include "../../../elementwise/metax/elementwise_metax_api.h" + +ELEMENTWISE_DESCRIPTOR(quickgelu, metax) + +#endif // __QUICKGELU_METAX_API_H__ diff --git a/src/infiniop/ops/quickgelu/metax/quickgelu_metax.maca b/src/infiniop/ops/quickgelu/metax/quickgelu_metax.maca new file mode 100644 index 000000000..729366a5f --- /dev/null +++ b/src/infiniop/ops/quickgelu/metax/quickgelu_metax.maca @@ -0,0 +1,60 @@ +#include "quickgelu_metax.h" + +#include "../../../elementwise/metax/elementwise_metax.h" + +#include "../cuda/kernel.cuh" + +namespace op::quickgelu::metax { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &input_desc = input_desc_vec.at(0); + const auto &output_shape = out_desc->shape(); + const auto &input_shape = input_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); + + CHECK_SAME_SHAPE(output_shape, input_shape); + + // create METAX elementwise descriptor + CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::QuickGeluOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::QuickGeluOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::QuickGeluOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::QuickGeluOp, double>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::quickgelu::metax diff --git a/src/infiniop/ops/quickgelu/nvidia/quickgelu_nvidia.cu b/src/infiniop/ops/quickgelu/nvidia/quickgelu_nvidia.cu index 387e08ecb..09968f3f2 100644 --- a/src/infiniop/ops/quickgelu/nvidia/quickgelu_nvidia.cu +++ b/src/infiniop/ops/quickgelu/nvidia/quickgelu_nvidia.cu @@ -49,7 +49,7 @@ infiniStatus_t Descriptor::calculate( _info, workspace, output, inputs, stream); case INFINI_DTYPE_BF16: - return _device_info->calculate<256, cuda::QuickGeluOp, __nv_bfloat16>( + return _device_info->calculate<256, cuda::QuickGeluOp, cuda_bfloat16>( _info, workspace, output, inputs, stream); case INFINI_DTYPE_F32: diff --git a/src/infiniop/ops/quickgelu/operator.cc b/src/infiniop/ops/quickgelu/operator.cc index f85a3e49a..851c96530 100644 --- a/src/infiniop/ops/quickgelu/operator.cc +++ b/src/infiniop/ops/quickgelu/operator.cc @@ -8,6 +8,9 @@ #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) #include "nvidia/quickgelu_nvidia.cuh" #endif +#ifdef ENABLE_METAX_API +#include "metax/quickgelu_metax.h" +#endif __INFINI_C infiniStatus_t infiniopCreateQuickGeluDescriptor( infiniopHandle_t handle, @@ -38,6 +41,9 @@ __INFINI_C infiniStatus_t infiniopCreateQuickGeluDescriptor( #endif #ifdef ENABLE_HYGON_API CREATE(INFINI_DEVICE_HYGON, nvidia); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -68,6 +74,9 @@ __INFINI_C infiniStatus_t infiniopGetQuickGeluWorkspaceSize(infiniopQuickGeluDes #endif #ifdef ENABLE_HYGON_API GET(INFINI_DEVICE_HYGON, nvidia) +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -104,6 +113,9 @@ __INFINI_C infiniStatus_t infiniopQuickGelu( #endif #ifdef ENABLE_HYGON_API CALCULATE(INFINI_DEVICE_HYGON, nvidia); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -134,6 +146,9 @@ __INFINI_C infiniStatus_t infiniopDestroyQuickGeluDescriptor(infiniopQuickGeluDe #endif #ifdef ENABLE_HYGON_API DELETE(INFINI_DEVICE_HYGON, nvidia); +#endif +#ifdef ENABLE_METAX_API + DELETE(INFINI_DEVICE_METAX, metax); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index ad88fcb43..5847ac1f4 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -1114,6 +1114,35 @@ def sigmoid_(lib): ] +@OpRegister.operator +def quickgelu_(lib): + lib.infiniopCreateQuickGeluDescriptor.restype = c_int32 + lib.infiniopCreateQuickGeluDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetQuickGeluWorkspaceSize.restype = c_int32 + lib.infiniopGetQuickGeluWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopQuickGelu.restype = c_int32 + lib.infiniopQuickGelu.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyQuickGeluDescriptor.restype = c_int32 + lib.infiniopDestroyQuickGeluDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + @OpRegister.operator def topksoftmax_(lib): lib.infiniopCreateTopksoftmaxDescriptor.restype = c_int32 diff --git a/test/infiniop/quickgelu.py b/test/infiniop/quickgelu.py new file mode 100644 index 000000000..680430618 --- /dev/null +++ b/test/infiniop/quickgelu.py @@ -0,0 +1,176 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + TestWorkspace, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, +) +from enum import Enum, auto + +# ============================================================================== +# Configuration (Internal Use Only) +# ============================================================================== +# These are not meant to be imported from other modules +_TEST_CASES_ = [ + # shape, x_stride, y_stride + ((13, 4), None, None), + ((13, 4), (10, 1), (10, 1)), + ((13, 4), (0, 1), (0, 1)), + ((13, 4, 4), None, None), + ((13, 4, 4), (20, 4, 1), (20, 4, 1)), + ((13, 4, 4), (4, 0, 1), (4, 0, 1)), + ((16, 5632), None, None), + ((16, 5632), (13312, 1), (13312, 1)), + ((4, 4, 5632), None, None), + ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1)), + ((4, 4, 56320), None, None), +] + + +class Inplace(Enum): + OUT_OF_PLACE = auto() + INPLACE_X = auto() + + +# Inplace options applied for each test case in _TEST_CASES_ +_INPLACE = [ + Inplace.OUT_OF_PLACE, + Inplace.INPLACE_X, +] + +# Form the test cases by appending each element of _INPLACE to each tuple in _TEST_CASES +_TEST_CASES = [ + test_case + (inplace_item,) + for test_case in _TEST_CASES_ + for inplace_item in _INPLACE +] + +# Data types used for testing +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32, InfiniDtype.BF16] + +# Tolerance map for different data types +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-6, "rtol": 1e-6}, + InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def torch_quickgelu(y, x): + # quickgelu(x) = x * sigmoid(1.702 * x) + alpha = 1.702 + y.copy_(x * torch.sigmoid(alpha * x)) + + +def test( + handle, + device, + shape, + x_stride=None, + y_stride=None, + inplace=Inplace.OUT_OF_PLACE, + dtype=torch.float16, + sync=None, +): + x = TestTensor(shape, x_stride, dtype, device) + if inplace == Inplace.INPLACE_X: + if x_stride != y_stride: + return + y = x + else: + y = TestTensor(shape, y_stride, dtype, device, mode="ones") + + if y.is_broadcast(): + return + + print( + f"Testing QuickGelu on {InfiniDeviceNames[device]} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} " + f"dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}" + ) + + torch_quickgelu(y.torch_tensor(), x.torch_tensor()) + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateQuickGeluDescriptor( + handle, + ctypes.byref(descriptor), + y.descriptor, + x.descriptor, + ) + ) + + # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel + for tensor in [x, y]: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetQuickGeluWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, y.device) + + def lib_quickgelu(): + check_error( + LIBINFINIOP.infiniopQuickGelu( + descriptor, + workspace.data(), + workspace.size(), + y.data(), + x.data(), + None, + ) + ) + + lib_quickgelu() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) + + assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) + + # Profiling workflow + if PROFILE: + # fmt: off + profile_operation("PyTorch", lambda: torch_quickgelu(y.torch_tensor(), x.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_quickgelu(), device, NUM_PRERUN, NUM_ITERATIONS) + # fmt: on + check_error(LIBINFINIOP.infiniopDestroyQuickGeluDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + # Configure testing options + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92m Test passed! \033[0m")