Commit c1debd8

fix block table bugs (ROCm#310)
* fix block table bugs
* add seqlens_k args
* add return type
* change reshape and cache api
* add block table transfer layer
* add output argument
* fix a bug
* remove useless file
* remove seq_k api
1 parent 92b72cc commit c1debd8

8 files changed (+90, -53 lines)

aiter/ops/cache.py

Lines changed: 2 additions & 2 deletions
@@ -39,8 +39,8 @@ def reshape_and_cache_flash(
     value_cache: Tensor,
     slot_mapping: Tensor,
     kv_cache_dtype: str,
-    k_scale: float,
-    v_scale: float,
+    k_scale: Tensor,
+    v_scale: Tensor,
 ): ...


 @compile_ops("module_cache")

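With this change, reshape_and_cache_flash takes the KV-cache scales as tensors instead of Python floats. A minimal call sketch follows; it assumes the op is importable from aiter.ops.cache, that "auto" is a valid kv_cache_dtype string, and that single-element float32 CUDA tensors are acceptable for the scales (the kernel change below reads them through a float pointer). Tensor shapes follow the comments in cache_kernels.cu.

import torch
from aiter.ops.cache import reshape_and_cache_flash  # assumed import path

num_tokens, num_heads, head_size = 16, 8, 128
num_blocks, block_size = 4, 16

# key/value: [num_tokens, num_heads, head_size]; caches: [num_blocks, block_size, num_heads, head_size]
key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")
value = torch.randn_like(key)
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size, dtype=torch.float16, device="cuda")
value_cache = torch.zeros_like(key_cache)
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device="cuda")  # one cache slot per token

# Scales are now tensors (previously Python floats).
k_scale = torch.ones(1, dtype=torch.float32, device="cuda")
v_scale = torch.ones(1, dtype=torch.float32, device="cuda")

reshape_and_cache_flash(key, value, key_cache, value_cache, slot_mapping,
                        "auto", k_scale, v_scale)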
aiter/ops/mha.py

Lines changed: 31 additions & 14 deletions
@@ -1,12 +1,13 @@
 # SPDX-License-Identifier: MIT
-# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 from torch import Tensor, Generator
 from typing import Optional, Tuple
 from ..jit.core import compile_ops, CK_DIR, AITER_CSRC_DIR, AITER_ROOT_DIR
 from ..utility import dtypes
 import torch

+
 @compile_ops("module_mha_fwd", fc_name="mha_fwd")
 def mha_fwd(
     q: Tensor,
@@ -48,7 +49,7 @@ def mha_varlen_fwd(
     bias: Optional[Tensor] = None,
     alibi_slopes: Optional[Tensor] = None,
     gen: Optional[Generator] = None,
-): ...
+) -> list[Tensor]: ...


 @compile_ops("module_mha_bwd", fc_name="mha_bwd")
@@ -419,7 +420,9 @@ def pssk():
         # bwd_hd64_bf16_causal_a32_rtz_pssk
         # bwd_hd64_fp16_a32_pssk
         # bwd_hd64_fp16_causal_a32_pssk
-        ret = is_v3_atomic_fp32 == True # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
+        ret = (
+            is_v3_atomic_fp32 == True
+        ) # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
         ret &= hdim_q == 64
         ret &= nmask or (
             mask and seqlen_q == seqlen_k
@@ -474,7 +477,9 @@ def psskddv():
         # bwd_hd192_bf16_causal_a32_rtz_psskddv
         ret = is_v3_atomic_fp32 == True
         ret &= hdim_q > 64 and hdim_q <= 192
-        ret &= nmask or (mask and seqlen_q == seqlen_k) # TODO: or (seqlen_q != seqlen_k and mask_type == top_left)
+        ret &= nmask or (
+            mask and seqlen_q == seqlen_k
+        ) # TODO: or (seqlen_q != seqlen_k and mask_type == top_left)

         return ret

@@ -759,6 +764,7 @@ def _flash_attn_varlen_forward(
     return_lse: bool = False,
     return_softmax: bool = False,
     block_table: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
     zero_tensors: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     # causal=true is the same as causal=false in this case
@@ -878,7 +884,7 @@ def _flash_attn_varlen_forward(
         window_size_right,
         return_lse,
         return_softmax,
-        None,
+        out,
         block_table,
         bias,
         alibi_slopes,
@@ -963,7 +969,9 @@ def _flash_attn_varlen_backward(
     ]

     (_, nhead_q, hdim_q) = q.shape
-    (_, nhead_k, hdim_v) = v.shape
+
+    nhead_k = v.shape[-2]
+    hdim_v = v.shape[-1]

     # mask
     window_size_left = -1 if window_size_left >= max_seqlen_k else window_size_left
@@ -994,12 +1002,14 @@ def pssk():
         # bwd_hd128_bf16_causal_a32_rtz_pssk_group
         # bwd_hd128_fp16_a32_pssk_group
         # bwd_hd128_fp16_causal_a32_pssk_group
-        ret = is_v3_atomic_fp32 == True # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
+        ret = (
+            is_v3_atomic_fp32 == True
+        ) # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
         ret &= hdim_q == 64 or hdim_q == 128
-        ret &= nmask # TODO: or (mask and mask_type == mask_enum::mask_top_left)
+        ret &= nmask  # TODO: or (mask and mask_type == mask_enum::mask_top_left)

         return ret
-
+
     def psskddv():
         # bwd_hd128_bf16_a32_rtne_psskddv_group
         # bwd_hd128_bf16_a32_rtna_psskddv_group
@@ -1009,9 +1019,11 @@ def psskddv():
         # bwd_hd128_bf16_causal_a32_rtz_psskddv_group
         # bwd_hd128_fp16_a32_psskddv_group
         # bwd_hd128_fp16_causal_a32_psskddv_group
-        ret = is_v3_atomic_fp32 == True # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
+        ret = (
+            is_v3_atomic_fp32 == True
+        ) # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
         ret &= hdim_q > 64 and hdim_q < 128
-        ret &= nmask # TODO: or (mask and mask_type == mask_enum::mask_top_left)
+        ret &= nmask  # TODO: or (mask and mask_type == mask_enum::mask_top_left)

         return ret

@@ -1027,7 +1039,7 @@ def can_impl_fmha_v3_bwd():
         ret &= hdim_q >= 64 and hdim_q <= 128 and hdim_q % 8 == 0
         ret &= mask or nmask
         ret &= pssk() or psskddv()
-        ret &= 'gfx942' in torch.cuda.get_device_properties("cuda").gcnArchName
+        ret &= "gfx942" in torch.cuda.get_device_properties("cuda").gcnArchName

         return ret

@@ -1122,15 +1134,16 @@ def forward(
         return_lse,
         return_softmax,
         block_table,
+        out,
         is_grad_enabled,
         is_v3_atomic_fp32: Optional[bool] = True,
         how_v3_bf16_cvt: Optional[int] = 1,
     ):
         is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v])
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        head_size_q_og = q.size(2)
-        head_size_v_og = v.size(2)
+        head_size_q_og = q.size(-1)
+        head_size_v_og = v.size(-1)
         if head_size_q_og % 8 != 0:
             q = torch.nn.functional.pad(q, [0, 8 - head_size_q_og % 8])
             k = torch.nn.functional.pad(k, [0, 8 - head_size_q_og % 8])
@@ -1154,6 +1167,7 @@ def forward(
             return_lse=return_lse,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=block_table,
+            out=out,
         )
         if is_grad:
             ctx.save_for_backward(
@@ -1243,6 +1257,7 @@ def backward(ctx, dout, *args):
             None,
             None,
             None,
+            None,
         )


@@ -1264,6 +1279,7 @@ def flash_attn_varlen_func(
     return_lse=False,
     return_attn_probs=False,
     block_table=None,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1338,5 +1354,6 @@ def flash_attn_varlen_func(
         return_lse,
         return_attn_probs,
         block_table,
+        out,
         torch.is_grad_enabled(),
     )

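The mha.py changes thread an optional preallocated output buffer (out) from flash_attn_varlen_func through FlashAttnVarlenFunc.forward into _flash_attn_varlen_forward, where it replaces the previous hard-coded None. A usage sketch follows; the parameter names other than block_table and out follow the upstream flash-attn varlen convention that this wrapper appears to mirror, so treat them (and the import path) as assumptions.

import torch
from aiter.ops.mha import flash_attn_varlen_func  # assumed import path

nheads, head_dim = 8, 128
seqlens = [5, 7, 3]                      # three variable-length sequences
total = sum(seqlens)

q = torch.randn(total, nheads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
cu_seqlens = torch.tensor([0, 5, 12, 15], dtype=torch.int32, device="cuda")  # prefix sums of seqlens

# Preallocated output buffer, e.g. reused across decode steps.
out = torch.empty_like(q)

attn = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens),
    max_seqlen_k=max(seqlens),
    causal=True,
    out=out,  # new in this commit: the result is written into the caller's buffer
)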
csrc/include/cache.h

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ void reshape_and_cache_flash(torch::Tensor &key, torch::Tensor &value,
                              torch::Tensor &value_cache,
                              torch::Tensor &slot_mapping,
                              const std::string &kv_cache_dtype,
-                             const double k_scale, const double v_scale);
+                             torch::Tensor& k_scale, torch::Tensor& v_scale);

 void reshape_and_cache_with_pertoken_quant(torch::Tensor &key, torch::Tensor &value,
                                            torch::Tensor &key_cache, torch::Tensor &value_cache,

csrc/include/rocm_ops.hpp

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #define ACTIVATION_PYBIND \
     m.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); \
@@ -573,4 +573,4 @@
     .value("No", ActivationType::No) \
     .value("Silu", ActivationType::Silu) \
     .value("Gelu", ActivationType::Gelu) \
-    .export_values();
+    .export_values();

csrc/include/torch/mha_varlen_fwd.h

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 #pragma once
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 #include <torch/extension.h>

 namespace aiter {
@@ -10,7 +10,7 @@ mha_varlen_fwd(at::Tensor& q, // [total_q, hq, d]
               const at::Tensor& k,            // [total_k, hk, d]
               const at::Tensor& v,            // [total_k, hk, d]
               const at::Tensor& cu_seqlens_q, // [b+1]
-              const at::Tensor& cu_seqlens_k, // [b+1]
+              std::optional<const at::Tensor> &cu_seqlens_k, // [b+1]
               int max_seqlen_q,
               int max_seqlen_k,
               float p_dropout,

csrc/kernels/cache_kernels.cu

Lines changed: 7 additions & 6 deletions
@@ -274,7 +274,7 @@ namespace vllm
        const int64_t *__restrict__ slot_mapping, // [num_tokens]
        const int block_stride, const int key_stride, const int value_stride,
        const int num_heads, const int head_size, const int block_size,
-       const float k_scale, const float v_scale)
+       const float* k_scale, const float* v_scale)
     {
         const int64_t token_idx = blockIdx.x;
         const int64_t slot_idx = slot_mapping[token_idx];
@@ -305,9 +305,9 @@ namespace vllm
         else
         {
             key_cache[tgt_key_value_idx] =
-                fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
+                fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
             value_cache[tgt_key_value_idx] =
-                fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
+                fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
         }
     }
 }
@@ -873,7 +873,7 @@ void reshape_and_cache(
        reinterpret_cast<CACHE_T *>(key_cache.data_ptr()),                  \
        reinterpret_cast<CACHE_T *>(value_cache.data_ptr()),                \
        slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,         \
-       value_stride, num_heads, head_size, block_size, k_scale, v_scale);
+       value_stride, num_heads, head_size, block_size, k_scale.data_ptr<float>(), v_scale.data_ptr<float>());

 void reshape_and_cache_flash(
     torch::Tensor &key, // [num_tokens, num_heads, head_size]
@@ -882,8 +882,9 @@ void reshape_and_cache_flash(
     torch::Tensor &
         value_cache, // [num_blocks, block_size, num_heads, head_size]
     torch::Tensor &slot_mapping, // [num_tokens]
-    const std::string &kv_cache_dtype, const double k_scale,
-    const double v_scale)
+    const std::string &kv_cache_dtype,
+    torch::Tensor& k_scale,
+    torch::Tensor& v_scale)
 {
     int num_tokens = key.size(0);
     int num_heads = key.size(1);

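Because the kernel now reads the scales through a raw float* obtained via k_scale.data_ptr<float>() and v_scale.data_ptr<float>(), callers that previously passed Python floats need float32 tensors resident on the same device as the caches. A hypothetical migration helper is sketched below (to_scale_tensor is not part of this repo).

import torch

def to_scale_tensor(scale, device="cuda"):
    # Accept either the old Python-float form or an existing tensor and
    # normalize it to a float32 tensor on the target device.
    if isinstance(scale, torch.Tensor):
        return scale.to(device=device, dtype=torch.float32)
    return torch.tensor([float(scale)], dtype=torch.float32, device=device)

k_scale = to_scale_tensor(1.0)
v_scale = to_scale_tensor(1.0)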