From e45aeacd525cb60b8e60ed4b6738e4d6010da013 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Fri, 26 Jun 2026 19:22:13 +0800 Subject: [PATCH] feat: support nsa paged attn on nv --- .../ops/nsa_compress_paged_cache.hpp | 15 + .../infinicore/ops/nsa_paged_attention.hpp | 19 + include/infiniop.h | 2 + .../infiniop/ops/nsa_compress_paged_cache.h | 39 ++ include/infiniop/ops/nsa_paged_attention.h | 47 +++ .../nsa_compress_paged_cache.cc | 29 ++ .../nsa_compress_paged_cache_infiniop.cc | 59 +++ .../nsa_paged_attention.cc | 40 ++ .../nsa_paged_attention_infiniop.cc | 67 +++ .../nsa_compress_paged_cache/cuda/kernel.cuh | 93 +++++ .../ops/nsa_compress_paged_cache/info.h | 119 ++++++ .../nsa_compress_paged_cache.h | 39 ++ .../nvidia/nsa_compress_paged_cache_nvidia.cu | 151 +++++++ .../nsa_compress_paged_cache_nvidia.cuh | 8 + .../ops/nsa_compress_paged_cache/operator.cc | 130 ++++++ .../ops/nsa_paged_attention/cuda/kernel.cuh | 239 +++++++++++ src/infiniop/ops/nsa_paged_attention/info.h | 165 ++++++++ .../nsa_paged_attention/nsa_paged_attention.h | 48 +++ .../nvidia/nsa_paged_attention_nvidia.cu | 209 ++++++++++ .../nvidia/nsa_paged_attention_nvidia.cuh | 8 + .../ops/nsa_paged_attention/operator.cc | 138 +++++++ test/infiniop/libinfiniop/op_register.py | 92 +++++ test/infiniop/nsa_compress_paged_cache.py | 271 +++++++++++++ test/infiniop/nsa_paged_attention.py | 380 ++++++++++++++++++ 24 files changed, 2407 insertions(+) create mode 100644 include/infinicore/ops/nsa_compress_paged_cache.hpp create mode 100644 include/infinicore/ops/nsa_paged_attention.hpp create mode 100644 include/infiniop/ops/nsa_compress_paged_cache.h create mode 100644 include/infiniop/ops/nsa_paged_attention.h create mode 100644 src/infinicore/ops/nsa_compress_paged_cache/nsa_compress_paged_cache.cc create mode 100644 src/infinicore/ops/nsa_compress_paged_cache/nsa_compress_paged_cache_infiniop.cc create mode 100644 src/infinicore/ops/nsa_paged_attention/nsa_paged_attention.cc create mode 100644 
src/infinicore/ops/nsa_paged_attention/nsa_paged_attention_infiniop.cc create mode 100644 src/infiniop/ops/nsa_compress_paged_cache/cuda/kernel.cuh create mode 100644 src/infiniop/ops/nsa_compress_paged_cache/info.h create mode 100644 src/infiniop/ops/nsa_compress_paged_cache/nsa_compress_paged_cache.h create mode 100644 src/infiniop/ops/nsa_compress_paged_cache/nvidia/nsa_compress_paged_cache_nvidia.cu create mode 100644 src/infiniop/ops/nsa_compress_paged_cache/nvidia/nsa_compress_paged_cache_nvidia.cuh create mode 100644 src/infiniop/ops/nsa_compress_paged_cache/operator.cc create mode 100644 src/infiniop/ops/nsa_paged_attention/cuda/kernel.cuh create mode 100644 src/infiniop/ops/nsa_paged_attention/info.h create mode 100644 src/infiniop/ops/nsa_paged_attention/nsa_paged_attention.h create mode 100644 src/infiniop/ops/nsa_paged_attention/nvidia/nsa_paged_attention_nvidia.cu create mode 100644 src/infiniop/ops/nsa_paged_attention/nvidia/nsa_paged_attention_nvidia.cuh create mode 100644 src/infiniop/ops/nsa_paged_attention/operator.cc create mode 100644 test/infiniop/nsa_compress_paged_cache.py create mode 100644 test/infiniop/nsa_paged_attention.py diff --git a/include/infinicore/ops/nsa_compress_paged_cache.hpp b/include/infinicore/ops/nsa_compress_paged_cache.hpp new file mode 100644 index 000000000..bb5380632 --- /dev/null +++ b/include/infinicore/ops/nsa_compress_paged_cache.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "../device.hpp" +#include "../graph/graph.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_CLASS(NsaCompressPagedCache, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, int, bool); + +void nsa_compress_paged_cache_(Tensor k_cmp, Tensor v_cmp, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &kv_lens, int nsa_block_size, + bool update_last_only = false); + +} // namespace infinicore::op diff --git 
a/include/infinicore/ops/nsa_paged_attention.hpp b/include/infinicore/ops/nsa_paged_attention.hpp new file mode 100644 index 000000000..00e9adc10 --- /dev/null +++ b/include/infinicore/ops/nsa_paged_attention.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include "../device.hpp" +#include "../graph/graph.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_CLASS(NsaPagedAttention, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, float, int, int, int); + +Tensor nsa_paged_attention(const Tensor &q, const Tensor &k_cmp, const Tensor &v_cmp, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &kv_lens, const Tensor &gates, + float scale, int nsa_block_size, int window_size, int select_blocks); + +void nsa_paged_attention_(Tensor out, const Tensor &q, const Tensor &k_cmp, const Tensor &v_cmp, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &kv_lens, const Tensor &gates, + float scale, int nsa_block_size, int window_size, int select_blocks); + +} // namespace infinicore::op diff --git a/include/infiniop.h b/include/infiniop.h index bc4b84bc6..19aaf9457 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -98,6 +98,8 @@ #include "infiniop/ops/mul.h" #include "infiniop/ops/multi_margin_loss.h" #include "infiniop/ops/nrm2.h" +#include "infiniop/ops/nsa_compress_paged_cache.h" +#include "infiniop/ops/nsa_paged_attention.h" #include "infiniop/ops/ones.h" #include "infiniop/ops/pad.h" #include "infiniop/ops/paged_attention.h" diff --git a/include/infiniop/ops/nsa_compress_paged_cache.h b/include/infiniop/ops/nsa_compress_paged_cache.h new file mode 100644 index 000000000..f69d3fb5d --- /dev/null +++ b/include/infiniop/ops/nsa_compress_paged_cache.h @@ -0,0 +1,39 @@ +#ifndef __INFINIOP_NSA_COMPRESS_PAGED_CACHE_API_H__ +#define __INFINIOP_NSA_COMPRESS_PAGED_CACHE_API_H__ + +#include 
"../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopNsaCompressPagedCacheDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateNsaCompressPagedCacheDescriptor( + infiniopHandle_t handle, + infiniopNsaCompressPagedCacheDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t k_cmp_desc, + infiniopTensorDescriptor_t v_cmp_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t seq_lens_desc, + int nsa_block_size, + int update_last_only); + +__INFINI_C __export infiniStatus_t infiniopGetNsaCompressPagedCacheWorkspaceSize( + infiniopNsaCompressPagedCacheDescriptor_t desc, + size_t *size); + +__INFINI_C __export infiniStatus_t infiniopNsaCompressPagedCache( + infiniopNsaCompressPagedCacheDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *k_cmp, + void *v_cmp, + const void *k_cache, + const void *v_cache, + const void *block_tables, + const void *seq_lens, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyNsaCompressPagedCacheDescriptor( + infiniopNsaCompressPagedCacheDescriptor_t desc); + +#endif // __INFINIOP_NSA_COMPRESS_PAGED_CACHE_API_H__ diff --git a/include/infiniop/ops/nsa_paged_attention.h b/include/infiniop/ops/nsa_paged_attention.h new file mode 100644 index 000000000..e92414f86 --- /dev/null +++ b/include/infiniop/ops/nsa_paged_attention.h @@ -0,0 +1,47 @@ +#ifndef __INFINIOP_NSA_PAGED_ATTENTION_API_H__ +#define __INFINIOP_NSA_PAGED_ATTENTION_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopNsaPagedAttentionDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateNsaPagedAttentionDescriptor( + infiniopHandle_t handle, + infiniopNsaPagedAttentionDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_cmp_desc, + infiniopTensorDescriptor_t v_cmp_desc, + 
infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t seq_lens_desc, + infiniopTensorDescriptor_t gates_desc, + float scale, + int nsa_block_size, + int window_size, + int select_blocks); + +__INFINI_C __export infiniStatus_t infiniopGetNsaPagedAttentionWorkspaceSize( + infiniopNsaPagedAttentionDescriptor_t desc, + size_t *size); + +__INFINI_C __export infiniStatus_t infiniopNsaPagedAttention( + infiniopNsaPagedAttentionDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cmp, + const void *v_cmp, + const void *k_cache, + const void *v_cache, + const void *block_tables, + const void *seq_lens, + const void *gates, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyNsaPagedAttentionDescriptor( + infiniopNsaPagedAttentionDescriptor_t desc); + +#endif // __INFINIOP_NSA_PAGED_ATTENTION_API_H__ diff --git a/src/infinicore/ops/nsa_compress_paged_cache/nsa_compress_paged_cache.cc b/src/infinicore/ops/nsa_compress_paged_cache/nsa_compress_paged_cache.cc new file mode 100644 index 000000000..5d0269ff6 --- /dev/null +++ b/src/infinicore/ops/nsa_compress_paged_cache/nsa_compress_paged_cache.cc @@ -0,0 +1,29 @@ +#include "infinicore/ops/nsa_compress_paged_cache.hpp" +#include "../../utils.hpp" + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(NsaCompressPagedCache); + +NsaCompressPagedCache::NsaCompressPagedCache(Tensor k_cmp, Tensor v_cmp, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &kv_lens, int nsa_block_size, + bool update_last_only) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k_cmp, v_cmp, k_cache, v_cache, block_tables, kv_lens); + INFINICORE_GRAPH_OP_DISPATCH(k_cmp->device().getType(), k_cmp, v_cmp, k_cache, v_cache, block_tables, kv_lens, nsa_block_size, update_last_only); +} + +void NsaCompressPagedCache::execute(Tensor k_cmp, 
Tensor v_cmp, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &kv_lens, int nsa_block_size, + bool update_last_only) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN( + NsaCompressPagedCache, + k_cmp, v_cmp, k_cache, v_cache, block_tables, kv_lens, nsa_block_size, update_last_only); +} + +void nsa_compress_paged_cache_(Tensor k_cmp, Tensor v_cmp, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &kv_lens, int nsa_block_size, + bool update_last_only) { + NsaCompressPagedCache::execute(k_cmp, v_cmp, k_cache, v_cache, block_tables, kv_lens, nsa_block_size, update_last_only); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/nsa_compress_paged_cache/nsa_compress_paged_cache_infiniop.cc b/src/infinicore/ops/nsa_compress_paged_cache/nsa_compress_paged_cache_infiniop.cc new file mode 100644 index 000000000..97f28d59b --- /dev/null +++ b/src/infinicore/ops/nsa_compress_paged_cache/nsa_compress_paged_cache_infiniop.cc @@ -0,0 +1,59 @@ +#include "infinicore/ops/nsa_compress_paged_cache.hpp" + +#include "../infiniop_impl.hpp" + +namespace infinicore::op::nsa_compress_paged_cache_impl::infiniop { + +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, NsaCompressPagedCache, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens; +}; + +void *plan(Tensor k_cmp, Tensor v_cmp, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &cache_lens, int nsa_block_size, bool update_last_only) { + size_t seed = hash_combine(k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, nsa_block_size, update_last_only); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, NsaCompressPagedCache, + seed, + k_cmp->desc(), v_cmp->desc(), k_cache->desc(), v_cache->desc(), + block_tables->desc(), cache_lens->desc(), nsa_block_size, static_cast(update_last_only)); + + 
INFINIOP_WORKSPACE_TENSOR(workspace, NsaCompressPagedCache, descriptor); + + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(k_cmp), + graph::GraphTensor(v_cmp), + graph::GraphTensor(k_cache), + graph::GraphTensor(v_cache), + graph::GraphTensor(block_tables), + graph::GraphTensor(cache_lens)}; +} + +void run(void *planned_meta) { + auto *p = reinterpret_cast(planned_meta); + INFINICORE_CHECK_ERROR( + infiniopNsaCompressPagedCache( + p->descriptor->desc, + p->workspace->data(), + p->workspace->numel(), + p->k_cmp->data(), + p->v_cmp->data(), + p->k_cache->data(), + p->v_cache->data(), + p->block_tables->data(), + p->cache_lens->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(NsaCompressPagedCache, &plan, &run, &cleanup); + +} // namespace infinicore::op::nsa_compress_paged_cache_impl::infiniop diff --git a/src/infinicore/ops/nsa_paged_attention/nsa_paged_attention.cc b/src/infinicore/ops/nsa_paged_attention/nsa_paged_attention.cc new file mode 100644 index 000000000..8d74a2e64 --- /dev/null +++ b/src/infinicore/ops/nsa_paged_attention/nsa_paged_attention.cc @@ -0,0 +1,40 @@ +#include "infinicore/ops/nsa_paged_attention.hpp" +#include "../../utils.hpp" + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(NsaPagedAttention); + +NsaPagedAttention::NsaPagedAttention(Tensor out, const Tensor &q, const Tensor &k_cmp, const Tensor &v_cmp, + const Tensor &k_cache, const Tensor &v_cache, const Tensor &block_tables, + const Tensor &kv_lens, const Tensor &gates, float scale, int nsa_block_size, + int window_size, int select_blocks) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, kv_lens, gates); + INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), + out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, kv_lens, 
gates, scale, nsa_block_size, window_size, select_blocks); +} + +void NsaPagedAttention::execute(Tensor out, const Tensor &q, const Tensor &k_cmp, const Tensor &v_cmp, + const Tensor &k_cache, const Tensor &v_cache, const Tensor &block_tables, + const Tensor &kv_lens, const Tensor &gates, float scale, int nsa_block_size, + int window_size, int select_blocks) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN( + NsaPagedAttention, + out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, kv_lens, gates, scale, nsa_block_size, window_size, select_blocks); +} + +Tensor nsa_paged_attention(const Tensor &q, const Tensor &k_cmp, const Tensor &v_cmp, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &kv_lens, const Tensor &gates, + float scale, int nsa_block_size, int window_size, int select_blocks) { + auto out = Tensor::empty(q->shape(), q->dtype(), q->device()); + nsa_paged_attention_(out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, kv_lens, gates, scale, nsa_block_size, window_size, select_blocks); + return out; +} + +void nsa_paged_attention_(Tensor out, const Tensor &q, const Tensor &k_cmp, const Tensor &v_cmp, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &kv_lens, const Tensor &gates, + float scale, int nsa_block_size, int window_size, int select_blocks) { + NsaPagedAttention::execute(out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, kv_lens, gates, scale, nsa_block_size, window_size, select_blocks); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/nsa_paged_attention/nsa_paged_attention_infiniop.cc b/src/infinicore/ops/nsa_paged_attention/nsa_paged_attention_infiniop.cc new file mode 100644 index 000000000..d2f26af06 --- /dev/null +++ b/src/infinicore/ops/nsa_paged_attention/nsa_paged_attention_infiniop.cc @@ -0,0 +1,67 @@ +#include "infinicore/ops/nsa_paged_attention.hpp" + +#include "../infiniop_impl.hpp" + +namespace infinicore::op::nsa_paged_attention_impl::infiniop 
{ + +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, NsaPagedAttention, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, gates; +}; + +void *plan(Tensor out, const Tensor &q, const Tensor &k_cmp, const Tensor &v_cmp, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &cache_lens, const Tensor &gates, + float scale, int nsa_block_size, int window_size, int select_blocks) { + size_t seed = hash_combine(out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, gates, scale, nsa_block_size, window_size, select_blocks); // scale is passed to descriptor creation below, so it must be part of the cache key + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, NsaPagedAttention, + seed, + out->desc(), q->desc(), k_cmp->desc(), v_cmp->desc(), k_cache->desc(), v_cache->desc(), + block_tables->desc(), cache_lens->desc(), gates->desc(), + scale, nsa_block_size, window_size, select_blocks); + + INFINIOP_WORKSPACE_TENSOR(workspace, NsaPagedAttention, descriptor); + + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(out), + graph::GraphTensor(q), + graph::GraphTensor(k_cmp), + graph::GraphTensor(v_cmp), + graph::GraphTensor(k_cache), + graph::GraphTensor(v_cache), + graph::GraphTensor(block_tables), + graph::GraphTensor(cache_lens), + graph::GraphTensor(gates)}; +} + +void run(void *planned_meta) { + auto *p = reinterpret_cast(planned_meta); + INFINICORE_CHECK_ERROR( + infiniopNsaPagedAttention( + p->descriptor->desc, + p->workspace->data(), + p->workspace->numel(), + p->out->data(), + p->q->data(), + p->k_cmp->data(), + p->v_cmp->data(), + p->k_cache->data(), + p->v_cache->data(), + p->block_tables->data(), + p->cache_lens->data(), + p->gates->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(NsaPagedAttention, 
&plan, &run, &cleanup); + +} // namespace infinicore::op::nsa_paged_attention_impl::infiniop diff --git a/src/infiniop/ops/nsa_compress_paged_cache/cuda/kernel.cuh b/src/infiniop/ops/nsa_compress_paged_cache/cuda/kernel.cuh new file mode 100644 index 000000000..0b1cacd5d --- /dev/null +++ b/src/infiniop/ops/nsa_compress_paged_cache/cuda/kernel.cuh @@ -0,0 +1,93 @@ +#ifndef __NSA_COMPRESS_PAGED_CACHE_CUDA_KERNEL_CUH__ +#define __NSA_COMPRESS_PAGED_CACHE_CUDA_KERNEL_CUH__ + +#include +#include +#include +#include + +namespace op::nsa_compress_paged_cache::cuda { + +template +__device__ inline float loadFloat(const T *ptr) { return static_cast(*ptr); } +template <> +__device__ inline float loadFloat<__half>(const __half *ptr) { return __half2float(*ptr); } +template <> +__device__ inline float loadFloat<__nv_bfloat16>(const __nv_bfloat16 *ptr) { return __bfloat162float(*ptr); } +template +__device__ inline void storeFloat(T *ptr, float value) { *ptr = static_cast(value); } +template <> +__device__ inline void storeFloat<__half>(__half *ptr, float value) { *ptr = __float2half(value); } +template <> +__device__ inline void storeFloat<__nv_bfloat16>(__nv_bfloat16 *ptr, float value) { *ptr = __float2bfloat16(value); } + +template +__device__ void compressPagedCacheKernel( + Tdata *k_cmp, + Tdata *v_cmp, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + size_t max_num_blocks_per_seq, + size_t page_block_size, + size_t subblocks_per_page, + int nsa_block_size, + int update_last_only, + ptrdiff_t k_cmp_block_stride, + ptrdiff_t k_cmp_head_stride, + ptrdiff_t v_cmp_block_stride, + ptrdiff_t v_cmp_head_stride, + ptrdiff_t k_cache_batch_stride, + ptrdiff_t k_cache_head_stride, + ptrdiff_t k_cache_row_stride, + ptrdiff_t v_cache_batch_stride, + ptrdiff_t v_cache_head_stride, + ptrdiff_t v_cache_row_stride, + ptrdiff_t block_table_batch_stride, + ptrdiff_t cache_lens_stride) { + const size_t seq = blockIdx.x; + size_t 
logical_nsa_block = blockIdx.y; + const size_t kv_head = blockIdx.z; + const int dim = threadIdx.x; + constexpr int kHeadDim = 128; + if (dim >= kHeadDim) { + return; + } + + const int64_t seq_len = static_cast(cache_lens[seq * cache_lens_stride]); + if (seq_len <= 0) { + return; + } + if (update_last_only) { + logical_nsa_block = static_cast((seq_len - 1) / nsa_block_size); + } + const int64_t tok_begin = static_cast(logical_nsa_block) * nsa_block_size; + if (tok_begin >= seq_len) { + return; + } + const int64_t tok_end = (tok_begin + nsa_block_size < seq_len) ? (tok_begin + nsa_block_size) : seq_len; + const size_t logical_page = static_cast(tok_begin / static_cast(page_block_size)); + if (logical_page >= max_num_blocks_per_seq) { + return; + } + const size_t subblock = static_cast((tok_begin % static_cast(page_block_size)) / nsa_block_size); + const Tindex physical = block_tables[seq * block_table_batch_stride + logical_page]; + const size_t cmp_block = static_cast(physical) * subblocks_per_page + subblock; + float k_sum = 0.0f; + float v_sum = 0.0f; + for (int64_t tok = tok_begin; tok < tok_end; ++tok) { + const size_t row = static_cast(tok % static_cast(page_block_size)); + const size_t k_base = static_cast(physical) * k_cache_batch_stride + kv_head * k_cache_head_stride + row * k_cache_row_stride; + const size_t v_base = static_cast(physical) * v_cache_batch_stride + kv_head * v_cache_head_stride + row * v_cache_row_stride; + k_sum += loadFloat(k_cache + k_base + dim); + v_sum += loadFloat(v_cache + v_base + dim); + } + const float inv = 1.0f / static_cast(tok_end - tok_begin); + storeFloat(k_cmp + cmp_block * k_cmp_block_stride + kv_head * k_cmp_head_stride + dim, k_sum * inv); + storeFloat(v_cmp + cmp_block * v_cmp_block_stride + kv_head * v_cmp_head_stride + dim, v_sum * inv); +} + +} // namespace op::nsa_compress_paged_cache::cuda + +#endif // __NSA_COMPRESS_PAGED_CACHE_CUDA_KERNEL_CUH__ diff --git a/src/infiniop/ops/nsa_compress_paged_cache/info.h 
b/src/infiniop/ops/nsa_compress_paged_cache/info.h new file mode 100644 index 000000000..2cd4dcd36 --- /dev/null +++ b/src/infiniop/ops/nsa_compress_paged_cache/info.h @@ -0,0 +1,119 @@ +#ifndef __NSA_COMPRESS_PAGED_CACHE_INFO_H__ +#define __NSA_COMPRESS_PAGED_CACHE_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" + +namespace op::nsa_compress_paged_cache { + +class NsaCompressPagedCacheInfo { + NsaCompressPagedCacheInfo() = default; + +public: + infiniDtype_t dtype; + infiniDtype_t index_dtype; + int nsa_block_size; + int update_last_only; + size_t num_seqs; + size_t num_kv_heads; + size_t head_size; + size_t page_block_size; + size_t subblocks_per_page; + size_t max_num_blocks_per_seq; + + ptrdiff_t k_cmp_block_stride; + ptrdiff_t k_cmp_head_stride; + ptrdiff_t v_cmp_block_stride; + ptrdiff_t v_cmp_head_stride; + ptrdiff_t k_cache_batch_stride; + ptrdiff_t k_cache_head_stride; + ptrdiff_t k_cache_row_stride; + ptrdiff_t v_cache_batch_stride; + ptrdiff_t v_cache_head_stride; + ptrdiff_t v_cache_row_stride; + ptrdiff_t block_table_batch_stride; + ptrdiff_t cache_lens_stride; + + static utils::Result create( + infiniopTensorDescriptor_t k_cmp_desc, + infiniopTensorDescriptor_t v_cmp_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t cache_lens_desc, + int nsa_block_size, + int update_last_only) { + auto dtype = k_cache_desc->dtype(); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16); + if (v_cache_desc->dtype() != dtype || k_cmp_desc->dtype() != dtype || v_cmp_desc->dtype() != dtype) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (k_cmp_desc->ndim() != 3 || v_cmp_desc->ndim() != 3 || k_cache_desc->ndim() != 4 || v_cache_desc->ndim() != 4 || block_tables_desc->ndim() != 2 || cache_lens_desc->ndim() != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + CHECK_OR_RETURN(k_cmp_desc->stride(2) == 1, 
INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(v_cmp_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(k_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(v_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(block_tables_desc->stride(1) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(cache_lens_desc->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + + const auto index_dtype = block_tables_desc->dtype(); + if (index_dtype != cache_lens_desc->dtype()) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (index_dtype != INFINI_DTYPE_I64 && index_dtype != INFINI_DTYPE_I32 && index_dtype != INFINI_DTYPE_U32) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + auto k_shape = k_cache_desc->shape(); + const size_t num_blocks = k_shape[0]; + const size_t num_kv_heads = k_shape[1]; + const size_t page_block_size = k_shape[2]; + const size_t head_size = k_shape[3]; + if (head_size != 128 || nsa_block_size <= 0 || page_block_size % static_cast(nsa_block_size) != 0) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (v_cache_desc->shape() != k_shape) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + const size_t subblocks_per_page = page_block_size / static_cast(nsa_block_size); + if (k_cmp_desc->shape()[0] != num_blocks * subblocks_per_page || k_cmp_desc->shape()[1] != num_kv_heads || k_cmp_desc->shape()[2] != head_size) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (v_cmp_desc->shape() != k_cmp_desc->shape()) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (cache_lens_desc->shape()[0] != block_tables_desc->shape()[0]) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + return utils::Result(NsaCompressPagedCacheInfo{ + dtype, + index_dtype, + nsa_block_size, + update_last_only, + block_tables_desc->shape()[0], + num_kv_heads, + head_size, + page_block_size, + subblocks_per_page, + block_tables_desc->shape()[1], + k_cmp_desc->stride(0), + k_cmp_desc->stride(1), + 
v_cmp_desc->stride(0), + v_cmp_desc->stride(1), + k_cache_desc->stride(0), + k_cache_desc->stride(1), + k_cache_desc->stride(2), + v_cache_desc->stride(0), + v_cache_desc->stride(1), + v_cache_desc->stride(2), + block_tables_desc->stride(0), + cache_lens_desc->stride(0), + }); + } +}; + +} // namespace op::nsa_compress_paged_cache + +#endif // __NSA_COMPRESS_PAGED_CACHE_INFO_H__ diff --git a/src/infiniop/ops/nsa_compress_paged_cache/nsa_compress_paged_cache.h b/src/infiniop/ops/nsa_compress_paged_cache/nsa_compress_paged_cache.h new file mode 100644 index 000000000..51ef64b36 --- /dev/null +++ b/src/infiniop/ops/nsa_compress_paged_cache/nsa_compress_paged_cache.h @@ -0,0 +1,39 @@ +#ifndef NSA_COMPRESS_PAGED_CACHE_H +#define NSA_COMPRESS_PAGED_CACHE_H + +#include "../../operator.h" +#include "info.h" + /* DESCRIPTOR(NAMESPACE) declares op::nsa_compress_paged_cache::NAMESPACE::Descriptor, a final InfiniopDescriptor subclass that stores the validated NsaCompressPagedCacheInfo, an opaque backend pointer and the workspace size. The constructor is private; instances come from the static create(), and calculate() runs the operator. */ +#define DESCRIPTOR(NAMESPACE) \ + namespace op::nsa_compress_paged_cache::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + NsaCompressPagedCacheInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor(Opaque *opaque, NsaCompressPagedCacheInfo info, \ + size_t workspace_size, infiniDevice_t device_type, int device_id) \ + : InfiniopDescriptor{device_type, device_id}, _opaque(opaque), \ + _info(info), _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + size_t workspaceSize() const { return _workspace_size; } \ + static infiniStatus_t create( \ + infiniopHandle_t handle, Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t k_cmp_desc, \ + infiniopTensorDescriptor_t v_cmp_desc, \ + infiniopTensorDescriptor_t k_cache_desc, \ + infiniopTensorDescriptor_t v_cache_desc, \ + infiniopTensorDescriptor_t block_tables_desc, \ + infiniopTensorDescriptor_t seq_lens_desc, int nsa_block_size, \ + int update_last_only); \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, void *k_cmp, void *v_cmp, \ + const void *k_cache, const void *v_cache, const void *block_tables, \ +
const void *seq_lens, void *stream) const; \ + }; \ + } + +#endif // NSA_COMPRESS_PAGED_CACHE_H diff --git a/src/infiniop/ops/nsa_compress_paged_cache/nvidia/nsa_compress_paged_cache_nvidia.cu b/src/infiniop/ops/nsa_compress_paged_cache/nvidia/nsa_compress_paged_cache_nvidia.cu new file mode 100644 index 000000000..4e08dac17 --- /dev/null +++ b/src/infiniop/ops/nsa_compress_paged_cache/nvidia/nsa_compress_paged_cache_nvidia.cu @@ -0,0 +1,151 @@ +#include "nsa_compress_paged_cache_nvidia.cuh" + +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" + +#include +#include +#include +#include + +#include "../cuda/kernel.cuh" + +namespace op::nsa_compress_paged_cache::nvidia { + +namespace { + /* Kernel entry point: forwards every argument unchanged to cuda::compressPagedCacheKernel. NOTE(review): the template parameter list (presumably Tdata and Tindex) is missing in this copy of the patch - confirm against the original commit. */ +template +INFINIOP_CUDA_KERNEL launchNsaCompressPagedCacheHd128( + Tdata *k_cmp, + Tdata *v_cmp, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + size_t max_num_blocks_per_seq, + size_t page_block_size, + size_t subblocks_per_page, + int nsa_block_size, + int update_last_only, + ptrdiff_t k_cmp_block_stride, + ptrdiff_t k_cmp_head_stride, + ptrdiff_t v_cmp_block_stride, + ptrdiff_t v_cmp_head_stride, + ptrdiff_t k_cache_batch_stride, + ptrdiff_t k_cache_head_stride, + ptrdiff_t k_cache_row_stride, + ptrdiff_t v_cache_batch_stride, + ptrdiff_t v_cache_head_stride, + ptrdiff_t v_cache_row_stride, + ptrdiff_t block_table_batch_stride, + ptrdiff_t cache_lens_stride) { + op::nsa_compress_paged_cache::cuda::compressPagedCacheKernel( + k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, max_num_blocks_per_seq, + page_block_size, subblocks_per_page, nsa_block_size, update_last_only, k_cmp_block_stride, k_cmp_head_stride, + v_cmp_block_stride, v_cmp_head_stride, k_cache_batch_stride, k_cache_head_stride, + k_cache_row_stride, v_cache_batch_stride, v_cache_head_stride, v_cache_row_stride, + block_table_batch_stride, cache_lens_stride); +} + /* Host-side launcher: casts the untyped pointers and launches the kernel wrapper above, passing shapes and strides from info. A (num_seqs, max_nsa_blocks, num_kv_heads) grid and a 128-thread block are built here; presumably they and stream form the launch configuration, but the text between the launch brackets and the static_cast target types are missing in this copy - TODO confirm. */ +template +infiniStatus_t
launchTyped( + void *k_cmp, + void *v_cmp, + const void *k_cache, + const void *v_cache, + const void *block_tables, + const void *cache_lens, + const NsaCompressPagedCacheInfo &info, + cudaStream_t stream) { + /* One NSA block per sequence when update_last_only is set; otherwise ceil(max_num_blocks_per_seq * page_block_size / nsa_block_size). */ const size_t max_nsa_blocks = info.update_last_only ? 1 : (info.max_num_blocks_per_seq * info.page_block_size + info.nsa_block_size - 1) / info.nsa_block_size; + dim3 grid(info.num_seqs, max_nsa_blocks, info.num_kv_heads); + dim3 block(128); + launchNsaCompressPagedCacheHd128<<>>( + static_cast(k_cmp), + static_cast(v_cmp), + static_cast(k_cache), + static_cast(v_cache), + static_cast(block_tables), + static_cast(cache_lens), + info.max_num_blocks_per_seq, + info.page_block_size, + info.subblocks_per_page, + info.nsa_block_size, + info.update_last_only, + info.k_cmp_block_stride, + info.k_cmp_head_stride, + info.v_cmp_block_stride, + info.v_cmp_head_stride, + info.k_cache_batch_stride, + info.k_cache_head_stride, + info.k_cache_row_stride, + info.v_cache_batch_stride, + info.v_cache_head_stride, + info.v_cache_row_stride, + info.block_table_batch_stride, + info.cache_lens_stride); + /* NOTE(review): success is returned unconditionally; the launch result is not inspected and a zero-sized grid dimension is not guarded against. */ return INFINI_STATUS_SUCCESS; +} + +} // namespace + /* Backend state owned by the descriptor: the value returned by the handle's internal() at creation time. */ +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + /* Builds the info through NsaCompressPagedCacheInfo::create (CHECK_RESULT presumably returns the error status on failure - verify) and then allocates a Descriptor with a workspace size of 0. */ +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t k_cmp_desc, + infiniopTensorDescriptor_t v_cmp_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t cache_lens_desc, + int nsa_block_size, + int update_last_only) { + auto info_res = NsaCompressPagedCacheInfo::create( + k_cmp_desc, v_cmp_desc, k_cache_desc, v_cache_desc, block_tables_desc, cache_lens_desc, nsa_block_size, update_last_only); + CHECK_RESULT(info_res); + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info_res.take(), 0, handle->device,
handle->device_id); + return INFINI_STATUS_SUCCESS; +} + /* Runs the compression. workspace and workspace_size are ignored. Dispatch is two-level: F16 versus any other _info.dtype, then I64 / I32 / fallback on _info.index_dtype. NOTE(review): the explicit template arguments of the launchTyped calls are missing in this copy, so the six calls look identical - confirm against the original commit. */ +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *k_cmp, + void *v_cmp, + const void *k_cache, + const void *v_cache, + const void *block_tables, + const void *cache_lens, + void *stream_) const { + (void)workspace; + (void)workspace_size; + auto stream = static_cast(stream_); + if (_info.dtype == INFINI_DTYPE_F16) { + if (_info.index_dtype == INFINI_DTYPE_I64) { + return launchTyped(k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, _info, stream); + } + if (_info.index_dtype == INFINI_DTYPE_I32) { + return launchTyped(k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, _info, stream); + } + return launchTyped(k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, _info, stream); + } + if (_info.index_dtype == INFINI_DTYPE_I64) { + return launchTyped(k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, _info, stream); + } + if (_info.index_dtype == INFINI_DTYPE_I32) { + return launchTyped(k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, _info, stream); + } + return launchTyped(k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, _info, stream); +} + +} // namespace op::nsa_compress_paged_cache::nvidia diff --git a/src/infiniop/ops/nsa_compress_paged_cache/nvidia/nsa_compress_paged_cache_nvidia.cuh b/src/infiniop/ops/nsa_compress_paged_cache/nvidia/nsa_compress_paged_cache_nvidia.cuh new file mode 100644 index 000000000..7f60547cb --- /dev/null +++ b/src/infiniop/ops/nsa_compress_paged_cache/nvidia/nsa_compress_paged_cache_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __NSA_COMPRESS_PAGED_CACHE_NVIDIA_H__ +#define __NSA_COMPRESS_PAGED_CACHE_NVIDIA_H__ + +#include "../nsa_compress_paged_cache.h" + +DESCRIPTOR(nvidia) + +#endif // __NSA_COMPRESS_PAGED_CACHE_NVIDIA_H__ diff --git a/src/infiniop/ops/nsa_compress_paged_cache/operator.cc b/src/infiniop/ops/nsa_compress_paged_cache/operator.cc new file mode 100644 index 000000000..a25311fa3 ---
/dev/null +++ b/src/infiniop/ops/nsa_compress_paged_cache/operator.cc @@ -0,0 +1,130 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/nsa_compress_paged_cache.h" + +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API) +#include "nvidia/nsa_compress_paged_cache_nvidia.cuh" +#endif + /* Creates the operator descriptor for the device type of handle. Every enabled backend (NVIDIA, ALI, ILUVATAR, HYGON) is routed to the nvidia implementation; any other device yields INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED. NOTE(review): the reinterpret_cast target types in the dispatch macros of this file are missing in this copy of the patch - confirm against the original commit. */ +__INFINI_C infiniStatus_t infiniopCreateNsaCompressPagedCacheDescriptor( + infiniopHandle_t handle, + infiniopNsaCompressPagedCacheDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t k_cmp_desc, + infiniopTensorDescriptor_t v_cmp_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t seq_lens_desc, + int nsa_block_size, + int update_last_only) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::nsa_compress_paged_cache::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + k_cmp_desc, v_cmp_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, \ + nsa_block_size, update_last_only); + + switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_ALI_API + CREATE(INFINI_DEVICE_ALI, nvidia) +#endif +#ifdef ENABLE_ILUVATAR_API + CREATE(INFINI_DEVICE_ILUVATAR, nvidia) +#endif +#ifdef ENABLE_HYGON_API + CREATE(INFINI_DEVICE_HYGON, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + /* Writes the descriptor's workspaceSize() to *size; the backend is chosen from desc->device_type. */ +__INFINI_C infiniStatus_t infiniopGetNsaCompressPagedCacheWorkspaceSize( + infiniopNsaCompressPagedCacheDescriptor_t desc, + size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_ALI_API + GET(INFINI_DEVICE_ALI, nvidia) +#endif +#ifdef ENABLE_ILUVATAR_API
+ GET(INFINI_DEVICE_ILUVATAR, nvidia) +#endif +#ifdef ENABLE_HYGON_API + GET(INFINI_DEVICE_HYGON, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + /* Runs the operator by forwarding all buffers and the stream to the backend descriptor's calculate(). */ +__INFINI_C infiniStatus_t infiniopNsaCompressPagedCache( + infiniopNsaCompressPagedCacheDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *k_cmp, + void *v_cmp, + const void *k_cache, + const void *v_cache, + const void *block_tables, + const void *seq_lens, + void *stream) { +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc)->calculate( \ + workspace, workspace_size, k_cmp, v_cmp, k_cache, v_cache, block_tables, seq_lens, stream); + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_ALI_API + CALCULATE(INFINI_DEVICE_ALI, nvidia) +#endif +#ifdef ENABLE_ILUVATAR_API + CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia) +#endif +#ifdef ENABLE_HYGON_API + CALCULATE(INFINI_DEVICE_HYGON, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + /* Deletes the backend descriptor selected by desc->device_type. */ +__INFINI_C infiniStatus_t infiniopDestroyNsaCompressPagedCacheDescriptor( + infiniopNsaCompressPagedCacheDescriptor_t desc) { +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_ALI_API + DESTROY(INFINI_DEVICE_ALI, nvidia) +#endif +#ifdef ENABLE_ILUVATAR_API + DESTROY(INFINI_DEVICE_ILUVATAR, nvidia) +#endif +#ifdef ENABLE_HYGON_API + DESTROY(INFINI_DEVICE_HYGON, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} diff --git a/src/infiniop/ops/nsa_paged_attention/cuda/kernel.cuh b/src/infiniop/ops/nsa_paged_attention/cuda/kernel.cuh new file mode 100644 index 000000000..3d8af3ddc --- /dev/null +++ b/src/infiniop/ops/nsa_paged_attention/cuda/kernel.cuh @@ -0,0 +1,239 @@ +#ifndef
__NSA_PAGED_ATTENTION_CUDA_KERNEL_CUH__ +#define __NSA_PAGED_ATTENTION_CUDA_KERNEL_CUH__ + +#include +#include +#include +#include +#include + +namespace op::nsa_paged_attention::cuda { + /* loadFloat / storeFloat convert between the storage type and float. The generic versions use a cast; the __half and __nv_bfloat16 specialisations use the CUDA conversion intrinsics. NOTE(review): template parameter lists and cast target types are missing in this copy of the patch - confirm against the original commit. */ +template +__device__ inline float loadFloat(const T *ptr) { + return static_cast(*ptr); +} + +template <> +__device__ inline float loadFloat<__half>(const __half *ptr) { + return __half2float(*ptr); +} + +template <> +__device__ inline float loadFloat<__nv_bfloat16>(const __nv_bfloat16 *ptr) { + return __bfloat162float(*ptr); +} + +template +__device__ inline void storeFloat(T *ptr, float value) { + *ptr = static_cast(value); +} + +template <> +__device__ inline void storeFloat<__half>(__half *ptr, float value) { + *ptr = __float2half(value); +} + +template <> +__device__ inline void storeFloat<__nv_bfloat16>(__nv_bfloat16 *ptr, float value) { + *ptr = __float2bfloat16(value); +} + /* Decode-time NSA attention for a head dimension of 128. blockIdx.x selects the sequence, blockIdx.y the query head, and threadIdx.x the channel this thread owns. Three softmax-weighted branches are accumulated in float: (1) all compressed blocks from k_cmp / v_cmp, (2) the tokens of the best-scoring compressed blocks re-read from the paged k_cache / v_cache, (3) a sliding window over the last window_size tokens. The branches are then blended with the per-head gates and stored to out. Every thread of the block must reach each __syncthreads(), so the loop bounds and continue conditions depend only on block-uniform values. */ +template +__device__ void nsaPagedDecodeHd128Kernel( + Tdata *out, + const Tdata *q, + const Tdata *k_cmp, + const Tdata *v_cmp, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const Tgate *gates, + size_t num_heads, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + size_t subblocks_per_page, + int nsa_block_size, + int window_size, + int select_blocks, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_cmp_block_stride, + ptrdiff_t k_cmp_head_stride, + ptrdiff_t v_cmp_block_stride, + ptrdiff_t v_cmp_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_head_stride, + ptrdiff_t k_row_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_head_stride, + ptrdiff_t v_row_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + ptrdiff_t block_table_batch_stride, + ptrdiff_t cache_lens_stride, + ptrdiff_t gates_seq_stride, + ptrdiff_t gates_branch_stride, + ptrdiff_t gates_head_stride) { + constexpr int kHeadDim = 128; + __shared__ float scratch[kHeadDim]; + /* scratch holds the per-channel q*k products that are reduced to one dot product per key. */ + const size_t seq =
static_cast(blockIdx.x); + const size_t head = static_cast(blockIdx.y); + const int dim = threadIdx.x; + if (dim >= kHeadDim) { + return; + } + /* Grouped-query mapping: each run of num_heads / num_kv_heads consecutive query heads shares one KV head. */ + const size_t kv_head = head / (num_heads / num_kv_heads); + const int64_t seq_len = static_cast(cache_lens[seq * cache_lens_stride]); + const float qd = loadFloat(q + seq * q_stride + head * q_head_stride + dim); + float comp_acc = 0.0f; + float comp_m = -INFINITY; + float comp_l = 0.0f; + /* Per-thread list of the best compressed-block scores, kept sorted in descending order. At most kMaxSelectBlocks entries are tracked, so a larger select_blocks is clamped and a value below 1 is raised to 1. */ + constexpr int kMaxSelectBlocks = 32; + const int active_select_blocks = max(1, min(select_blocks, kMaxSelectBlocks)); + float top_scores[kMaxSelectBlocks]; + int top_blocks[kMaxSelectBlocks]; +#if !defined(ENABLE_ILUVATAR_API) && !defined(ENABLE_HYGON_API) +#pragma unroll +#endif + for (int i = 0; i < kMaxSelectBlocks; ++i) { + top_scores[i] = -INFINITY; + top_blocks[i] = -1; + } + /* Branch 1 (compressed): score every compressed block of this sequence, feed the top-k list and accumulate a softmax-weighted sum of v_cmp. */ + const int64_t nsa_blocks = (seq_len + nsa_block_size - 1) / nsa_block_size; + for (int64_t nsa_block = 0; nsa_block < nsa_blocks; ++nsa_block) { + const int64_t tok_begin = nsa_block * nsa_block_size; + const size_t logical_page = static_cast(tok_begin / static_cast(page_block_size)); + if (logical_page >= max_num_blocks_per_seq) { + continue; + } + const size_t subblock = static_cast((tok_begin % static_cast(page_block_size)) / nsa_block_size); + const Tindex physical = block_tables[seq * block_table_batch_stride + logical_page]; + const size_t cmp_block = static_cast(physical) * subblocks_per_page + subblock; + const size_t base_k = cmp_block * k_cmp_block_stride + kv_head * k_cmp_head_stride; + const size_t base_v = cmp_block * v_cmp_block_stride + kv_head * v_cmp_head_stride; + const float kd = loadFloat(k_cmp + base_k + dim); + scratch[dim] = qd * kd; + __syncthreads(); + /* Tree reduction over the 128 products; scratch[0] ends up holding the q.k dot product. */ for (int stride = 64; stride > 0; stride >>= 1) { + if (dim < stride) { + scratch[dim] += scratch[dim + stride]; + } + __syncthreads(); + } + const float score = scratch[0] * scale; + /* Insertion into the sorted top-k list; every thread sees the same score and so keeps an identical list. */ if (score > top_scores[active_select_blocks - 1]) { + int insert_pos = active_select_blocks - 1; +
while (insert_pos > 0 && score > top_scores[insert_pos - 1]) { + top_scores[insert_pos] = top_scores[insert_pos - 1]; + top_blocks[insert_pos] = top_blocks[insert_pos - 1]; + --insert_pos; + } + top_scores[insert_pos] = score; + top_blocks[insert_pos] = static_cast(nsa_block); + } + const float vd = loadFloat(v_cmp + base_v + dim); + /* Online softmax: rescale the running accumulator and normaliser by exp(old_max - new_max) before adding this block. */ const float new_m = fmaxf(comp_m, score); + const float alpha = expf(comp_m - new_m); + const float beta = expf(score - new_m); + comp_acc = comp_acc * alpha + beta * vd; + comp_l = comp_l * alpha + beta; + comp_m = new_m; + /* Barrier so scratch[0] is not overwritten by the next iteration while other threads still read it. */ __syncthreads(); + } + const float comp_out = comp_l > 0.0f ? comp_acc / comp_l : 0.0f; + /* Branch 2 (selected): attend token by token over the blocks recorded in top_blocks, reading keys and values from the paged cache. Unused slots hold -1 and are skipped. */ + float sel_acc = 0.0f; + float sel_m = -INFINITY; + float sel_l = 0.0f; +#if !defined(ENABLE_ILUVATAR_API) && !defined(ENABLE_HYGON_API) +#pragma unroll +#endif + for (int selected = 0; selected < active_select_blocks; ++selected) { + const int nsa_block = top_blocks[selected]; + if (nsa_block < 0) { + continue; + } + const int64_t tok_begin = static_cast(nsa_block) * nsa_block_size; + const int64_t tok_end = min(tok_begin + static_cast(nsa_block_size), seq_len); + for (int64_t tok = tok_begin; tok < tok_end; ++tok) { + const size_t logical_block = static_cast(tok / static_cast(page_block_size)); + const size_t block_offset = static_cast(tok % static_cast(page_block_size)); + if (logical_block >= max_num_blocks_per_seq) { + continue; + } + const Tindex physical = block_tables[seq * block_table_batch_stride + logical_block]; + const size_t base_k = static_cast(physical) * k_batch_stride + kv_head * k_head_stride + block_offset * k_row_stride; + const size_t base_v = static_cast(physical) * v_batch_stride + kv_head * v_head_stride + block_offset * v_row_stride; + const float kd = loadFloat(k_cache + base_k + dim); + scratch[dim] = qd * kd; + __syncthreads(); + for (int stride = 64; stride > 0; stride >>= 1) { + if (dim < stride) { + scratch[dim] += scratch[dim + stride]; + } + __syncthreads(); + } + const float score =
scratch[0] * scale; + const float vd = loadFloat(v_cache + base_v + dim); + const float new_m = fmaxf(sel_m, score); + const float alpha = expf(sel_m - new_m); + const float beta = expf(score - new_m); + sel_acc = sel_acc * alpha + beta * vd; + sel_l = sel_l * alpha + beta; + sel_m = new_m; + __syncthreads(); + } + } + const float sel_out = sel_l > 0.0f ? sel_acc / sel_l : 0.0f; + /* Branch 3 (sliding window): attend over the last window_size tokens. With window_size <= 0 the range is empty and the branch contributes 0. */ + float win_acc = 0.0f; + float win_m = -INFINITY; + float win_l = 0.0f; + const int64_t win_begin = window_size > 0 ? ((seq_len > window_size) ? (seq_len - window_size) : 0) : seq_len; + for (int64_t tok = win_begin; tok < seq_len; ++tok) { + const size_t logical_block = static_cast(tok / static_cast(page_block_size)); + const size_t block_offset = static_cast(tok % static_cast(page_block_size)); + if (logical_block >= max_num_blocks_per_seq) { + continue; + } + const Tindex physical = block_tables[seq * block_table_batch_stride + logical_block]; + const size_t base_k = static_cast(physical) * k_batch_stride + kv_head * k_head_stride + block_offset * k_row_stride; + const size_t base_v = static_cast(physical) * v_batch_stride + kv_head * v_head_stride + block_offset * v_row_stride; + const float kd = loadFloat(k_cache + base_k + dim); + scratch[dim] = qd * kd; + __syncthreads(); + for (int stride = 64; stride > 0; stride >>= 1) { + if (dim < stride) { + scratch[dim] += scratch[dim + stride]; + } + __syncthreads(); + } + const float score = scratch[0] * scale; + const float vd = loadFloat(v_cache + base_v + dim); + const float new_m = fmaxf(win_m, score); + const float alpha = expf(win_m - new_m); + const float beta = expf(score - new_m); + win_acc = win_acc * alpha + beta * vd; + win_l = win_l * alpha + beta; + win_m = new_m; + __syncthreads(); + } + const float win_out = win_l > 0.0f ?
win_acc / win_l : 0.0f; + + const float g_cmp = loadFloat(gates + seq * gates_seq_stride + 0 * gates_branch_stride + head * gates_head_stride); + const float g_sel = loadFloat(gates + seq * gates_seq_stride + 1 * gates_branch_stride + head * gates_head_stride); + const float g_swa = loadFloat(gates + seq * gates_seq_stride + 2 * gates_branch_stride + head * gates_head_stride); + storeFloat(out + seq * o_stride + head * o_head_stride + dim, g_cmp * comp_out + g_sel * sel_out + g_swa * win_out); +} + +} // namespace op::nsa_paged_attention::cuda + +#endif // __NSA_PAGED_ATTENTION_CUDA_KERNEL_CUH__ diff --git a/src/infiniop/ops/nsa_paged_attention/info.h b/src/infiniop/ops/nsa_paged_attention/info.h new file mode 100644 index 000000000..2174e578a --- /dev/null +++ b/src/infiniop/ops/nsa_paged_attention/info.h @@ -0,0 +1,165 @@ +#ifndef __NSA_PAGED_ATTENTION_INFO_H__ +#define __NSA_PAGED_ATTENTION_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" + +namespace op::nsa_paged_attention { + +class NsaPagedAttentionInfo { + NsaPagedAttentionInfo() = default; + +public: + infiniDtype_t dtype; + infiniDtype_t gates_dtype; + infiniDtype_t index_dtype; + float scale; + int nsa_block_size; + int window_size; + int select_blocks; + + size_t num_seqs; + size_t num_heads; + size_t num_kv_heads; + size_t head_size; + size_t page_block_size; + size_t subblocks_per_page; + size_t max_num_blocks_per_seq; + + ptrdiff_t q_stride; + ptrdiff_t q_head_stride; + ptrdiff_t k_cmp_block_stride; + ptrdiff_t k_cmp_head_stride; + ptrdiff_t v_cmp_block_stride; + ptrdiff_t v_cmp_head_stride; + ptrdiff_t k_batch_stride; + ptrdiff_t k_head_stride; + ptrdiff_t k_row_stride; + ptrdiff_t v_batch_stride; + ptrdiff_t v_head_stride; + ptrdiff_t v_row_stride; + ptrdiff_t o_stride; + ptrdiff_t o_head_stride; + ptrdiff_t block_table_batch_stride; + ptrdiff_t cache_lens_stride; + ptrdiff_t gates_seq_stride; + ptrdiff_t gates_branch_stride; + ptrdiff_t gates_head_stride; + + static utils::Result 
create( + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_cmp_desc, + infiniopTensorDescriptor_t v_cmp_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t cache_lens_desc, + infiniopTensorDescriptor_t gates_desc, + float scale, + int nsa_block_size, + int window_size, + int select_blocks) { + auto dtype = q_desc->dtype(); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16); + if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype + || k_cmp_desc->dtype() != dtype || v_cmp_desc->dtype() != dtype) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (gates_desc->dtype() != dtype && gates_desc->dtype() != INFINI_DTYPE_F32) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (q_desc->ndim() != 3 || out_desc->ndim() != 3 || k_cmp_desc->ndim() != 3 || v_cmp_desc->ndim() != 3 + || k_cache_desc->ndim() != 4 || v_cache_desc->ndim() != 4 || block_tables_desc->ndim() != 2 + || cache_lens_desc->ndim() != 1 || gates_desc->ndim() != 3) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + CHECK_OR_RETURN(q_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(out_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(k_cmp_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(v_cmp_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(k_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(v_cache_desc->stride(3) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(block_tables_desc->stride(1) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(cache_lens_desc->stride(0) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); + + const auto index_dtype = block_tables_desc->dtype(); + if (index_dtype != cache_lens_desc->dtype()) { + return 
INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (index_dtype != INFINI_DTYPE_I64 && index_dtype != INFINI_DTYPE_I32 && index_dtype != INFINI_DTYPE_U32) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + auto q_shape = q_desc->shape(); + auto k_shape = k_cache_desc->shape(); + const size_t num_seqs = q_shape[0]; + const size_t num_heads = q_shape[1]; + const size_t head_size = q_shape[2]; + const size_t num_kv_heads = k_shape[1]; + const size_t page_block_size = k_shape[2]; + if (head_size != 128 || page_block_size == 0 || nsa_block_size <= 0 || window_size < 0 || select_blocks <= 0) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (num_heads % num_kv_heads != 0 || page_block_size % static_cast(nsa_block_size) != 0) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (k_shape[3] != head_size || v_cache_desc->shape() != k_shape) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + const size_t subblocks_per_page = page_block_size / static_cast(nsa_block_size); + if (k_cmp_desc->shape()[0] != k_shape[0] * subblocks_per_page || k_cmp_desc->shape()[1] != num_kv_heads + || k_cmp_desc->shape()[2] != head_size || v_cmp_desc->shape() != k_cmp_desc->shape()) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (out_desc->shape()[0] != num_seqs || out_desc->shape()[1] != num_heads || out_desc->shape()[2] != head_size) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (cache_lens_desc->shape()[0] != num_seqs) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (gates_desc->shape()[0] != num_seqs || gates_desc->shape()[1] != 3 || gates_desc->shape()[2] != num_heads) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + return utils::Result(NsaPagedAttentionInfo{ + dtype, + gates_desc->dtype(), + index_dtype, + scale, + nsa_block_size, + window_size, + select_blocks, + num_seqs, + num_heads, + num_kv_heads, + head_size, + page_block_size, + subblocks_per_page, + block_tables_desc->shape()[1], + q_desc->stride(0), + q_desc->stride(1), + k_cmp_desc->stride(0), + k_cmp_desc->stride(1), + 
v_cmp_desc->stride(0), + v_cmp_desc->stride(1), + k_cache_desc->stride(0), + k_cache_desc->stride(1), + k_cache_desc->stride(2), + v_cache_desc->stride(0), + v_cache_desc->stride(1), + v_cache_desc->stride(2), + out_desc->stride(0), + out_desc->stride(1), + block_tables_desc->stride(0), + cache_lens_desc->stride(0), + gates_desc->stride(0), + gates_desc->stride(1), + gates_desc->stride(2), + }); + } +}; + +} // namespace op::nsa_paged_attention + +#endif // __NSA_PAGED_ATTENTION_INFO_H__ diff --git a/src/infiniop/ops/nsa_paged_attention/nsa_paged_attention.h b/src/infiniop/ops/nsa_paged_attention/nsa_paged_attention.h new file mode 100644 index 000000000..e013c10d7 --- /dev/null +++ b/src/infiniop/ops/nsa_paged_attention/nsa_paged_attention.h @@ -0,0 +1,48 @@ +#ifndef NSA_PAGED_ATTENTION_H +#define NSA_PAGED_ATTENTION_H + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + namespace op::nsa_paged_attention::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + NsaPagedAttentionInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor(Opaque *opaque, NsaPagedAttentionInfo info, \ + size_t workspace_size, infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t out_desc, \ + infiniopTensorDescriptor_t q_desc, \ + infiniopTensorDescriptor_t k_cmp_desc, \ + infiniopTensorDescriptor_t v_cmp_desc, \ + infiniopTensorDescriptor_t k_cache_desc, \ + infiniopTensorDescriptor_t v_cache_desc, \ + infiniopTensorDescriptor_t block_tables_desc, \ + infiniopTensorDescriptor_t seq_lens_desc, \ + infiniopTensorDescriptor_t gates_desc, float scale, \ + int 
nsa_block_size, int window_size, int select_blocks); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, void *out, \ + const void *q, const void *k_cmp, const void *v_cmp, \ + const void *k_cache, const void *v_cache, \ + const void *block_tables, const void *seq_lens, \ + const void *gates, void *stream) const; \ + }; \ + } + +#endif // NSA_PAGED_ATTENTION_H diff --git a/src/infiniop/ops/nsa_paged_attention/nvidia/nsa_paged_attention_nvidia.cu b/src/infiniop/ops/nsa_paged_attention/nvidia/nsa_paged_attention_nvidia.cu new file mode 100644 index 000000000..9d22c1825 --- /dev/null +++ b/src/infiniop/ops/nsa_paged_attention/nvidia/nsa_paged_attention_nvidia.cu @@ -0,0 +1,209 @@ +#include "nsa_paged_attention_nvidia.cuh" + +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" + +#include +#include +#include +#include +#include + +#include "../cuda/kernel.cuh" + +namespace op::nsa_paged_attention::nvidia { + +namespace { + /* Kernel entry point: forwards every argument unchanged to cuda::nsaPagedDecodeHd128Kernel. NOTE(review): the template parameter list (presumably Tdata, Tindex and Tgate) is missing in this copy of the patch - confirm against the original commit. */ +template +INFINIOP_CUDA_KERNEL launchNsaPagedDecodeHd128( + Tdata *out, + const Tdata *q, + const Tdata *k_cmp, + const Tdata *v_cmp, + const Tdata *k_cache, + const Tdata *v_cache, + const Tindex *block_tables, + const Tindex *cache_lens, + const Tgate *gates, + size_t num_heads, + size_t num_kv_heads, + float scale, + size_t max_num_blocks_per_seq, + size_t page_block_size, + size_t subblocks_per_page, + int nsa_block_size, + int window_size, + int select_blocks, + ptrdiff_t q_stride, + ptrdiff_t q_head_stride, + ptrdiff_t k_cmp_block_stride, + ptrdiff_t k_cmp_head_stride, + ptrdiff_t v_cmp_block_stride, + ptrdiff_t v_cmp_head_stride, + ptrdiff_t k_batch_stride, + ptrdiff_t k_head_stride, + ptrdiff_t k_row_stride, + ptrdiff_t v_batch_stride, + ptrdiff_t v_head_stride, + ptrdiff_t v_row_stride, + ptrdiff_t o_stride, + ptrdiff_t o_head_stride, + ptrdiff_t block_table_batch_stride, + ptrdiff_t cache_lens_stride, + ptrdiff_t gates_seq_stride, + ptrdiff_t
gates_branch_stride, + ptrdiff_t gates_head_stride) { + op::nsa_paged_attention::cuda::nsaPagedDecodeHd128Kernel( + out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, gates, num_heads, num_kv_heads, scale, + max_num_blocks_per_seq, page_block_size, subblocks_per_page, nsa_block_size, window_size, select_blocks, q_stride, q_head_stride, + k_cmp_block_stride, k_cmp_head_stride, v_cmp_block_stride, v_cmp_head_stride, + k_batch_stride, k_head_stride, k_row_stride, v_batch_stride, v_head_stride, v_row_stride, + o_stride, o_head_stride, block_table_batch_stride, cache_lens_stride, gates_seq_stride, + gates_branch_stride, gates_head_stride); +} + /* Host-side launcher: casts the untyped pointers and launches the kernel wrapper above, passing shapes, scalars and strides from info. A (num_seqs, num_heads) grid and a 128-thread block are built here; presumably they and stream form the launch configuration, but the text between the launch brackets and the static_cast target types are missing in this copy - TODO confirm. NOTE(review): success is returned unconditionally; the launch result is not inspected and a zero-sized grid dimension is not guarded against. */ +template +infiniStatus_t launchTyped( + void *out, + const void *q, + const void *k_cmp, + const void *v_cmp, + const void *k_cache, + const void *v_cache, + const void *block_tables, + const void *cache_lens, + const void *gates, + const NsaPagedAttentionInfo &info, + cudaStream_t stream) { + dim3 grid(info.num_seqs, info.num_heads); + dim3 block(128); + launchNsaPagedDecodeHd128<<>>( + static_cast(out), + static_cast(q), + static_cast(k_cmp), + static_cast(v_cmp), + static_cast(k_cache), + static_cast(v_cache), + static_cast(block_tables), + static_cast(cache_lens), + static_cast(gates), + info.num_heads, + info.num_kv_heads, + info.scale, + info.max_num_blocks_per_seq, + info.page_block_size, + info.subblocks_per_page, + info.nsa_block_size, + info.window_size, + info.select_blocks, + info.q_stride, + info.q_head_stride, + info.k_cmp_block_stride, + info.k_cmp_head_stride, + info.v_cmp_block_stride, + info.v_cmp_head_stride, + info.k_batch_stride, + info.k_head_stride, + info.k_row_stride, + info.v_batch_stride, + info.v_head_stride, + info.v_row_stride, + info.o_stride, + info.o_head_stride, + info.block_table_batch_stride, + info.cache_lens_stride, + info.gates_seq_stride, + info.gates_branch_stride, + info.gates_head_stride); + return INFINI_STATUS_SUCCESS; +} + /* Third dispatch level: chooses the launchTyped instantiation from info.gates_dtype (F16, BF16 or F32) and returns INFINI_STATUS_BAD_TENSOR_DTYPE for anything else. */ +template +infiniStatus_t launchByGate( + void
*out, const void *q, const void *k_cmp, const void *v_cmp, const void *k_cache, const void *v_cache, + const void *block_tables, const void *cache_lens, const void *gates, + const NsaPagedAttentionInfo &info, cudaStream_t stream) { + if (info.gates_dtype == INFINI_DTYPE_F16) { + return launchTyped(out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, gates, info, stream); + } + if (info.gates_dtype == INFINI_DTYPE_BF16) { + return launchTyped(out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, gates, info, stream); + } + if (info.gates_dtype == INFINI_DTYPE_F32) { + return launchTyped(out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, gates, info, stream); + } + return INFINI_STATUS_BAD_TENSOR_DTYPE; +} + +} // namespace + /* Backend state owned by the descriptor: the value returned by the handle's internal() at creation time. */ +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + /* Builds the info through NsaPagedAttentionInfo::create (CHECK_RESULT presumably returns the error status on failure - verify) and then allocates a Descriptor with a workspace size of 0. */ +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_cmp_desc, + infiniopTensorDescriptor_t v_cmp_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t cache_lens_desc, + infiniopTensorDescriptor_t gates_desc, + float scale, + int nsa_block_size, + int window_size, + int select_blocks) { + auto info_res = NsaPagedAttentionInfo::create(out_desc, q_desc, k_cmp_desc, v_cmp_desc, k_cache_desc, v_cache_desc, + block_tables_desc, cache_lens_desc, gates_desc, scale, nsa_block_size, window_size, select_blocks); + CHECK_RESULT(info_res); + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info_res.take(), 0, handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + /* Runs the decode attention. workspace and workspace_size are ignored. Dispatch is F16 versus any other _info.dtype, then I64 / I32 / fallback on _info.index_dtype, and finally the gate dtype inside launchByGate. NOTE(review): the explicit template arguments of the launchByGate calls are missing in this copy, so the six calls look identical - confirm against the original commit. */ +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cmp, + const
void *v_cmp, + const void *k_cache, + const void *v_cache, + const void *block_tables, + const void *cache_lens, + const void *gates, + void *stream_) const { + (void)workspace; + (void)workspace_size; + auto stream = static_cast(stream_); + + if (_info.dtype == INFINI_DTYPE_F16) { + if (_info.index_dtype == INFINI_DTYPE_I64) { + return launchByGate(out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, gates, _info, stream); + } + if (_info.index_dtype == INFINI_DTYPE_I32) { + return launchByGate(out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, gates, _info, stream); + } + return launchByGate(out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, gates, _info, stream); + } + if (_info.index_dtype == INFINI_DTYPE_I64) { + return launchByGate(out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, gates, _info, stream); + } + if (_info.index_dtype == INFINI_DTYPE_I32) { + return launchByGate(out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, gates, _info, stream); + } + return launchByGate(out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, gates, _info, stream); +} + +} // namespace op::nsa_paged_attention::nvidia diff --git a/src/infiniop/ops/nsa_paged_attention/nvidia/nsa_paged_attention_nvidia.cuh b/src/infiniop/ops/nsa_paged_attention/nvidia/nsa_paged_attention_nvidia.cuh new file mode 100644 index 000000000..a9c7392aa --- /dev/null +++ b/src/infiniop/ops/nsa_paged_attention/nvidia/nsa_paged_attention_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __NSA_PAGED_ATTENTION_NVIDIA_H__ +#define __NSA_PAGED_ATTENTION_NVIDIA_H__ + +#include "../nsa_paged_attention.h" + +DESCRIPTOR(nvidia) + +#endif // __NSA_PAGED_ATTENTION_NVIDIA_H__ diff --git a/src/infiniop/ops/nsa_paged_attention/operator.cc b/src/infiniop/ops/nsa_paged_attention/operator.cc new file mode 100644 index 000000000..32bf183b2 --- /dev/null +++ b/src/infiniop/ops/nsa_paged_attention/operator.cc @@ -0,0 +1,138 @@ +#include
"../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/nsa_paged_attention.h" + +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API) +#include "nvidia/nsa_paged_attention_nvidia.cuh" +#endif + +__INFINI_C infiniStatus_t infiniopCreateNsaPagedAttentionDescriptor( + infiniopHandle_t handle, + infiniopNsaPagedAttentionDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_cmp_desc, + infiniopTensorDescriptor_t v_cmp_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t seq_lens_desc, + infiniopTensorDescriptor_t gates_desc, + float scale, + int nsa_block_size, + int window_size, + int select_blocks) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::nsa_paged_attention::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + out_desc, q_desc, k_cmp_desc, v_cmp_desc, k_cache_desc, v_cache_desc, block_tables_desc, \ + seq_lens_desc, gates_desc, scale, nsa_block_size, window_size, select_blocks); + + switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_ALI_API + CREATE(INFINI_DEVICE_ALI, nvidia) +#endif +#ifdef ENABLE_ILUVATAR_API + CREATE(INFINI_DEVICE_ILUVATAR, nvidia) +#endif +#ifdef ENABLE_HYGON_API + CREATE(INFINI_DEVICE_HYGON, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__INFINI_C infiniStatus_t infiniopGetNsaPagedAttentionWorkspaceSize( + infiniopNsaPagedAttentionDescriptor_t desc, + size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_ALI_API + 
GET(INFINI_DEVICE_ALI, nvidia) +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia) +#endif +#ifdef ENABLE_HYGON_API + GET(INFINI_DEVICE_HYGON, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__INFINI_C infiniStatus_t infiniopNsaPagedAttention( + infiniopNsaPagedAttentionDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cmp, + const void *v_cmp, + const void *k_cache, + const void *v_cache, + const void *block_tables, + const void *seq_lens, + const void *gates, + void *stream) { +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc)->calculate( \ + workspace, workspace_size, out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, seq_lens, gates, stream); + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_ALI_API + CALCULATE(INFINI_DEVICE_ALI, nvidia) +#endif +#ifdef ENABLE_ILUVATAR_API + CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia) +#endif +#ifdef ENABLE_HYGON_API + CALCULATE(INFINI_DEVICE_HYGON, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__INFINI_C infiniStatus_t infiniopDestroyNsaPagedAttentionDescriptor( + infiniopNsaPagedAttentionDescriptor_t desc) { +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_ALI_API + DESTROY(INFINI_DEVICE_ALI, nvidia) +#endif +#ifdef ENABLE_ILUVATAR_API + DESTROY(INFINI_DEVICE_ILUVATAR, nvidia) +#endif +#ifdef ENABLE_HYGON_API + DESTROY(INFINI_DEVICE_HYGON, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 027a7ec7f..05fd1b085 100644 --- 
a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -2918,3 +2918,95 @@ def chunk_gated_delta_rule_(lib): lib.infiniopDestroyChunkGatedDeltaRuleDescriptor.argtypes = [ infiniopOperatorDescriptor_t ] + + +@OpRegister.operator +def nsa_compress_paged_cache_(lib): + lib.infiniopCreateNsaCompressPagedCacheDescriptor.restype = c_int32 + lib.infiniopCreateNsaCompressPagedCacheDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_int32, + c_int32, + ] + + lib.infiniopGetNsaCompressPagedCacheWorkspaceSize.restype = c_int32 + lib.infiniopGetNsaCompressPagedCacheWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopNsaCompressPagedCache.restype = c_int32 + lib.infiniopNsaCompressPagedCache.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyNsaCompressPagedCacheDescriptor.restype = c_int32 + lib.infiniopDestroyNsaCompressPagedCacheDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def nsa_paged_attention_(lib): + lib.infiniopCreateNsaPagedAttentionDescriptor.restype = c_int32 + lib.infiniopCreateNsaPagedAttentionDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_float, + c_int32, + c_int32, + c_int32, + ] + + lib.infiniopGetNsaPagedAttentionWorkspaceSize.restype = c_int32 + 
lib.infiniopGetNsaPagedAttentionWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopNsaPagedAttention.restype = c_int32 + lib.infiniopNsaPagedAttention.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyNsaPagedAttentionDescriptor.restype = c_int32 + lib.infiniopDestroyNsaPagedAttentionDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] diff --git a/test/infiniop/nsa_compress_paged_cache.py b/test/infiniop/nsa_compress_paged_cache.py new file mode 100644 index 000000000..fd5239dc8 --- /dev/null +++ b/test/infiniop/nsa_compress_paged_cache.py @@ -0,0 +1,271 @@ +import ctypes +from ctypes import c_uint64 + +import torch +from libinfiniop import ( + LIBINFINIOP, + InfiniDeviceNames, + InfiniDtype, + InfiniDtypeNames, + TestTensor, + TestWorkspace, + check_error, + debug, + get_args, + get_test_devices, + get_tolerance, + infiniopOperatorDescriptor_t, + profile_operation, + test_operator, +) + + +# ============================================================================== +# Reference Implementation +# ============================================================================== +def ref_nsa_compress_paged_cache( + k_cmp, + v_cmp, + k_cache, + v_cache, + block_tables, + seq_lens, + nsa_block_size, + update_last_only, +): + k_ref = k_cmp.clone() + v_ref = v_cmp.clone() + page_block_size = k_cache.shape[2] + subblocks_per_page = page_block_size // nsa_block_size + + for seq in range(block_tables.shape[0]): + seq_len = int(seq_lens[seq].item()) + if seq_len <= 0: + continue + if update_last_only: + nsa_blocks = [(seq_len - 1) // nsa_block_size] + else: + nsa_blocks = range((seq_len + nsa_block_size - 1) // nsa_block_size) + + for nsa_block in nsa_blocks: + tok_begin = nsa_block * nsa_block_size + if tok_begin >= seq_len: + continue + tok_end = 
min(tok_begin + nsa_block_size, seq_len) + logical_page = tok_begin // page_block_size + subblock = (tok_begin % page_block_size) // nsa_block_size + physical = int(block_tables[seq, logical_page].item()) + cmp_block = physical * subblocks_per_page + subblock + rows = ( + torch.arange(tok_begin, tok_end, device=k_cache.device) + % page_block_size + ) + k_ref[cmp_block] = ( + k_cache[physical, :, rows, :].float().mean(dim=1).to(k_ref.dtype) + ) + v_ref[cmp_block] = ( + v_cache[physical, :, rows, :].float().mean(dim=1).to(v_ref.dtype) + ) + return k_ref, v_ref + + +# ============================================================================== +# Test Configuration (Internal Use Only) +# ============================================================================== +_TEST_CASES_ = [ + # (num_seqs, max_seq_len, num_kv_heads, page_block_size, nsa_block_size, update_last_only) + (2, 192, 2, 128, 64, False), + (3, 191, 1, 128, 64, True), +] + +_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16] + +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2}, + InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 100 + + +def _seq_lens(num_seqs, max_seq_len, nsa_block_size): + values = [] + for i in range(num_seqs): + values.append(max(nsa_block_size // 2 + 1, max_seq_len - i * 37)) + return torch.tensor(values, dtype=torch.int64) + + +def test( + handle, + device, + num_seqs, + max_seq_len, + num_kv_heads, + page_block_size, + nsa_block_size, + update_last_only, + dtype, + sync, +): + print( + f"Testing NsaCompressPagedCache on {InfiniDeviceNames[device]} with " + f"num_seqs={num_seqs}, max_seq_len={max_seq_len}, num_kv_heads={num_kv_heads}, " + f"page_block_size={page_block_size}, nsa_block_size={nsa_block_size}, " + f"update_last_only={update_last_only}, dtype={InfiniDtypeNames[dtype]}" + ) + + head_size = 128 + max_blocks_per_seq = (max_seq_len + page_block_size - 1) // page_block_size + 
num_blocks = num_seqs * max_blocks_per_seq + subblocks_per_page = page_block_size // nsa_block_size + + k_cache = TestTensor( + (num_blocks, num_kv_heads, page_block_size, head_size), + None, + dtype, + device, + scale=0.1, + ) + v_cache = TestTensor( + (num_blocks, num_kv_heads, page_block_size, head_size), + None, + dtype, + device, + scale=0.1, + ) + k_cmp = TestTensor( + (num_blocks * subblocks_per_page, num_kv_heads, head_size), + None, + dtype, + device, + mode="zeros", + ) + v_cmp = TestTensor( + (num_blocks * subblocks_per_page, num_kv_heads, head_size), + None, + dtype, + device, + mode="zeros", + ) + + block_tables_torch = torch.arange(num_blocks, dtype=torch.int64).view( + num_seqs, max_blocks_per_seq + ) + seq_lens_torch = _seq_lens(num_seqs, max_seq_len, nsa_block_size) + block_tables = TestTensor.from_torch(block_tables_torch, InfiniDtype.I64, device) + seq_lens = TestTensor.from_torch(seq_lens_torch, InfiniDtype.I64, device) + + ans_k, ans_v = ref_nsa_compress_paged_cache( + k_cmp.torch_tensor(), + v_cmp.torch_tensor(), + k_cache.torch_tensor(), + v_cache.torch_tensor(), + block_tables.torch_tensor(), + seq_lens.torch_tensor(), + nsa_block_size, + update_last_only, + ) + + if sync: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateNsaCompressPagedCacheDescriptor( + handle, + ctypes.byref(descriptor), + k_cmp.descriptor, + v_cmp.descriptor, + k_cache.descriptor, + v_cache.descriptor, + block_tables.descriptor, + seq_lens.descriptor, + nsa_block_size, + int(update_last_only), + ) + ) + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetNsaCompressPagedCacheWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, device) + + k_cmp.destroy_desc() + v_cmp.destroy_desc() + k_cache.destroy_desc() + v_cache.destroy_desc() + block_tables.destroy_desc() + seq_lens.destroy_desc() + + def lib_nsa_compress_paged_cache(): + check_error( 
+ LIBINFINIOP.infiniopNsaCompressPagedCache( + descriptor, + workspace.data(), + workspace_size.value, + k_cmp.data(), + v_cmp.data(), + k_cache.data(), + v_cache.data(), + block_tables.data(), + seq_lens.data(), + None, + ) + ) + + lib_nsa_compress_paged_cache() + + if sync: + sync() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + print("Verifying compressed K cache...") + debug(k_cmp.actual_tensor(), ans_k, atol=atol, rtol=rtol) + print("Verifying compressed V cache...") + debug(v_cmp.actual_tensor(), ans_v, atol=atol, rtol=rtol) + assert torch.allclose(k_cmp.actual_tensor(), ans_k, atol=atol, rtol=rtol) + assert torch.allclose(v_cmp.actual_tensor(), ans_v, atol=atol, rtol=rtol) + + if PROFILE: + profile_operation( + "PyTorch", + lambda: ref_nsa_compress_paged_cache( + k_cmp.torch_tensor(), + v_cmp.torch_tensor(), + k_cache.torch_tensor(), + v_cache.torch_tensor(), + block_tables.torch_tensor(), + seq_lens.torch_tensor(), + nsa_block_size, + update_last_only, + ), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + profile_operation( + " lib", lib_nsa_compress_paged_cache, device, NUM_PRERUN, NUM_ITERATIONS + ) + + check_error(LIBINFINIOP.infiniopDestroyNsaCompressPagedCacheDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/nsa_paged_attention.py b/test/infiniop/nsa_paged_attention.py new file mode 100644 index 000000000..9d18a07f8 --- /dev/null +++ b/test/infiniop/nsa_paged_attention.py @@ -0,0 +1,380 @@ +import ctypes +from ctypes import c_uint64 + +import torch +from libinfiniop import ( + LIBINFINIOP, + InfiniDeviceNames, + InfiniDtype, + InfiniDtypeNames, + TestTensor, + TestWorkspace, + check_error, + debug, + get_args, + 
get_test_devices, + get_tolerance, + infiniopOperatorDescriptor_t, + profile_operation, + test_operator, +) + + +# ============================================================================== +# Reference Implementation +# ============================================================================== +def ref_nsa_compress_paged_cache( + k_cache, + v_cache, + block_tables, + seq_lens, + nsa_block_size, +): + page_block_size = k_cache.shape[2] + subblocks_per_page = page_block_size // nsa_block_size + k_cmp = torch.zeros( + (k_cache.shape[0] * subblocks_per_page, k_cache.shape[1], k_cache.shape[3]), + dtype=k_cache.dtype, + device=k_cache.device, + ) + v_cmp = torch.zeros_like(k_cmp) + + for seq in range(block_tables.shape[0]): + seq_len = int(seq_lens[seq].item()) + for nsa_block in range((seq_len + nsa_block_size - 1) // nsa_block_size): + tok_begin = nsa_block * nsa_block_size + tok_end = min(tok_begin + nsa_block_size, seq_len) + logical_page = tok_begin // page_block_size + subblock = (tok_begin % page_block_size) // nsa_block_size + physical = int(block_tables[seq, logical_page].item()) + cmp_block = physical * subblocks_per_page + subblock + rows = ( + torch.arange(tok_begin, tok_end, device=k_cache.device) + % page_block_size + ) + k_cmp[cmp_block] = ( + k_cache[physical, :, rows, :].float().mean(dim=1).to(k_cmp.dtype) + ) + v_cmp[cmp_block] = ( + v_cache[physical, :, rows, :].float().mean(dim=1).to(v_cmp.dtype) + ) + return k_cmp, v_cmp + + +def _attention_one(q, keys, values, scale): + if keys.numel() == 0: + return torch.zeros((values.shape[-1],), dtype=torch.float32, device=q.device) + scores = torch.sum(keys.float() * q.float().view(1, -1), dim=-1) * scale + probs = torch.softmax(scores, dim=-1) + return torch.sum(probs.view(-1, 1) * values.float(), dim=0) + + +def ref_nsa_paged_attention( + q, + k_cmp, + v_cmp, + k_cache, + v_cache, + block_tables, + seq_lens, + gates, + scale, + nsa_block_size, + window_size, + select_blocks, +): + num_seqs, 
num_heads, head_size = q.shape + num_kv_heads = k_cache.shape[1] + page_block_size = k_cache.shape[2] + subblocks_per_page = page_block_size // nsa_block_size + out = torch.empty_like(q, dtype=torch.float32) + + for seq in range(num_seqs): + seq_len = int(seq_lens[seq].item()) + nsa_blocks = (seq_len + nsa_block_size - 1) // nsa_block_size + for head in range(num_heads): + kv_head = head // (num_heads // num_kv_heads) + comp_keys = [] + comp_values = [] + block_scores = [] + for nsa_block in range(nsa_blocks): + tok_begin = nsa_block * nsa_block_size + logical_page = tok_begin // page_block_size + subblock = (tok_begin % page_block_size) // nsa_block_size + physical = int(block_tables[seq, logical_page].item()) + cmp_block = physical * subblocks_per_page + subblock + key = k_cmp[cmp_block, kv_head] + value = v_cmp[cmp_block, kv_head] + comp_keys.append(key) + comp_values.append(value) + block_scores.append( + torch.sum(q[seq, head].float() * key.float()) * scale + ) + + comp_out = _attention_one( + q[seq, head], torch.stack(comp_keys), torch.stack(comp_values), scale + ) + + top_count = min(select_blocks, nsa_blocks) + top_blocks = torch.topk( + torch.stack(block_scores), k=top_count + ).indices.tolist() + selected_keys = [] + selected_values = [] + for nsa_block in top_blocks: + tok_begin = nsa_block * nsa_block_size + tok_end = min(tok_begin + nsa_block_size, seq_len) + for tok in range(tok_begin, tok_end): + logical_page = tok // page_block_size + row = tok % page_block_size + physical = int(block_tables[seq, logical_page].item()) + selected_keys.append(k_cache[physical, kv_head, row]) + selected_values.append(v_cache[physical, kv_head, row]) + sel_out = _attention_one( + q[seq, head], + torch.stack(selected_keys), + torch.stack(selected_values), + scale, + ) + + win_begin = max(0, seq_len - window_size) if window_size > 0 else seq_len + window_keys = [] + window_values = [] + for tok in range(win_begin, seq_len): + logical_page = tok // page_block_size + row = 
tok % page_block_size + physical = int(block_tables[seq, logical_page].item()) + window_keys.append(k_cache[physical, kv_head, row]) + window_values.append(v_cache[physical, kv_head, row]) + win_out = _attention_one( + q[seq, head], + torch.stack(window_keys), + torch.stack(window_values), + scale, + ) + + g = gates[seq, :, head].float() + out[seq, head] = g[0] * comp_out + g[1] * sel_out + g[2] * win_out + return out.to(q.dtype) + + +# ============================================================================== +# Test Configuration (Internal Use Only) +# ============================================================================== +_TEST_CASES_ = [ + # (num_seqs, num_heads, num_kv_heads, page_block_size, max_seq_len, nsa_block_size, window_size, select_blocks) + (2, 4, 2, 128, 192, 64, 64, 2), + (3, 8, 2, 128, 255, 64, 128, 4), +] + +_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16] + +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 2e-2, "rtol": 2e-2}, + InfiniDtype.BF16: {"atol": 8e-2, "rtol": 8e-2}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 100 + + +def _seq_lens(num_seqs, max_seq_len, nsa_block_size): + values = [] + for i in range(num_seqs): + values.append(max(nsa_block_size + 1, max_seq_len - i * 31)) + return torch.tensor(values, dtype=torch.int64) + + +def test( + handle, + device, + num_seqs, + num_heads, + num_kv_heads, + page_block_size, + max_seq_len, + nsa_block_size, + window_size, + select_blocks, + dtype, + sync, +): + print( + f"Testing NsaPagedAttention on {InfiniDeviceNames[device]} with " + f"num_seqs={num_seqs}, num_heads={num_heads}, num_kv_heads={num_kv_heads}, " + f"page_block_size={page_block_size}, max_seq_len={max_seq_len}, " + f"nsa_block_size={nsa_block_size}, window_size={window_size}, " + f"select_blocks={select_blocks}, dtype={InfiniDtypeNames[dtype]}" + ) + + head_size = 128 + scale = 1.0 / (head_size**0.5) + max_blocks_per_seq = (max_seq_len + page_block_size - 1) // page_block_size + 
num_blocks = num_seqs * max_blocks_per_seq + q = TestTensor((num_seqs, num_heads, head_size), None, dtype, device, scale=0.1) + out = TestTensor( + (num_seqs, num_heads, head_size), None, dtype, device, mode="zeros" + ) + k_cache = TestTensor( + (num_blocks, num_kv_heads, page_block_size, head_size), + None, + dtype, + device, + scale=0.1, + ) + v_cache = TestTensor( + (num_blocks, num_kv_heads, page_block_size, head_size), + None, + dtype, + device, + scale=0.1, + ) + + block_tables_torch = torch.arange(num_blocks, dtype=torch.int64).view( + num_seqs, max_blocks_per_seq + ) + seq_lens_torch = _seq_lens(num_seqs, max_seq_len, nsa_block_size) + block_tables = TestTensor.from_torch(block_tables_torch, InfiniDtype.I64, device) + seq_lens = TestTensor.from_torch(seq_lens_torch, InfiniDtype.I64, device) + + k_cmp_torch, v_cmp_torch = ref_nsa_compress_paged_cache( + k_cache.torch_tensor(), + v_cache.torch_tensor(), + block_tables.torch_tensor(), + seq_lens.torch_tensor(), + nsa_block_size, + ) + k_cmp = TestTensor.from_torch(k_cmp_torch, dtype, device) + v_cmp = TestTensor.from_torch(v_cmp_torch, dtype, device) + + gates_torch = torch.rand((num_seqs, 3, num_heads), dtype=torch.float32) * 0.8 + 0.1 + gates = TestTensor.from_torch(gates_torch, InfiniDtype.F32, device) + + ans = ref_nsa_paged_attention( + q.torch_tensor(), + k_cmp.torch_tensor(), + v_cmp.torch_tensor(), + k_cache.torch_tensor(), + v_cache.torch_tensor(), + block_tables.torch_tensor(), + seq_lens.torch_tensor(), + gates.torch_tensor(), + scale, + nsa_block_size, + window_size, + select_blocks, + ) + + if sync: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateNsaPagedAttentionDescriptor( + handle, + ctypes.byref(descriptor), + out.descriptor, + q.descriptor, + k_cmp.descriptor, + v_cmp.descriptor, + k_cache.descriptor, + v_cache.descriptor, + block_tables.descriptor, + seq_lens.descriptor, + gates.descriptor, + scale, + nsa_block_size, + window_size, + 
select_blocks, + ) + ) + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetNsaPagedAttentionWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, device) + + out.destroy_desc() + q.destroy_desc() + k_cmp.destroy_desc() + v_cmp.destroy_desc() + k_cache.destroy_desc() + v_cache.destroy_desc() + block_tables.destroy_desc() + seq_lens.destroy_desc() + gates.destroy_desc() + + def lib_nsa_paged_attention(): + check_error( + LIBINFINIOP.infiniopNsaPagedAttention( + descriptor, + workspace.data(), + workspace_size.value, + out.data(), + q.data(), + k_cmp.data(), + v_cmp.data(), + k_cache.data(), + v_cache.data(), + block_tables.data(), + seq_lens.data(), + gates.data(), + None, + ) + ) + + lib_nsa_paged_attention() + + if sync: + sync() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(out.actual_tensor(), ans, atol=atol, rtol=rtol) + assert torch.allclose(out.actual_tensor(), ans, atol=atol, rtol=rtol) + + if PROFILE: + profile_operation( + "PyTorch", + lambda: ref_nsa_paged_attention( + q.torch_tensor(), + k_cmp.torch_tensor(), + v_cmp.torch_tensor(), + k_cache.torch_tensor(), + v_cache.torch_tensor(), + block_tables.torch_tensor(), + seq_lens.torch_tensor(), + gates.torch_tensor(), + scale, + nsa_block_size, + window_size, + select_blocks, + ), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + profile_operation( + " lib", lib_nsa_paged_attention, device, NUM_PRERUN, NUM_ITERATIONS + ) + + check_error(LIBINFINIOP.infiniopDestroyNsaPagedAttentionDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m")