-
Notifications
You must be signed in to change notification settings - Fork 14.4k
Lower affine modulo by powers of two using bitwise AND #146311
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-affine Author: Yuxi Sun (sherylll) ChangesThis patch adds a special-case optimization in the affine-to-standard lowering pass to replace modulo operations by constant powers of two with a single bitwise AND operation. This reduces instruction count and improves performance for common cases like Full diff: https://github.com/llvm/llvm-project/pull/146311.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 66b3f2a4f93a5..de9c7874767e4 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -80,12 +80,24 @@ class AffineApplyExpander
/// let remainder = srem a, b;
/// negative = a < 0 in
/// select negative, remainder + b, remainder.
+ ///
+ /// Special case for power of 2: use bitwise AND (x & (n-1)) for non-negative x.
Value visitModExpr(AffineBinaryOpExpr expr) {
if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
if (rhsConst.getValue() <= 0) {
emitError(loc, "modulo by non-positive value is not supported");
return nullptr;
}
+
+ // Special case: x mod n where n is a power of 2 can be optimized to x & (n-1)
+ int64_t rhsValue = rhsConst.getValue();
+ if (rhsValue > 0 && (rhsValue & (rhsValue - 1)) == 0) {
+ auto lhs = visit(expr.getLHS());
+ assert(lhs && "unexpected affine expr lowering failure");
+
+ Value maskCst = builder.create<arith::ConstantIndexOp>(loc, rhsValue - 1);
+ return builder.create<arith::AndIOp>(loc, lhs, maskCst);
+ }
}
auto lhs = visit(expr.getLHS());
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 550ea71882e14..07f7c64fe6ea5 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -927,3 +927,12 @@ func.func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: me
// CHECK: scf.reduce.return %[[RES]] : i64
// CHECK: }
// CHECK: }
+
+#map_mod_8 = affine_map<(i) -> (i mod 8)>
+// CHECK-LABEL: func @affine_apply_mod_8
+func.func @affine_apply_mod_8(%arg0 : index) -> (index) {
+ // CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index
+ // CHECK-NEXT: %[[v0:.*]] = arith.andi %arg0, %[[c7]] : index
+ %0 = affine.apply #map_mod_8 (%arg0)
+ return %0 : index
+}
|
You can test this locally with the following command:git-clang-format --diff HEAD~1 HEAD --extensions cpp -- mlir/lib/Dialect/Affine/Utils/Utils.cpp View the diff from clang-format here.diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index de9c78747..0cffe52dd 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -81,21 +81,24 @@ public:
/// negative = a < 0 in
/// select negative, remainder + b, remainder.
///
- /// Special case for power of 2: use bitwise AND (x & (n-1)) for non-negative x.
+ /// Special case for power of 2: use bitwise AND (x & (n-1)) for non-negative
+ /// x.
Value visitModExpr(AffineBinaryOpExpr expr) {
if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
if (rhsConst.getValue() <= 0) {
emitError(loc, "modulo by non-positive value is not supported");
return nullptr;
}
-
- // Special case: x mod n where n is a power of 2 can be optimized to x & (n-1)
+
+ // Special case: x mod n where n is a power of 2 can be optimized to x &
+ // (n-1)
int64_t rhsValue = rhsConst.getValue();
if (rhsValue > 0 && (rhsValue & (rhsValue - 1)) == 0) {
auto lhs = visit(expr.getLHS());
assert(lhs && "unexpected affine expr lowering failure");
-
- Value maskCst = builder.create<arith::ConstantIndexOp>(loc, rhsValue - 1);
+
+ Value maskCst =
+ builder.create<arith::ConstantIndexOp>(loc, rhsValue - 1);
return builder.create<arith::AndIOp>(loc, lhs, maskCst);
}
}
|
9792917
to
db4de47
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why would we not do this on arith.remsi / arith.remui?
@Groverkss you mean add a canonicalizer to arith.remsi / arith.remui? In that case do we also need to handle cmp + select? The case I have in mind is incrementing (I just noticed my change didn't check for non-negativity of LHS, oops...) |
This patch adds a special-case optimization in the affine-to-standard lowering pass to replace modulo operations by constant powers of two with a single bitwise AND operation. This reduces instruction count and improves performance for common cases like
x mod 2
.