Commit db302ae

Merge pull request #4083 from ROCm/gqa-accuracy-rel
Update 6.4 GQA tests
2 parents 6716e6d + 6a95267 commit db302ae

13 files changed (+293 -171 lines)

Dockerfile

Lines changed: 2 additions & 2 deletions
@@ -112,8 +112,8 @@ RUN cget -p $PREFIX install doxygen@Release_1_9_8
 
 COPY ./test/onnx/.onnxrt-commit /
 
-ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
-ARG ONNXRUNTIME_BRANCH=main
+ARG ONNXRUNTIME_REPO=https://github.com/rocm/onnxruntime
+ARG ONNXRUNTIME_BRANCH=rocm6.4_internal_testing
 ARG ONNXRUNTIME_COMMIT
 
 RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime && \

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -188,6 +188,7 @@ register_migraphx_ops(
 greater
 greater_or_equal
 group_query_attention
+group
 gru
 identity
 if_op

src/include/migraphx/op/group.hpp

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#ifndef MIGRAPHX_GUARD_OPERATORS_GROUP_HPP
+#define MIGRAPHX_GUARD_OPERATORS_GROUP_HPP
+
+#include <migraphx/argument.hpp>
+#include <migraphx/module.hpp>
+#include <migraphx/check_shapes.hpp>
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace op {
+
+struct group
+{
+    std::string tag = "";
+
+    std::string name() const { return "group"; }
+    template <class Self, class F>
+    static auto reflect(Self& self, F f)
+    {
+        return pack(f(self.tag, "tag"));
+    }
+
+    shape compute_shape(const std::vector<shape>& inputs, const std::vector<module_ref>& mods) const
+    {
+        if(mods.size() != 1)
+            MIGRAPHX_THROW("should have one submodule.");
+        module_ref mod = mods[0];
+        check_shapes{inputs, *this}.has_at_least(1);
+
+        auto result =
+            mod->compute_shapes(inputs, {.name = name(), .strict_type = true, .strict_lens = true});
+        if(result.size() == 1)
+            return result.front();
+        return shape{result};
+    }
+};
+
+} // namespace op
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
+
+#endif

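For context, a minimal sketch of how a submodule might be wrapped in the new group op. This is illustrative only and not part of the commit; it assumes the existing module API (program::create_module, module::add_parameter, and the add_instruction overload that takes module arguments) and uses a trivial relu body in place of a real attention subgraph:

#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/shape.hpp>

int main()
{
    migraphx::program p;
    auto* mm  = p.get_main_module();
    auto* sub = p.create_module("attn_body"); // stand-in for the attention subgraph

    migraphx::shape s{migraphx::shape::float_type, {2, 8}};
    auto x = mm->add_parameter("x", s);

    // The submodule's parameters receive the group's inputs; it returns one
    // value, so group::compute_shape forwards that single output shape.
    auto px = sub->add_parameter("x0", s);
    auto r  = sub->add_instruction(migraphx::make_op("relu"), px);
    sub->add_return({r});

    // The "attention" tag is what find_mlir_attention_op later keys on.
    auto g = mm->add_instruction(
        migraphx::make_op("group", {{"tag", "attention"}}), {x}, {sub});
    mm->add_return({g});
    return 0;
}

If the wrapped module returned more than one value, compute_shape would instead build a tuple shape from all of the submodule's output shapes.
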
src/include/migraphx/operators.hpp

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@
 #include <migraphx/op/greater.hpp>
 #include <migraphx/op/greater_or_equal.hpp>
 #include <migraphx/op/group_query_attention.hpp>
+#include <migraphx/op/group.hpp>
 #include <migraphx/op/gru.hpp>
 #include <migraphx/op/identity.hpp>
 #include <migraphx/op/if_op.hpp>

src/targets/gpu/fuse_mlir.cpp

Lines changed: 9 additions & 111 deletions
@@ -917,126 +917,24 @@ struct find_mlir_standalone_attention_op
     }
 };
 
-struct find_mlir_gqa_attention_op
+struct find_mlir_attention_op
 {
     mlir_mode dot_mode = mlir_mode::none;
 
-    auto matcher() const { return match::name("gpu::kv_cache_attention"); }
-
-    auto finalize_attention_module(module_ref m) const
-    {
-        eliminate_common_subexpression{}.apply(*m);
-        dead_code_elimination{}.apply(*m);
-    }
+    auto matcher() const { return match::name("group"); }
 
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
-        auto attn = r.result;
-
-        float scale_val = attn->get_operator().to_value().get("scale", 0.0);
-        std::size_t num_heads = attn->get_operator().to_value().get("num_heads", 32);
-        std::size_t kv_num_heads = attn->get_operator().to_value().get("kv_num_heads", 32);
-        auto kv_num_heads_factor = num_heads / kv_num_heads;
-        auto qkv = attn->inputs().at(0);
-        auto pk = attn->inputs().at(1);
-        auto pv = attn->inputs().at(2);
-        auto csl = attn->inputs().at(3);
-        auto batch_size = pk->get_shape().lens()[0];
-        auto seq_len = qkv->get_shape().lens()[2];
-        auto head_size = qkv->get_shape().lens()[3];
-        auto max_seq_len = pk->get_shape().lens()[2];
-        csl = mpm.get_module().insert_instruction(
-            attn, make_op("multibroadcast", {{"out_lens", {batch_size, num_heads}}}), csl);
-
-        module m_attn;
-        std::vector<instruction_ref> inputs = {qkv, pk, pv, csl};
-        std::unordered_map<instruction_ref, instruction_ref> map_main_to_mattn;
-        m_attn.add_params(inputs, &map_main_to_mattn);
-
-        auto q = m_attn.add_instruction(
-            make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {num_heads}}}),
-            map_main_to_mattn.at(qkv));
-        auto k = map_main_to_mattn.at(pk);
-        auto v = map_main_to_mattn.at(pv);
-        if(kv_num_heads_factor != 1)
+        auto group = r.result;
+        auto tag = group->get_operator().to_value().get("tag", "");
+        if(tag != "attention")
        {
-            auto kv_new_lens = k->get_shape().lens();
-            kv_new_lens.at(1) = num_heads;
-            k = m_attn.add_instruction(
-                make_op("unsqueeze", {{"axes", {2}}}), k);
-            v = m_attn.add_instruction(
-                make_op("unsqueeze", {{"axes", {2}}}), v);
-            auto kv_unsqueezed_lens = k->get_shape().lens();
-            kv_unsqueezed_lens.at(2) = kv_num_heads_factor;
-            k = m_attn.add_instruction(make_op("multibroadcast", {{"out_lens", kv_unsqueezed_lens}}), k);
-            v = m_attn.add_instruction(make_op("multibroadcast", {{"out_lens", kv_unsqueezed_lens}}), v);
-            k = m_attn.add_instruction(make_op("reshape", {{"dims", kv_new_lens}}), k);
-            v = m_attn.add_instruction(make_op("reshape", {{"dims", kv_new_lens}}), v);
-        }
-        auto kt = m_attn.add_instruction(
-            make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), k);
-        auto gemm1 = m_attn.add_instruction(make_op("dot"), q, kt);
-
-        std::vector<int> range_vec(max_seq_len);
-        std::iota(range_vec.begin(), range_vec.end(), 0);
-        shape range_s{csl->get_shape().type(), {max_seq_len}};
-        auto range = m_attn.add_literal(range_s, range_vec);
-        std::vector<std::size_t> bnsm{batch_size, num_heads, seq_len, max_seq_len};
-        auto bc_range =
-            m_attn.add_instruction(make_op("multibroadcast", {{"out_lens", bnsm}}), range);
-
-        auto scalar_s = shape{qkv->get_shape().type(), {1}};
-        auto ninf =
-            m_attn.add_literal(literal{scalar_s, {-std::numeric_limits<float>::infinity()}});
-        ninf = m_attn.add_instruction(make_op("multibroadcast", {{"out_lens", bnsm}}), ninf);
-
-        if(float_equal(scale_val, 0.0))
-        {
-            scale_val = 1.0f / std::sqrt(static_cast<float>(head_size));
-        }
-        auto scale = m_attn.add_literal(literal{scalar_s, {scale_val}});
-        scale = m_attn.add_instruction(make_op("multibroadcast", {{"out_lens", bnsm}}), scale);
-
-        if(seq_len > 1)
-        {
-            std::vector<int> seq_range_vec(seq_len);
-            std::iota(seq_range_vec.begin(), seq_range_vec.end(), 1);
-            shape seq_range_s{csl->get_shape().type(), {seq_len}};
-            auto seq_range = m_attn.add_literal(seq_range_s, seq_range_vec);
-            seq_range =
-                m_attn.add_instruction(make_op("reshape", {{"dims", {seq_len, 1}}}), seq_range);
-            seq_range =
-                m_attn.add_instruction(make_op("multibroadcast", {{"out_lens", bnsm}}), seq_range);
-            auto causal_mask =
-                m_attn.add_instruction(make_op("greater_or_equal"), bc_range, seq_range);
-            causal_mask = m_attn.add_instruction(
-                make_op("convert", {{"target_type", shape::bool_type}}), causal_mask);
-            gemm1 = m_attn.add_instruction(make_op("where"), causal_mask, ninf, gemm1);
+            return;
        }
 
-        auto bc_csl =
-            m_attn.add_instruction(make_op("reshape", {{"dims", {batch_size, num_heads, 1, 1}}}),
-                                   map_main_to_mattn.at(csl));
-        auto mask_comp =
-            m_attn.add_instruction(make_op("multibroadcast", {{"out_lens", bnsm}}), bc_csl);
-        auto mask = m_attn.add_instruction(make_op("greater_or_equal"), bc_range, mask_comp);
-        mask =
-            m_attn.add_instruction(make_op("convert", {{"target_type", shape::bool_type}}), mask);
-        auto mul = m_attn.add_instruction(make_op("mul"), gemm1, scale);
-        auto where = m_attn.add_instruction(make_op("where"), mask, ninf, mul);
-        auto softmax = m_attn.add_instruction(make_op("softmax", {{"axis", 3}}), where);
-        auto scores = m_attn.add_instruction(make_op("dot"), softmax, v);
-        auto out =
-            m_attn.add_instruction(make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), scores);
-        out = m_attn.add_instruction(make_op("reshape", {{"dims", attn->get_shape().lens()}}), out);
-        m_attn.add_return({out});
-
-        finalize_attention_module(&m_attn);
-        module_ref mpm_attn = mpm.create_module("mlir_attn", std::move(m_attn));
-        mpm_attn->set_bypass();
-
+        auto* m_attn = group->module_inputs()[0];
         mpm.get_module().replace_instruction(
-            attn, mlir_op{attn->get_operator()}, mlir_contiguous(mpm, inputs), {mpm_attn});
+            group, mlir_op{group->get_operator()}, mlir_contiguous(mpm, group->inputs()), {m_attn});
     }
 };
 
@@ -1198,7 +1096,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
         mpm.run_pass(dead_code_elimination{});
     }
 
-    match::find_matches(mpm, find_mlir_gqa_attention_op{mlir_mode::all});
+    match::find_matches(mpm, find_mlir_attention_op{mlir_mode::all});
     mpm.run_pass(dead_code_elimination{});
 
     match::find_matches(

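For readers unfamiliar with MIGraphX matcher passes, the second hunk above is the driver side: match::find_matches walks the module and calls apply on every instruction accepted by matcher(). A stripped-down, hypothetical matcher wired up the same way might look like this (names here are illustrative, not from this PR):

#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>

namespace {

// Select "group" instructions and act only on those tagged "attention",
// mirroring the filter used by find_mlir_attention_op above.
struct visit_attention_groups
{
    auto matcher() const { return migraphx::match::name("group"); }

    void apply(migraphx::module_pass_manager& mpm,
               const migraphx::match::matcher_result& r) const
    {
        auto ins = r.result;
        if(ins->get_operator().to_value().get("tag", std::string{}) != "attention")
            return;
        // A real rewrite would replace ins here, handing its submodule
        // (ins->module_inputs().front()) to a fused op, as the pass above does.
        (void)mpm;
    }
};

} // namespace

// Hypothetical driver showing how such a matcher is applied within a pass.
void run_example(migraphx::module_pass_manager& mpm)
{
    migraphx::match::find_matches(mpm, visit_attention_groups{});
    mpm.run_pass(migraphx::dead_code_elimination{});
}
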
src/targets/gpu/lowering.cpp

Lines changed: 0 additions & 11 deletions
@@ -555,17 +555,6 @@ struct miopen_apply
                                 {"output_shape", to_value(ins->get_shape())}}),
                 ins->inputs());
         });
-
-        apply_map.emplace("gpu::kv_cache_attention", [=](instruction_ref ins) {
-            auto s = ins->get_shape();
-            auto output = insert_allocation(ins, s);
-            auto new_inputs = ins->inputs();
-            new_inputs.push_back(output);
-            return mod->replace_instruction(
-                ins,
-                make_op("gpu::precompile_op", {{"op", to_value(ins->get_operator())}}),
-                new_inputs);
-        });
     }
 
     void add_scan_slice_op()
