Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions src/triton/ops/rms_norm/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from scripts.triton import aot

# --- Variant axes: `configs()` emits one compile config per combination. ---

# Element dtype shared by all data pointers of a variant.
_DTYPES = ("fp16", "bf16", "fp32")
# Values bound to the kernel's `BLOCK_SIZE` constexpr.
_BLOCK_SIZES = (2048,)
# Pointer alignment applied to every data pointer of a variant.
# NOTE(review): `None` presumably means "no alignment assumption" -- confirm
# against `aot.Signature`.
_ALIGNMENTS = (16, None)

# --- Launch parameters shared by every variant. ---

_NUM_WARPS = 8
_NUM_STAGES = 3

# --- Kernel argument names, grouped by how `_signature()` types them. ---

# Pointers: typed with the variant's dtype and alignment.
_DATA_PTRS = ("x_ptr", "w_ptr", "y_ptr")
# Scalars typed as "i32".
_I32_SCALARS = ("m", "n")
# Scalars typed as "i64".
_I64_SCALARS = ("stride_xm", "stride_xn", "stride_wn", "stride_ym", "stride_yn")


def _signature(dtype, block_size, alignment):
    """Build the `aot.Signature` for one kernel variant.

    Every name in `_DATA_PTRS` gets the same `dtype` and `alignment`.
    `eps` is typed "fp32", the names in `_I32_SCALARS` are typed "i32" and
    the names in `_I64_SCALARS` are typed "i64". `block_size` is bound to
    the `BLOCK_SIZE` constexpr.
    """
    ptr_dtypes = dict.fromkeys(_DATA_PTRS, dtype)
    ptr_alignments = dict.fromkeys(_DATA_PTRS, alignment)

    # Insertion order matches the original literal: eps, then i32, then i64.
    scalars = {"eps": "fp32"}
    scalars.update(dict.fromkeys(_I32_SCALARS, "i32"))
    scalars.update(dict.fromkeys(_I64_SCALARS, "i64"))

    return aot.Signature(
        pointer_dtypes=ptr_dtypes,
        pointer_alignments=ptr_alignments,
        scalar_dtypes=scalars,
        constexprs={"BLOCK_SIZE": block_size},
    )


def configs():
    """Yield one tuple of `aot.CompileConfig` objects per dtype.

    Each tuple holds a config for every (block size, alignment) pair, block
    size varying slowest. All configs of a dtype share the same `out_name`,
    the same launch grid and the module-level warp/stage counts.
    """
    for dtype in _DTYPES:
        out_name = f"infini_ops_triton_rms_norm_{dtype}"
        variants = []
        for block_size in _BLOCK_SIZES:
            for alignment in _ALIGNMENTS:
                variants.append(
                    aot.CompileConfig(
                        signature=_signature(dtype, block_size, alignment),
                        grid="m, 1, 1",
                        out_name=out_name,
                        num_warps=_NUM_WARPS,
                        num_stages=_NUM_STAGES,
                    )
                )
        yield tuple(variants)
59 changes: 59 additions & 0 deletions src/triton/ops/rms_norm/rms_norm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#ifndef INFINI_OPS_TRITON_RMS_NORM_H_
#define INFINI_OPS_TRITON_RMS_NORM_H_

#include <cuda.h>

#include <cassert>
#include <cstdint>

#include "base/rms_norm.h"
#include "data_type.h"
#include "rms_norm/infini_ops_triton_rms_norm.h"

namespace infini::ops {

// `RmsNorm` specialization for `Device::Type::kNvidia` that forwards to the
// generated Triton launcher declared in
// "rms_norm/infini_ops_triton_rms_norm.h".
//
// Both constructors simply forward to the `RmsNorm` base; all work happens in
// `operator()`.
template <>
class Operator<RmsNorm, Device::Type::kNvidia, 8> : public RmsNorm {
 public:
  using RmsNorm::operator();

  Operator(const Tensor input, const Tensor weight, float eps, Tensor out)
      : RmsNorm{input, weight, eps, out} {}

  Operator(const Tensor input, const Tensor weight, Tensor out)
      : RmsNorm{input, weight, out} {}

