@@ -1,12 +1,13 @@
 # SPDX-License-Identifier: MIT
-# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 from torch import Tensor, Generator
 from typing import Optional, Tuple
 from ..jit.core import compile_ops, CK_DIR, AITER_CSRC_DIR, AITER_ROOT_DIR
 from ..utility import dtypes
 import torch
 
+
 @compile_ops("module_mha_fwd", fc_name="mha_fwd")
 def mha_fwd(
     q: Tensor,
@@ -48,7 +49,7 @@ def mha_varlen_fwd(
     bias: Optional[Tensor] = None,
     alibi_slopes: Optional[Tensor] = None,
     gen: Optional[Generator] = None,
-): ...
+) -> list[Tensor]: ...
 
 
 @compile_ops("module_mha_bwd", fc_name="mha_bwd")
@@ -419,7 +420,9 @@ def pssk():
         # bwd_hd64_bf16_causal_a32_rtz_pssk
         # bwd_hd64_fp16_a32_pssk
         # bwd_hd64_fp16_causal_a32_pssk
-        ret = is_v3_atomic_fp32 == True # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
+        ret = (
+            is_v3_atomic_fp32 == True
+        )  # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
         ret &= hdim_q == 64
         ret &= nmask or (
             mask and seqlen_q == seqlen_k
@@ -474,7 +477,9 @@ def psskddv():
         # bwd_hd192_bf16_causal_a32_rtz_psskddv
         ret = is_v3_atomic_fp32 == True
         ret &= hdim_q > 64 and hdim_q <= 192
-        ret &= nmask or (mask and seqlen_q == seqlen_k) # TODO: or (seqlen_q != seqlen_k and mask_type == top_left)
+        ret &= nmask or (
+            mask and seqlen_q == seqlen_k
+        )  # TODO: or (seqlen_q != seqlen_k and mask_type == top_left)
 
         return ret
 
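The two hunks above only re-wrap the eligibility predicate for the v3 backward kernels; the logic is unchanged. A hedged, standalone restatement of the psskddv check, with the helper name and explicit arguments being illustrative assumptions (the real code reads these values from the backward call's tensors and mask flags):

    # Self-contained sketch of the psskddv eligibility rule shown in the diff:
    # fp32 atomics, 64 < hdim_q <= 192, and either no mask or a mask with
    # equal query/key lengths.
    def can_use_psskddv(hdim_q, mask, nmask, seqlen_q, seqlen_k, is_v3_atomic_fp32):
        ret = is_v3_atomic_fp32 == True
        ret &= hdim_q > 64 and hdim_q <= 192
        ret &= nmask or (mask and seqlen_q == seqlen_k)
        return ret

    # e.g. hdim_q=128 with a causal mask and seqlen_q == seqlen_k qualifies,
    # while hdim_q=256, or a causal mask with unequal lengths, does not.
    assert can_use_psskddv(128, mask=True, nmask=False,
                           seqlen_q=1024, seqlen_k=1024,
                           is_v3_atomic_fp32=True)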
@@ -759,6 +764,7 @@ def _flash_attn_varlen_forward(
     return_lse: bool = False,
     return_softmax: bool = False,
     block_table: Optional[torch.Tensor] = None,
+    out: Optional[torch.Tensor] = None,
     zero_tensors: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     # causal=true is the same as causal=false in this case
@@ -878,7 +884,7 @@ def _flash_attn_varlen_forward(
         window_size_right,
         return_lse,
         return_softmax,
-        None,
+        out,
         block_table,
         bias,
         alibi_slopes,
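The hunk above threads the new out argument through to the kernel call in place of the previous hard-coded None. A minimal sketch of the intent, with shapes, dtype, and device assumed: the caller preallocates the packed output buffer and the forward pass writes into it instead of allocating a fresh tensor.

    import torch

    # Assumed packed varlen layout: (total_tokens, nheads, head_dim_v).
    total_q, nhead_q, hdim_v = 8192, 32, 128
    out_buf = torch.empty(total_q, nhead_q, hdim_v, dtype=torch.bfloat16, device="cuda")
    # out_buf would then be forwarded as the `out` argument shown above instead of None,
    # letting the attention kernel write its result in place.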
@@ -963,7 +969,9 @@ def _flash_attn_varlen_backward(
     ]
 
     (_, nhead_q, hdim_q) = q.shape
-    (_, nhead_k, hdim_v) = v.shape
+
+    nhead_k = v.shape[-2]
+    hdim_v = v.shape[-1]
 
     # mask
     window_size_left = -1 if window_size_left >= max_seqlen_k else window_size_left
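Reading nhead_k and hdim_v from the last two dimensions of v (rather than unpacking a fixed 3-tuple alongside q) keeps the code correct when v's head count and head dimension differ from q's, as in MQA/GQA. An illustrative shape check, with all sizes assumed:

    import torch

    # Packed varlen tensors are (total_tokens, nheads, head_dim); under GQA the
    # key/value tensors may have fewer heads than q, and hdim_v may differ from hdim_q.
    q = torch.randn(8192, 32, 192, dtype=torch.float16)  # nhead_q=32, hdim_q=192
    v = torch.randn(8192, 8, 128, dtype=torch.float16)   # nhead_k=8,  hdim_v=128
    nhead_k, hdim_v = v.shape[-2], v.shape[-1]
    assert (nhead_k, hdim_v) == (8, 128)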
@@ -994,12 +1002,14 @@ def pssk():
         # bwd_hd128_bf16_causal_a32_rtz_pssk_group
         # bwd_hd128_fp16_a32_pssk_group
         # bwd_hd128_fp16_causal_a32_pssk_group
-        ret = is_v3_atomic_fp32 == True # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
+        ret = (
+            is_v3_atomic_fp32 == True
+        )  # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
         ret &= hdim_q == 64 or hdim_q == 128
-        ret &= nmask # TODO: or (mask and mask_type == mask_enum::mask_top_left)
+        ret &= nmask  # TODO: or (mask and mask_type == mask_enum::mask_top_left)
 
         return ret
-
+
     def psskddv():
         # bwd_hd128_bf16_a32_rtne_psskddv_group
         # bwd_hd128_bf16_a32_rtna_psskddv_group
@@ -1009,9 +1019,11 @@ def psskddv():
         # bwd_hd128_bf16_causal_a32_rtz_psskddv_group
         # bwd_hd128_fp16_a32_psskddv_group
         # bwd_hd128_fp16_causal_a32_psskddv_group
-        ret = is_v3_atomic_fp32 == True # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
+        ret = (
+            is_v3_atomic_fp32 == True
+        )  # nhead_stride_dq_acc >= stride_dq_acc must be guaranteed
         ret &= hdim_q > 64 and hdim_q < 128
-        ret &= nmask # TODO: or (mask and mask_type == mask_enum::mask_top_left)
+        ret &= nmask  # TODO: or (mask and mask_type == mask_enum::mask_top_left)
 
         return ret
 
@@ -1027,7 +1039,7 @@ def can_impl_fmha_v3_bwd():
     ret &= hdim_q >= 64 and hdim_q <= 128 and hdim_q % 8 == 0
     ret &= mask or nmask
     ret &= pssk() or psskddv()
-    ret &= 'gfx942' in torch.cuda.get_device_properties("cuda").gcnArchName
+    ret &= "gfx942" in torch.cuda.get_device_properties("cuda").gcnArchName
 
 
     return ret
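The architecture gate above only switches to double quotes; functionally it still restricts the v3 backward to gfx942 (MI300-class) devices. A hedged sketch of the same query, with a getattr guard added as an assumption for builds where gcnArchName is not exposed:

    import torch

    # gcnArchName is reported by ROCm builds of PyTorch; fall back to an empty
    # string elsewhere so the check simply evaluates to False.
    props = torch.cuda.get_device_properties("cuda")
    arch_name = getattr(props, "gcnArchName", "")
    supports_v3_bwd = "gfx942" in arch_name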
@@ -1122,15 +1134,16 @@ def forward(
         return_lse,
         return_softmax,
         block_table,
+        out,
         is_grad_enabled,
         is_v3_atomic_fp32: Optional[bool] = True,
         how_v3_bf16_cvt: Optional[int] = 1,
     ):
         is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v])
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        head_size_q_og = q.size(2)
-        head_size_v_og = v.size(2)
+        head_size_q_og = q.size(-1)
+        head_size_v_og = v.size(-1)
         if head_size_q_og % 8 != 0:
             q = torch.nn.functional.pad(q, [0, 8 - head_size_q_og % 8])
             k = torch.nn.functional.pad(k, [0, 8 - head_size_q_og % 8])
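head_size_q_og and head_size_v_og are now taken from the last dimension, which is the head dimension regardless of tensor rank; the padding above then rounds that size up to a multiple of 8. A worked example of the arithmetic, with the head size assumed:

    # For head_size_q_og = 100: 100 % 8 == 4, so q and k get 8 - 4 = 4 zero
    # elements of right-padding on the last dim, giving 104 (a multiple of 8).
    head_size_q_og = 100
    pad = 0 if head_size_q_og % 8 == 0 else 8 - head_size_q_og % 8
    assert pad == 4 and (head_size_q_og + pad) % 8 == 0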
@@ -1154,6 +1167,7 @@ def forward(
             return_lse=return_lse,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=block_table,
+            out=out,
         )
         if is_grad:
             ctx.save_for_backward(
@@ -1243,6 +1257,7 @@ def backward(ctx, dout, *args):
             None,
             None,
             None,
+            None,
         )
 
 
@@ -1264,6 +1279,7 @@ def flash_attn_varlen_func(
     return_lse=False,
     return_attn_probs=False,
     block_table=None,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1338,5 +1354,6 @@ def flash_attn_varlen_func(
         return_lse,
         return_attn_probs,
         block_table,
+        out,
         torch.is_grad_enabled(),
     )
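A hedged end-to-end usage sketch for the new out keyword on flash_attn_varlen_func. The positional argument order is assumed to follow the upstream flash-attn varlen API, and the import path, shapes, and dtypes are illustrative, not taken from this diff:

    import torch
    from aiter.ops.mha import flash_attn_varlen_func  # import path assumed

    total, nheads, hdim, seqlen = 4096, 16, 128, 512
    q = torch.randn(total, nheads, hdim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(total, nheads, hdim, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(total, nheads, hdim, device="cuda", dtype=torch.bfloat16)
    # 8 sequences of length 512 packed together: cu_seqlens = [0, 512, ..., 4096].
    cu_seqlens = torch.arange(0, total + 1, seqlen, device="cuda", dtype=torch.int32)
    out_buf = torch.empty_like(q)  # preallocated output the kernel writes into

    attn = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens, cu_seqlens,   # cu_seqlens_q, cu_seqlens_k (assumed order)
        seqlen, seqlen,           # max_seqlen_q, max_seqlen_k (assumed order)
        causal=True,              # keyword assumed, as in upstream flash-attn
        out=out_buf,              # new keyword added in this diff
    )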