From 31588d29e9299dfcbf6909fd0bcd71dca73e3e8d Mon Sep 17 00:00:00 2001 From: Paul Kirth Date: Tue, 26 Mar 2024 00:49:00 +0000 Subject: [PATCH 1/2] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20in?= =?UTF-8?q?itial=20version?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Created using spr 1.3.4 --- llvm/include/llvm/IR/ProfDataUtils.h | 9 +++- llvm/lib/IR/ProfDataUtils.cpp | 45 +++++++++++++------ .../Transforms/Utils/LoopRotationUtils.cpp | 2 +- llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 7 +-- 4 files changed, 41 insertions(+), 22 deletions(-) diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h index 255fa2ff1c790..dc983eed13a8d 100644 --- a/llvm/include/llvm/IR/ProfDataUtils.h +++ b/llvm/include/llvm/IR/ProfDataUtils.h @@ -65,10 +65,15 @@ bool extractBranchWeights(const MDNode *ProfileData, SmallVectorImpl &Weights); /// Faster version of extractBranchWeights() that skips checks and must only -/// be called with "branch_weights" metadata nodes. -void extractFromBranchWeightMD(const MDNode *ProfileData, +/// be called with "branch_weights" metadata nodes. Supports uint32_t. +void extractFromBranchWeightMD32(const MDNode *ProfileData, SmallVectorImpl &Weights); +/// Faster version of extractBranchWeights() that skips checks and must only +/// be called with "branch_weights" metadata nodes. Supports uint64_t. +void extractFromBranchWeightMD64(const MDNode *ProfileData, + SmallVectorImpl &Weights); + /// Extract branch weights attatched to an Instruction /// /// \param I The Instruction to extract weights from. diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp index b1a10d0ce5a52..b4e09e76993f9 100644 --- a/llvm/lib/IR/ProfDataUtils.cpp +++ b/llvm/lib/IR/ProfDataUtils.cpp @@ -65,6 +65,26 @@ bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { return ProfDataName->getString().equals(Name); } +template >> +static void extractFromBranchWeightMD(const MDNode *ProfileData, + SmallVectorImpl &Weights) { + assert(isBranchWeightMD(ProfileData) && "wrong metadata"); + + unsigned NOps = ProfileData->getNumOperands(); + assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); + Weights.resize(NOps - WeightsIdx); + + for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { + ConstantInt *Weight = + mdconst::dyn_extract(ProfileData->getOperand(Idx)); + assert(Weight && "Malformed branch_weight in MD_prof node"); + assert(Weight->getValue().getActiveBits() <= 32 && + "Too many bits for uint32_t"); + Weights[Idx - WeightsIdx] = Weight->getZExtValue(); + } +} + } // namespace namespace llvm { @@ -100,24 +120,21 @@ MDNode *getValidBranchWeightMDNode(const Instruction &I) { return nullptr; } -void extractFromBranchWeightMD(const MDNode *ProfileData, +void extractFromBranchWeightMD32(const MDNode *ProfileData, SmallVectorImpl &Weights) { - assert(isBranchWeightMD(ProfileData) && "wrong metadata"); - - unsigned NOps = ProfileData->getNumOperands(); - assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); - Weights.resize(NOps - WeightsIdx); + extractFromBranchWeightMD(ProfileData, Weights); +} - for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { - ConstantInt *Weight = - mdconst::dyn_extract(ProfileData->getOperand(Idx)); - assert(Weight && "Malformed branch_weight in MD_prof node"); - assert(Weight->getValue().getActiveBits() <= 32 && - "Too many bits for uint32_t"); - Weights[Idx - WeightsIdx] = Weight->getZExtValue(); - } +void extractFromBranchWeightMD64(const MDNode *ProfileData, + SmallVectorImpl &Weights) { + extractFromBranchWeightMD(ProfileData, Weights); } + + + + + bool extractBranchWeights(const MDNode *ProfileData, SmallVectorImpl &Weights) { if (!isBranchWeightMD(ProfileData)) diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp index bc67117113719..f4b43ce370a5d 100644 --- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp @@ -287,7 +287,7 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI, return; SmallVector Weights; - extractFromBranchWeightMD(WeightMD, Weights); + extractFromBranchWeightMD32(WeightMD, Weights); if (Weights.size() != 2) return; uint32_t OrigLoopExitWeight = Weights[0]; diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 55bbffb18879f..a425e26d490e4 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -1065,11 +1065,8 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1, static void GetBranchWeights(Instruction *TI, SmallVectorImpl &Weights) { MDNode *MD = TI->getMetadata(LLVMContext::MD_prof); - assert(MD); - for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) { - ConstantInt *CI = mdconst::extract(MD->getOperand(i)); - Weights.push_back(CI->getValue().getZExtValue()); - } + assert(MD && "Invalid branch-weight metadata"); + extractFromBranchWeightMD64(MD, Weights); // If TI is a conditional eq, the default case is the false case, // and the corresponding branch-weight data is at index 2. We swap the From b69df3fade21cbae9a81cb6161b8f171e60762fd Mon Sep 17 00:00:00 2001 From: Paul Kirth Date: Tue, 26 Mar 2024 00:55:03 +0000 Subject: [PATCH 2/2] git clang-format Created using spr 1.3.4 --- llvm/include/llvm/IR/ProfDataUtils.h | 4 ++-- llvm/lib/IR/ProfDataUtils.cpp | 9 ++------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h index dc983eed13a8d..457ffdff8fe37 100644 --- a/llvm/include/llvm/IR/ProfDataUtils.h +++ b/llvm/include/llvm/IR/ProfDataUtils.h @@ -67,12 +67,12 @@ bool extractBranchWeights(const MDNode *ProfileData, /// Faster version of extractBranchWeights() that skips checks and must only /// be called with "branch_weights" metadata nodes. Supports uint32_t. void extractFromBranchWeightMD32(const MDNode *ProfileData, - SmallVectorImpl &Weights); + SmallVectorImpl &Weights); /// Faster version of extractBranchWeights() that skips checks and must only /// be called with "branch_weights" metadata nodes. Supports uint64_t. void extractFromBranchWeightMD64(const MDNode *ProfileData, - SmallVectorImpl &Weights); + SmallVectorImpl &Weights); /// Extract branch weights attatched to an Instruction /// diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp index b4e09e76993f9..36e165e641f46 100644 --- a/llvm/lib/IR/ProfDataUtils.cpp +++ b/llvm/lib/IR/ProfDataUtils.cpp @@ -121,20 +121,15 @@ MDNode *getValidBranchWeightMDNode(const Instruction &I) { } void extractFromBranchWeightMD32(const MDNode *ProfileData, - SmallVectorImpl &Weights) { + SmallVectorImpl &Weights) { extractFromBranchWeightMD(ProfileData, Weights); } void extractFromBranchWeightMD64(const MDNode *ProfileData, - SmallVectorImpl &Weights) { + SmallVectorImpl &Weights) { extractFromBranchWeightMD(ProfileData, Weights); } - - - - - bool extractBranchWeights(const MDNode *ProfileData, SmallVectorImpl &Weights) { if (!isBranchWeightMD(ProfileData))