@@ -17,26 +17,32 @@ def model_benchmark_shapes(args):
         N = config["intermediate_size"]
         K = config["hidden_size"]
-        shapes.append((M, N, K))
+        shapes.append((M, N, K, 'TN'))


     return shapes


 def get_x_vals():
     x_vals = [
-        (1, 1280, 8192),
-        (32, 1280, 8192),
-        (64, 1280, 8192),
-        (128, 1280, 8192),
-        (192, 1280, 8192),
-        (256, 1280, 8192),
-        (320, 1280, 8192),
-        (512, 1280, 8192),
-        (1024, 1280, 8192),
-        (2048, 1280, 8192),
-        (4096, 1280, 8192),
-        (8192, 1280, 8192),
-        (16384, 1280, 8192),
+        (1, 1280, 8192, 'TN'),
+        (32, 1280, 8192, 'TN'),
+        (64, 1280, 8192, 'TN'),
+        (128, 1280, 8192, 'TN'),
+        (192, 1280, 8192, 'TN'),
+        (256, 1280, 8192, 'TN'),
+        (320, 1280, 8192, 'TN'),
+        (512, 1280, 8192, 'TN'),
+        (1024, 1280, 8192, 'TN'),
+        (2048, 1280, 8192, 'TN'),
+        (4096, 1280, 8192, 'TN'),
+        (8192, 1280, 8192, 'TN'),
+        (16384, 1280, 8192, 'TN'),
+        (8192, 7168, 20480, 'NT'),
+        (1024, 20480, 8192, 'NT'),
+        (1024, 8192, 20480, 'NT'),
+        (8192, 7168, 20480, 'TN'),
+        (1024, 20480, 8192, 'TN'),
+        (1024, 8192, 20480, 'TN'),
     ]
     return x_vals
@@ -45,11 +51,11 @@ def run_benchmark(args):
     assert not (args.shape and args.model) or not (args.shape and args.M), \
         "User can specify --shape or --model MODEL -M VAL exclusively"

-    x_names = ['M', 'N', 'K']
+    x_names = ['M', 'N', 'K', 'layout']
     if args.model:
         x_vals_list = model_benchmark_shapes(args)
     elif args.shape:
-        x_vals_list = [args.shape]
+        x_vals_list = [args.shape + [args.layout]]
     else:
         x_vals_list = get_x_vals()
@@ -71,10 +77,10 @@ def run_benchmark(args):
         ylabel=ylabel, plot_name=f'GEMM A16W16 Benchmark', args={"metric": args.metric})

     @triton.testing.perf_report([benchmark])
-    def bench_gemm_a16w16(M, N, K, metric, provider):
+    def bench_gemm_a16w16(M, N, K, layout, metric, provider):
         # NOTE: Assume bias and output has the same dtype
         c_dtype = torch.bfloat16
-        x, w = generate_gemm_a16w16_inputs(M, N, K, c_dtype)
+        x, w = generate_gemm_a16w16_inputs(M, N, K, c_dtype, layout)
         # flops
         flops = 2.0 * M * N * K
         # memory transfer
@@ -119,6 +125,8 @@ def parse_args():
                         help="user-defined shape to benchmark")
     parser.add_argument("--metric", type=str, choices=["time", "throughput", "bandwidth"],
                         default="throughput", help="metric to plot")
+    parser.add_argument("--layout", type=str, choices=["TT", "TN", "NT", "NN"],
+                        default="TN", help="Layout of input and weight matrix")
     args = parser.parse_args()
     return args
0 commit comments