Commit a726914

for ci
1 parent a2d574d commit a726914

1 file changed: +24 -12 lines changed

test/layers/test_moba_attention.py

Lines changed: 24 additions & 12 deletions
@@ -21,22 +21,27 @@
 get_cur_cu_seq_len_k = None
 import os

+from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.attention.moba_attention_backend import (
     MobaAttentionBackend,
 )
+from fastdeploy.platforms import _Backend, current_platform


 class ModelConfig:
     def __init__(self):
         self.num_hidden_layers = 12
         self.head_dim = 128
+        self.num_attention_heads = 8
+        self.num_key_value_heads = 1


 class ParallelConfig:
     def __init__(self):
         self.block_size = 128
         self.max_model_len = 128 * 1024
         self.max_num_seqs = 1
+        self.tensor_parallel_size = 1


 class ForwardMode:
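The two new ModelConfig fields mirror the head layout used in __main__ further down (8 query heads, 1 KV head). A minimal illustrative check of the implied grouping, using only the stub class above (the group-size reading is an inference, not stated in the commit):

# Illustrative only: grouped-query ratio implied by the new stub fields.
cfg = ModelConfig()
assert cfg.num_attention_heads // cfg.num_key_value_heads == 8  # 8 query heads share each KV head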
@@ -48,16 +53,17 @@ class FDConfig:
     def __init__(self):
         self.parallel_config = ParallelConfig()
         self.model_config = ModelConfig()
+        self.quant_config = {}


 def test_moba_attention(seq_len, num_heads, num_kv_heads, head_dim):
     max_seq_len = int(128 * 1024)
-    moba_encoder_top_k_left = int(10)
-    moba_encoder_top_k_right = int(15)
-    moba_use_encoder_seq_limit = int(10 * 128)
-    moba_decoder_top_k_left = int(10)
-    moba_decoder_top_k_right = int(10)
-    moba_use_decoder_seq_limit = int(10 * 128)
+    moba_encoder_top_k_left = int(5)
+    moba_encoder_top_k_right = int(10)
+    moba_use_encoder_seq_limit = int(20 * 128)
+    moba_decoder_top_k_left = int(20)
+    moba_decoder_top_k_right = int(20)
+    moba_use_decoder_seq_limit = int(20 * 128)
     os.environ["FD_ATTENTION_BACKEND"] = "MOBA_ATTN"
     os.environ["FD_MOBA_MLP_WEIGHT_PATH"] = "None"
     os.environ["FD_MOBA_ENCODER_TOP_K_LEFT"] = str(moba_encoder_top_k_left)
@@ -70,7 +76,7 @@ def test_moba_attention(seq_len, num_heads, num_kv_heads, head_dim):
     os.environ["FD_MOBA_MAX_SEQ_LENGTH"] = str(max_seq_len)

     max_dec_len_this_time = int(0)
-    qkv = paddle.randn([1, seq_len, num_heads + 2 * num_kv_heads, head_dim], dtype="bfloat16")
+    qkv = paddle.randn([1, 4 * seq_len, num_heads + 2 * num_kv_heads, head_dim], dtype="bfloat16")
     q_input = qkv[:, :, :num_heads, :].reshape([-1, num_heads, head_dim])
     k_input = qkv[:, :, num_heads : num_heads + num_kv_heads, :].reshape([-1, num_kv_heads, head_dim])
     v_input = qkv[:, :, num_heads + num_kv_heads :, :].reshape([-1, num_kv_heads, head_dim])
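A quick reading of the retuned thresholds and the enlarged qkv buffer, as illustrative arithmetic only (it assumes the seq_len = 2 * 1024 set in __main__ at the end of the file):

# Illustrative arithmetic, not part of the test file.
encoder_seq_limit = 20 * 128      # 2560 tokens
decoder_seq_limit = 20 * 128      # 2560 tokens
num_qkv_tokens = 4 * (2 * 1024)   # qkv now spans 4 * seq_len = 8192 tokens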
@@ -80,14 +86,14 @@ def test_moba_attention(seq_len, num_heads, num_kv_heads, head_dim):

     seq_lens_decoder = paddle.to_tensor([0], dtype="int32")

-    cachesk = paddle.zeros([(seq_len + 63) // 64 * 64, num_kv_heads, 64, head_dim], dtype="bfloat16")
-    cachesv = paddle.zeros([(seq_len + 63) // 64 * 64, num_kv_heads, 64, head_dim], dtype="bfloat16")
+    cachesk = paddle.zeros([(seq_len + 63) // 64 * 256, num_kv_heads, 64, head_dim], dtype="bfloat16")
+    cachesv = paddle.zeros([(seq_len + 63) // 64 * 256, num_kv_heads, 64, head_dim], dtype="bfloat16")

     block_tables = paddle.arange((seq_len + 63) // 64).astype("int32")

     rotary_embs = paddle.ones([seq_len, head_dim], dtype="float32")

-    cache_k_block_means = paddle.zeros([(seq_len + 63) // 64, num_kv_heads, 64, head_dim], dtype="bfloat16")
+    cache_k_block_means = paddle.zeros([(seq_len + 63) // 64 + 10, num_kv_heads, 64, head_dim], dtype="bfloat16")

     fd_config = FDConfig()
     forward_meta = ForwardMode()
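A rough sizing check for the enlarged caches, again as illustrative arithmetic (assuming seq_len = 2 * 1024 from __main__ and the 64-token block size already present in the shapes):

# Illustrative arithmetic, not part of the test file.
num_blocks = (2 * 1024 + 63) // 64   # 32 blocks of 64 tokens
cachesk_dim0 = num_blocks * 256      # 8192 rows (the old factor of 64 gave 2048)
block_means_dim0 = num_blocks + 10   # 42 rows (was 32)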
@@ -98,6 +104,7 @@ def test_moba_attention(seq_len, num_heads, num_kv_heads, head_dim):
     moba_attention_backend = MobaAttentionBackend(fd_config, 8, 1, 128)
     moba_attention_backend.init_attention_metadata(forward_meta)
     moba_attention_backend.get_kv_cache_shape(100)
+
     if moba_attention is None:
         return
     if get_cur_cu_seq_len_k is None:
@@ -150,12 +157,17 @@ def test_moba_attention(seq_len, num_heads, num_kv_heads, head_dim):
         "none",
     )[0]

-    return out
+    attention = Attention(fd_config, 0)
+
+    selected_backend = _Backend.__members__.get(os.environ["FD_ATTENTION_BACKEND"])
+    attention_cls = current_platform.get_attention_backend_cls(selected_backend)
+
+    return out, attention, attention_cls


 if __name__ == "__main__":
     if paddle.is_compiled_with_cuda():
-        seq_len = int(20 * 1024)
+        seq_len = int(2 * 1024)
         num_heads = int(8)
         num_kv_heads = int(1)
         head_dim = int(128)
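Since the test now returns a tuple, a caller could also sanity-check the backend class the platform resolves for MOBA_ATTN. A hypothetical usage sketch; it assumes __main__ goes on to call test_moba_attention with the values above, and that the platform maps MOBA_ATTN to MobaAttentionBackend, neither of which is shown in this diff:

# Hypothetical usage, not part of the commit.
out, attention, attention_cls = test_moba_attention(seq_len, num_heads, num_kv_heads, head_dim)
assert attention_cls is MobaAttentionBackend  # assumed mapping for FD_ATTENTION_BACKEND="MOBA_ATTN"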
