diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index 8e9d7e0b72142..09e3945f5a8ff 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -64,6 +64,9 @@ inline class_match<const SCEV> m_SCEV() { return class_match<const SCEV>(); }
 inline class_match<const SCEVConstant> m_SCEVConstant() {
   return class_match<const SCEVConstant>();
 }
+inline class_match<const SCEVVScale> m_SCEVVScale() {
+  return class_match<const SCEVVScale>();
+}
 
 template <typename Class> struct bind_ty {
   Class *&VR;
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 4ba69034d6448..786e4516ace05 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -923,10 +923,11 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS,
 /// If S involves the addition of a constant integer value, return that integer
 /// value, and mutate S to point to a new SCEV with that value excluded.
 static Immediate ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) {
-  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
-    if (C->getAPInt().getSignificantBits() <= 64) {
-      S = SE.getConstant(C->getType(), 0);
-      return Immediate::getFixed(C->getValue()->getSExtValue());
+  const APInt *C;
+  if (match(S, m_scev_APInt(C))) {
+    if (C->getSignificantBits() <= 64) {
+      S = SE.getConstant(S->getType(), 0);
+      return Immediate::getFixed(C->getSExtValue());
     }
   } else if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
     SmallVector<const SCEV *, 8> NewOps(Add->operands());
@@ -942,14 +943,10 @@ static Immediate ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) {
                            // FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
                            SCEV::FlagAnyWrap);
     return Result;
-  } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
-    if (EnableVScaleImmediates && M->getNumOperands() == 2) {
-      if (const SCEVConstant *C = dyn_cast<SCEVConstant>(M->getOperand(0)))
-        if (isa<SCEVVScale>(M->getOperand(1))) {
-          S = SE.getConstant(M->getType(), 0);
-          return Immediate::getScalable(C->getValue()->getSExtValue());
-        }
-    }
+  } else if (EnableVScaleImmediates &&
+             match(S, m_scev_Mul(m_scev_APInt(C), m_SCEVVScale()))) {
+    S = SE.getConstant(S->getType(), 0);
+    return Immediate::getScalable(C->getSExtValue());
   }
   return Immediate::getZero();
 }
@@ -1133,23 +1130,22 @@ static bool isHighCostExpansion(const SCEV *S,
       return false;
   }
 
-  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
-    if (Mul->getNumOperands() == 2) {
-      // Multiplication by a constant is ok
-      if (isa<SCEVConstant>(Mul->getOperand(0)))
-        return isHighCostExpansion(Mul->getOperand(1), Processed, SE);
-
-      // If we have the value of one operand, check if an existing
-      // multiplication already generates this expression.
-      if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(Mul->getOperand(1))) {
-        Value *UVal = U->getValue();
-        for (User *UR : UVal->users()) {
-          // If U is a constant, it may be used by a ConstantExpr.
-          Instruction *UI = dyn_cast<Instruction>(UR);
-          if (UI && UI->getOpcode() == Instruction::Mul &&
-              SE.isSCEVable(UI->getType())) {
-            return SE.getSCEV(UI) == Mul;
-          }
+  const SCEV *Op0, *Op1;
+  if (match(S, m_scev_Mul(m_SCEV(Op0), m_SCEV(Op1)))) {
+    // Multiplication by a constant is ok
+    if (isa<SCEVConstant>(Op0))
+      return isHighCostExpansion(Op1, Processed, SE);
+
+    // If we have the value of one operand, check if an existing
+    // multiplication already generates this expression.
+    if (const auto *U = dyn_cast<SCEVUnknown>(Op1)) {
+      Value *UVal = U->getValue();
+      for (User *UR : UVal->users()) {
+        // If U is a constant, it may be used by a ConstantExpr.
+        Instruction *UI = dyn_cast<Instruction>(UR);
+        if (UI && UI->getOpcode() == Instruction::Mul &&
+            SE.isSCEVable(UI->getType())) {
+          return SE.getSCEV(UI) == S;
         }
       }
     }
@@ -3333,14 +3329,11 @@ static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst,
     IncOffset = Immediate::getFixed(IncConst->getValue()->getSExtValue());
   } else {
     // Look for mul(vscale, constant), to detect a scalable offset.
-    auto *IncVScale = dyn_cast<SCEVMulExpr>(IncExpr);
-    if (!IncVScale || IncVScale->getNumOperands() != 2 ||
-        !isa<SCEVVScale>(IncVScale->getOperand(1)))
+    const APInt *C;
+    if (!match(IncExpr, m_scev_Mul(m_scev_APInt(C), m_SCEVVScale())) ||
+        C->getSignificantBits() > 64)
       return false;
-    auto *Scale = dyn_cast<SCEVConstant>(IncVScale->getOperand(0));
-    if (!Scale || Scale->getType()->getScalarSizeInBits() > 64)
-      return false;
-    IncOffset = Immediate::getScalable(Scale->getValue()->getSExtValue());
+    IncOffset = Immediate::getScalable(C->getSExtValue());
   }
 
   if (!isAddressUse(TTI, UserInst, Operand))
@@ -3818,6 +3811,8 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C,
     return nullptr;
   }
   const SCEV *Start, *Step;
+  const SCEVConstant *Op0;
+  const SCEV *Op1;
   if (match(S, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Step)))) {
     // Split a non-zero base out of an addrec.
     if (Start->isZero())
@@ -3839,19 +3834,13 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C,
                                      // FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
                                      SCEV::FlagAnyWrap);
     }
-  } else if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
+  } else if (match(S, m_scev_Mul(m_SCEVConstant(Op0), m_SCEV(Op1)))) {
     // Break (C * (a + b + c)) into C*a + C*b + C*c.
-    if (Mul->getNumOperands() != 2)
-      return S;
-    if (const SCEVConstant *Op0 =
-            dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
-      C = C ? cast<SCEVConstant>(SE.getMulExpr(C, Op0)) : Op0;
-      const SCEV *Remainder =
-          CollectSubexprs(Mul->getOperand(1), C, Ops, L, SE, Depth+1);
-      if (Remainder)
-        Ops.push_back(SE.getMulExpr(C, Remainder));
-      return nullptr;
-    }
+    C = C ? cast<SCEVConstant>(SE.getMulExpr(C, Op0)) : Op0;
+    const SCEV *Remainder = CollectSubexprs(Op1, C, Ops, L, SE, Depth + 1);
+    if (Remainder)
+      Ops.push_back(SE.getMulExpr(C, Remainder));
+    return nullptr;
   }
   return S;
 }
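Note on the matchers used above: the rewrite relies on the binary m_scev_Mul matcher together with m_scev_APInt and the new m_SCEVVScale class matcher to recognize a `C * vscale` offset. A minimal usage sketch follows, assuming the caller lives inside LLVM and already has a `const SCEV *` in hand; the helper name getScalableOffset is hypothetical and only mirrors the mul(C, vscale) check that canFoldIVIncExpr now performs, it is not part of this patch.

// Sketch only, not part of the patch. Assumes the SCEVPatternMatch matchers
// shown in the diff above (m_scev_Mul, m_scev_APInt, m_SCEVVScale).
#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
#include <cstdint>
#include <optional>

using namespace llvm;
using namespace llvm::SCEVPatternMatch;

// Returns the constant multiplier K if Expr has the form K * vscale and K
// fits in a signed 64-bit immediate, otherwise std::nullopt.
static std::optional<int64_t> getScalableOffset(const SCEV *Expr) {
  const APInt *C;
  if (match(Expr, m_scev_Mul(m_scev_APInt(C), m_SCEVVScale())) &&
      C->getSignificantBits() <= 64)
    return C->getSExtValue();
  return std::nullopt;
}

Callers that only need the fixed-offset case can keep matching m_scev_APInt on its own, as the updated ExtractImmediate does above.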