diff --git a/src/triton/ops/rms_norm/build.py b/src/triton/ops/rms_norm/build.py
new file mode 100644
index 000000000..d3fbc1f71
--- /dev/null
+++ b/src/triton/ops/rms_norm/build.py
@@ -0,0 +1,64 @@
+"""AOT build configuration for the Triton RMSNorm kernel."""
+
+from scripts.triton import aot
+
+# Element types a specialised kernel is generated for.
+_DTYPES = ("fp16", "bf16", "fp32")
+# Values bound to the kernel's `BLOCK_SIZE` constexpr (columns per tile).
+_BLOCK_SIZES = (2048,)
+# Alignment passed for every data pointer; one variant is built per entry.
+# NOTE(review): `None` presumably means "no alignment assumption" - confirm
+# against `aot.Signature`.
+_ALIGNMENTS = (16, None)
+_NUM_WARPS = 8
+_NUM_STAGES = 3
+# Kernel arguments, grouped by how they are typed in the signature.
+_DATA_PTRS = ("x_ptr", "w_ptr", "y_ptr")
+_I32_SCALARS = ("m", "n")
+_I64_SCALARS = (
+    "stride_xm",
+    "stride_xn",
+    "stride_wn",
+    "stride_ym",
+    "stride_yn",
+)
+
+
+def _signature(dtype, block_size, alignment):
+    """Build the kernel signature for one dtype/block-size/alignment variant.
+
+    All three data pointers share `dtype` and `alignment`; `eps` is `fp32`,
+    `m`/`n` are `i32`, the strides are `i64`, and `BLOCK_SIZE` is bound to
+    `block_size`.
+    """
+    return aot.Signature(
+        pointer_dtypes={name: dtype for name in _DATA_PTRS},
+        pointer_alignments={name: alignment for name in _DATA_PTRS},
+        scalar_dtypes={
+            "eps": "fp32",
+            **{name: "i32" for name in _I32_SCALARS},
+            **{name: "i64" for name in _I64_SCALARS},
+        },
+        constexprs={"BLOCK_SIZE": block_size},
+    )
+
+
+def configs():
+    """Yield, per dtype, a tuple of compile configs sharing one output name.
+
+    Each tuple holds one `aot.CompileConfig` for every combination of
+    `_BLOCK_SIZES` and `_ALIGNMENTS`; the grid string is `m, 1, 1`, i.e. one
+    program per row.
+    """
+    for dtype in _DTYPES:
+        yield tuple(
+            aot.CompileConfig(
+                signature=_signature(dtype, block_size, alignment),
+                grid="m, 1, 1",
+                out_name=f"infini_ops_triton_rms_norm_{dtype}",
+                num_warps=_NUM_WARPS,
+                num_stages=_NUM_STAGES,
+            )
+            for block_size in _BLOCK_SIZES
+            for alignment in _ALIGNMENTS
+        )
diff --git a/src/triton/ops/rms_norm/rms_norm.h b/src/triton/ops/rms_norm/rms_norm.h
new file mode 100644
index 000000000..3c68abde4
--- /dev/null
+++ b/src/triton/ops/rms_norm/rms_norm.h
@@ -0,0 +1,73 @@
+#ifndef INFINI_OPS_TRITON_RMS_NORM_H_
+#define INFINI_OPS_TRITON_RMS_NORM_H_
+
+// NOTE(review): in this copy of the patch the angle-bracket arguments are
+// missing (the three system `#include` targets, the `Operator` specialization
+// arguments and the cast target types below); they look stripped in transit.
+// Restore them from the original change before applying.
+
+#include
+
+#include
+#include
+
+#include "base/rms_norm.h"
+#include "data_type.h"
+#include "rms_norm/infini_ops_triton_rms_norm.h"
+
+namespace infini::ops {
+
+// `RmsNorm` implementation that dispatches to the generated
+// `infini_ops_triton_rms_norm` kernel launcher.
+template <>
+class Operator : public RmsNorm {
+ public:
+  using RmsNorm::operator();
+
+  Operator(const Tensor input, const Tensor weight, float eps, Tensor out)
+      : RmsNorm{input, weight, eps, out} {}
+
+  Operator(const Tensor input, const Tensor weight, Tensor out)
+      : RmsNorm{input, weight, out} {}
+
+  // Launches the kernel for `out.dtype()` over `batch_size_ * nhead_` rows of
+  // `dim_` columns on `stream_`; `input` and `out` must share a dtype.
+  void operator()(const Tensor input, const Tensor weight, float eps,
+                  Tensor out) const override {
+    assert(input.dtype() == out.dtype() &&
+           "Triton `RmsNorm` requires input and output to have the same dtype");
+
+    load_infini_ops_triton_rms_norm(out.dtype());
+
+    const auto input_strides = input.strides();
+    const auto weight_strides = weight.strides();
+    const auto out_strides = out.strides();
+
+    const auto n_rows = static_cast(batch_size_ * nhead_);
+    const auto n_cols = static_cast(dim_);
+
+    // NOTE(review): one row stride (dimension `ndim_ - 2`) serves all
+    // `batch_size_ * nhead_` rows, which presumably assumes the leading
+    // dimensions collapse to a uniform stride and `ndim_ >= 2` - confirm.
+    const auto stride_xm = static_cast(input_strides[ndim_ - 2]);
+    const auto stride_xn = static_cast(input_strides[ndim_ - 1]);
+    const auto stride_wn = static_cast(weight_strides.back());
+    const auto stride_ym = static_cast(out_strides[ndim_ - 2]);
+    const auto stride_yn = static_cast(out_strides[ndim_ - 1]);
+
+    auto result = launch_infini_ops_triton_rms_norm(
+        out.dtype(), static_cast(stream_),
+        reinterpret_cast(const_cast(input.data())),
+        reinterpret_cast(const_cast(weight.data())),
+        reinterpret_cast(out.data()), eps, n_rows, n_cols,
+        stride_xm, stride_xn, stride_wn, stride_ym, stride_yn);
+
+    // NOTE(review): `assert` is compiled out when `NDEBUG` is defined, so a
+    // failed launch is not reported in such builds.
+    assert(result == CUDA_SUCCESS && "Triton `RmsNorm` launch failed");
+  }
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/triton/ops/rms_norm/rms_norm.py b/src/triton/ops/rms_norm/rms_norm.py
new file mode 100644
index 000000000..f7e23759d
--- /dev/null
+++ b/src/triton/ops/rms_norm/rms_norm.py
@@ -0,0 +1,57 @@
+import triton
+import triton.language as tl
+
+
+# RMSNorm over the rows of a strided 2-D view:
+#     y = x / sqrt(mean(x^2) + eps) * w
+# One program handles one row (`pid` along grid axis 0). The row is walked in
+# tiles of `BLOCK_SIZE` columns so that rows wider than `BLOCK_SIZE` are
+# normalised correctly; all arithmetic is done in fp32.
+#
+#   x_ptr, w_ptr, y_ptr  input, weight and output base pointers
+#   eps                  added to the mean square before the square root
+#   m, n                 number of rows and columns
+#   stride_*             element strides of x (row, column), w (column) and
+#                        y (row, column)
+#   BLOCK_SIZE           compile-time tile width in columns
+@triton.jit
+def kernel(
+    x_ptr,
+    w_ptr,
+    y_ptr,
+    eps,
+    m,
+    n,
+    stride_xm,
+    stride_xn,
+    stride_wn,
+    stride_ym,
+    stride_yn,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    if pid >= m:
+        return
+
+    x_row_ptr = x_ptr + pid * stride_xm
+    y_row_ptr = y_ptr + pid * stride_ym
+
+    # Pass 1: accumulate per-lane partial sums of x^2 across all column tiles.
+    # Masked lanes load 0.0 and therefore do not contribute.
+    sq_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for start in range(0, n, BLOCK_SIZE):
+        offs = start + tl.arange(0, BLOCK_SIZE)
+        mask = offs < n
+        x = tl.load(x_row_ptr + offs * stride_xn, mask=mask, other=0.0).to(tl.float32)
+        sq_acc += x * x
+
+    mean_sq = tl.sum(sq_acc) / n
+    rrms = 1.0 / tl.sqrt(mean_sq + eps)
+
+    # Pass 2: re-read each tile, scale it and write the result.
+    for start in range(0, n, BLOCK_SIZE):
+        offs = start + tl.arange(0, BLOCK_SIZE)
+        mask = offs < n
+        x = tl.load(x_row_ptr + offs * stride_xn, mask=mask, other=0.0).to(tl.float32)
+        w = tl.load(w_ptr + offs * stride_wn, mask=mask, other=0.0).to(tl.float32)
+        tl.store(y_row_ptr + offs * stride_yn, x * rrms * w, mask=mask)