
Commit 062acc7

[TRITON][GEMM] Add layout option to GEMM A16W16 (ROCm#383)

* Initial GEMM tuning
* Revert tuning
* Capitalize layout naming
* Add layout to x_vals

1 parent c1debd8

2 files changed (+36 -21 lines)


op_benchmarks/triton/bench_gemm_a16w16.py (26 additions, 18 deletions)
```diff
@@ -17,26 +17,32 @@ def model_benchmark_shapes(args):
         N = config["intermediate_size"]
         K = config["hidden_size"]

-        shapes.append((M, N, K))
+        shapes.append((M, N, K, 'TN'))

     return shapes


 def get_x_vals():
     x_vals = [
-        (1, 1280, 8192),
-        (32, 1280, 8192),
-        (64, 1280, 8192),
-        (128, 1280, 8192),
-        (192, 1280, 8192),
-        (256, 1280, 8192),
-        (320, 1280, 8192),
-        (512, 1280, 8192),
-        (1024, 1280, 8192),
-        (2048, 1280, 8192),
-        (4096, 1280, 8192),
-        (8192, 1280, 8192),
-        (16384, 1280, 8192),
+        (1, 1280, 8192, 'TN'),
+        (32, 1280, 8192, 'TN'),
+        (64, 1280, 8192, 'TN'),
+        (128, 1280, 8192, 'TN'),
+        (192, 1280, 8192, 'TN'),
+        (256, 1280, 8192, 'TN'),
+        (320, 1280, 8192, 'TN'),
+        (512, 1280, 8192, 'TN'),
+        (1024, 1280, 8192, 'TN'),
+        (2048, 1280, 8192, 'TN'),
+        (4096, 1280, 8192, 'TN'),
+        (8192, 1280, 8192, 'TN'),
+        (16384, 1280, 8192, 'TN'),
+        (8192, 7168, 20480, 'NT'),
+        (1024, 20480, 8192, 'NT'),
+        (1024, 8192, 20480, 'NT'),
+        (8192, 7168, 20480, 'TN'),
+        (1024, 20480, 8192, 'TN'),
+        (1024, 8192, 20480, 'TN'),
     ]
     return x_vals

```
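Why every tuple grows a fourth element: `triton.testing.perf_report` unpacks each entry of `x_vals` positionally into the parameters named by `x_names`, so a 4-tuple `(M, N, K, layout)` only works once `x_names` lists four names (the change in the next hunk). A minimal sketch of that wiring, not part of the commit; the `line_arg`/`line_vals` values below are stand-ins, not the script's actual configuration:

```python
import triton
import triton.testing

# Each x_vals entry is unpacked positionally into the x_names parameters,
# so 4-tuples require four names, including 'layout'.
benchmark = triton.testing.Benchmark(
    x_names=['M', 'N', 'K', 'layout'],
    x_vals=[(1, 1280, 8192, 'TN'), (8192, 7168, 20480, 'NT')],
    line_arg='provider',          # stand-in: selects the implementation
    line_vals=['triton'],
    line_names=['Triton'],
    ylabel='throughput',
    plot_name='GEMM A16W16 Benchmark',
    args={'metric': 'throughput'},
)

@triton.testing.perf_report([benchmark])
def bench(M, N, K, layout, metric, provider):
    # Called once per x_vals entry with (M, N, K, layout) filled in.
    return 0.0

bench.run(print_data=True)
```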

```diff
@@ -45,11 +51,11 @@ def run_benchmark(args):
     assert not(args.shape and args.model) or not(args.shape and args.M), \
         "User can specify --shape or --model MODEL -M VAL exclusively"

-    x_names = ['M', 'N', 'K']
+    x_names = ['M', 'N', 'K', 'layout']
     if args.model:
         x_vals_list = model_benchmark_shapes(args)
     elif args.shape:
-        x_vals_list = [args.shape]
+        x_vals_list = [args.shape + [args.layout]]
     else:
         x_vals_list = get_x_vals()

```

```diff
@@ -71,10 +77,10 @@ def run_benchmark(args):
         ylabel=ylabel, plot_name=f'GEMM A16W16 Benchmark', args={"metric": args.metric})

     @triton.testing.perf_report([benchmark])
-    def bench_gemm_a16w16(M, N, K, metric, provider):
+    def bench_gemm_a16w16(M, N, K, layout, metric, provider):
         # NOTE: Assume bias and output has the same dtype
         c_dtype = torch.bfloat16
-        x, w = generate_gemm_a16w16_inputs(M, N, K, c_dtype)
+        x, w = generate_gemm_a16w16_inputs(M, N, K, c_dtype, layout)
         # flops
         flops = 2.0 * M * N * K
         # memory transfer
```
```diff
@@ -119,6 +125,8 @@ def parse_args():
                         help="user-defined shape to benchmark")
     parser.add_argument("--metric", type=str, choices=["time", "throughput", "bandwidth"],
                         default="throughput", help="metric to plot")
+    parser.add_argument("--layout", type=str, choices=["TT", "TN", "NT", "NN"],
+                        default="TN", help="Layout of input and weight matrix")
     args = parser.parse_args()
     return args

```
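On the `--shape` path above, `args.shape` arrives as a list (the flag presumably takes three integers), so appending the parsed `--layout` string yields a single four-element x_val matching the extended `x_names`. A quick illustration with hypothetical stand-in values, not taken from the commit:

```python
# Stand-ins for the parsed arguments (hypothetical values):
shape = [1024, 8192, 20480]    # args.shape, e.g. --shape 1024 8192 20480
layout = 'NT'                  # args.layout, e.g. --layout NT

# Mirrors x_vals_list = [args.shape + [args.layout]] from the diff above.
x_vals_list = [shape + [layout]]
print(x_vals_list)             # [[1024, 8192, 20480, 'NT']]
```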

op_tests/triton_tests/test_gemm_a16w16.py (10 additions, 3 deletions)
```diff
@@ -5,9 +5,16 @@
 from aiter.ops.triton.gemm_a16w16 import gemm_a16w16


-def generate_gemm_a16w16_inputs(M, N, K, dtype):
-    x = torch.randn((M, K), dtype=dtype).cuda()
-    weight = torch.randn((K, N), dtype=dtype).cuda()
+def generate_gemm_a16w16_inputs(M, N, K, dtype, layout="TN"):
+    if layout[0] == 'T':
+        x = torch.randn((M, K), dtype=dtype).cuda()
+    else:
+        x = torch.randn((K, M), dtype=dtype).cuda().T
+
+    if layout[1] == 'T':
+        weight = torch.randn((K, N), dtype=dtype).cuda()
+    else:
+        weight = torch.randn((N, K), dtype=dtype).cuda().T

     return x, weight

```
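A quick way to see what the layout letters mean for the generated tensors: 'T' yields a contiguous (row-major) matrix, while 'N' yields a transposed view of a contiguous buffer, so the logical shapes stay `(M, K)` and `(K, N)` but the strides change. A minimal check, assuming the test module is importable from the repo root and a CUDA/ROCm device is available:

```python
import torch

from op_tests.triton_tests.test_gemm_a16w16 import generate_gemm_a16w16_inputs

x, w = generate_gemm_a16w16_inputs(64, 128, 256, torch.bfloat16, layout="NT")

# 'N' for the activation: x is a (K, M) buffer viewed as (M, K) via .T
print(x.shape, x.is_contiguous())   # torch.Size([64, 256]) False
# 'T' for the weight: w is stored directly as a contiguous (K, N) matrix
print(w.shape, w.is_contiguous())   # torch.Size([256, 128]) True
```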
