@@ -162,6 +162,66 @@ module {
162162
163163// -----
164164
165+ // CHECK: #wave.normal_form<full_types,memory_only_types>
166+ module attributes {wave.normal_form = #wave.normal_form <full_types >} {
167+ // CHECK-LABEL: @memory_resharding_allowed
168+ func.func @memory_resharding_allowed (%mem: !wave.tensor <[@M ] of f16 , <shared >>) attributes {wave.hyperparameters = #wave.hyperparameters <{M = 128 }>} {
169+ %cst = arith.constant 0.0 : f16
170+ // The register carries no elements_per_thread of its own; the expected 8 elements per thread come from backward propagation of the write below.
171+ // CHECK: wave.register {{.*}} : vector<8xf16>
172+ %reg8 = wave.register %cst : !wave.tensor <[@M ] of f16 , <register >>
173+
174+ // Write 8 elements per thread to memory; only the register operand is expected to become a vector, the memory operand keeps its tensor type.
175+ // CHECK: wave.write {{.*}} : vector<8xf16>, !wave.tensor<[@M] of f16, <shared>>
176+ wave.write %reg8 , %mem {elements_per_thread = 8 } : !wave.tensor <[@M ] of f16 , <register >>, !wave.tensor <[@M ] of f16 , <shared >>
177+
178+ // Read 4 elements per thread from the same memory. Differing from the write's 8 should be allowed (memory resharding) rather than reported as a mismatch.
179+ // CHECK: wave.read {{.*}} : (!wave.tensor<[@M] of f16, <shared>>) -> vector<4xf16>
180+ %reg4 = wave.read %mem {elements_per_thread = 4 } : (!wave.tensor <[@M ] of f16 , <shared >>) -> !wave.tensor <[@M ] of f16 , <register >>
181+
182+ return
183+ }
184+ }
185+
186+ // -----
187+
188+ // CHECK: #wave.normal_form<full_types,memory_only_types>
189+ module attributes {wave.normal_form = #wave.normal_form <full_types >} {
190+ // CHECK-LABEL: @write_backward_propagation
191+ func.func @write_backward_propagation (%mem: !wave.tensor <[@M ] of f16 , <shared >>) attributes {wave.hyperparameters = #wave.hyperparameters <{M = 128 }>} {
192+ %cst = arith.constant 0.0 : f16
193+ // The register has no explicit elements_per_thread; it should receive 4 via backward propagation from the write below.
194+ // CHECK: wave.register {{.*}} : vector<4xf16>
195+ %reg = wave.register %cst : !wave.tensor <[@M ] of f16 , <register >>
196+
197+ // The write (elements_per_thread = 4) should propagate that value backward to its register operand.
198+ // CHECK: wave.write {{.*}} : vector<4xf16>, !wave.tensor<[@M] of f16, <shared>>
199+ wave.write %reg , %mem {elements_per_thread = 4 } : !wave.tensor <[@M ] of f16 , <register >>, !wave.tensor <[@M ] of f16 , <shared >>
200+
201+ return
202+ }
203+ }
204+
205+ // -----
206+
207+ // CHECK: #wave.normal_form<full_types,memory_only_types>
208+ module attributes {wave.normal_form = #wave.normal_form <full_types >} {
209+ // CHECK-LABEL: @read_register_propagation
210+ func.func @read_register_propagation (%mem: !wave.tensor <[@M ] of f16 , <shared >>) attributes {wave.hyperparameters = #wave.hyperparameters <{M = 128 }>} {
211+ // The read should propagate its elements_per_thread (6) only to its register result, without validating it against the memory operand.
212+ // CHECK: wave.read {{.*}} : (!wave.tensor<[@M] of f16, <shared>>) -> vector<6xf16>
213+ %reg = wave.read %mem {elements_per_thread = 6 } : (!wave.tensor <[@M ] of f16 , <shared >>) -> !wave.tensor <[@M ] of f16 , <register >>
214+
215+ // The downstream exp2 should pick up 6 elements per thread for both its operand and its result.
216+ // CHECK: wave.exp2 {{.*}} : (vector<6xf16>) -> vector<6xf16>
217+ %result = wave.exp2 %reg : (!wave.tensor <[@M ] of f16 , <register >>) -> !wave.tensor <[@M ] of f16 , <register >>
218+
219+ return
220+ }
221+ }
222+
223+ // -----
224+
165225module attributes {wave.normal_form = #wave.normal_form <full_types >} {
166226func.func @mma_compute_lhs_from_rhs (%mem1: !wave.tensor <[@N , @K ] of f16 , <global >>, %mem2: !wave.tensor <[@M , @N ] of f32 , <global >>) attributes {wave.hyperparameters = #wave.hyperparameters <{M = 16 , N = 16 , K = 16 }>, wave.constraints = [#wave.hardware_constraint <threads_per_wave = 32 , waves_per_block = [1 , 1 , 1 ], mma_type = #wave.mma_kind <f32 _16 x16 x16 _f16 >, vector_shapes = {M = 1 , N = 1 , K = 16 }, max_bits_per_load = 128 >]} {
167227 // LHS without elements_per_thread - will be computed from RHS + MMA constraints.
@@ -174,7 +234,7 @@ func.func @mma_compute_lhs_from_rhs(%mem1: !wave.tensor<[@N, @K] of f16, <global
174234 // ACC properly initialized through read operation.
175235 %acc = wave.read %mem2 {elements_per_thread = 8 } : (!wave.tensor <[@M , @N ] of f32 , <global >>) -> !wave.tensor <[@M , @N ] of f32 , <register >>
176236
177- // LHS elements_per_thread computed via MMA backward propagation
237+ // LHS elements_per_thread computed via MMA backward propagation.
178238 %result = wave.mma %lhs , %rhs , %acc {kind = #wave.mma_kind <f32 _16 x16 x16 _f16 >} : (!wave.tensor <[@M , @K ] of f16 , <register >>, !wave.tensor <[@N , @K ] of f16 , <register >>, !wave.tensor <[@M , @N ] of f32 , <register >>) -> !wave.tensor <[@M , @N ] of f32 , <register >>
179239 return
180240}
@@ -194,7 +254,7 @@ func.func @mma_compute_rhs_from_lhs(%mem1: !wave.tensor<[@M, @K] of f16, <global
194254 // ACC properly initialized through read operation.
195255 %acc = wave.read %mem2 {elements_per_thread = 8 } : (!wave.tensor <[@M , @N ] of f32 , <global >>) -> !wave.tensor <[@M , @N ] of f32 , <register >>
196256
197- // RHS elements_per_thread computed via MMA backward propagation
257+ // RHS elements_per_thread computed via MMA backward propagation.
198258 %result = wave.mma %lhs , %rhs , %acc {kind = #wave.mma_kind <f32 _16 x16 x16 _f16 >} : (!wave.tensor <[@M , @K ] of f16 , <register >>, !wave.tensor <[@N , @K ] of f16 , <register >>, !wave.tensor <[@M , @N ] of f32 , <register >>) -> !wave.tensor <[@M , @N ] of f32 , <register >>
199259 return
200260}
@@ -205,7 +265,7 @@ func.func @mma_compute_rhs_from_lhs(%mem1: !wave.tensor<[@M, @K] of f16, <global
205265// Test that MMA can compute both LHS and RHS when both are uninitialized.
206266module attributes {wave.normal_form = #wave.normal_form <full_types >} {
207267 func.func @mma_compute_both_lhs_rhs (%mem1: !wave.tensor <[@M , @K ] of f16 , <global >>, %mem2: !wave.tensor <[@N , @K ] of f16 , <global >>, %mem3: !wave.tensor <[@M , @N ] of f32 , <global >>) attributes {wave.hyperparameters = #wave.hyperparameters <{M = 16 , N = 16 , K = 16 }>, wave.constraints = [#wave.hardware_constraint <threads_per_wave = 32 , waves_per_block = [1 , 1 , 1 ], mma_type = #wave.mma_kind <f32 _16 x16 x16 _f16 >, vector_shapes = {M = 1 , N = 1 , K = 16 }, max_bits_per_load = 128 >]} {
208- // Both LHS and RHS without elements_per_thread - can compute from MMA formulas
268+ // Both LHS and RHS without elements_per_thread - can compute from MMA formulas.
209269 %lhs_init = arith.constant 0.0 : f16
210270 %lhs = wave.register %lhs_init : !wave.tensor <[@M , @K ] of f16 , <register >>
211271 %rhs_init = arith.constant 0.0 : f16
@@ -215,7 +275,7 @@ module attributes {wave.normal_form = #wave.normal_form<full_types>} {
215275 %acc = wave.read %mem3 {elements_per_thread = 8 } : (!wave.tensor <[@M , @N ] of f32 , <global >>) -> !wave.tensor <[@M , @N ] of f32 , <register >>
216276
217277 // With proper MMA formulas, we can now compute both LHS and RHS from constraints,
218- // so this should succeed instead of failing
278+ // so this should succeed instead of failing.
219279 %result = wave.mma %lhs , %rhs , %acc {kind = #wave.mma_kind <f32 _16 x16 x16 _f16 >} : (!wave.tensor <[@M , @K ] of f16 , <register >>, !wave.tensor <[@N , @K ] of f16 , <register >>, !wave.tensor <[@M , @N ] of f32 , <register >>) -> !wave.tensor <[@M , @N ] of f32 , <register >>
220280 return
221281 }
@@ -226,14 +286,14 @@ module attributes {wave.normal_form = #wave.normal_form<full_types>} {
226286// Test the MMA error reported when an operand has the wrong elements_per_thread.
227287module attributes {wave.normal_form = #wave.normal_form <full_types >} {
228288 func.func @mma_operand_mismatch (%mem1: !wave.tensor <[@M , @K ] of f16 , <global >>, %mem2: !wave.tensor <[@M , @N ] of f32 , <global >>) attributes {wave.hyperparameters = #wave.hyperparameters <{M = 16 , N = 16 , K = 16 }>, wave.constraints = [#wave.hardware_constraint <threads_per_wave = 32 , waves_per_block = [1 , 1 , 1 ], mma_type = #wave.mma_kind <f32 _16 x16 x16 _f16 >, vector_shapes = {M = 1 , N = 1 , K = 16 }, max_bits_per_load = 128 >]} {
229- // LHS with wrong elements_per_thread (should be 8, not 4)
289+ // LHS with wrong elements_per_thread (should be 8, not 4).
230290 %lhs = wave.read %mem1 {elements_per_thread = 4 } : (!wave.tensor <[@M , @K ] of f16 , <global >>) -> !wave.tensor <[@M , @K ] of f16 , <register >>
231291
232292 // RHS without elements_per_thread - will be computed from MMA constraints.
233293 %rhs_init = arith.constant 0.0 : f16
234294 %rhs = wave.register %rhs_init : !wave.tensor <[@N , @K ] of f16 , <register >>
235295
236- // ACC properly initialized
296+ // ACC properly initialized.
237297 %acc = wave.read %mem2 {elements_per_thread = 8 } : (!wave.tensor <[@M , @N ] of f32 , <global >>) -> !wave.tensor <[@M , @N ] of f32 , <register >>
238298
239299 // expected-error @below {{failed to propagate elements per thread backward: mismatch between computed from MMA kind (8) and LHS operand #0 (4)}}
0 commit comments