Skip to content

[LSR] Clean up code using SCEVPatternMatch (NFC) #145556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ inline class_match<const SCEV> m_SCEV() { return class_match<const SCEV>(); }
inline class_match<const SCEVConstant> m_SCEVConstant() {
return class_match<const SCEVConstant>();
}
inline class_match<const SCEVVScale> m_SCEVVScale() {
return class_match<const SCEVVScale>();
}

template <typename Class> struct bind_ty {
Class *&VR;
Expand Down
85 changes: 37 additions & 48 deletions llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -923,10 +923,11 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS,
/// If S involves the addition of a constant integer value, return that integer
/// value, and mutate S to point to a new SCEV with that value excluded.
static Immediate ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) {
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
if (C->getAPInt().getSignificantBits() <= 64) {
S = SE.getConstant(C->getType(), 0);
return Immediate::getFixed(C->getValue()->getSExtValue());
const APInt *C;
if (match(S, m_scev_APInt(C))) {
if (C->getSignificantBits() <= 64) {
S = SE.getConstant(S->getType(), 0);
return Immediate::getFixed(C->getSExtValue());
}
} else if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
SmallVector<const SCEV *, 8> NewOps(Add->operands());
Expand All @@ -942,14 +943,10 @@ static Immediate ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) {
// FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
SCEV::FlagAnyWrap);
return Result;
} else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
if (EnableVScaleImmediates && M->getNumOperands() == 2) {
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(M->getOperand(0)))
if (isa<SCEVVScale>(M->getOperand(1))) {
S = SE.getConstant(M->getType(), 0);
return Immediate::getScalable(C->getValue()->getSExtValue());
}
}
} else if (EnableVScaleImmediates &&
match(S, m_scev_Mul(m_scev_APInt(C), m_SCEVVScale()))) {
S = SE.getConstant(S->getType(), 0);
return Immediate::getScalable(C->getSExtValue());
}
return Immediate::getZero();
}
Expand Down Expand Up @@ -1133,23 +1130,22 @@ static bool isHighCostExpansion(const SCEV *S,
return false;
}

if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
if (Mul->getNumOperands() == 2) {
// Multiplication by a constant is ok
if (isa<SCEVConstant>(Mul->getOperand(0)))
return isHighCostExpansion(Mul->getOperand(1), Processed, SE);

// If we have the value of one operand, check if an existing
// multiplication already generates this expression.
if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(Mul->getOperand(1))) {
Value *UVal = U->getValue();
for (User *UR : UVal->users()) {
// If U is a constant, it may be used by a ConstantExpr.
Instruction *UI = dyn_cast<Instruction>(UR);
if (UI && UI->getOpcode() == Instruction::Mul &&
SE.isSCEVable(UI->getType())) {
return SE.getSCEV(UI) == Mul;
}
const SCEV *Op0, *Op1;
if (match(S, m_scev_Mul(m_SCEV(Op0), m_SCEV(Op1)))) {
// Multiplication by a constant is ok
if (isa<SCEVConstant>(Op0))
return isHighCostExpansion(Op1, Processed, SE);

// If we have the value of one operand, check if an existing
// multiplication already generates this expression.
if (const auto *U = dyn_cast<SCEVUnknown>(Op1)) {
Value *UVal = U->getValue();
for (User *UR : UVal->users()) {
// If U is a constant, it may be used by a ConstantExpr.
Instruction *UI = dyn_cast<Instruction>(UR);
if (UI && UI->getOpcode() == Instruction::Mul &&
SE.isSCEVable(UI->getType())) {
return SE.getSCEV(UI) == S;
}
}
}
Expand Down Expand Up @@ -3333,14 +3329,11 @@ static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst,
IncOffset = Immediate::getFixed(IncConst->getValue()->getSExtValue());
} else {
// Look for mul(vscale, constant), to detect a scalable offset.
auto *IncVScale = dyn_cast<SCEVMulExpr>(IncExpr);
if (!IncVScale || IncVScale->getNumOperands() != 2 ||
!isa<SCEVVScale>(IncVScale->getOperand(1)))
const APInt *C;
if (!match(IncExpr, m_scev_Mul(m_scev_APInt(C), m_SCEVVScale())) ||
C->getSignificantBits() > 64)
return false;
auto *Scale = dyn_cast<SCEVConstant>(IncVScale->getOperand(0));
if (!Scale || Scale->getType()->getScalarSizeInBits() > 64)
return false;
IncOffset = Immediate::getScalable(Scale->getValue()->getSExtValue());
IncOffset = Immediate::getScalable(C->getSExtValue());
}

if (!isAddressUse(TTI, UserInst, Operand))
Expand Down Expand Up @@ -3818,6 +3811,8 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C,
return nullptr;
}
const SCEV *Start, *Step;
const SCEVConstant *Op0;
const SCEV *Op1;
if (match(S, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Step)))) {
// Split a non-zero base out of an addrec.
if (Start->isZero())
Expand All @@ -3839,19 +3834,13 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C,
// FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
SCEV::FlagAnyWrap);
}
} else if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
} else if (match(S, m_scev_Mul(m_SCEVConstant(Op0), m_SCEV(Op1)))) {
// Break (C * (a + b + c)) into C*a + C*b + C*c.
if (Mul->getNumOperands() != 2)
return S;
if (const SCEVConstant *Op0 =
dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
C = C ? cast<SCEVConstant>(SE.getMulExpr(C, Op0)) : Op0;
const SCEV *Remainder =
CollectSubexprs(Mul->getOperand(1), C, Ops, L, SE, Depth+1);
if (Remainder)
Ops.push_back(SE.getMulExpr(C, Remainder));
return nullptr;
}
C = C ? cast<SCEVConstant>(SE.getMulExpr(C, Op0)) : Op0;
const SCEV *Remainder = CollectSubexprs(Op1, C, Ops, L, SE, Depth + 1);
if (Remainder)
Ops.push_back(SE.getMulExpr(C, Remainder));
return nullptr;
}
return S;
}
Expand Down