diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index 66b3f2a4f93a5..0cffe52dd0615 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -80,12 +80,27 @@ class AffineApplyExpander /// let remainder = srem a, b; /// negative = a < 0 in /// select negative, remainder + b, remainder. + /// + /// Special case for power of 2: use bitwise AND (x & (n-1)) for non-negative + /// x. Value visitModExpr(AffineBinaryOpExpr expr) { if (auto rhsConst = dyn_cast(expr.getRHS())) { if (rhsConst.getValue() <= 0) { emitError(loc, "modulo by non-positive value is not supported"); return nullptr; } + + // Special case: x mod n where n is a power of 2 can be optimized to x & + // (n-1) + int64_t rhsValue = rhsConst.getValue(); + if (rhsValue > 0 && (rhsValue & (rhsValue - 1)) == 0) { + auto lhs = visit(expr.getLHS()); + assert(lhs && "unexpected affine expr lowering failure"); + + Value maskCst = + builder.create(loc, rhsValue - 1); + return builder.create(loc, lhs, maskCst); + } } auto lhs = visit(expr.getLHS()); diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir index 550ea71882e14..07f7c64fe6ea5 100644 --- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir +++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir @@ -927,3 +927,12 @@ func.func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: me // CHECK: scf.reduce.return %[[RES]] : i64 // CHECK: } // CHECK: } + +#map_mod_8 = affine_map<(i) -> (i mod 8)> +// CHECK-LABEL: func @affine_apply_mod_8 +func.func @affine_apply_mod_8(%arg0 : index) -> (index) { + // CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index + // CHECK-NEXT: %[[v0:.*]] = arith.andi %arg0, %[[c7]] : index + %0 = affine.apply #map_mod_8 (%arg0) + return %0 : index +}