Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ dist
.vscode
tmp/
requirements-musa.txt
logs/
logs/

/benchmark/
11 changes: 7 additions & 4 deletions lightllm/common/basemodel/attention/fa3/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor
from lightllm.common.basemodel.batch_objs import is_mtp_verify_decode as is_mtp_verify_decode_fn


class Fa3AttBackend(BaseAttBackend):
Expand Down Expand Up @@ -125,8 +126,9 @@ class Fa3DecodeAttState(BaseDecodeAttState):
def init_state(self):
self.backend: Fa3AttBackend = self.backend
args_mtp_step = get_env_start_args().mtp_step
is_mtp_verify_decode = is_mtp_verify_decode_fn(args_mtp_step, self.infer_state.b_num_accepted_tokens)

if args_mtp_step > 0:
if is_mtp_verify_decode:
# 修正 mtp 在 fa3 下的输入。
mtp_size = args_mtp_step + 1
b_q_seq_len = torch.full(
Expand All @@ -143,8 +145,9 @@ def init_state(self):
self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int()
self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int()

att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1)
assert self.infer_state.batch_size % (args_mtp_step + 1) == 0
mtp_size = args_mtp_step + 1 if is_mtp_verify_decode else 1
att_batch_size = self.infer_state.batch_size // mtp_size
assert self.infer_state.batch_size % mtp_size == 0

model = self.backend.model
# 可以使用 cuda graph的时候从 buffer中申请
Expand All @@ -163,7 +166,7 @@ def init_state(self):
device=self.infer_state.input_ids.device,
)

if args_mtp_step > 0:
if is_mtp_verify_decode:
page_table_copy(
page_table=self.page_table[:, : self.infer_state.max_kv_seq_len],
req_to_token_indexs=model.req_manager.req_to_token_indexs,
Expand Down
29 changes: 15 additions & 14 deletions lightllm/common/basemodel/attention/fa3/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from ..base_att import AttControl
from typing import Optional, TYPE_CHECKING
from lightllm.utils.sgl_utils import flash_attn_with_kvcache
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.basemodel.triton_kernel.quantization.q_per_head_fp8_quant import q_per_head_fp8_quant
from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops
from typing import Union
Expand Down Expand Up @@ -45,9 +44,12 @@ def init_state(self):
torch.arange(batch_size, device=device), self.infer_state.b_q_seq_len
)
# 为了减少推理计算量,在推理外部初始化k_descale和v_descale
self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)

self.k_descale = (
offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
)
self.v_descale = (
offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
)

def prefill_att(
self,
Expand Down Expand Up @@ -116,20 +118,19 @@ def init_state(self):
super().init_state()
self.backend: Fp8Fa3AttBackend = self.backend

args_mtp_step = get_env_start_args().mtp_step
att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1)
assert self.infer_state.batch_size % (args_mtp_step + 1) == 0

device = self.infer_state.input_ids.device
batch_size = att_batch_size
batch_size = self.b_att_seq_len.shape[0]
mem_manager = self.backend.model.mem_manager

offline_scales: torch.Tensor = mem_manager.scales
head_num = mem_manager.head_num

# 为了减少推理计算量,在推理外部初始化k_descale和v_descale
self.k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
self.v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
self.k_descale = (
offline_scales[:, :head_num].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
)
self.v_descale = (
offline_scales[:, head_num:].view(-1, 1, head_num).expand(offline_scales.shape[0], batch_size, head_num)
)

return

Expand Down Expand Up @@ -180,11 +181,11 @@ def _fp8_decode_att(
k_cache=cache_k,
v_cache=cache_v,
page_table=self.page_table,
cache_seqlens=self.infer_state.b_seq_len,
cache_seqlens=self.b_att_seq_len,
cu_seqlens_q=self.cu_seqlens_q,
cu_seqlens_k_new=self.cu_seqlens_k,
max_seqlen_q=self.decode_max_q_seq_len,
causal=False,
causal=True,
window_size=(-1, -1),
softcap=0.0,
q_descale=q_scale.view(self.infer_state.batch_size, k_head_num),
Expand Down
11 changes: 7 additions & 4 deletions lightllm/common/basemodel/attention/fa3/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor
from lightllm.utils.sgl_utils import flash_attn_varlen_func
from lightllm.common.basemodel.batch_objs import is_mtp_verify_decode as is_mtp_verify_decode_fn


class MlaFa3AttBackend(BaseAttBackend):
Expand Down Expand Up @@ -108,8 +109,9 @@ class MlaFa3DecodeAttState(BaseDecodeAttState):
def init_state(self):
self.backend: MlaFa3AttBackend = self.backend
args_mtp_step = get_env_start_args().mtp_step
is_mtp_verify_decode = is_mtp_verify_decode_fn(args_mtp_step, self.infer_state.b_num_accepted_tokens)

if args_mtp_step > 0:
if is_mtp_verify_decode:
# 修正 mtp 在 fa3 下的输入。
mtp_size = args_mtp_step + 1
b_q_seq_len = torch.full(
Expand All @@ -126,8 +128,9 @@ def init_state(self):
self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int()
self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int()

att_batch_size = self.infer_state.batch_size // (args_mtp_step + 1)
assert self.infer_state.batch_size % (args_mtp_step + 1) == 0
mtp_size = args_mtp_step + 1 if is_mtp_verify_decode else 1
att_batch_size = self.infer_state.batch_size // mtp_size
assert self.infer_state.batch_size % mtp_size == 0

model = self.backend.model
# 可以使用 cuda graph的时候从 buffer中申请
Expand All @@ -146,7 +149,7 @@ def init_state(self):
device=self.infer_state.input_ids.device,
)

if args_mtp_step > 0:
if is_mtp_verify_decode:
page_table_copy(
page_table=self.page_table[:, : self.infer_state.max_kv_seq_len],
req_to_token_indexs=model.req_manager.req_to_token_indexs,
Expand Down
Loading
Loading