Skip to content

Commit 3754a99

Browse files
[Feature] block sparse attention (#3668)
* 支持稀疏attn * fix bug * code style * fix moba attn get kv shape * 修复a100编译 * codestyle * code style * code style * code style * fix conflict * 增加单侧 * code style * 增加eblite 加载时间 * fix bug * for ci * for ci * for ci * for ci * 支持mlp block size 128 * 增加小算子单测 * fix 单测 mlp * 将环境变量加入到config里面 * fix rollout config * 修复显存 * add test server * add test server * fix mlp 最后一层使用full attn
1 parent ccd52b5 commit 3754a99

31 files changed

+6553
-10
lines changed

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -845,15 +845,15 @@ void SpeculateStepPaddle(
845845
const int max_draft_tokens);
846846

847847
void MergePrefillDecodeOutput(
848-
const paddle::Tensor &encoder_res,
849-
const paddle::Tensor &decoder_res,
850-
const paddle::Tensor &seq_lens_encoder,
851-
const paddle::Tensor &seq_lens_decoder,
852-
const paddle::Tensor &seq_lens_this_time,
853-
const paddle::Tensor &cu_seq_q,
854-
const int head_num,
855-
const int head_dim,
856-
const int max_token);
848+
const paddle::Tensor &encoder_res,
849+
const paddle::Tensor &decoder_res,
850+
const paddle::Tensor &seq_lens_encoder,
851+
const paddle::Tensor &seq_lens_decoder,
852+
const paddle::Tensor &seq_lens_this_time,
853+
const paddle::Tensor &cu_seq_q,
854+
const int head_num,
855+
const int head_dim,
856+
const int max_token);
857857

