@@ -100,7 +100,7 @@ func.func @missing_elements_per_thread(%mem: !wave.tensor<[@M] of f16, <global>>
100100module attributes {wave.normal_form = #wave.normal_form <full_types >} {
101101func.func @read_write_conflict (%mem: !wave.tensor <[@M ] of f16 , <global >>) attributes {wave.hyperparameters = #wave.hyperparameters <{M = 128 }>} {
102102 %reg = wave.read %mem {elements_per_thread = 4 } : (!wave.tensor <[@M ] of f16 , <global >>) -> !wave.tensor <[@M ] of f16 , <register >>
103- // expected-error @below {{failed to propagate elements per thread backward: mismatch between elements_per_thread attribute (8) and operand #0 (4)}}
103+ // expected-error @below {{failed to propagate elements per thread backward: mismatch between elements_per_thread attribute (8) and register operand #0 (4)}}
104104 wave.write %reg , %mem {elements_per_thread = 8 } : !wave.tensor <[@M ] of f16 , <register >>, !wave.tensor <[@M ] of f16 , <global >>
105105 return
106106}
@@ -112,7 +112,7 @@ module attributes {wave.normal_form = #wave.normal_form<full_types>} {
112112func.func @read_write_conflict_indirect (%mem: !wave.tensor <[@M ] of f16 , <global >>) attributes {wave.hyperparameters = #wave.hyperparameters <{M = 128 }>} {
113113 %reg = wave.read %mem {elements_per_thread = 4 } : (!wave.tensor <[@M ] of f16 , <global >>) -> !wave.tensor <[@M ] of f16 , <register >>
114114 %val = wave.exp2 %reg : (!wave.tensor <[@M ] of f16 , <register >>) -> !wave.tensor <[@M ] of f16 , <register >>
115- // expected-error @below {{failed to propagate elements per thread backward: mismatch between elements_per_thread attribute (8) and operand #0 (4)}}
115+ // expected-error @below {{failed to propagate elements per thread backward: mismatch between elements_per_thread attribute (8) and register operand #0 (4)}}
116116 wave.write %reg , %mem {elements_per_thread = 8 } : !wave.tensor <[@M ] of f16 , <register >>, !wave.tensor <[@M ] of f16 , <global >>
117117 return
118118}
@@ -162,6 +162,66 @@ module {
162162
163163// -----
164164
165+ // CHECK: #wave.normal_form<full_types,memory_only_types>
166+ module attributes {wave.normal_form = #wave.normal_form <full_types >} {
167+ // CHECK-LABEL: @memory_resharding_allowed
168+ func.func @memory_resharding_allowed (%mem: !wave.tensor <[@M ] of f16 , <shared >>) attributes {wave.hyperparameters = #wave.hyperparameters <{M = 128 }>} {
169+ %cst = arith.constant 0.0 : f16
170+ // Register gets 8 elements per thread from write operation's backward propagation
171+ // CHECK: wave.register {{.*}} : vector<8xf16>
172+ %reg8 = wave.register %cst : !wave.tensor <[@M ] of f16 , <register >>
173+
174+ // Write 8 elements per thread to memory
175+ // CHECK: wave.write {{.*}} : vector<8xf16>, !wave.tensor<[@M] of f16, <shared>>
176+ wave.write %reg8 , %mem {elements_per_thread = 8 } : !wave.tensor <[@M ] of f16 , <register >>, !wave.tensor <[@M ] of f16 , <shared >>
177+
178+ // Read 4 elements per thread from same memory - this should be allowed (memory resharding)
179+ // CHECK: wave.read {{.*}} : (!wave.tensor<[@M] of f16, <shared>>) -> vector<4xf16>
180+ %reg4 = wave.read %mem {elements_per_thread = 4 } : (!wave.tensor <[@M ] of f16 , <shared >>) -> !wave.tensor <[@M ] of f16 , <register >>
181+
182+ return
183+ }
184+ }
185+
186+ // -----
187+
188+ // CHECK: #wave.normal_form<full_types,memory_only_types>
189+ module attributes {wave.normal_form = #wave.normal_form <full_types >} {
190+ // CHECK-LABEL: @write_backward_propagation
191+ func.func @write_backward_propagation (%mem: !wave.tensor <[@M ] of f16 , <shared >>) attributes {wave.hyperparameters = #wave.hyperparameters <{M = 128 }>} {
192+ %cst = arith.constant 0.0 : f16
193+ // RegisterOp without explicit elements_per_thread - should get it from backward propagation
194+ // CHECK: wave.register {{.*}} : vector<4xf16>
195+ %reg = wave.register %cst : !wave.tensor <[@M ] of f16 , <register >>
196+
197+ // WriteOp should propagate elements_per_thread backward to register operand
198+ // CHECK: wave.write {{.*}} : vector<4xf16>, !wave.tensor<[@M] of f16, <shared>>
199+ wave.write %reg , %mem {elements_per_thread = 4 } : !wave.tensor <[@M ] of f16 , <register >>, !wave.tensor <[@M ] of f16 , <shared >>
200+
201+ return
202+ }
203+ }
204+
205+ // -----
206+
207+ // CHECK: #wave.normal_form<full_types,memory_only_types>
208+ module attributes {wave.normal_form = #wave.normal_form <full_types >} {
209+ // CHECK-LABEL: @read_register_propagation
210+ func.func @read_register_propagation (%mem: !wave.tensor <[@M ] of f16 , <shared >>) attributes {wave.hyperparameters = #wave.hyperparameters <{M = 128 }>} {
211+ // ReadOp should only propagate to its register result, not validate memory
212+ // CHECK: wave.read {{.*}} : (!wave.tensor<[@M] of f16, <shared>>) -> vector<6xf16>
213+ %reg = wave.read %mem {elements_per_thread = 6 } : (!wave.tensor <[@M ] of f16 , <shared >>) -> !wave.tensor <[@M ] of f16 , <register >>
214+
215+ // Downstream operation should get 6 elements per thread
216+ // CHECK: wave.exp2 {{.*}} : (vector<6xf16>) -> vector<6xf16>
217+ %result = wave.exp2 %reg : (!wave.tensor <[@M ] of f16 , <register >>) -> !wave.tensor <[@M ] of f16 , <register >>
218+
219+ return
220+ }
221+ }
222+
223+ // -----
224+
165225module attributes {wave.normal_form = #wave.normal_form <full_types >} {
166226func.func @mma_uninitialized_lhs (%mem1: !wave.tensor <[@N , @K ] of f16 , <global >>, %mem2: !wave.tensor <[@M , @N ] of f32 , <global >>) attributes {wave.hyperparameters = #wave.hyperparameters <{M = 16 , N = 16 , K = 16 }>, wave.constraints = [#wave.hardware_constraint <threads_per_wave = 32 , waves_per_block = [1 , 1 , 1 ], mma_type = #wave.mma_kind <f32 _16 x16 x16 _f16 >, vector_shapes = {M = 1 , N = 1 , K = 16 }, max_bits_per_load = 128 >]} {
167227 // LHS without elements_per_thread - this will remain uninitialized.
0 commit comments