Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions include/infinicore/ops/nsa_compress_paged_cache.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {

INFINICORE_GRAPH_OP_CLASS(NsaCompressPagedCache, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, int, bool);

void nsa_compress_paged_cache_(Tensor k_cmp, Tensor v_cmp, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &kv_lens, int nsa_block_size,
bool update_last_only = false);

} // namespace infinicore::op
19 changes: 19 additions & 0 deletions include/infinicore/ops/nsa_paged_attention.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {

INFINICORE_GRAPH_OP_CLASS(NsaPagedAttention, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, float, int, int, int);

Tensor nsa_paged_attention(const Tensor &q, const Tensor &k_cmp, const Tensor &v_cmp, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &kv_lens, const Tensor &gates,
float scale, int nsa_block_size, int window_size, int select_blocks);

void nsa_paged_attention_(Tensor out, const Tensor &q, const Tensor &k_cmp, const Tensor &v_cmp, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &kv_lens, const Tensor &gates,
float scale, int nsa_block_size, int window_size, int select_blocks);

} // namespace infinicore::op
2 changes: 2 additions & 0 deletions include/infiniop.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@
#include "infiniop/ops/mul.h"
#include "infiniop/ops/multi_margin_loss.h"
#include "infiniop/ops/nrm2.h"
#include "infiniop/ops/nsa_compress_paged_cache.h"
#include "infiniop/ops/nsa_paged_attention.h"
#include "infiniop/ops/ones.h"
#include "infiniop/ops/pad.h"
#include "infiniop/ops/paged_attention.h"
Expand Down
39 changes: 39 additions & 0 deletions include/infiniop/ops/nsa_compress_paged_cache.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#ifndef __INFINIOP_NSA_COMPRESS_PAGED_CACHE_API_H__
#define __INFINIOP_NSA_COMPRESS_PAGED_CACHE_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopNsaCompressPagedCacheDescriptor_t;

__INFINI_C __export infiniStatus_t infiniopCreateNsaCompressPagedCacheDescriptor(
infiniopHandle_t handle,
infiniopNsaCompressPagedCacheDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t k_cmp_desc,
infiniopTensorDescriptor_t v_cmp_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
int nsa_block_size,
int update_last_only);

__INFINI_C __export infiniStatus_t infiniopGetNsaCompressPagedCacheWorkspaceSize(
infiniopNsaCompressPagedCacheDescriptor_t desc,
size_t *size);

__INFINI_C __export infiniStatus_t infiniopNsaCompressPagedCache(
infiniopNsaCompressPagedCacheDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *k_cmp,
void *v_cmp,
const void *k_cache,
const void *v_cache,
const void *block_tables,
const void *seq_lens,
void *stream);

__INFINI_C __export infiniStatus_t infiniopDestroyNsaCompressPagedCacheDescriptor(
infiniopNsaCompressPagedCacheDescriptor_t desc);

#endif // __INFINIOP_NSA_COMPRESS_PAGED_CACHE_API_H__
47 changes: 47 additions & 0 deletions include/infiniop/ops/nsa_paged_attention.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#ifndef __INFINIOP_NSA_PAGED_ATTENTION_API_H__
#define __INFINIOP_NSA_PAGED_ATTENTION_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopNsaPagedAttentionDescriptor_t;

__INFINI_C __export infiniStatus_t infiniopCreateNsaPagedAttentionDescriptor(
infiniopHandle_t handle,
infiniopNsaPagedAttentionDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cmp_desc,
infiniopTensorDescriptor_t v_cmp_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t gates_desc,
float scale,
int nsa_block_size,
int window_size,
int select_blocks);

__INFINI_C __export infiniStatus_t infiniopGetNsaPagedAttentionWorkspaceSize(
infiniopNsaPagedAttentionDescriptor_t desc,
size_t *size);