858858
std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
859859
const paddle::Tensor &top_p,
Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,330 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/extension.h"
16+
#include "moba_attn.h"
17+
18+
19+
std::vector<paddle::Tensor> MobaAttention(
20+
const paddle::Tensor& qkv,
21+
const paddle::Tensor& q_input,
22+
const paddle::Tensor& k_input,
23+
const paddle::Tensor& v_input,
24+
const paddle::Tensor& cu_seq_q,
25+
const paddle::Tensor& cu_seq_k,
26+
const paddle::Tensor& cu_seq_q_pack,
27+
const paddle::Tensor& q_pack_tokens,
28+
const paddle::Tensor& seq_len_encoder,
29+
const paddle::Tensor& seq_len_decoder,
30+
const paddle::Tensor& key_cache,
31+
const paddle::Tensor& value_cache,
32+
const paddle::Tensor& block_tables,
33+
const paddle::Tensor& rope_sin_cos,
34+
const paddle::Tensor& k_block_means,
35+
const paddle::optional<paddle::Tensor>& attn_gate_weight,
36+
const paddle::optional<paddle::Tensor>& qkv_bias,
37+
const paddle::optional<paddle::Tensor>& cache_k_quant_scale,
38+
const paddle::optional<paddle::Tensor>& cache_v_quant_scale,
39+
const paddle::optional<paddle::Tensor>& cache_k_dequant_scale,
40+
const paddle::optional<paddle::Tensor>& cache_v_dequant_scale,
41+
const paddle::optional<paddle::Tensor>& cache_k_zero_points,
42+
const paddle::optional<paddle::Tensor>& cache_v_zero_points,
43+
const int head_num,
44+
const int kv_head_num,
45+
const int head_dim,
46+
const int max_seq_len,
47+
const int max_enc_len_this_time,
48+
const int max_dec_len_this_time,
49+
const int moba_encoder_top_k_left,
50+
const int moba_encoder_top_k_right,
51+
const int moba_use_encoder_seq_limit,
52+
const int moba_decoder_top_k_left,
53+
const int moba_decoder_top_k_right,
54+
const int moba_use_decoder_seq_limit,
55+
const bool moba_use_mlp,
56+
const std::string &cache_quant_type_str) {
57+
58+
paddle::Tensor out = paddle::empty({qkv.dims()[0], head_num * head_dim}, qkv.dtype(), qkv.place());
59+
if (max_dec_len_this_time > 0) {
60+
MobaDecoderAttnWriteCacheKv(
61+
qkv,
62+
q_input,
63+
cu_seq_q,
64+
cu_seq_k,
65+
seq_len_encoder,
66+
seq_len_decoder,
67+
key_cache,
68+
value_cache,
69+
block_tables,
70+
rope_sin_cos,
71+
k_block_means,
72+
qkv_bias,
73+
cache_k_quant_scale,
74+
cache_v_quant_scale,
75+
cache_k_dequant_scale,
76+
cache_v_dequant_scale,
77+
cache_k_zero_points,
78+
cache_v_zero_points,
79+
head_num,
80+
kv_head_num,
81+
head_dim,
82+
max_seq_len,
83+
cache_quant_type_str);
84+
85+
auto qk_gate_weight = MobaQKGemm(
86+
q_input,
87+
k_block_means,
88+
seq_len_encoder,
89+
seq_len_decoder,
90+
cu_seq_q,
91+
cu_seq_k,
92+
max_dec_len_this_time,
93+
max_dec_len_this_time,
94+
head_num,
95+
kv_head_num,
96+
true,
97+
moba_use_decoder_seq_limit
98+
)[0];
99+
100+
auto qk_gate_topk_idx = QkSortDecoder(
101+
qk_gate_weight,
102+
seq_len_encoder,
103+
seq_len_decoder,
104+
head_num,
105+
kv_head_num,
106+
moba_decoder_top_k_left,
107+
moba_decoder_top_k_right,
108+
moba_use_decoder_seq_limit
109+
)[0];
110+
111+
MobaDecoderAttn(
112+
q_input,
113+
seq_len_encoder,
114+
seq_len_decoder,
115+
cu_seq_q,
116+
key_cache,
117+
value_cache,
118+
block_tables,
119+
k_block_means,
120+
out,
121+
qk_gate_topk_idx,
122+
cache_k_quant_scale,
123+
cache_v_quant_scale,
124+
cache_k_dequant_scale,
125+
cache_v_dequant_scale,
126+
cache_k_zero_points,
127+
cache_v_zero_points,
128+
head_num,
129+
kv_head_num,
130+
head_dim,
131+
max_seq_len,
132+
moba_use_decoder_seq_limit,
133+
max_dec_len_this_time,
134+
max_dec_len_this_time,
135+
cache_quant_type_str
136+
);
137+
}
138+
139+
if (max_enc_len_this_time > 0) {
140+
FusedBlockMeanAndRope(
141+
qkv,
142+
k_block_means,
143+
q_input,
144+
k_input,
145+
v_input,
146+
rope_sin_cos,
147+
seq_len_encoder,
148+
seq_len_decoder,
149+
cu_seq_q,
150+
cu_seq_k,
151+
qkv_bias,
152+
head_num,
153+
kv_head_num,
154+
head_dim,
155+
max_seq_len,
156+
max_enc_len_this_time,
157+
max_enc_len_this_time,
158+
cache_quant_type_str
159+
);
160+
161+
MobaEncoderAttnWriteCacheKv(
162+
k_input,
163+
v_input,
164+
cu_seq_k,
165+
seq_len_encoder,
166+
seq_len_decoder,
167+
key_cache,
168+
value_cache,
169+
block_tables,
170+
cache_k_quant_scale,
171+
cache_v_quant_scale,
172+
cache_k_dequant_scale,
173+
cache_v_dequant_scale,
174+
cache_k_zero_points,
175+
cache_v_zero_points,
176+
head_num,
177+
kv_head_num,
178+
head_dim,
179+
max_enc_len_this_time,
180+
cache_quant_type_str
181+
);
182+
183+
GetKVFromCache(
184+
k_input,
185+
v_input,
186+
cu_seq_k,
187+
seq_len_encoder,
188+
seq_len_decoder,
189+
key_cache,
190+
value_cache,
191+
block_tables,
192+
cache_k_dequant_scale,
193+
cache_v_dequant_scale,
194+
cache_k_zero_points,
195+
cache_v_zero_points,
196+
head_num,
197+
kv_head_num,
198+
head_dim,
199+
max_seq_len,
200+
max_enc_len_this_time + max_dec_len_this_time,
201+
cache_quant_type_str
202+
);
203+
204+
paddle::Tensor *k_gate_weight = const_cast<paddle::Tensor*>(&k_block_means);
205+
206+
if (moba_use_mlp && attn_gate_weight) {
207+
paddle::Tensor k_gate_mlp = MobaMlpEinsum(
208+
k_input,
209+
attn_gate_weight.get(),
210+
seq_len_encoder,
211+
seq_len_decoder,
212+
cu_seq_k,
213+
max_seq_len,
214+
kv_head_num
215+
)[0];
216+
k_gate_weight = &k_gate_mlp;
217+
}
218+
219+
auto qk_gate_weight = MobaQKGemm(
220+
q_input,
221+
*k_gate_weight,
222+
seq_len_encoder,
223+
seq_len_decoder,
224+
cu_seq_q,
225+
cu_seq_k,
226+
max_enc_len_this_time,
227+
max_enc_len_this_time + max_dec_len_this_time,
228+
head_num,
229+
kv_head_num,
230+
false,
231+
moba_use_encoder_seq_limit
232+
)[0];
233+
234+
235+
auto qk_gate_topk_idx = QkSortEncoder(
236+
qk_gate_weight,
237+
seq_len_encoder,
238+
seq_len_decoder,
239+
cu_seq_q,
240+
cu_seq_k,
241+
cu_seq_q_pack,
242+
q_pack_tokens,
243+
max_enc_len_this_time,
244+
max_enc_len_this_time + max_dec_len_this_time,
245+
head_num,
246+
kv_head_num,
247+
moba_encoder_top_k_left,
248+
moba_encoder_top_k_right,
249+
moba_use_mlp && !attn_gate_weight ? max_seq_len : moba_use_encoder_seq_limit)[0];
250+
251+
MobaEncoderAttn(
252+
q_input,
253+
k_input,
254+
v_input,
255+
qk_gate_topk_idx,
256+
cu_seq_q,
257+
cu_seq_k,
258+
cu_seq_q_pack,
259+
seq_len_encoder,
260+
seq_len_decoder,
261+
out,
262+
max_enc_len_this_time,
263+
max_enc_len_this_time + max_dec_len_this_time,
264+
head_num,
265+
kv_head_num,
266+
head_dim,
267+
max_seq_len
268+
);
269+
}
270+
271+
return {out};
272+
}
273+
274+
275+
// Registration of the `moba_attention` custom op with Paddle.
// Input order must match the MobaAttention kernel's tensor parameters,
// attr order its scalar parameters, and the inplace map declares which
// inputs the kernels mutate (so the framework aliases rather than copies
// them).
PD_BUILD_OP(moba_attention)
    .Inputs({
        "qkv",
        "q_input",
        "k_input",
        "v_input",
        "cu_seq_q",
        "cu_seq_k",
        "cu_seq_q_pack",
        "q_pack_tokens",
        "seq_len_encoder",
        "seq_len_decoder",
        "key_cache",
        "value_cache",
        "block_tables",
        "rope_sin_cos",
        "k_block_means",
        // Optional inputs map to paddle::optional<Tensor> parameters.
        paddle::Optional("attn_gate_weight"),
        paddle::Optional("qkv_bias"),
        paddle::Optional("cache_k_quant_scale"),
        paddle::Optional("cache_v_quant_scale"),
        paddle::Optional("cache_k_dequant_scale"),
        paddle::Optional("cache_v_dequant_scale"),
        paddle::Optional("cache_k_zero_points"),
        paddle::Optional("cache_v_zero_points")})
    .Attrs({
        "head_num: int",
        "kv_head_num: int",
        "head_dim: int",
        "max_seq_len: int",
        "max_enc_len_this_time: int",
        "max_dec_len_this_time: int",
        "moba_encoder_top_k_left: int",
        "moba_encoder_top_k_right: int",
        "moba_use_encoder_seq_limit: int",
        "moba_decoder_top_k_left: int",
        "moba_decoder_top_k_right: int",
        "moba_use_decoder_seq_limit: int",
        "moba_use_mlp: bool",
        "cache_quant_type_str: std::string"})
    .Outputs({
        "out",
        // The *_out names below are the in-place aliases declared in
        // SetInplaceMap, not extra allocations.
        "q_input_out",
        "k_input_out",
        "v_input_out",
        "key_cache_out",
        "value_cache_out",
        "k_block_means_out"})
    .SetInplaceMap({{
        "q_input", "q_input_out"},
        {"k_input", "k_input_out"},
        {"v_input", "v_input_out"},
        {"key_cache", "key_cache_out"},
        {"value_cache", "value_cache_out"},
        {"k_block_means", "k_block_means_out"}})
    .SetKernelFn(PD_KERNEL(MobaAttention));

0 commit comments

Comments
 (0)