Skip to content

Commit b28246f

Browse files
authored
[ROCm][V1][Bugfix] Add get_builder_cls method to the ROCmAttentionBackend class (vllm-project#14065)
Signed-off-by: Sage Moore <[email protected]>
1 parent 3b5567a commit b28246f

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

vllm/v1/attention/backends/rocm_attn.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from vllm.attention.ops.paged_attn import PagedAttention
1010
from vllm.attention.ops.prefix_prefill import context_attention_fwd
1111
from vllm.logger import init_logger
12-
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
12+
from vllm.v1.attention.backends.flash_attn import (
13+
FlashAttentionMetadata, FlashAttentionMetadataBuilder)
1314

1415
logger = init_logger(__name__)
1516

@@ -49,6 +50,10 @@ def get_kv_cache_shape(
4950
def use_cascade_attention(*args, **kwargs) -> bool:
5051
return False
5152

53+
@staticmethod
54+
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
55+
return FlashAttentionMetadataBuilder
56+
5257

5358
class ROCmAttentionImpl(AttentionImpl):
5459

0 commit comments

Comments
 (0)