@@ -917,126 +917,24 @@ struct find_mlir_standalone_attention_op
     }
 };
 
-struct find_mlir_gqa_attention_op
+struct find_mlir_attention_op
 {
     mlir_mode dot_mode = mlir_mode::none;
 
-    auto matcher() const { return match::name("gpu::kv_cache_attention"); }
-
-    auto finalize_attention_module(module_ref m) const
-    {
-        eliminate_common_subexpression{}.apply(*m);
-        dead_code_elimination{}.apply(*m);
-    }
+    auto matcher() const { return match::name("group"); }
 
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
-        auto attn = r.result;
-
-        float scale_val = attn->get_operator().to_value().get("scale", 0.0);
-        std::size_t num_heads = attn->get_operator().to_value().get("num_heads", 32);
-        std::size_t kv_num_heads = attn->get_operator().to_value().get("kv_num_heads", 32);
-        auto kv_num_heads_factor = num_heads / kv_num_heads;
-        auto qkv = attn->inputs().at(0);
-        auto pk = attn->inputs().at(1);
-        auto pv = attn->inputs().at(2);
-        auto csl = attn->inputs().at(3);
-        auto batch_size = pk->get_shape().lens()[0];
-        auto seq_len = qkv->get_shape().lens()[2];
-        auto head_size = qkv->get_shape().lens()[3];
-        auto max_seq_len = pk->get_shape().lens()[2];
-        csl = mpm.get_module().insert_instruction(
-            attn, make_op("multibroadcast", {{"out_lens", {batch_size, num_heads}}}), csl);
-
-        module m_attn;
-        std::vector<instruction_ref> inputs = {qkv, pk, pv, csl};
-        std::unordered_map<instruction_ref, instruction_ref> map_main_to_mattn;
-        m_attn.add_params(inputs, &map_main_to_mattn);
-
-        auto q = m_attn.add_instruction(
-            make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {num_heads}}}),
-            map_main_to_mattn.at(qkv));
-        auto k = map_main_to_mattn.at(pk);
-        auto v = map_main_to_mattn.at(pv);
-        if(kv_num_heads_factor != 1)
+        auto group = r.result;
+        auto tag = group->get_operator().to_value().get("tag", "");
+        if(tag != "attention")
         {
-            auto kv_new_lens = k->get_shape().lens();
-            kv_new_lens.at(1) = num_heads;
-            k = m_attn.add_instruction(make_op("unsqueeze", {{"axes", {2}}}), k);
-            v = m_attn.add_instruction(make_op("unsqueeze", {{"axes", {2}}}), v);
-            auto kv_unsqueezed_lens = k->get_shape().lens();
-            kv_unsqueezed_lens.at(2) = kv_num_heads_factor;
-            k = m_attn.add_instruction(make_op("multibroadcast", {{"out_lens", kv_unsqueezed_lens}}), k);
-            v = m_attn.add_instruction(make_op("multibroadcast", {{"out_lens", kv_unsqueezed_lens}}), v);
-            k = m_attn.add_instruction(make_op("reshape", {{"dims", kv_new_lens}}), k);
-            v = m_attn.add_instruction(make_op("reshape", {{"dims", kv_new_lens}}), v);
-        }
-        auto kt = m_attn.add_instruction(
-            make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), k);
-        auto gemm1 = m_attn.add_instruction(make_op("dot"), q, kt);
-
-        std::vector<int> range_vec(max_seq_len);
-        std::iota(range_vec.begin(), range_vec.end(), 0);
-        shape range_s{csl->get_shape().type(), {max_seq_len}};
-        auto range = m_attn.add_literal(range_s, range_vec);
-        std::vector<std::size_t> bnsm{batch_size, num_heads, seq_len, max_seq_len};
-        auto bc_range =
-            m_attn.add_instruction(make_op("multibroadcast", {{"out_lens", bnsm}}), range);
-
-        auto scalar_s = shape{qkv->get_shape().type(), {1}};
-        auto ninf =
-            m_attn.add_literal(literal{scalar_s, {-std::numeric_limits<float>::infinity()}});
-        ninf = m_attn.add_instruction(make_op("multibroadcast", {{"out_lens", bnsm}}), ninf);
-
-        if(float_equal(scale_val, 0.0))
-        {
-            scale_val = 1.0f / std::sqrt(static_cast<float>(head_size));
-        }
-        auto scale = m_attn.add_literal(literal{scalar_s, {scale_val}});
-        scale = m_attn.add_instruction(make_op("multibroadcast", {{"out_lens", bnsm}}), scale);
-
-        if(seq_len > 1)
-        {
-            std::vector<int> seq_range_vec(seq_len);
-            std::iota(seq_range_vec.begin(), seq_range_vec.end(), 1);
-            shape seq_range_s{csl->get_shape().type(), {seq_len}};
-            auto seq_range = m_attn.add_literal(seq_range_s, seq_range_vec);
-            seq_range =
-                m_attn.add_instruction(make_op("reshape", {{"dims", {seq_len, 1}}}), seq_range);
-            seq_range =
-                m_attn.add_instruction(make_op("multibroadcast", {{"out_lens", bnsm}}), seq_range);
-            auto causal_mask =
-                m_attn.add_instruction(make_op("greater_or_equal"), bc_range, seq_range);
-            causal_mask = m_attn.add_instruction(
-                make_op("convert", {{"target_type", shape::bool_type}}), causal_mask);
-            gemm1 = m_attn.add_instruction(make_op("where"), causal_mask, ninf, gemm1);
+            return;
         }
 
-        auto bc_csl =
-            m_attn.add_instruction(make_op("reshape", {{"dims", {batch_size, num_heads, 1, 1}}}),
-                                   map_main_to_mattn.at(csl));
-        auto mask_comp =
-            m_attn.add_instruction(make_op("multibroadcast", {{"out_lens", bnsm}}), bc_csl);
-        auto mask = m_attn.add_instruction(make_op("greater_or_equal"), bc_range, mask_comp);
-        mask =
-            m_attn.add_instruction(make_op("convert", {{"target_type", shape::bool_type}}), mask);
-        auto mul = m_attn.add_instruction(make_op("mul"), gemm1, scale);
-        auto where = m_attn.add_instruction(make_op("where"), mask, ninf, mul);
-        auto softmax = m_attn.add_instruction(make_op("softmax", {{"axis", 3}}), where);
-        auto scores = m_attn.add_instruction(make_op("dot"), softmax, v);
-        auto out =
-            m_attn.add_instruction(make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), scores);
-        out = m_attn.add_instruction(make_op("reshape", {{"dims", attn->get_shape().lens()}}), out);
-        m_attn.add_return({out});
-
-        finalize_attention_module(&m_attn);
-        module_ref mpm_attn = mpm.create_module("mlir_attn", std::move(m_attn));
-        mpm_attn->set_bypass();
-
+        auto* m_attn = group->module_inputs()[0];
 
         mpm.get_module().replace_instruction(
-            attn, mlir_op{attn->get_operator()}, mlir_contiguous(mpm, inputs), {mpm_attn});
+            group, mlir_op{group->get_operator()}, mlir_contiguous(mpm, group->inputs()), {m_attn});
     }
 };
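For reference, the replacement matcher after this hunk would read roughly as follows. This is assembled from the context and added lines above rather than copied from the committed file, and it assumes the surrounding MIGraphX helpers (mlir_op, mlir_contiguous, match::name) remain in scope as before, so treat it as a sketch:

struct find_mlir_attention_op
{
    mlir_mode dot_mode = mlir_mode::none;

    auto matcher() const { return match::name("group"); }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto group = r.result;
        // Only lift "group" ops that are tagged as attention regions.
        auto tag = group->get_operator().to_value().get("tag", "");
        if(tag != "attention")
        {
            return;
        }

        // The attention submodule arrives as the group's module input and is
        // handed to the MLIR op directly instead of being rebuilt here, which
        // is what makes the removed GQA-specific construction unnecessary.
        auto* m_attn = group->module_inputs()[0];

        mpm.get_module().replace_instruction(
            group, mlir_op{group->get_operator()}, mlir_contiguous(mpm, group->inputs()), {m_attn});
    }
};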
@@ -1198,7 +1096,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
         mpm.run_pass(dead_code_elimination{});
     }
 
-    match::find_matches(mpm, find_mlir_gqa_attention_op{mlir_mode::all});
+    match::find_matches(mpm, find_mlir_attention_op{mlir_mode::all});
    mpm.run_pass(dead_code_elimination{});

    match::find_matches(