Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions water/include/water/Dialect/Wave/IR/WaveOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def Exp2Op : UnaryWaveOp<"exp2"> {

def MmaOp : WaveOp<"mma",
[DeclareOpInterfaceMethods<WaveInferTypeOpInterface>,
DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>,
DeclareOpInterfaceMethods<WaveInferIndexExprsOpInterface,
["initializeIndexExprsForward", "initializeIndexExprsBackward"]>]>,
WaveArithmeticOpDoc {
Expand All @@ -130,6 +131,11 @@ def MmaOp : WaveOp<"mma",
"$lhs `,` $rhs `,` $accumulator " # commonArgumentsSyntax # "attr-dict `:`"
"functional-type(operands, results)";
let hasVerifier = 1;

let extraClassDeclaration = [{
/// Compute the expected elements per thread for this MMA operation.
unsigned computeElementsPerThread();
}];
}

//-----------------------------------------------------------------------------
Expand Down Expand Up @@ -161,12 +167,14 @@ def IterateOp : Op<WaveDialect, "iterate", [

let arguments = (ins
Arg<WaveSymbolAttr, "Iterator symbol">:$iterator,
Arg<Variadic<WaveTensorType>, "Carried values">:$iter_args,
Arg<Variadic<WaveTensorType>, "Captured values">:$captures
// Accept both WaveTensorType (before PropagateElementsPerThread) and AnyVectorOfAnyRank (after)
Arg<Variadic<AnyTypeOf<[WaveTensorType, AnyVectorOfAnyRank]>>, "Carried values">:$iter_args,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't this be WaveTensorInRegisters? That constraint already accepts tensors with no address space, tensors in the register address space, and 1D vectors. And we most likely don't want any vector of any rank here, which would include scalable, 0d and other nonsense.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is because WaveTensorInRegisters doesn't work with Variadic. I think because it's a TypeConstraint (or something like that) and not a Type.

Arg<Variadic<AnyTypeOf<[WaveTensorType, AnyVectorOfAnyRank]>>, "Captured values">:$captures
);

let results = (outs
Res<Variadic<WaveTensorType>, "Yielded values">:$results
// Results follow the same type constraints as inputs
Res<Variadic<AnyTypeOf<[WaveTensorType, AnyVectorOfAnyRank]>>, "Yielded values">:$results
);

let regions = (region
Expand Down Expand Up @@ -206,7 +214,8 @@ def YieldOp : Op<WaveDialect, "yield",
let summary = "Yields values from the current control flow context";

let arguments = (ins
Arg<Variadic<WaveTensorType>, "Yielded values">:$values
// Must match the type constraints of wave.iterate results
Arg<Variadic<AnyTypeOf<[WaveTensorType, AnyVectorOfAnyRank]>>, "Yielded values">:$values
);

let assemblyFormat = "$values attr-dict `:` type($values)";
Expand Down Expand Up @@ -274,7 +283,7 @@ def ExtractSliceOp : WaveOp<"extract_slice", [WaveInferTypeOpInterface, Identity

def ReadOp : WaveOp<"read", [
WaveInferTypeOpInterface, IdentityTypeInferenceOpTrait,
WaveElementsPerThreadOpInterface, AttrBasedElementsPerThreadOpTrait,
DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>,
CompatibleOperandsAndResultsIgnoreSpaceOpTrait,
WaveInferIndexExprsOpInterface, IdentityIndexExprsOpTrait]> {
let summary = "Reads from memory";
Expand Down Expand Up @@ -328,7 +337,7 @@ def RegisterOp : WaveOp<"register", [

def WriteOp : WaveOp<"write", [
WaveInferTypeOpInterface, NoOpTypeInferenceOpTrait,
WaveElementsPerThreadOpInterface, AttrBasedElementsPerThreadOpTrait,
DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>,
CompatibleOperandsAndResultsIgnoreSpaceOpTrait,
DeclareOpInterfaceMethods<WaveInferIndexExprsOpInterface>]> {
let summary = "Writes into memory";
Expand Down
1 change: 1 addition & 0 deletions water/lib/Dialect/Wave/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRWaveDialect
MLIRIR
MLIRControlFlowInterfaces
MLIRFunctionInterfaces
MLIRFuncDialect
)

# Install the Wave dialect library so Python can find it at runtime.
Expand Down
235 changes: 220 additions & 15 deletions water/lib/Dialect/Wave/IR/WaveOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "water/Dialect/Wave/IR/WaveOps.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
Expand Down Expand Up @@ -198,10 +199,26 @@ void wave::IterateOp::makeNonIsolated(mlir::RewriterBase &rewriter) {
}

/// Report whether two types may be paired as iter_arg/result types of an
/// iterate op. Two wave tensors are compatible per the dialect's structural
/// compatibility check (including address space); two vectors must be exactly
/// equal; any mixed combination is rejected.
bool wave::IterateOp::areTypesCompatible(mlir::Type lhs, mlir::Type rhs) {
  // Handle both WaveTensorType and VectorType combinations.
  auto lhsTensor = llvm::dyn_cast<wave::WaveTensorType>(lhs);
  auto rhsTensor = llvm::dyn_cast<wave::WaveTensorType>(rhs);

  // Both are wave tensors - defer to the existing compatibility logic.
  if (lhsTensor && rhsTensor) {
    return detail::verifyTypesCompatible(lhsTensor, rhsTensor,
                                         /*includeAddressSpace=*/true)
        .succeeded();
  }

  // Both are vectors - simple equality check (MLIR types are uniqued, so
  // equality compares identity).
  auto lhsVector = llvm::dyn_cast<mlir::VectorType>(lhs);
  auto rhsVector = llvm::dyn_cast<mlir::VectorType>(rhs);
  if (lhsVector && rhsVector)
    return lhsVector == rhsVector;

  // Mixed type categories are not compatible.
  return false;
}

mlir::OperandRange
Expand Down Expand Up @@ -250,19 +267,42 @@ mlir::LogicalResult wave::IterateOp::verify() {
}
for (auto &&[i, iterArg, result] :
llvm::enumerate(iterArgTypes, resultTypes)) {
auto iterArgTensor = llvm::cast<wave::WaveTensorType>(iterArg);
auto resultTensor = llvm::cast<wave::WaveTensorType>(result);
if (!iterArgTensor.getFullySpecified() || !resultTensor.getFullySpecified())
continue;
// Handle verification for both wave tensors and vectors
auto iterArgTensor = llvm::dyn_cast<wave::WaveTensorType>(iterArg);
auto resultTensor = llvm::dyn_cast<wave::WaveTensorType>(result);
auto iterArgVector = llvm::dyn_cast<mlir::VectorType>(iterArg);
auto resultVector = llvm::dyn_cast<mlir::VectorType>(result);

auto allDims =
llvm::to_vector(llvm::iota_range<int>(0, iterArgTensor.getRank(),
/*Inclusive=*/false));
auto istr = std::to_string(i);
if (mlir::failed(detail::verifyTypesMatchingDimensions(
getLoc(), "iter_args #" + istr, iterArgTensor, allDims,
"result #" + istr, resultTensor, allDims)))
return mlir::failure();

// Both are wave tensors - use existing shape verification logic
if (iterArgTensor && resultTensor) {
if (!iterArgTensor.getFullySpecified() ||
!resultTensor.getFullySpecified())
continue;

auto allDims =
llvm::to_vector(llvm::iota_range<int>(0, iterArgTensor.getRank(),
/*Inclusive=*/false));
if (mlir::failed(detail::verifyTypesMatchingDimensions(
getLoc(), "iter_args #" + istr, iterArgTensor, allDims,
"result #" + istr, resultTensor, allDims)))
return mlir::failure();
}
// Both are vectors - check exact type equality
else if (iterArgVector && resultVector) {
if (iterArgVector != resultVector) {
return emitOpError() << "iter_args #" << i << " type (" << iterArgVector
<< ") must match result #" << i << " type ("
<< resultVector << ")";
}
}
// Mixed types are not allowed
else {
return emitOpError() << "iter_args #" << i << " and result #" << i
<< " must be the same category of types (both wave "
"tensors or both vectors)";
}
}

return mlir::success();
Expand Down Expand Up @@ -1078,6 +1118,99 @@ LogicalResult MmaOp::verify() {
accumulatorType.getElementType());
}

/// Compute the expected elements per thread for this MMA operation.
/// Extracts threadsPerWave from ancestor operations with hardware constraints.
/// Returns 0 if no constraints are found.
unsigned wave::MmaOp::computeElementsPerThread() {
if (!getKindAttr()) {
return 0;
}
wave::WaveMmaSpec spec =
wave::WaveMmaKindAttr::getSpec(getContext(), getKind());

// Extract threads per wave from hardware constraint by walking up the
// ancestry.
mlir::Operation *op = getOperation();
while (op) {
if (auto constraints = op->getAttrOfType<mlir::ArrayAttr>(
wave::WaveDialect::kWaveConstraintsAttrName)) {
for (mlir::Attribute constraint : constraints) {
if (auto hardwareConstraint =
llvm::dyn_cast<wave::HardwareConstraintAttr>(constraint)) {
unsigned totalElements = spec.m * spec.n;
return totalElements / hardwareConstraint.getThreadsPerWave();
}
}
}
op = op->getParentOp();
}

// Return 0 to indicate failure if no constraints found.
return 0;
}

/// Forward propagation for MMA: the result's elements-per-thread value is
/// fully determined by the MMA intrinsic shape and the hardware constraint,
/// not by the operand lattice values, so operands are ignored here.
llvm::FailureOr<mlir::ChangeResult>
wave::MmaOp::propagateElementsPerThreadForward(
    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
    llvm::raw_ostream &errs) {
  unsigned elements = computeElementsPerThread();
  if (elements == 0) {
    errs << "MMA operation has no hardware constraints available";
    return mlir::failure();
  }

  // Treat the computed value as a constant and push it onto the result.
  return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
      wave::ElementsPerThreadLatticeValue(elements),
      llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(), resultElements,
      "computed from MMA kind", "", "result", errs);
}

/// Backward propagation for MMA: pushes the elements-per-thread value derived
/// from the MMA intrinsic onto the accumulator operand, and verifies that the
/// LHS/RHS operands already carry concrete (non-bottom) lattice values.
llvm::FailureOr<mlir::ChangeResult>
wave::MmaOp::propagateElementsPerThreadBackward(
    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>,
    llvm::raw_ostream &errs) {
  // For MMA, the accumulator should have the same elements per thread as the
  // result. The LHS and RHS operands may have different constraints based on
  // their dimensions.
  unsigned expectedElementsPerThread = computeElementsPerThread();
  if (expectedElementsPerThread == 0) {
    // computeElementsPerThread returns 0 when the kind attribute is missing
    // or no hardware constraint is visible from this op.
    errs << "MMA operation has no hardware constraints available";
    return mlir::failure();
  }
  wave::ElementsPerThreadLatticeValue expectedAccumulator(
      expectedElementsPerThread);

  unsigned accumulatorOperandNumber =
      getAccumulatorMutable().getOperandNumber();

  // Validate that LHS and RHS operands have concrete elements_per_thread
  // values. We don't propagate to them, but we check they've been properly
  // initialized.
  // NOTE(review): this loop assumes LHS and RHS occupy operand slots 0 and 1
  // — confirm against the op's ODS operand order.
  for (unsigned i = 0; i < 2 && i < operandElements.size();
       ++i) { // LHS (0) and RHS (1) operands
    if (operandElements[i].isBottom()) {
      errs << "MMA operand #" << i << " (";
      errs << (i == 0 ? "LHS" : "RHS");
      errs << ") has uninitialized elements_per_thread";
      return mlir::failure();
    }
  }

  // Propagate to the accumulator operand only; a one-element slice of the
  // operand lattice stands in for the "results" of the constant check.
  if (operandElements.size() > accumulatorOperandNumber) {
    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> accumulatorOnly =
        operandElements.slice(accumulatorOperandNumber, 1);

    return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
        expectedAccumulator,
        llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(), accumulatorOnly,
        "computed from MMA kind", "", "accumulator operand", errs);
  }

  // No accumulator slot present in the lattice — nothing to update.
  return mlir::ChangeResult::NoChange;
}

//-----------------------------------------------------------------------------
// ReadOp
//-----------------------------------------------------------------------------
Expand Down Expand Up @@ -1233,6 +1366,34 @@ LogicalResult ReadOp::verify() {
bounds.getMapping());
}

/// Forward propagation for read: only the result (register) side participates
/// in the dataflow. The memory operand is ignored — a read may fetch any
/// number of elements regardless of how the memory was produced.
llvm::FailureOr<mlir::ChangeResult>
wave::ReadOp::propagateElementsPerThreadForward(
    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>,
    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
    llvm::raw_ostream &errs) {
  std::optional<int64_t> attrValue = getElementsPerThread();
  if (!attrValue.has_value())
    return mlir::ChangeResult::NoChange;

  // Push the attribute value onto the result lattice.
  return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
      wave::ElementsPerThreadLatticeValue(*attrValue),
      llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(), resultElements,
      "elements_per_thread attribute", "", "result", errs);
}

/// Backward propagation for read: nothing flows back into the memory operand;
/// for elements_per_thread purposes, memory is decoupled from the register
/// dataflow.
llvm::FailureOr<mlir::ChangeResult>
wave::ReadOp::propagateElementsPerThreadBackward(
    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue>,
    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>,
    llvm::raw_ostream &) {
  return mlir::ChangeResult::NoChange;
}

//-----------------------------------------------------------------------------
// RegisterOp
//-----------------------------------------------------------------------------
Expand Down Expand Up @@ -1314,6 +1475,50 @@ LogicalResult WriteOp::verify() {
bounds.getMapping());
}

