@@ -381,7 +381,19 @@ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,

        # Use FlashAttention-2 when args.use_flash_attn_v2 is True
        args = get_args()
-        self.flash_attn_func = flash_attn_varlen_func if args.use_flash_attn_v2 else flash_attn_unpadded_func
+        self.use_flash_attn_builder_v1 = False
+        self.use_flash_attn_builder_v2 = False
+        self.use_flash_attn = False
+        if args.use_flash_attn_builder:
+            if hasattr(flash_attn_builder, 'flash_attn_func'):
+                self.flash_attn_func = flash_attn_builder.flash_attn_func
+                self.use_flash_attn_builder_v1 = True
+            else:
+                self.flash_attn_func = flash_attn_builder.flash_attn_func_v2
+                self.use_flash_attn_builder_v2 = True
+        else:
+            self.flash_attn_func = flash_attn_varlen_func if args.use_flash_attn_v2 else flash_attn_unpadded_func
+            self.use_flash_attn = True


    def forward(self, q, k, v):
        """Implements the multihead softmax attention.
@@ -392,22 +404,19 @@ def forward(self, q, k, v):

        assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v)))
        assert all((get_accelerator().on_accelerator(i) for i in (q, k, v)))
-        # if get_accelerator().device_name() == 'cuda':
-        #     assert all((i.is_cuda for i in (q,k,v)))
-        # else:
-        #     assert all((i.is_xpu for i in (q,k,v)))

        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = k.shape[1]

-        if get_accelerator().device_name() == 'cuda':
-            # goes for cuda device
+        if self.use_flash_attn:
            q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
            cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
                                        device=q.device)
-        else:
-            # goes for other device
+        elif self.use_flash_attn_builder_v1:
            q, k, v = [rearrange(x, 'b s h d -> b h s d').contiguous() for x in [q, k, v]]
+        else:
+            # use_flash_attn_builder_v2
+            q, k, v = [rearrange(x, 'b s h d -> b h s d') for x in [q, k, v]]

        if self.training:
            # during training q,k,v always have same seqlen
@@ -424,16 +433,26 @@ def forward(self, q, k, v):
                    device=q.device) if get_accelerator().device_name() == 'cuda' else None
            dropout_p = 0

-        output = self.flash_attn_func(
-            q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
-            dropout_p,
-            softmax_scale=self.softmax_scale, causal=is_causal
-        ) if get_accelerator().device_name() == 'cuda' else flash_attn_builder.flash_attn_func(
-            q, k, v, self.dropout_p, self.softmax_scale, is_causal
-        )
+        if self.use_flash_attn:
+            output = self.flash_attn_func(
+                q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
+                dropout_p,
+                softmax_scale=self.softmax_scale, causal=is_causal
+            )
+        else:
+            # use_flash_attn_builder
+            output = self.flash_attn_func(
+                q, k, v, self.dropout_p, self.softmax_scale, is_causal
+            )
+
+        if self.use_flash_attn:
+            output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+        elif self.use_flash_attn_builder_v1:
+            output = rearrange(output, 'b h s d -> b s h d').contiguous()
+        else:
+            # use_flash_attn_builder_v2:
+            output = rearrange(output, 'b h s d -> b s h d')

-        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) if get_accelerator().device_name() == 'cuda' else rearrange(
-            output, 'b h s d -> b s h d').contiguous()
        return output


class FlashSelfAttentionTriton(torch.nn.Module):
@@ -492,7 +511,8 @@ def __init__(self, config, layer_number,
        self.num_key_value_heads = config.num_key_value_heads
        self.use_gqa = (self.num_attention_heads != self.num_key_value_heads)

-        self.use_flash_attn = (args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2) \
+        self.use_flash_attn = (args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2 or \
+            args.use_flash_attn_builder) \
            and attention_type == AttnType.self_attn \
            and self.attn_mask_type == AttnMaskType.causal
        self.use_flash_attn_triton = args.use_flash_attn_triton
@@ -504,12 +524,13 @@ def __init__(self, config, layer_number,
                flash_attn_builder = None

            if args.use_flash_attn_v1:
-                assert flash_attn_unpadded_func != None or flash_attn_builder != None, ("Cannot import FlashAttention v1 "
-                                                                                        "and Cannot find FlashAttention Builder")
+                assert flash_attn_unpadded_func != None, "Cannot import FlashAttention v1 "
            if args.use_flash_attn_v2:
                assert flash_attn_varlen_func != None, "Cannot import FlashAttention v2 "
            if args.use_flash_attn_triton:
                assert flash_attn_func != None, "Cannot import FlashAttention triton "
+            if args.use_flash_attn_builder:
+                assert flash_attn_builder != None, "Cannot find FlashAttention op builder "

            assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
                                                          'self-attention for now')
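
For reference, a minimal standalone sketch of the dispatch pattern the diff introduces: probe the DeepSpeed FlashAttention op builder for a v1 or v2 kernel, otherwise fall back to the pip-installed flash-attn varlen/unpadded functions, and track which tensor layout each path expects. The helper name pick_flash_attn and the fake_builder stand-in are hypothetical and exist only for this sketch; only the kernel names (flash_attn_func, flash_attn_func_v2, flash_attn_varlen_func, flash_attn_unpadded_func) come from the diff itself.

    # Sketch only, not the PR's code. `SimpleNamespace` stands in for the op builder
    # so the example runs without flash-attn or DeepSpeed installed.
    from types import SimpleNamespace

    import torch
    from einops import rearrange


    def pick_flash_attn(builder=None, varlen_func=None, unpadded_func=None, use_v2=False):
        """Return (attention function, expected layout) using the same precedence as the diff."""
        if builder is not None:
            if hasattr(builder, 'flash_attn_func'):
                return builder.flash_attn_func, 'b h s d (contiguous)'   # builder v1 path
            return builder.flash_attn_func_v2, 'b h s d'                 # builder v2 path
        # fall back to the pip-installed flash-attn kernels (unpadded/varlen API)
        return (varlen_func if use_v2 else unpadded_func), '(b s) h d'


    # hypothetical builder exposing only the v2 entry point
    fake_builder = SimpleNamespace(flash_attn_func_v2=lambda *a: a[0])
    func, layout = pick_flash_attn(builder=fake_builder)
    print(layout)  # -> 'b h s d'

    # layout conversions used by the two branches in forward()
    q = torch.randn(2, 16, 8, 64, dtype=torch.float16)      # b s h d
    q_builder = rearrange(q, 'b s h d -> b h s d')           # fed to builder kernels
    q_varlen = rearrange(q, 'b s ... -> (b s) ...')          # fed to varlen/unpadded kernels
    print(q_builder.shape, q_varlen.shape)

The point of the branching in forward() is that the builder kernels take batched 'b h s d' tensors directly, while the flash-attn varlen/unpadded kernels take flattened '(b s) h d' tensors plus cu_seqlens offsets, so the output has to be rearranged back per branch as well.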