From f5b22632e74040da26279379b1564a695d0e7207 Mon Sep 17 00:00:00 2001 From: Vasileios Porpodas Date: Fri, 23 Aug 2024 10:38:50 -0700 Subject: [PATCH] [SandboxIR] Implement GlobalIFunc This patch implements sandboxir::GlobalIFunc mirroring llvm::GlobalIFunc. --- llvm/include/llvm/SandboxIR/SandboxIR.h | 90 +++++++++++++- llvm/lib/SandboxIR/SandboxIR.cpp | 37 ++++++ llvm/unittests/SandboxIR/SandboxIRTest.cpp | 137 ++++++++++++++++++--- llvm/unittests/SandboxIR/TrackerTest.cpp | 32 +++++ 4 files changed, 279 insertions(+), 17 deletions(-) diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h index 24c34466b4415..624309def4df9 100644 --- a/llvm/include/llvm/SandboxIR/SandboxIR.h +++ b/llvm/include/llvm/SandboxIR/SandboxIR.h @@ -128,6 +128,7 @@ class DSOLocalEquivalent; class ConstantTokenNone; class GlobalValue; class GlobalObject; +class GlobalIFunc; class Context; class Function; class Instruction; @@ -332,6 +333,7 @@ class Value { friend class GlobalValue; // For `Val`. friend class DSOLocalEquivalent; // For `Val`. friend class GlobalObject; // For `Val`. + friend class GlobalIFunc; // For `Val`. /// All values point to the context. Context &Ctx; @@ -1128,6 +1130,7 @@ class GlobalValue : public Constant { friend class Context; // For constructor. public: + using LinkageTypes = llvm::GlobalValue::LinkageTypes; /// For isa/dyn_cast. static bool classof(const sandboxir::Value *From) { switch (From->getSubclassID()) { @@ -1285,6 +1288,88 @@ class GlobalObject : public GlobalValue { } }; +/// Provides API functions, like getIterator() and getReverseIterator() to +/// GlobalIFunc, Function, GlobalVariable and GlobalAlias. In LLVM IR these are +/// provided by ilist_node. +template +class GlobalWithNodeAPI : public ParentT { + /// Helper for mapped_iterator. + struct LLVMGVToGV { + Context &Ctx; + LLVMGVToGV(Context &Ctx) : Ctx(Ctx) {} + GlobalT &operator()(LLVMGlobalT &LLVMGV) const; + }; + +public: + GlobalWithNodeAPI(Value::ClassID ID, LLVMParentT *C, Context &Ctx) + : ParentT(ID, C, Ctx) {} + + // TODO: Missing getParent(). Should be added once Module is available. + + using iterator = mapped_iterator< + decltype(static_cast(nullptr)->getIterator()), LLVMGVToGV>; + using reverse_iterator = mapped_iterator< + decltype(static_cast(nullptr)->getReverseIterator()), + LLVMGVToGV>; + iterator getIterator() const { + auto *LLVMGV = cast(this->Val); + LLVMGVToGV ToGV(this->Ctx); + return map_iterator(LLVMGV->getIterator(), ToGV); + } + reverse_iterator getReverseIterator() const { + auto *LLVMGV = cast(this->Val); + LLVMGVToGV ToGV(this->Ctx); + return map_iterator(LLVMGV->getReverseIterator(), ToGV); + } +}; + +class GlobalIFunc final + : public GlobalWithNodeAPI { + GlobalIFunc(llvm::GlobalObject *C, Context &Ctx) + : GlobalWithNodeAPI(ClassID::GlobalIFunc, C, Ctx) {} + friend class Context; // For constructor. + +public: + /// For isa/dyn_cast. + static bool classof(const sandboxir::Value *From) { + return From->getSubclassID() == ClassID::GlobalIFunc; + } + + // TODO: Missing create() because we don't have a sandboxir::Module yet. + + // TODO: Missing functions: copyAttributesFrom(), removeFromParent(), + // eraseFromParent() + + void setResolver(Constant *Resolver); + + Constant *getResolver() const; + + // Return the resolver function after peeling off potential ConstantExpr + // indirection. + Function *getResolverFunction(); + const Function *getResolverFunction() const { + return const_cast(this)->getResolverFunction(); + } + + static bool isValidLinkage(LinkageTypes L) { + return llvm::GlobalIFunc::isValidLinkage(L); + } + + // TODO: Missing applyAlongResolverPath(). + +#ifndef NDEBUG + void verify() const override { + assert(isa(Val) && "Expected a GlobalIFunc!"); + } + void dumpOS(raw_ostream &OS) const override { + dumpCommonPrefix(OS); + dumpCommonSuffix(OS); + } +#endif +}; + class BlockAddress final : public Constant { BlockAddress(llvm::BlockAddress *C, Context &Ctx) : Constant(ClassID::BlockAddress, C, Ctx) {} @@ -4219,7 +4304,8 @@ class Context { size_t getNumValues() const { return LLVMValueToValueMap.size(); } }; -class Function : public GlobalObject { +class Function : public GlobalWithNodeAPI { /// Helper for mapped_iterator. struct LLVMBBToBB { Context &Ctx; @@ -4230,7 +4316,7 @@ class Function : public GlobalObject { }; /// Use Context::createFunction() instead. Function(llvm::Function *F, sandboxir::Context &Ctx) - : GlobalObject(ClassID::Function, F, Ctx) {} + : GlobalWithNodeAPI(ClassID::Function, F, Ctx) {} friend class Context; // For constructor. public: diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp index 2f20fd3ff1dcc..03d3e9e607f01 100644 --- a/llvm/lib/SandboxIR/SandboxIR.cpp +++ b/llvm/lib/SandboxIR/SandboxIR.cpp @@ -2519,6 +2519,39 @@ void GlobalObject::setSection(StringRef S) { cast(Val)->setSection(S); } +template +GlobalT &GlobalWithNodeAPI:: + LLVMGVToGV::operator()(LLVMGlobalT &LLVMGV) const { + return cast(*Ctx.getValue(&LLVMGV)); +} + +namespace llvm::sandboxir { +// Explicit instantiations. +template class GlobalWithNodeAPI; +template class GlobalWithNodeAPI; +} // namespace llvm::sandboxir + +void GlobalIFunc::setResolver(Constant *Resolver) { + Ctx.getTracker() + .emplaceIfTracking< + GenericSetter<&GlobalIFunc::getResolver, &GlobalIFunc::setResolver>>( + this); + cast(Val)->setResolver( + cast(Resolver->Val)); +} + +Constant *GlobalIFunc::getResolver() const { + return Ctx.getOrCreateConstant(cast(Val)->getResolver()); +} + +Function *GlobalIFunc::getResolverFunction() { + return cast(Ctx.getOrCreateConstant( + cast(Val)->getResolverFunction())); +} + void GlobalValue::setUnnamedAddr(UnnamedAddr V) { Ctx.getTracker() .emplaceIfTrackingsecond = std::unique_ptr( new Function(cast(C), *this)); break; + case llvm::Value::GlobalIFuncVal: + It->second = std::unique_ptr( + new GlobalIFunc(cast(C), *this)); + break; default: It->second = std::unique_ptr(new Constant(C, *this)); break; diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index b1f3a6c0cf550..3b80dbd8fb66e 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -859,6 +859,84 @@ define void @foo() { EXPECT_EQ(GO->canIncreaseAlignment(), LLVMGO->canIncreaseAlignment()); } +TEST_F(SandboxIRTest, GlobalIFunc) { + parseIR(C, R"IR( +declare external void @bar() +@ifunc0 = ifunc void(), ptr @foo +@ifunc1 = ifunc void(), ptr @foo +define void @foo() { + call void @ifunc0() + call void @ifunc1() + call void @bar() + ret void +} +)IR"); + Function &LLVMF = *M->getFunction("foo"); + auto *LLVMBB = &*LLVMF.begin(); + auto LLVMIt = LLVMBB->begin(); + auto *LLVMCall0 = cast(&*LLVMIt++); + auto *LLVMIFunc0 = cast(LLVMCall0->getCalledOperand()); + + sandboxir::Context Ctx(C); + + auto &F = *Ctx.createFunction(&LLVMF); + auto *BB = &*F.begin(); + auto It = BB->begin(); + auto *Call0 = cast(&*It++); + auto *Call1 = cast(&*It++); + auto *CallBar = cast(&*It++); + // Check classof(), creation. + auto *IFunc0 = cast(Call0->getCalledOperand()); + auto *IFunc1 = cast(Call1->getCalledOperand()); + auto *Bar = cast(CallBar->getCalledOperand()); + + // Check getIterator(). + { + auto It0 = IFunc0->getIterator(); + auto It1 = IFunc1->getIterator(); + EXPECT_EQ(&*It0, IFunc0); + EXPECT_EQ(&*It1, IFunc1); + EXPECT_EQ(std::next(It0), It1); + EXPECT_EQ(std::prev(It1), It0); + EXPECT_EQ(&*std::next(It0), IFunc1); + EXPECT_EQ(&*std::prev(It1), IFunc0); + } + // Check getReverseIterator(). + { + auto RevIt0 = IFunc0->getReverseIterator(); + auto RevIt1 = IFunc1->getReverseIterator(); + EXPECT_EQ(&*RevIt0, IFunc0); + EXPECT_EQ(&*RevIt1, IFunc1); + EXPECT_EQ(std::prev(RevIt0), RevIt1); + EXPECT_EQ(std::next(RevIt1), RevIt0); + EXPECT_EQ(&*std::prev(RevIt0), IFunc1); + EXPECT_EQ(&*std::next(RevIt1), IFunc0); + } + + // Check setResolver(), getResolver(). + EXPECT_EQ(IFunc0->getResolver(), Ctx.getValue(LLVMIFunc0->getResolver())); + auto *OrigResolver = IFunc0->getResolver(); + auto *NewResolver = Bar; + EXPECT_NE(NewResolver, OrigResolver); + IFunc0->setResolver(NewResolver); + EXPECT_EQ(IFunc0->getResolver(), NewResolver); + IFunc0->setResolver(OrigResolver); + EXPECT_EQ(IFunc0->getResolver(), OrigResolver); + // Check getResolverFunction(). + EXPECT_EQ(IFunc0->getResolverFunction(), + Ctx.getValue(LLVMIFunc0->getResolverFunction())); + // Check isValidLinkage(). + for (auto L : + {GlobalValue::ExternalLinkage, GlobalValue::AvailableExternallyLinkage, + GlobalValue::LinkOnceAnyLinkage, GlobalValue::LinkOnceODRLinkage, + GlobalValue::WeakAnyLinkage, GlobalValue::WeakODRLinkage, + GlobalValue::AppendingLinkage, GlobalValue::InternalLinkage, + GlobalValue::PrivateLinkage, GlobalValue::ExternalWeakLinkage, + GlobalValue::CommonLinkage}) { + EXPECT_EQ(IFunc0->isValidLinkage(L), LLVMIFunc0->isValidLinkage(L)); + } +} + TEST_F(SandboxIRTest, BlockAddress) { parseIR(C, R"IR( define void @foo(ptr %ptr) { @@ -1200,29 +1278,58 @@ define void @foo(i8 %v) { TEST_F(SandboxIRTest, Function) { parseIR(C, R"IR( -define void @foo(i32 %arg0, i32 %arg1) { +define void @foo0(i32 %arg0, i32 %arg1) { bb0: br label %bb1 bb1: ret void } +define void @foo1() { + ret void +} + )IR"); - llvm::Function *LLVMF = &*M->getFunction("foo"); - llvm::Argument *LLVMArg0 = LLVMF->getArg(0); - llvm::Argument *LLVMArg1 = LLVMF->getArg(1); + llvm::Function *LLVMF0 = &*M->getFunction("foo0"); + llvm::Function *LLVMF1 = &*M->getFunction("foo1"); + llvm::Argument *LLVMArg0 = LLVMF0->getArg(0); + llvm::Argument *LLVMArg1 = LLVMF0->getArg(1); sandboxir::Context Ctx(C); - sandboxir::Function *F = Ctx.createFunction(LLVMF); + sandboxir::Function *F0 = Ctx.createFunction(LLVMF0); + sandboxir::Function *F1 = Ctx.createFunction(LLVMF1); + + // Check getIterator(). + { + auto It0 = F0->getIterator(); + auto It1 = F1->getIterator(); + EXPECT_EQ(&*It0, F0); + EXPECT_EQ(&*It1, F1); + EXPECT_EQ(std::next(It0), It1); + EXPECT_EQ(std::prev(It1), It0); + EXPECT_EQ(&*std::next(It0), F1); + EXPECT_EQ(&*std::prev(It1), F0); + } + // Check getReverseIterator(). + { + auto RevIt0 = F0->getReverseIterator(); + auto RevIt1 = F1->getReverseIterator(); + EXPECT_EQ(&*RevIt0, F0); + EXPECT_EQ(&*RevIt1, F1); + EXPECT_EQ(std::prev(RevIt0), RevIt1); + EXPECT_EQ(std::next(RevIt1), RevIt0); + EXPECT_EQ(&*std::prev(RevIt0), F1); + EXPECT_EQ(&*std::next(RevIt1), F0); + } // Check F arguments - EXPECT_EQ(F->arg_size(), 2u); - EXPECT_FALSE(F->arg_empty()); - EXPECT_EQ(F->getArg(0), Ctx.getValue(LLVMArg0)); - EXPECT_EQ(F->getArg(1), Ctx.getValue(LLVMArg1)); + EXPECT_EQ(F0->arg_size(), 2u); + EXPECT_FALSE(F0->arg_empty()); + EXPECT_EQ(F0->getArg(0), Ctx.getValue(LLVMArg0)); + EXPECT_EQ(F0->getArg(1), Ctx.getValue(LLVMArg1)); // Check F.begin(), F.end(), Function::iterator - llvm::BasicBlock *LLVMBB = &*LLVMF->begin(); - for (sandboxir::BasicBlock &BB : *F) { + llvm::BasicBlock *LLVMBB = &*LLVMF0->begin(); + for (sandboxir::BasicBlock &BB : *F0) { EXPECT_EQ(&BB, Ctx.getValue(LLVMBB)); LLVMBB = LLVMBB->getNextNode(); } @@ -1232,17 +1339,17 @@ define void @foo(i32 %arg0, i32 %arg1) { // Check F.dumpNameAndArgs() std::string Buff; raw_string_ostream BS(Buff); - F->dumpNameAndArgs(BS); - EXPECT_EQ(Buff, "void @foo(i32 %arg0, i32 %arg1)"); + F0->dumpNameAndArgs(BS); + EXPECT_EQ(Buff, "void @foo0(i32 %arg0, i32 %arg1)"); } { // Check F.dump() std::string Buff; raw_string_ostream BS(Buff); BS << "\n"; - F->dumpOS(BS); + F0->dumpOS(BS); EXPECT_EQ(Buff, R"IR( -void @foo(i32 %arg0, i32 %arg1) { +void @foo0(i32 %arg0, i32 %arg1) { bb0: br label %bb1 ; SB4. (Br) diff --git a/llvm/unittests/SandboxIR/TrackerTest.cpp b/llvm/unittests/SandboxIR/TrackerTest.cpp index 6454c54336e6a..d4ff4fd6464e5 100644 --- a/llvm/unittests/SandboxIR/TrackerTest.cpp +++ b/llvm/unittests/SandboxIR/TrackerTest.cpp @@ -1558,6 +1558,38 @@ define void @foo() { EXPECT_EQ(GV->getVisibility(), OrigVisibility); } +TEST_F(TrackerTest, GlobalIFuncSetters) { + parseIR(C, R"IR( +declare external void @bar() +@ifunc = ifunc void(), ptr @foo +define void @foo() { + call void @ifunc() + call void @bar() + ret void +} +)IR"); + Function &LLVMF = *M->getFunction("foo"); + sandboxir::Context Ctx(C); + + auto &F = *Ctx.createFunction(&LLVMF); + auto *BB = &*F.begin(); + auto It = BB->begin(); + auto *Call0 = cast(&*It++); + auto *Call1 = cast(&*It++); + // Check classof(), creation. + auto *IFunc = cast(Call0->getCalledOperand()); + auto *Bar = cast(Call1->getCalledOperand()); + // Check setResolver(). + auto *OrigResolver = IFunc->getResolver(); + auto *NewResolver = Bar; + EXPECT_NE(NewResolver, OrigResolver); + Ctx.save(); + IFunc->setResolver(NewResolver); + EXPECT_EQ(IFunc->getResolver(), NewResolver); + Ctx.revert(); + EXPECT_EQ(IFunc->getResolver(), OrigResolver); +} + TEST_F(TrackerTest, SetVolatile) { parseIR(C, R"IR( define void @foo(ptr %arg0, i8 %val) {