
Commit 3c5f475

remove contiguous copy for flash-attn opbuilder (#372)
* remove unnecessary code for latest flash-attn opbuilder
* add use-flash-attn-builder to make flash_attn usage clear and compatible
* use hasattr
1 parent 888a63a commit 3c5f475

File tree

2 files changed: +45 / -22 lines

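For reference, the "use hasattr" item in the commit message refers to probing the loaded op-builder module for which kernel it exposes. Below is a minimal, self-contained Python sketch of that selection logic; the SimpleNamespace objects are hypothetical stand-ins for the module returned by DeepSpeed's flash-attn op builder, not the real thing.

# Minimal sketch of the hasattr-based dispatch this commit introduces.
# The stand-in objects mimic the module loaded by the flash-attn op builder;
# only the selection logic from megatron/model/transformer.py is shown.
from types import SimpleNamespace

def pick_flash_attn_func(flash_attn_builder):
    # Prefer the v1-style builder entry point if present, otherwise fall
    # back to flash_attn_func_v2.
    if hasattr(flash_attn_builder, 'flash_attn_func'):
        return flash_attn_builder.flash_attn_func, 'builder_v1'
    return flash_attn_builder.flash_attn_func_v2, 'builder_v2'

v1_like = SimpleNamespace(flash_attn_func=lambda *a, **kw: 'v1 kernel')
v2_like = SimpleNamespace(flash_attn_func_v2=lambda *a, **kw: 'v2 kernel')
print(pick_flash_attn_func(v1_like)[1])  # builder_v1
print(pick_flash_attn_func(v2_like)[1])  # builder_v2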

megatron/arguments.py

Lines changed: 3 additions & 1 deletion
@@ -421,7 +421,7 @@ def validate_args(args, defaults={}):
     args.compression_training = False
 
     # FlashAttention
-    args.use_flash_attn = args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2
+    args.use_flash_attn = args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2 or args.use_flash_attn_builder
 
     # AML
     if args.aml_data_download_path is not None:
@@ -910,6 +910,8 @@ def _add_training_args(parser):
                        'https://arxiv.org/abs/2307.08691')
     group.add_argument('--use-flash-attn-triton', action='store_true',
                        help='use FlashAttention implementation of attention using Triton.')
+    group.add_argument('--use-flash-attn-builder', action='store_true',
+                       help='use FlashAttention op builder.')
     group.add_argument('--disable-bias-linear', action='store_false',
                        help='Disable bias in the linear layers',
                        dest='add_bias_linear')
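
As a quick illustration of the change above, the sketch below shows the new flag feeding into the derived args.use_flash_attn setting. It uses a standalone argparse parser rather than Megatron's real argument machinery, so it only mirrors the relevant lines.

# Standalone sketch: any of the four FlashAttention flags enables the
# FlashAttention code path, as validate_args() now computes.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use-flash-attn-v1', action='store_true')
parser.add_argument('--use-flash-attn-v2', action='store_true')
parser.add_argument('--use-flash-attn-triton', action='store_true')
parser.add_argument('--use-flash-attn-builder', action='store_true')

args = parser.parse_args(['--use-flash-attn-builder'])
args.use_flash_attn = (args.use_flash_attn_v1 or args.use_flash_attn_triton
                       or args.use_flash_attn_v2 or args.use_flash_attn_builder)
print(args.use_flash_attn)  # True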

megatron/model/transformer.py

Lines changed: 42 additions & 21 deletions
@@ -381,7 +381,19 @@ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
 
         # Use FlashAttention-2 when args.use_flash_attn_v2 is True
         args = get_args()
-        self.flash_attn_func = flash_attn_varlen_func if args.use_flash_attn_v2 else flash_attn_unpadded_func
+        self.use_flash_attn_builder_v1 = False
+        self.use_flash_attn_builder_v2 = False
+        self.use_flash_attn = False
+        if args.use_flash_attn_builder:
+            if hasattr(flash_attn_builder, 'flash_attn_func'):
+                self.flash_attn_func = flash_attn_builder.flash_attn_func
+                self.use_flash_attn_builder_v1 = True
+            else:
+                self.flash_attn_func = flash_attn_builder.flash_attn_func_v2
+                self.use_flash_attn_builder_v2 = True
+        else:
+            self.flash_attn_func = flash_attn_varlen_func if args.use_flash_attn_v2 else flash_attn_unpadded_func
+            self.use_flash_attn = True
 
     def forward(self, q, k, v):
         """Implements the multihead softmax attention.
@@ -392,22 +404,19 @@ def forward(self, q, k, v):
 
         assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v)))
         assert all((get_accelerator().on_accelerator(i) for i in (q, k, v)))
-        # if get_accelerator().device_name() == 'cuda':
-        #     assert all((i.is_cuda for i in (q,k,v)))
-        # else:
-        #     assert all((i.is_xpu for i in (q,k,v)))
 
         batch_size, seqlen_q = q.shape[0], q.shape[1]
         seqlen_k = k.shape[1]
 
-        if get_accelerator().device_name() == 'cuda':
-            # goes for cuda device
+        if self.use_flash_attn:
             q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
             cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
                                         device=q.device)
-        else:
-            # goes for other device
+        elif self.use_flash_attn_builder_v1:
             q, k, v = [rearrange(x, 'b s h d -> b h s d').contiguous() for x in [q, k, v]]
+        else:
+            # use_flash_attn_builder_v2
+            q, k, v = [rearrange(x, 'b s h d -> b h s d') for x in [q, k, v]]
 
         if self.training:
             # during training q,k,v always have same seqlen
@@ -424,16 +433,26 @@ def forward(self, q, k, v):
                                         device=q.device) if get_accelerator().device_name() == 'cuda' else None
         dropout_p = 0
 
-        output = self.flash_attn_func(
-            q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
-            dropout_p,
-            softmax_scale=self.softmax_scale, causal=is_causal
-        ) if get_accelerator().device_name() == 'cuda' else flash_attn_builder.flash_attn_func(
-            q, k, v, self.dropout_p, self.softmax_scale, is_causal
-        )
+        if self.use_flash_attn:
+            output = self.flash_attn_func(
+                q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
+                dropout_p,
+                softmax_scale=self.softmax_scale, causal=is_causal
+            )
+        else:
+            # use_flash_attn_builder
+            output = self.flash_attn_func(
+                q, k, v, self.dropout_p, self.softmax_scale, is_causal
+            )
+
+        if self.use_flash_attn:
+            output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+        elif self.use_flash_attn_builder_v1:
+            output = rearrange(output, 'b h s d -> b s h d').contiguous()
+        else:
+            # use_flash_attn_builder_v2
+            output = rearrange(output, 'b h s d -> b s h d')
 
-        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) if get_accelerator().device_name() == 'cuda' else rearrange(
-            output, 'b h s d -> b s h d').contiguous()
         return output
 
 class FlashSelfAttentionTriton(torch.nn.Module):
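
The forward() changes above boil down to three input/output layouts, including the dropped .contiguous() copy that gives this commit its title. The sketch below reproduces just the rearrange calls on dummy tensors (torch and einops only; no attention kernel is invoked), so the shapes can be checked in isolation.

# Dummy-tensor sketch of the three layouts prepared in FlashSelfAttention.forward().
import torch
from einops import rearrange

b, s, h, d = 2, 4, 8, 16
q = torch.randn(b, s, h, d)

# flash-attn v1/v2 (varlen) path: flatten batch and sequence, track cu_seqlens.
q_varlen = rearrange(q, 'b s ... -> (b s) ...')                        # (b*s, h, d)
cu_seqlens_q = torch.arange(0, (b + 1) * s, step=s, dtype=torch.int32)

# op-builder v1 path: b h s d layout with an explicit contiguous copy.
q_builder_v1 = rearrange(q, 'b s h d -> b h s d').contiguous()         # (b, h, s, d)

# op-builder v2 path (this commit): same layout, no contiguous copy needed.
q_builder_v2 = rearrange(q, 'b s h d -> b h s d')                      # (b, h, s, d)

# The output is mapped back the same way: '(b s) ... -> b s ...' for varlen,
# 'b h s d -> b s h d' (with .contiguous() only for builder v1) otherwise.
print(q_varlen.shape, q_builder_v1.shape, q_builder_v2.shape)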
@@ -492,7 +511,8 @@ def __init__(self, config, layer_number,
         self.num_key_value_heads = config.num_key_value_heads
         self.use_gqa = (self.num_attention_heads != self.num_key_value_heads)
 
-        self.use_flash_attn = (args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2) \
+        self.use_flash_attn = (args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2 or \
+            args.use_flash_attn_builder) \
             and attention_type == AttnType.self_attn \
             and self.attn_mask_type == AttnMaskType.causal
         self.use_flash_attn_triton = args.use_flash_attn_triton
@@ -504,12 +524,13 @@ def __init__(self, config, layer_number,
             flash_attn_builder = None
 
         if args.use_flash_attn_v1:
-            assert flash_attn_unpadded_func != None or flash_attn_builder != None, ("Cannot import FlashAttention v1 "
-                                                                                    "and Cannot find FlashAttention Builder")
+            assert flash_attn_unpadded_func != None, "Cannot import FlashAttention v1 "
         if args.use_flash_attn_v2:
             assert flash_attn_varlen_func != None, "Cannot import FlashAttention v2 "
         if args.use_flash_attn_triton:
             assert flash_attn_func != None, "Cannot import FlashAttention triton "
+        if args.use_flash_attn_builder:
+            assert flash_attn_builder != None, "Cannot find FlashAttention op builder "
 
         assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
                                                       'self-attention for now')