/// Forward propagation for write: a write has no results, so this only
/// validates that the register operand (value_to_store) agrees with the
/// elements_per_thread attribute. The memory operand is intentionally ignored
/// — any layout may be written to memory.
llvm::FailureOr<mlir::ChangeResult>
wave::WriteOp::propagateElementsPerThreadForward(
    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue>,
    llvm::raw_ostream &errs) {
  std::optional<int64_t> attrValue = getElementsPerThread();
  if (!attrValue.has_value())
    return mlir::ChangeResult::NoChange;

  // Check only the first operand (the stored value) against the attribute;
  // there is nothing to write into, so the mutable side is empty.
  return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
      wave::ElementsPerThreadLatticeValue(*attrValue),
      operandElements.slice(0, 1),
      llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue>(),
      "elements_per_thread attribute", "register operand", "", errs);
}

/// Backward propagation for write: targets only the register operand
/// (value_to_store). The memory operand stays outside the
/// elements_per_thread dataflow.
llvm::FailureOr<mlir::ChangeResult>
wave::WriteOp::propagateElementsPerThreadBackward(
    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>,
    llvm::raw_ostream &errs) {
  std::optional<int64_t> attrValue = getElementsPerThread();
  if (!attrValue.has_value())
    return mlir::ChangeResult::NoChange;

  // Slice out just the register operand and propagate the attribute value
  // onto it.
  llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> registerSlot =
      operandElements.slice(0, 1);
  return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
      wave::ElementsPerThreadLatticeValue(*attrValue),
      llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(), registerSlot,
      "elements_per_thread attribute", "", "register operand", errs);
}

// Propagate index expressions forward from the operands to the result of the
// WriteOp. Since WriteOp has no results, this is a no-op.
llvm::FailureOr<mlir::ChangeResult> wave::WriteOp::propagateIndexExprsForward(
Expand Down
2 changes: 1 addition & 1 deletion water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ materializeAffine(Location loc, ArrayRef<Attribute> symbols, AffineMap map,
std::optional<int64_t> value = hyper.getSymbolValue(name);
#ifndef NDEBUG
if (!value) {
llvm::errs() << "symbol: " << name << "\n";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why remove this?

assert(false && "unknown symbol, should have been caught by verifiers");
}
#endif
Expand Down Expand Up @@ -134,6 +133,7 @@ materializeAffine(Location loc, ArrayRef<Attribute> symbols, AffineMap map,
AffineMap submap =
AffineMap::get(map.getNumDims(), map.getNumSymbols(), expr);
SmallVector<Value> symVals = baseSymVals;

affine::canonicalizeMapAndOperands(&submap, &symVals);

Value apply = affine::AffineApplyOp::create(rewriter, loc, submap, symVals);
Expand Down
Loading