Skip to content

Commit 93fbc79

Browse files
authored
[water] Fix elements_per_thread propagation to ignore memory operands (#623)
Changes:
- ReadOp: only propagate the attribute to the result (register); ignore the memory operand.
- WriteOp: only validate/propagate against the register operand; ignore the memory operand.

This fixes false positives where memory resharding was incorrectly flagged as propagation errors. Fixes #622.

---------

Signed-off-by: tyb0807 <[email protected]>
1 parent 27c2a71 commit 93fbc79

File tree

3 files changed

+140
-8
lines changed

3 files changed

+140
-8
lines changed

water/include/water/Dialect/Wave/IR/WaveOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def ExtractSliceOp : WaveOp<"extract_slice", [WaveInferTypeOpInterface, Identity
282282

283283
def ReadOp : WaveOp<"read", [
284284
WaveInferTypeOpInterface, IdentityTypeInferenceOpTrait,
285-
WaveElementsPerThreadOpInterface, AttrBasedElementsPerThreadOpTrait,
285+
DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>,
286286
CompatibleOperandsAndResultsIgnoreSpaceOpTrait,
287287
WaveInferIndexExprsOpInterface, IdentityIndexExprsOpTrait]> {
288288
let summary = "Reads from memory";
@@ -336,7 +336,7 @@ def RegisterOp : WaveOp<"register", [
336336

337337
def WriteOp : WaveOp<"write", [
338338
WaveInferTypeOpInterface, NoOpTypeInferenceOpTrait,
339-
WaveElementsPerThreadOpInterface, AttrBasedElementsPerThreadOpTrait,
339+
DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>,
340340
CompatibleOperandsAndResultsIgnoreSpaceOpTrait,
341341
DeclareOpInterfaceMethods<WaveInferIndexExprsOpInterface>]> {
342342
let summary = "Writes into memory";

water/lib/Dialect/Wave/IR/WaveOps.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1448,6 +1448,34 @@ LogicalResult ReadOp::verify() {
14481448
bounds.getMapping());
14491449
}
14501450

1451+
/// Forward elements-per-thread propagation for `wave.read`.
///
/// The operand lattice values (first, unnamed parameter) are never consulted:
/// the memory operand is intentionally excluded from propagation, since any
/// number of elements may be read from memory regardless of how many were
/// written. Only `resultElements` is handed to the shared helper, together
/// with the constant taken from the op's `elements_per_thread` attribute.
///
/// Returns `NoChange` when the attribute is absent; otherwise returns whatever
/// `checkAndPropagateElementsPerThreadFromConstant` reports, with diagnostics
/// written to `errs`.
llvm::FailureOr<mlir::ChangeResult>
wave::ReadOp::propagateElementsPerThreadForward(
    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>,
    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
    llvm::raw_ostream &errs) {
  std::optional<int64_t> attrValue = getElementsPerThread();
  if (attrValue.has_value()) {
    // The attribute is the single source of truth for the register result.
    wave::ElementsPerThreadLatticeValue fromAttribute(*attrValue);
    // No operand lattices take part in the check for a read.
    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> noOperands;
    return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
        fromAttribute, noOperands, resultElements,
        "elements_per_thread attribute", "", "result", errs);
  }

  // Without the attribute there is nothing to propagate.
  return mlir::ChangeResult::NoChange;
}
1468+
1469+
/// Backward elements-per-thread propagation for `wave.read`: intentionally a
/// no-op.
///
/// A read never pushes information back onto its memory operand; memory is
/// decoupled from the register dataflow for elements_per_thread. All three
/// parameters are therefore unused and left unnamed, matching the sibling
/// propagation hooks and avoiding unused-parameter warnings.
///
/// Always returns `NoChange`.
llvm::FailureOr<mlir::ChangeResult>
wave::ReadOp::propagateElementsPerThreadBackward(
    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> /*operandElements*/,
    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> /*resultElements*/,
    llvm::raw_ostream & /*errs*/) {
  return mlir::ChangeResult::NoChange;
}
1478+
14511479
//-----------------------------------------------------------------------------
14521480
// RegisterOp
14531481
//-----------------------------------------------------------------------------
@@ -1529,6 +1557,50 @@ LogicalResult WriteOp::verify() {
15291557
bounds.getMapping());
15301558
}
15311559

1560+
/// Forward elements-per-thread "propagation" for `wave.write`.
///
/// A write has no results to update (the second, unnamed parameter is
/// ignored), so this hook only checks the lattice value of the
/// `value_to_store` operand against the op's `elements_per_thread` attribute.
/// The memory operand is intentionally left out of the check: a write may
/// target memory with any layout.
///
/// Returns `NoChange` when the attribute is absent; otherwise returns whatever
/// `checkAndPropagateElementsPerThreadFromConstant` reports, with diagnostics
/// written to `errs`.
llvm::FailureOr<mlir::ChangeResult>
wave::WriteOp::propagateElementsPerThreadForward(
    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue>,
    llvm::raw_ostream &errs) {
  std::optional<int64_t> attrValue = getElementsPerThread();
  if (!attrValue.has_value())
    return mlir::ChangeResult::NoChange;

  wave::ElementsPerThreadLatticeValue fromAttribute(*attrValue);

  // Narrow the operand lattices to the single entry for `value_to_store`.
  const unsigned storedValueIndex =
      getValueToStoreMutable().getOperandNumber();
  llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> storedValueOnly =
      operandElements.slice(storedValueIndex, 1);

  // There are no result lattices to write into.
  llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> noResults;
  return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
      fromAttribute, storedValueOnly, noResults,
      "elements_per_thread attribute", "operand", "", errs);
}
1582+
1583+
/// Backward elements-per-thread propagation for `wave.write`.
///
/// Pushes the constant from the op's `elements_per_thread` attribute onto the
/// lattice of the `value_to_store` operand only. The memory operand's lattice
/// is intentionally left untouched: any layout may be written to memory. The
/// result lattices (second, unnamed parameter) are ignored.
///
/// Returns `NoChange` when the attribute is absent; otherwise returns whatever
/// `checkAndPropagateElementsPerThreadFromConstant` reports, with diagnostics
/// written to `errs`.
llvm::FailureOr<mlir::ChangeResult>
wave::WriteOp::propagateElementsPerThreadBackward(
    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>,
    llvm::raw_ostream &errs) {
  std::optional<int64_t> attrValue = getElementsPerThread();
  if (!attrValue.has_value())
    return mlir::ChangeResult::NoChange;

  wave::ElementsPerThreadLatticeValue fromAttribute(*attrValue);

  // Narrow the mutable operand lattices to the entry for `value_to_store`.
  const unsigned storedValueIndex =
      getValueToStoreMutable().getOperandNumber();
  llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> storedValueOnly =
      operandElements.slice(storedValueIndex, 1);

  // Nothing is checked on the read-only side; only the target is updated.
  llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> nothingToCheck;
  return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
      fromAttribute, nothingToCheck, storedValueOnly,
      "elements_per_thread attribute", "", "operand", errs);
}
1603+
15321604
// Propagate index expressions forward from the operands to the result of the
15331605
// WriteOp. Since WriteOp has no results, this is a no-op.
15341606
llvm::FailureOr<mlir::ChangeResult> wave::WriteOp::propagateIndexExprsForward(

water/test/Dialect/Wave/propagate-elements-per-thread.mlir

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,66 @@ module {
162162

163163
// -----
164164

165+
// CHECK: #wave.normal_form<full_types,memory_only_types>
166+
// Test: a write with elements_per_thread = 8 followed by a read with elements_per_thread = 4
// on the same `<shared>` tensor must both be accepted (no propagation mismatch is expected).
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
167+
// CHECK-LABEL: @memory_resharding_allowed
168+
func.func @memory_resharding_allowed(%mem: !wave.tensor<[@M] of f16, <shared>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} {
169+
%cst = arith.constant 0.0 : f16
170+
// The register has no attribute of its own; it is expected to become an 8-wide vector via the write's backward propagation.
171+
// CHECK: wave.register {{.*}} : vector<8xf16>
172+
%reg8 = wave.register %cst : !wave.tensor<[@M] of f16, <register>>
173+
174+
// Write 8 elements per thread to memory; the memory operand keeps its tensor type.
175+
// CHECK: wave.write {{.*}} : vector<8xf16>, !wave.tensor<[@M] of f16, <shared>>
176+
wave.write %reg8, %mem {elements_per_thread = 8} : !wave.tensor<[@M] of f16, <register>>, !wave.tensor<[@M] of f16, <shared>>
177+
178+
// Read 4 elements per thread from the same memory - this must be allowed (memory resharding).
179+
// CHECK: wave.read {{.*}} : (!wave.tensor<[@M] of f16, <shared>>) -> vector<4xf16>
180+
%reg4 = wave.read %mem {elements_per_thread = 4} : (!wave.tensor<[@M] of f16, <shared>>) -> !wave.tensor<[@M] of f16, <register>>
181+
182+
return
183+
}
184+
}
185+
186+
// -----
187+
188+
// CHECK: #wave.normal_form<full_types,memory_only_types>
189+
// Test: a write with elements_per_thread = 4 must propagate that value backward to the
// `wave.register` that produces its stored value.
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
190+
// CHECK-LABEL: @write_backward_propagation
191+
func.func @write_backward_propagation(%mem: !wave.tensor<[@M] of f16, <shared>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} {
192+
%cst = arith.constant 0.0 : f16
193+
// RegisterOp has no explicit elements_per_thread - it should receive 4 from the write's backward propagation.
194+
// CHECK: wave.register {{.*}} : vector<4xf16>
195+
%reg = wave.register %cst : !wave.tensor<[@M] of f16, <register>>
196+
197+
// WriteOp should propagate elements_per_thread backward to its register operand only; the memory operand keeps its tensor type.
198+
// CHECK: wave.write {{.*}} : vector<4xf16>, !wave.tensor<[@M] of f16, <shared>>
199+
wave.write %reg, %mem {elements_per_thread = 4} : !wave.tensor<[@M] of f16, <register>>, !wave.tensor<[@M] of f16, <shared>>
200+
201+
return
202+
}
203+
}
204+
205+
// -----
206+
207+
// CHECK: #wave.normal_form<full_types,memory_only_types>
208+
// Test: a read with elements_per_thread = 6 must propagate that value forward to its register
// result and on to a downstream elementwise op, without involving the memory operand.
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
209+
// CHECK-LABEL: @read_register_propagation
210+
func.func @read_register_propagation(%mem: !wave.tensor<[@M] of f16, <shared>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} {
211+
// ReadOp should only propagate to its register result; the memory operand keeps its tensor type and is not validated.
212+
// CHECK: wave.read {{.*}} : (!wave.tensor<[@M] of f16, <shared>>) -> vector<6xf16>
213+
%reg = wave.read %mem {elements_per_thread = 6} : (!wave.tensor<[@M] of f16, <shared>>) -> !wave.tensor<[@M] of f16, <register>>
214+
215+
// The downstream operation should see 6 elements per thread on both its operand and its result.
216+
// CHECK: wave.exp2 {{.*}} : (vector<6xf16>) -> vector<6xf16>
217+
%result = wave.exp2 %reg : (!wave.tensor<[@M] of f16, <register>>) -> !wave.tensor<[@M] of f16, <register>>
218+
219+
return
220+
}
221+
}
222+
223+
// -----
224+
165225
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
166226
func.func @mma_compute_lhs_from_rhs(%mem1: !wave.tensor<[@N, @K] of f16, <global>>, %mem2: !wave.tensor<[@M, @N] of f32, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint<threads_per_wave = 32, waves_per_block = [1, 1, 1], mma_type = #wave.mma_kind<f32_16x16x16_f16>, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} {
167227
// LHS without elements_per_thread - will be computed from RHS + MMA constraints.
@@ -174,7 +234,7 @@ func.func @mma_compute_lhs_from_rhs(%mem1: !wave.tensor<[@N, @K] of f16, <global
174234
// ACC properly initialized through read operation.
175235
%acc = wave.read %mem2 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, <global>>) -> !wave.tensor<[@M, @N] of f32, <register>>
176236

177-
// LHS elements_per_thread computed via MMA backward propagation
237+
// LHS elements_per_thread computed via MMA backward propagation.
178238
%result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
179239
return
180240
}
@@ -194,7 +254,7 @@ func.func @mma_compute_rhs_from_lhs(%mem1: !wave.tensor<[@M, @K] of f16, <global
194254
// ACC properly initialized through read operation.
195255
%acc = wave.read %mem2 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, <global>>) -> !wave.tensor<[@M, @N] of f32, <register>>
196256