__INFINI_C __export infiniStatus_t infiniopNsaPagedAttention(
infiniopNsaPagedAttentionDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cmp,
const void *v_cmp,
const void *k_cache,
const void *v_cache,
const void *block_tables,
const void *seq_lens,
const void *gates,
void *stream);

__INFINI_C __export infiniStatus_t infiniopDestroyNsaPagedAttentionDescriptor(
infiniopNsaPagedAttentionDescriptor_t desc);

#endif // __INFINIOP_NSA_PAGED_ATTENTION_API_H__
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#include "infinicore/ops/nsa_compress_paged_cache.hpp"
#include "../../utils.hpp"

namespace infinicore::op {

INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(NsaCompressPagedCache);

NsaCompressPagedCache::NsaCompressPagedCache(Tensor k_cmp, Tensor v_cmp, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &kv_lens, int nsa_block_size,
bool update_last_only) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k_cmp, v_cmp, k_cache, v_cache, block_tables, kv_lens);
INFINICORE_GRAPH_OP_DISPATCH(k_cmp->device().getType(), k_cmp, v_cmp, k_cache, v_cache, block_tables, kv_lens, nsa_block_size, update_last_only);
}

void NsaCompressPagedCache::execute(Tensor k_cmp, Tensor v_cmp, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &kv_lens, int nsa_block_size,
bool update_last_only) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(
NsaCompressPagedCache,
k_cmp, v_cmp, k_cache, v_cache, block_tables, kv_lens, nsa_block_size, update_last_only);
}

void nsa_compress_paged_cache_(Tensor k_cmp, Tensor v_cmp, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &kv_lens, int nsa_block_size,
bool update_last_only) {
NsaCompressPagedCache::execute(k_cmp, v_cmp, k_cache, v_cache, block_tables, kv_lens, nsa_block_size, update_last_only);
}

} // namespace infinicore::op
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#include "infinicore/ops/nsa_compress_paged_cache.hpp"

#include "../infiniop_impl.hpp"

namespace infinicore::op::nsa_compress_paged_cache_impl::infiniop {

INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, NsaCompressPagedCache, 100);

struct PlannedMeta {
std::shared_ptr<Descriptor> descriptor;
graph::GraphTensor workspace, k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens;
};

void *plan(Tensor k_cmp, Tensor v_cmp, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &cache_lens, int nsa_block_size, bool update_last_only) {
size_t seed = hash_combine(k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, nsa_block_size, update_last_only);
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, NsaCompressPagedCache,
seed,
k_cmp->desc(), v_cmp->desc(), k_cache->desc(), v_cache->desc(),
block_tables->desc(), cache_lens->desc(), nsa_block_size, static_cast<int>(update_last_only));

INFINIOP_WORKSPACE_TENSOR(workspace, NsaCompressPagedCache, descriptor);

return new PlannedMeta{
descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(k_cmp),
graph::GraphTensor(v_cmp),
graph::GraphTensor(k_cache),
graph::GraphTensor(v_cache),
graph::GraphTensor(block_tables),
graph::GraphTensor(cache_lens)};
}

void run(void *planned_meta) {
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
INFINICORE_CHECK_ERROR(
infiniopNsaCompressPagedCache(
p->descriptor->desc,
p->workspace->data(),
p->workspace->numel(),
p->k_cmp->data(),
p->v_cmp->data(),
p->k_cache->data(),
p->v_cache->data(),
p->block_tables->data(),
p->cache_lens->data(),
context::getStream()));
}

void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}

INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(NsaCompressPagedCache, &plan, &run, &cleanup);

} // namespace infinicore::op::nsa_compress_paged_cache_impl::infiniop
40 changes: 40 additions & 0 deletions src/infinicore/ops/nsa_paged_attention/nsa_paged_attention.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#include "infinicore/ops/nsa_paged_attention.hpp"
#include "../../utils.hpp"

