diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp index 9a368f372c296..ad91e25cede3e 100644 --- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp +++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp @@ -52,13 +52,45 @@ static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter, } static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc, - vector::MaskedLoadOp maskedOp) { + vector::MaskedLoadOp maskedOp, + bool passthru) { VectorType vectorType = maskedOp.getVectorType(); Value load = builder.create( loc, vectorType, maskedOp.getBase(), maskedOp.getIndices()); - Value res = builder.create( - loc, vectorType, maskedOp.getMask(), load, maskedOp.getPassThru()); - return res; + if (passthru) + load = builder.create(loc, vectorType, maskedOp.getMask(), + load, maskedOp.getPassThru()); + return load; +} + +/// Check if the given value comes from a: +/// +/// arith.select %cond, TRUE/FALSE, TRUE/FALSE +/// +/// i.e the condition is either always true or it's always false. +/// +/// Returns the condition to use for scf.if (condition) { true } else { false }. +static FailureOr matchFullSelect(OpBuilder &b, Value val) { + auto selectOp = val.getDefiningOp(); + if (!selectOp) + return failure(); + std::optional trueInt = getConstantIntValue(selectOp.getTrueValue()); + std::optional falseInt = + getConstantIntValue(selectOp.getFalseValue()); + if (!trueInt || !falseInt) + return failure(); + // getConstantIntValue returns -1 for "true" for bools. + if (trueInt.value() == -1 && falseInt.value() == 0) + return selectOp.getCondition(); + + if (trueInt.value() == 0 && falseInt.value() == -1) { + Value cond = selectOp.getCondition(); + Value one = b.create(cond.getLoc(), /*value=*/true, + /*width=*/1); + Value inverse = b.create(cond.getLoc(), cond, one); + return inverse; + } + return failure(); } static constexpr char kMaskedloadNeedsMask[] = @@ -78,6 +110,16 @@ struct MaskedLoadLowering final : OpRewritePattern { return failure(); } + // Check if this is either a full inbounds load or an empty, oob load. If + // so, take the fast path and don't generate a if condition, because we know + // doing the oob load is always safe. + if (succeeded(matchFullSelect(rewriter, maskedOp.getMask()))) { + Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(), + maskedOp, /*passthru=*/true); + rewriter.replaceOp(maskedOp, load); + return success(); + } + Location loc = maskedOp.getLoc(); Value src = maskedOp.getBase(); @@ -135,7 +177,8 @@ struct MaskedLoadLowering final : OpRewritePattern { }; auto elseBuilder = [&](OpBuilder &builder, Location loc) { - Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp); + Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp, + /*passthru=*/true); rewriter.create(loc, res); }; @@ -148,11 +191,65 @@ struct MaskedLoadLowering final : OpRewritePattern { } }; +struct FullMaskedLoadToConditionalLoad + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + +public: + LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp, + PatternRewriter &rewriter) const override { + FailureOr maybeCond = matchFullSelect(rewriter, loadOp.getMask()); + if (failed(maybeCond)) { + return failure(); + } + + Value cond = maybeCond.value(); + auto trueBuilder = [&](OpBuilder &builder, Location loc) { + Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp, + /*passthru=*/false); + rewriter.create(loc, res); + }; + auto falseBuilder = [&](OpBuilder &builder, Location loc) { + rewriter.create(loc, loadOp.getPassThru()); + }; + auto ifOp = rewriter.create(loadOp.getLoc(), cond, trueBuilder, + falseBuilder); + rewriter.replaceOp(loadOp, ifOp); + return success(); + } +}; + +struct FullMaskedStoreToConditionalStore + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + +public: + LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp, + PatternRewriter &rewriter) const override { + FailureOr maybeCond = matchFullSelect(rewriter, storeOp.getMask()); + if (failed(maybeCond)) { + return failure(); + } + Value cond = maybeCond.value(); + + auto trueBuilder = [&](OpBuilder &builder, Location loc) { + rewriter.create(loc, storeOp.getValueToStore(), + storeOp.getBase(), storeOp.getIndices()); + rewriter.create(loc); + }; + auto ifOp = rewriter.create(storeOp.getLoc(), cond, trueBuilder); + rewriter.replaceOp(storeOp, ifOp); + return success(); + } +}; + } // namespace void mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(patterns.getContext(), benefit); + patterns.add(patterns.getContext(), + benefit); } struct AmdgpuMaskedloadToLoadPass final diff --git a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir index febe46bf7a759..d6682ba14eeca 100644 --- a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir +++ b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir @@ -114,3 +114,28 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space>, %idx : index, %mask : vector<4xi1>, %passthru : vector<4xf32>) -> vector<4xf32> { + %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf32, #amdgpu.address_space>, vector<4xi1>, vector<4xf32> into vector<4xf32> + return %res : vector<4xf32> +} + +// ----- + +func.func @full_select_maskedload_fatrawbuffer_to_load(%mem : memref<8x8xf16, #amdgpu.address_space>, %idx : index, %cond : i1, %passthru : vector<4xf16>) -> vector<4xf16> { + %true = arith.constant dense : vector<4xi1> + %false = arith.constant dense : vector<4xi1> + %mask = arith.select %cond, %true, %false : vector<4xi1> + %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf16, #amdgpu.address_space>, vector<4xi1>, vector<4xf16> into vector<4xf16> + return %res : vector<4xf16> +} + +func.func @full_select_maskedload_to_load(%mem : memref<8x8xf16>, %idx : index, %cond : i1, %passthru : vector<4xf16>) -> vector<4xf16> { + %true = arith.constant dense : vector<4xi1> + %false = arith.constant dense : vector<4xi1> + %mask = arith.select %cond, %true, %false : vector<4xi1> + %res = vector.maskedload %mem[%idx, %idx], %mask, %passthru : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16> + return %res : vector<4xf16> +}