Skip to content

Commit 712f22e

Browse files
committed
Fix elements_per_thread propagation to ignore memory operands
Changes: - ReadOp: Only propagate attribute to result (register), ignore memory - WriteOp: Only validate/propagate with register operand, ignore memory This fixes false positives where memory resharding was incorrectly flagged as propagation errors. Fixes #622. Signed-off-by: tyb0807 <[email protected]>
1 parent 9a3bf1b commit 712f22e

File tree

3 files changed

+130
-4
lines changed

3 files changed

+130
-4
lines changed

water/include/water/Dialect/Wave/IR/WaveOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def ExtractSliceOp : WaveOp<"extract_slice", [WaveInferTypeOpInterface, Identity
280280

281281
def ReadOp : WaveOp<"read", [
282282
WaveInferTypeOpInterface, IdentityTypeInferenceOpTrait,
283-
WaveElementsPerThreadOpInterface, AttrBasedElementsPerThreadOpTrait,
283+
DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>,
284284
CompatibleOperandsAndResultsIgnoreSpaceOpTrait,
285285
WaveInferIndexExprsOpInterface, IdentityIndexExprsOpTrait]> {
286286
let summary = "Reads from memory";
@@ -334,7 +334,7 @@ def RegisterOp : WaveOp<"register", [
334334

335335
def WriteOp : WaveOp<"write", [
336336
WaveInferTypeOpInterface, NoOpTypeInferenceOpTrait,
337-
WaveElementsPerThreadOpInterface, AttrBasedElementsPerThreadOpTrait,
337+
DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>,
338338
CompatibleOperandsAndResultsIgnoreSpaceOpTrait,
339339
DeclareOpInterfaceMethods<WaveInferIndexExprsOpInterface>]> {
340340
let summary = "Writes into memory";

water/lib/Dialect/Wave/IR/WaveOps.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1321,6 +1321,32 @@ LogicalResult ReadOp::verify() {
13211321
bounds.getMapping());
13221322
}
13231323

1324+
llvm::FailureOr<mlir::ChangeResult> wave::ReadOp::propagateElementsPerThreadForward(
1325+
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>,
1326+
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
1327+
llvm::raw_ostream &errs) {
1328+
// ReadOp only propagates elements_per_thread attribute to result (register)
1329+
// Memory operand is ignored for propagation - you can read any number of elements
1330+
// from memory regardless of how many were written
1331+
std::optional<int64_t> elementsPerThread = getElementsPerThread();
1332+
if (!elementsPerThread)
1333+
return mlir::ChangeResult::NoChange;
1334+
1335+
wave::ElementsPerThreadLatticeValue expectedResult(*elementsPerThread);
1336+
return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
1337+
expectedResult, llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(),
1338+
resultElements, "elements_per_thread attribute", "", "result", errs);
1339+
}
1340+
1341+
llvm::FailureOr<mlir::ChangeResult> wave::ReadOp::propagateElementsPerThreadBackward(
1342+
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue>,
1343+
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
1344+
llvm::raw_ostream &) {
1345+
// ReadOp doesn't propagate backward to memory operand
1346+
// Memory is decoupled from register dataflow for elements_per_thread
1347+
return mlir::ChangeResult::NoChange;
1348+
}
1349+
13241350
//-----------------------------------------------------------------------------
13251351
// RegisterOp
13261352
//-----------------------------------------------------------------------------
@@ -1402,6 +1428,46 @@ LogicalResult WriteOp::verify() {
14021428
bounds.getMapping());
14031429
}
14041430

1431+
llvm::FailureOr<mlir::ChangeResult> wave::WriteOp::propagateElementsPerThreadForward(
1432+
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
1433+
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue>,
1434+
llvm::raw_ostream &errs) {
1435+
// WriteOp only validates that elements_per_thread attribute matches register operand
1436+
// Memory operand is ignored for propagation - you can write to memory with any layout
1437+
std::optional<int64_t> elementsPerThread = getElementsPerThread();
1438+
if (!elementsPerThread)
1439+
return mlir::ChangeResult::NoChange;
1440+
1441+
// Validate register operand (value_to_store) matches attribute
1442+
wave::ElementsPerThreadLatticeValue expectedValue(*elementsPerThread);
1443+
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> valueOnly =
1444+
operandElements.slice(0, 1); // Only first operand (value_to_store)
1445+
1446+
return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
1447+
expectedValue, valueOnly, llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue>(),
1448+
"elements_per_thread attribute", "register operand", "", errs);
1449+
}
1450+
1451+
llvm::FailureOr<mlir::ChangeResult> wave::WriteOp::propagateElementsPerThreadBackward(
1452+
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
1453+
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>,
1454+
llvm::raw_ostream &errs) {
1455+
// WriteOp only propagates backward to register operand (value_to_store)
1456+
// Memory operand is ignored - you can write any layout to memory
1457+
std::optional<int64_t> elementsPerThread = getElementsPerThread();
1458+
if (!elementsPerThread)
1459+
return mlir::ChangeResult::NoChange;
1460+
1461+
// Propagate to register operand only
1462+
wave::ElementsPerThreadLatticeValue expectedValue(*elementsPerThread);
1463+
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> valueOnly =
1464+
operandElements.slice(0, 1); // Only first operand (value_to_store)
1465+
1466+
return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
1467+
expectedValue, llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(),
1468+
valueOnly, "elements_per_thread attribute", "", "register operand", errs);
1469+
}
1470+
14051471
// Propagate index expressions forward from the operands to the result of the
14061472
// WriteOp. Since WriteOp has no results, this is a no-op.
14071473
llvm::FailureOr<mlir::ChangeResult> wave::WriteOp::propagateIndexExprsForward(

water/test/Dialect/Wave/propagate-elements-per-thread.mlir

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func.func @missing_elements_per_thread(%mem: !wave.tensor<[@M] of f16, <global>>
100100
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
101101
func.func @read_write_conflict(%mem: !wave.tensor<[@M] of f16, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} {
102102
%reg = wave.read %mem {elements_per_thread = 4} : (!wave.tensor<[@M] of f16, <global>>) -> !wave.tensor<[@M] of f16, <register>>
103-
// expected-error @below {{failed to propagate elements per thread backward: mismatch between elements_per_thread attribute (8) and operand #0 (4)}}
103+
// expected-error @below {{failed to propagate elements per thread backward: mismatch between elements_per_thread attribute (8) and register operand #0 (4)}}
104104
wave.write %reg, %mem {elements_per_thread = 8} : !wave.tensor<[@M] of f16, <register>>, !wave.tensor<[@M] of f16, <global>>
105105
return
106106
}
@@ -112,7 +112,7 @@ module attributes {wave.normal_form = #wave.normal_form<full_types>} {
112112
func.func @read_write_conflict_indirect(%mem: !wave.tensor<[@M] of f16, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} {
113113
%reg = wave.read %mem {elements_per_thread = 4} : (!wave.tensor<[@M] of f16, <global>>) -> !wave.tensor<[@M] of f16, <register>>
114114
%val = wave.exp2 %reg : (!wave.tensor<[@M] of f16, <register>>) -> !wave.tensor<[@M] of f16, <register>>
115-
// expected-error @below {{failed to propagate elements per thread backward: mismatch between elements_per_thread attribute (8) and operand #0 (4)}}
115+
// expected-error @below {{failed to propagate elements per thread backward: mismatch between elements_per_thread attribute (8) and register operand #0 (4)}}
116116
wave.write %reg, %mem {elements_per_thread = 8} : !wave.tensor<[@M] of f16, <register>>, !wave.tensor<[@M] of f16, <global>>
117117
return
118118
}
@@ -162,6 +162,66 @@ module {
162162

163163
// -----
164164

165+
// CHECK: #wave.normal_form<full_types,memory_only_types>
166+
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
167+
// CHECK-LABEL: @memory_resharding_allowed
168+
func.func @memory_resharding_allowed(%mem: !wave.tensor<[@M] of f16, <shared>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} {
169+
%cst = arith.constant 0.0 : f16
170+
// Register gets 8 elements per thread from write operation's backward propagation
171+
// CHECK: wave.register {{.*}} : vector<8xf16>
172+
%reg8 = wave.register %cst : !wave.tensor<[@M] of f16, <register>>
173+
174+
// Write 8 elements per thread to memory
175+
// CHECK: wave.write {{.*}} : vector<8xf16>, !wave.tensor<[@M] of f16, <shared>>
176+
wave.write %reg8, %mem {elements_per_thread = 8} : !wave.tensor<[@M] of f16, <register>>, !wave.tensor<[@M] of f16, <shared>>
177+
178+
// Read 4 elements per thread from same memory - this should be allowed (memory resharding)
179+
// CHECK: wave.read {{.*}} : (!wave.tensor<[@M] of f16, <shared>>) -> vector<4xf16>
180+
%reg4 = wave.read %mem {elements_per_thread = 4} : (!wave.tensor<[@M] of f16, <shared>>) -> !wave.tensor<[@M] of f16, <register>>
181+
182+
return
183+
}
184+
}
185+
186+
// -----
187+
188+
// CHECK: #wave.normal_form<full_types,memory_only_types>
189+
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
190+
// CHECK-LABEL: @write_backward_propagation
191+
func.func @write_backward_propagation(%mem: !wave.tensor<[@M] of f16, <shared>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} {
192+
%cst = arith.constant 0.0 : f16
193+
// RegisterOp without explicit elements_per_thread - should get it from backward propagation
194+
// CHECK: wave.register {{.*}} : vector<4xf16>
195+
%reg = wave.register %cst : !wave.tensor<[@M] of f16, <register>>
196+
197+
// WriteOp should propagate elements_per_thread backward to register operand
198+
// CHECK: wave.write {{.*}} : vector<4xf16>, !wave.tensor<[@M] of f16, <shared>>
199+
wave.write %reg, %mem {elements_per_thread = 4} : !wave.tensor<[@M] of f16, <register>>, !wave.tensor<[@M] of f16, <shared>>
200+
201+
return
202+
}
203+
}
204+
205+
// -----
206+
207+
// CHECK: #wave.normal_form<full_types,memory_only_types>
208+
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
209+
// CHECK-LABEL: @read_register_propagation
210+
func.func @read_register_propagation(%mem: !wave.tensor<[@M] of f16, <shared>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} {
211+
// ReadOp should only propagate to its register result, not validate memory
212+
// CHECK: wave.read {{.*}} : (!wave.tensor<[@M] of f16, <shared>>) -> vector<6xf16>
213+
%reg = wave.read %mem {elements_per_thread = 6} : (!wave.tensor<[@M] of f16, <shared>>) -> !wave.tensor<[@M] of f16, <register>>
214+
215+
// Downstream operation should get 6 elements per thread
216+
// CHECK: wave.exp2 {{.*}} : (vector<6xf16>) -> vector<6xf16>
217+
%result = wave.exp2 %reg : (!wave.tensor<[@M] of f16, <register>>) -> !wave.tensor<[@M] of f16, <register>>
218+
219+
return
220+
}
221+
}
222+
223+
// -----
224+
165225
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
166226
func.func @mma_uninitialized_lhs(%mem1: !wave.tensor<[@N, @K] of f16, <global>>, %mem2: !wave.tensor<[@M, @N] of f32, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint<threads_per_wave = 32, waves_per_block = [1, 1, 1], mma_type = #wave.mma_kind<f32_16x16x16_f16>, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} {
167227
// LHS without elements_per_thread - this will remain uninitialized.

0 commit comments

Comments
 (0)