  // Launches the Triton kernel for `input`/`weight` and writes into `out`.
  //
  // The kernel variant is selected by `out.dtype()`, so `input` must have the
  // same dtype as `out` (checked by assertion only). The launch goes to
  // `stream_`. Launch failure is likewise reported by assertion only, so it
  // is not detected in `NDEBUG` builds.
  void operator()(const Tensor input, const Tensor weight, float eps,
                  Tensor out) const override {
    assert(input.dtype() == out.dtype() &&
           "Triton `RmsNorm` requires input and output to have the same dtype");

    load_infini_ops_triton_rms_norm(out.dtype());

    const auto input_strides = input.strides();
    const auto weight_strides = weight.strides();
    const auto out_strides = out.strides();

    // NOTE(review): these narrow to `int32_t` without a range check --
    // confirm callers never exceed it.
    const auto n_rows = static_cast<int32_t>(batch_size_ * nhead_);
    const auto n_cols = static_cast<int32_t>(dim_);

    // The kernel addresses row `i` at `i * stride_*m`, using the stride of the
    // second-to-last dimension.
    // NOTE(review): with `n_rows = batch_size_ * nhead_` this presumably
    // requires the leading dimensions to be laid out so that one row stride
    // is valid across all of them -- confirm against the callers.
    const auto stride_xm = static_cast<int64_t>(input_strides[ndim_ - 2]);
    const auto stride_xn = static_cast<int64_t>(input_strides[ndim_ - 1]);
    const auto stride_wn = static_cast<int64_t>(weight_strides.back());
    const auto stride_ym = static_cast<int64_t>(out_strides[ndim_ - 2]);
    const auto stride_yn = static_cast<int64_t>(out_strides[ndim_ - 1]);

    // `[[maybe_unused]]`: `result` is only read by the `assert` below, which
    // compiles away under `NDEBUG` and would otherwise leave it unused.
    [[maybe_unused]] const auto result = launch_infini_ops_triton_rms_norm(
        out.dtype(), static_cast<CUstream>(stream_),
        reinterpret_cast<CUdeviceptr>(const_cast<void*>(input.data())),
        reinterpret_cast<CUdeviceptr>(const_cast<void*>(weight.data())),
        reinterpret_cast<CUdeviceptr>(out.data()), eps, n_rows, n_cols,
        stride_xm, stride_xn, stride_wn, stride_ym, stride_yn);

    assert(result == CUDA_SUCCESS && "Triton `RmsNorm` launch failed");
  }
};

} // namespace infini::ops

#endif
38 changes: 38 additions & 0 deletions src/triton/ops/rms_norm/rms_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import triton
import triton.language as tl


@triton.jit
def kernel(
    x_ptr,
    w_ptr,
    y_ptr,
    eps,
    m,
    n,
    stride_xm,
    stride_xn,
    stride_wn,
    stride_ym,
    stride_yn,
    BLOCK_SIZE: tl.constexpr,
):
    """RMS-normalize one row per program.

    For row ``pid`` computes ``y = x / sqrt(sum(x * x) / n + eps) * w`` over
    ``n`` columns, with the arithmetic done in float32.

    Rows are addressed as ``pid * stride_xm`` / ``pid * stride_ym`` and
    columns as ``col * stride_xn`` / ``col * stride_wn`` / ``col * stride_yn``.
    Programs with ``pid >= m`` do nothing.

    Columns are processed in chunks of ``BLOCK_SIZE``. The first chunk is
    loaded once and reused for the output; any remaining chunks (only present
    when ``n > BLOCK_SIZE``) are visited by the two loops below, so rows wider
    than ``BLOCK_SIZE`` are fully reduced and fully written.
    """
    pid = tl.program_id(0)
    if pid >= m:
        return

    x_row_ptr = x_ptr + pid * stride_xm
    y_row_ptr = y_ptr + pid * stride_ym

    # First chunk: columns [0, BLOCK_SIZE).
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < n

    x = tl.load(x_row_ptr + offs * stride_xn, mask=mask, other=0.0).to(tl.float32)
    w = tl.load(w_ptr + offs * stride_wn, mask=mask, other=0.0).to(tl.float32)

    # Per-lane partial sums of squares; masked-out lanes contribute 0.
    sq_acc = x * x

    # Remaining chunks only add to the reduction here. This loop runs zero
    # times when n <= BLOCK_SIZE.
    for start in range(BLOCK_SIZE, n, BLOCK_SIZE):
        tail_offs = start + offs
        tail_mask = tail_offs < n
        x_tail = tl.load(
            x_row_ptr + tail_offs * stride_xn, mask=tail_mask, other=0.0
        ).to(tl.float32)
        sq_acc += x_tail * x_tail

    mean_sq = tl.sum(sq_acc) / n
    rrms = 1.0 / tl.sqrt(mean_sq + eps)

    # Emit the first chunk from the values already loaded.
    y = x * rrms * w
    tl.store(y_row_ptr + offs * stride_yn, y, mask=mask)

    # Remaining chunks are reloaded, scaled and stored. The reduction above has
    # finished by now, and each chunk is read before it is written.
    for start in range(BLOCK_SIZE, n, BLOCK_SIZE):
        tail_offs = start + offs
        tail_mask = tail_offs < n
        x_tail = tl.load(
            x_row_ptr + tail_offs * stride_xn, mask=tail_mask, other=0.0
        ).to(tl.float32)
        w_tail = tl.load(
            w_ptr + tail_offs * stride_wn, mask=tail_mask, other=0.0
        ).to(tl.float32)
        tl.store(
            y_row_ptr + tail_offs * stride_yn,
            x_tail * rrms * w_tail,
            mask=tail_mask,
        )
Loading