namespace infinicore::op {

INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(NsaPagedAttention);

NsaPagedAttention::NsaPagedAttention(Tensor out, const Tensor &q, const Tensor &k_cmp, const Tensor &v_cmp,
const Tensor &k_cache, const Tensor &v_cache, const Tensor &block_tables,
const Tensor &kv_lens, const Tensor &gates, float scale, int nsa_block_size,
int window_size, int select_blocks) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, kv_lens, gates);
INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(),
out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, kv_lens, gates, scale, nsa_block_size, window_size, select_blocks);
}

void NsaPagedAttention::execute(Tensor out, const Tensor &q, const Tensor &k_cmp, const Tensor &v_cmp,
const Tensor &k_cache, const Tensor &v_cache, const Tensor &block_tables,
const Tensor &kv_lens, const Tensor &gates, float scale, int nsa_block_size,
int window_size, int select_blocks) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(
NsaPagedAttention,
out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, kv_lens, gates, scale, nsa_block_size, window_size, select_blocks);
}

Tensor nsa_paged_attention(const Tensor &q, const Tensor &k_cmp, const Tensor &v_cmp, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &kv_lens, const Tensor &gates,
float scale, int nsa_block_size, int window_size, int select_blocks) {
auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
nsa_paged_attention_(out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, kv_lens, gates, scale, nsa_block_size, window_size, select_blocks);
return out;
}

void nsa_paged_attention_(Tensor out, const Tensor &q, const Tensor &k_cmp, const Tensor &v_cmp, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &kv_lens, const Tensor &gates,
float scale, int nsa_block_size, int window_size, int select_blocks) {
NsaPagedAttention::execute(out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, kv_lens, gates, scale, nsa_block_size, window_size, select_blocks);
}

} // namespace infinicore::op
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#include "infinicore/ops/nsa_paged_attention.hpp"

#include "../infiniop_impl.hpp"

namespace infinicore::op::nsa_paged_attention_impl::infiniop {

INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, NsaPagedAttention, 100);

struct PlannedMeta {
std::shared_ptr<Descriptor> descriptor;
graph::GraphTensor workspace, out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, gates;
};

void *plan(Tensor out, const Tensor &q, const Tensor &k_cmp, const Tensor &v_cmp, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &cache_lens, const Tensor &gates,
float scale, int nsa_block_size, int window_size, int select_blocks) {
size_t seed = hash_combine(out, q, k_cmp, v_cmp, k_cache, v_cache, block_tables, cache_lens, gates, nsa_block_size, window_size, select_blocks);
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, NsaPagedAttention,
seed,
out->desc(), q->desc(), k_cmp->desc(), v_cmp->desc(), k_cache->desc(), v_cache->desc(),
block_tables->desc(), cache_lens->desc(), gates->desc(),
scale, nsa_block_size, window_size, select_blocks);

INFINIOP_WORKSPACE_TENSOR(workspace, NsaPagedAttention, descriptor);

return new PlannedMeta{
descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(out),
graph::GraphTensor(q),
graph::GraphTensor(k_cmp),
graph::GraphTensor(v_cmp),
graph::GraphTensor(k_cache),
graph::GraphTensor(v_cache),
graph::GraphTensor(block_tables),
graph::GraphTensor(cache_lens),
graph::GraphTensor(gates)};
}

void run(void *planned_meta) {
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
INFINICORE_CHECK_ERROR(
infiniopNsaPagedAttention(
p->descriptor->desc,
p->workspace->data(),
p->workspace->numel(),
p->out->data(),
p->q->data(),
p->k_cmp->data(),
p->v_cmp->data(),
p->k_cache->data(),
p->v_cache->data(),
p->block_tables->data(),
p->cache_lens->data(),
p->gates->data(),
context::getStream()));
}

void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}

INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(NsaPagedAttention, &plan, &run, &cleanup);

} // namespace infinicore::op::nsa_paged_attention_impl::infiniop
Loading
Loading