197-
// RHS elements_per_thread computed via MMA backward propagation
257+
// RHS elements_per_thread computed via MMA backward propagation.
198258
%result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
199259
return
200260
}
@@ -205,7 +265,7 @@ func.func @mma_compute_rhs_from_lhs(%mem1: !wave.tensor<[@M, @K] of f16, <global
205265
// Test MMA can compute both LHS and RHS when both are uninitialized
206266
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
207267
func.func @mma_compute_both_lhs_rhs(%mem1: !wave.tensor<[@M, @K] of f16, <global>>, %mem2: !wave.tensor<[@N, @K] of f16, <global>>, %mem3: !wave.tensor<[@M, @N] of f32, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint<threads_per_wave = 32, waves_per_block = [1, 1, 1], mma_type = #wave.mma_kind<f32_16x16x16_f16>, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} {
208-
// Both LHS and RHS without elements_per_thread - can compute from MMA formulas
268+
// Both LHS and RHS without elements_per_thread - can compute from MMA formulas.
209269
%lhs_init = arith.constant 0.0 : f16
210270
%lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, <register>>
211271
%rhs_init = arith.constant 0.0 : f16
@@ -215,7 +275,7 @@ module attributes {wave.normal_form = #wave.normal_form<full_types>} {
215275
%acc = wave.read %mem3 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, <global>>) -> !wave.tensor<[@M, @N] of f32, <register>>
216276

217277
// With proper MMA formulas, we can now compute both LHS and RHS from constraints,
218-
// so this should succeed instead of failing
278+
// so this should succeed instead of failing.
219279
%result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
220280
return
221281
}
@@ -226,14 +286,14 @@ module attributes {wave.normal_form = #wave.normal_form<full_types>} {
226286
// Test MMA error when operand has wrong elements_per_thread
227287
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
228288
func.func @mma_operand_mismatch(%mem1: !wave.tensor<[@M, @K] of f16, <global>>, %mem2: !wave.tensor<[@M, @N] of f32, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint<threads_per_wave = 32, waves_per_block = [1, 1, 1], mma_type = #wave.mma_kind<f32_16x16x16_f16>, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} {
229-
// LHS with wrong elements_per_thread (should be 8, not 4)
289+
// LHS with wrong elements_per_thread (should be 8, not 4).
230290
%lhs = wave.read %mem1 {elements_per_thread = 4} : (!wave.tensor<[@M, @K] of f16, <global>>) -> !wave.tensor<[@M, @K] of f16, <register>>
231291

232292
// RHS without elements_per_thread - will be computed from MMA constraints.
233293
%rhs_init = arith.constant 0.0 : f16
234294
%rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, <register>>
235295

236-
// ACC properly initialized
296+
// ACC properly initialized.
237297
%acc = wave.read %mem2 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, <global>>) -> !wave.tensor<[@M, @N] of f32, <register>>
238298

239299
// expected-error @below {{failed to propagate elements per thread backward: mismatch between computed from MMA kind (8) and LHS operand #0 (4)}}

0 commit comments

Comments
 (0)