water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp (21 changes: 15 additions & 6 deletions)
@@ -445,12 +445,21 @@ class ReadOpLoweringPattern : public OpConversionPattern<wave::ReadOp> {
   LogicalResult
   matchAndRewrite(wave::ReadOp op, wave::ReadOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type convertedType =
-        getTypeConverter()->convertType(op.getResult().getType());
-    if (!convertedType)
-      return rewriter.notifyMatchFailure(op,
-                                         "WaveTensorType conversion failed");
-    auto vectorType = cast<VectorType>(convertedType);
+    // Check whether the result is already a vector (the case after the
+    // PropagateElementsPerThread pass has run).
+    Type resultType = op.getResult().getType();
+    VectorType vectorType;
+    if (auto vecType = dyn_cast<VectorType>(resultType)) {
+      // Already converted to a vector by PropagateElementsPerThread.
+      vectorType = vecType;
+    } else {
+      // Still a WaveTensorType; it needs conversion.
+      Type convertedType = getTypeConverter()->convertType(resultType);
+      if (!convertedType)
+        return rewriter.notifyMatchFailure(op,
+                                           "WaveTensorType conversion failed");
+      vectorType = cast<VectorType>(convertedType);
+    }
     FailureOr<MemAccessInfo> memInfo = createMemoryIndicesAndMask(
         rewriter, getTypeConverter(), op, op.getMemory().getType(), vectorType);
     if (failed(memInfo))
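For reference, here is a minimal MLIR sketch of the two result forms the pattern now handles. The <register> address-space result in the first form is an assumption about the pre-pass IR, not taken from this diff; the vector form matches the new test below.

// Assumed form before PropagateElementsPerThread: the result is still a Wave
// tensor, so the pattern runs it through the type converter.
%r0 = wave.read %mem index [{
  M : [#wave.index_symbol<WG0>] -> (WG0, 1, 1),
  N : [#wave.index_symbol<T0>] -> (T0, 8, 1)
}] : (!wave.tensor<[@M, @N] of f16, <global>>)
    -> !wave.tensor<[@M, @N] of f16, <register>>

// Form after PropagateElementsPerThread (as in the new test): the result is
// already a vector and is used directly.
%r1 = wave.read %mem index [{
  M : [#wave.index_symbol<WG0>] -> (WG0, 1, 1),
  N : [#wave.index_symbol<T0>] -> (T0, 8, 1)
}] : (!wave.tensor<[@M, @N] of f16, <global>>) -> vector<8xf16>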
water/test/Dialect/Wave/lower-wave-to-mlir.mlir (20 changes: 20 additions & 0 deletions)
@@ -464,6 +464,26 @@ module attributes {wave.normal_form = #wave.normal_form<full_types,memory_only_types>} {
 
 // -----
 
+module attributes {wave.normal_form = #wave.normal_form<full_types,memory_only_types>} {
+  // CHECK-LABEL: @read_with_vector_result
+  func.func @read_with_vector_result(%mem: !wave.tensor<[@M, @N] of f16, <global>>)
+      attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 128}>} {
+    // Test ReadOp lowering when the result is already a vector type
+    // (simulates the IR after the PropagateElementsPerThread pass).
+    %0 = wave.read %mem index [{
+      M : [#wave.index_symbol<WG0>] -> (WG0, 1, 1),
+      N : [#wave.index_symbol<T0>] -> (T0, 8, 1)
+    }]
+      : (!wave.tensor<[@M, @N] of f16, <global>>) -> vector<8xf16>
+
+    // CHECK: %[[READ:.*]] = vector.load
+    // CHECK-SAME: : memref<128x128xf16, #gpu.address_space<global>>, vector<8xf16>
+    return
+  }
+}
+
+// -----
+
 module attributes {wave.normal_form = #wave.normal_form<full_types,memory_only_types>} {
   // CHECK-LABEL: @lower_write
   func.func @lower_write(%mem: !wave.tensor<[@M, @N] of f16, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_N = 64, M = 128, N = 128}>} {
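For reference, a sketch of the lowered IR that the CHECK lines in @read_with_vector_result expect; the operands %m, %i, and %j are hypothetical stand-ins for the converted buffer and the computed workgroup/thread indices.

%r = vector.load %m[%i, %j]
    : memref<128x128xf16, #gpu.address_space<global>>, vector<8xf16>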