From eebabb8defb3b504a882bf27cb466997b123c7d0 Mon Sep 17 00:00:00 2001 From: Ramkumar Ramachandra Date: Tue, 24 Jun 2025 18:12:09 +0100 Subject: [PATCH 1/3] [LSR] Clean up code using SCEVPatternMatch (NFC) --- .../Transforms/Scalar/LoopStrengthReduce.cpp | 111 ++++++++---------- 1 file changed, 48 insertions(+), 63 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 4ba69034d6448..9ffcfb47b0a89 100644 --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -923,10 +923,12 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS, /// If S involves the addition of a constant integer value, return that integer /// value, and mutate S to point to a new SCEV with that value excluded. static Immediate ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) { - if (const SCEVConstant *C = dyn_cast(S)) { - if (C->getAPInt().getSignificantBits() <= 64) { - S = SE.getConstant(C->getType(), 0); - return Immediate::getFixed(C->getValue()->getSExtValue()); + const APInt *C; + const SCEV *Op1; + if (match(S, m_scev_APInt(C))) { + if (C->getSignificantBits() <= 64) { + S = SE.getConstant(S->getType(), 0); + return Immediate::getFixed(C->getSExtValue()); } } else if (const SCEVAddExpr *Add = dyn_cast(S)) { SmallVector NewOps(Add->operands()); @@ -942,14 +944,11 @@ static Immediate ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) { // FIXME: AR->getNoWrapFlags(SCEV::FlagNW) SCEV::FlagAnyWrap); return Result; - } else if (const SCEVMulExpr *M = dyn_cast(S)) { - if (EnableVScaleImmediates && M->getNumOperands() == 2) { - if (const SCEVConstant *C = dyn_cast(M->getOperand(0))) - if (isa(M->getOperand(1))) { - S = SE.getConstant(M->getType(), 0); - return Immediate::getScalable(C->getValue()->getSExtValue()); - } - } + } else if (EnableVScaleImmediates && + match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op1))) && + isa(Op1)) { + S = SE.getConstant(S->getType(), 0); + return Immediate::getScalable(C->getSExtValue()); } return Immediate::getZero(); } @@ -1133,23 +1132,22 @@ static bool isHighCostExpansion(const SCEV *S, return false; } - if (const SCEVMulExpr *Mul = dyn_cast(S)) { - if (Mul->getNumOperands() == 2) { - // Multiplication by a constant is ok - if (isa(Mul->getOperand(0))) - return isHighCostExpansion(Mul->getOperand(1), Processed, SE); - - // If we have the value of one operand, check if an existing - // multiplication already generates this expression. - if (const SCEVUnknown *U = dyn_cast(Mul->getOperand(1))) { - Value *UVal = U->getValue(); - for (User *UR : UVal->users()) { - // If U is a constant, it may be used by a ConstantExpr. - Instruction *UI = dyn_cast(UR); - if (UI && UI->getOpcode() == Instruction::Mul && - SE.isSCEVable(UI->getType())) { - return SE.getSCEV(UI) == Mul; - } + const SCEV *Op0, *Op1; + if (match(S, m_scev_Mul(m_SCEV(Op0), m_SCEV(Op1)))) { + // Multiplication by a constant is ok + if (isa(Op0)) + return isHighCostExpansion(Op1, Processed, SE); + + // If we have the value of one operand, check if an existing + // multiplication already generates this expression. + if (const auto *U = dyn_cast(Op1)) { + Value *UVal = U->getValue(); + for (User *UR : UVal->users()) { + // If U is a constant, it may be used by a ConstantExpr. + Instruction *UI = dyn_cast(UR); + if (UI && UI->getOpcode() == Instruction::Mul && + SE.isSCEVable(UI->getType())) { + return SE.getSCEV(UI) == S; } } } @@ -3333,14 +3331,12 @@ static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst, IncOffset = Immediate::getFixed(IncConst->getValue()->getSExtValue()); } else { // Look for mul(vscale, constant), to detect a scalable offset. - auto *IncVScale = dyn_cast(IncExpr); - if (!IncVScale || IncVScale->getNumOperands() != 2 || - !isa(IncVScale->getOperand(1))) - return false; - auto *Scale = dyn_cast(IncVScale->getOperand(0)); - if (!Scale || Scale->getType()->getScalarSizeInBits() > 64) + const APInt *C; + const SCEV *Op1; + if (!match(IncExpr, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op1))) || + !isa(Op1) || C->getSignificantBits() > 64) return false; - IncOffset = Immediate::getScalable(Scale->getValue()->getSExtValue()); + IncOffset = Immediate::getScalable(C->getSExtValue()); } if (!isAddressUse(TTI, UserInst, Operand)) @@ -3818,6 +3814,8 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C, return nullptr; } const SCEV *Start, *Step; + const SCEVConstant *Op0; + const SCEV *Op1; if (match(S, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Step)))) { // Split a non-zero base out of an addrec. if (Start->isZero()) @@ -3839,19 +3837,13 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C, // FIXME: AR->getNoWrapFlags(SCEV::FlagNW) SCEV::FlagAnyWrap); } - } else if (const SCEVMulExpr *Mul = dyn_cast(S)) { + } else if (match(S, m_scev_Mul(m_SCEVConstant(Op0), m_SCEV(Op1)))) { // Break (C * (a + b + c)) into C*a + C*b + C*c. - if (Mul->getNumOperands() != 2) - return S; - if (const SCEVConstant *Op0 = - dyn_cast(Mul->getOperand(0))) { - C = C ? cast(SE.getMulExpr(C, Op0)) : Op0; - const SCEV *Remainder = - CollectSubexprs(Mul->getOperand(1), C, Ops, L, SE, Depth+1); - if (Remainder) - Ops.push_back(SE.getMulExpr(C, Remainder)); - return nullptr; - } + C = C ? cast(SE.getMulExpr(C, Op0)) : Op0; + const SCEV *Remainder = CollectSubexprs(Op1, C, Ops, L, SE, Depth + 1); + if (Remainder) + Ops.push_back(SE.getMulExpr(C, Remainder)); + return nullptr; } return S; } @@ -6478,13 +6470,10 @@ struct SCEVDbgValueBuilder { /// Components of the expression are omitted if they are an identity function. /// Chain (non-affine) SCEVs are not supported. bool SCEVToValueExpr(const llvm::SCEVAddRecExpr &SAR, ScalarEvolution &SE) { - assert(SAR.isAffine() && "Expected affine SCEV"); - // TODO: Is this check needed? - if (isa(SAR.getStart())) - return false; - - const SCEV *Start = SAR.getStart(); - const SCEV *Stride = SAR.getStepRecurrence(SE); + const SCEV *Start, *Stride; + [[maybe_unused]] bool Match = + match(&SAR, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Stride))); + assert(Match && "Expected affine SCEV"); // Skip pushing arithmetic noops. if (!isIdentityFunction(llvm::dwarf::DW_OP_mul, Stride)) { @@ -6549,14 +6538,10 @@ struct SCEVDbgValueBuilder { /// Components of the expression are omitted if they are an identity function. bool SCEVToIterCountExpr(const llvm::SCEVAddRecExpr &SAR, ScalarEvolution &SE) { - assert(SAR.isAffine() && "Expected affine SCEV"); - if (isa(SAR.getStart())) { - LLVM_DEBUG(dbgs() << "scev-salvage: IV SCEV. Unsupported nested AddRec: " - << SAR << '\n'); - return false; - } - const SCEV *Start = SAR.getStart(); - const SCEV *Stride = SAR.getStepRecurrence(SE); + const SCEV *Start, *Stride; + [[maybe_unused]] bool Match = + match(&SAR, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Stride))); + assert(Match && "Expected affine SCEV"); // Skip pushing arithmetic noops. if (!isIdentityFunction(llvm::dwarf::DW_OP_minus, Start)) { From 1d0a32ec3fa30e36349a5ddac998fab625237094 Mon Sep 17 00:00:00 2001 From: Ramkumar Ramachandra Date: Fri, 27 Jun 2025 14:49:41 +0100 Subject: [PATCH 2/3] [LSR/SCEVPM] Introduce and use m_SCEVVScale --- llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h | 3 +++ llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp | 9 +++------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h index 8e9d7e0b72142..09e3945f5a8ff 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h @@ -64,6 +64,9 @@ inline class_match m_SCEV() { return class_match(); } inline class_match m_SCEVConstant() { return class_match(); } +inline class_match m_SCEVVScale() { + return class_match(); +} template struct bind_ty { Class *&VR; diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 9ffcfb47b0a89..dcaa3a22638e0 100644 --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -924,7 +924,6 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS, /// value, and mutate S to point to a new SCEV with that value excluded. static Immediate ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) { const APInt *C; - const SCEV *Op1; if (match(S, m_scev_APInt(C))) { if (C->getSignificantBits() <= 64) { S = SE.getConstant(S->getType(), 0); @@ -945,8 +944,7 @@ static Immediate ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) { SCEV::FlagAnyWrap); return Result; } else if (EnableVScaleImmediates && - match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op1))) && - isa(Op1)) { + match(S, m_scev_Mul(m_scev_APInt(C), m_SCEVVScale()))) { S = SE.getConstant(S->getType(), 0); return Immediate::getScalable(C->getSExtValue()); } @@ -3332,9 +3330,8 @@ static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst, } else { // Look for mul(vscale, constant), to detect a scalable offset. const APInt *C; - const SCEV *Op1; - if (!match(IncExpr, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op1))) || - !isa(Op1) || C->getSignificantBits() > 64) + if (!match(IncExpr, m_scev_Mul(m_scev_APInt(C), m_SCEVVScale())) || + C->getSignificantBits() > 64) return false; IncOffset = Immediate::getScalable(C->getSExtValue()); } From 067bb35b9c0af0a93c5afcb7d97988cbfb115d19 Mon Sep 17 00:00:00 2001 From: Ramkumar Ramachandra Date: Fri, 27 Jun 2025 17:07:00 +0100 Subject: [PATCH 3/3] [LSR] Revert unpack-with-m_scev_AffineAddRec --- .../Transforms/Scalar/LoopStrengthReduce.cpp | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index dcaa3a22638e0..786e4516ace05 100644 --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -6467,10 +6467,13 @@ struct SCEVDbgValueBuilder { /// Components of the expression are omitted if they are an identity function. /// Chain (non-affine) SCEVs are not supported. bool SCEVToValueExpr(const llvm::SCEVAddRecExpr &SAR, ScalarEvolution &SE) { - const SCEV *Start, *Stride; - [[maybe_unused]] bool Match = - match(&SAR, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Stride))); - assert(Match && "Expected affine SCEV"); + assert(SAR.isAffine() && "Expected affine SCEV"); + // TODO: Is this check needed? + if (isa(SAR.getStart())) + return false; + + const SCEV *Start = SAR.getStart(); + const SCEV *Stride = SAR.getStepRecurrence(SE); // Skip pushing arithmetic noops. if (!isIdentityFunction(llvm::dwarf::DW_OP_mul, Stride)) { @@ -6535,10 +6538,14 @@ struct SCEVDbgValueBuilder { /// Components of the expression are omitted if they are an identity function. bool SCEVToIterCountExpr(const llvm::SCEVAddRecExpr &SAR, ScalarEvolution &SE) { - const SCEV *Start, *Stride; - [[maybe_unused]] bool Match = - match(&SAR, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Stride))); - assert(Match && "Expected affine SCEV"); + assert(SAR.isAffine() && "Expected affine SCEV"); + if (isa(SAR.getStart())) { + LLVM_DEBUG(dbgs() << "scev-salvage: IV SCEV. Unsupported nested AddRec: " + << SAR << '\n'); + return false; + } + const SCEV *Start = SAR.getStart(); + const SCEV *Stride = SAR.getStepRecurrence(SE); // Skip pushing arithmetic noops. if (!isIdentityFunction(llvm::dwarf::DW_OP_minus, Start)) {