diff --git a/include/CppInterOp/CppInterOp.h b/include/CppInterOp/CppInterOp.h index 2f3d330db..7b428e4b2 100644 --- a/include/CppInterOp/CppInterOp.h +++ b/include/CppInterOp/CppInterOp.h @@ -224,6 +224,8 @@ class JitCall { ///\param[in] nary - Use array new if we have to construct an array of /// objects (nary > 1). ///\param[in] args - a pointer to a argument list and argument size. + ///\param[in] is_arena - a pointer that indicates if placement new is to be + /// used // FIXME: Change the type of withFree from int to bool in the wrapper code. void InvokeConstructor(void* result, unsigned long nary = 1, ArgList args = {}, void* is_arena = nullptr) const { @@ -849,7 +851,7 @@ CPPINTEROP_API void Deallocate(TCppScope_t scope, TCppObject_t address, /// Creates one or more objects of class \c scope by calling its default /// constructor. -/// \param[in] scope Class to construct +/// \param[in] scope Class to construct, or handle to Constructor /// \param[in] arena If set, this API uses placement new to construct at this /// address. /// \param[in] is used to indicate the number of objects to construct. diff --git a/lib/CppInterOp/CppInterOp.cpp b/lib/CppInterOp/CppInterOp.cpp index afbd0e368..cd776dbce 100755 --- a/lib/CppInterOp/CppInterOp.cpp +++ b/lib/CppInterOp/CppInterOp.cpp @@ -3784,22 +3784,25 @@ void Deallocate(TCppScope_t scope, TCppObject_t address, TCppIndex_t count) { // FIXME: Add optional arguments to the operator new. TCppObject_t Construct(compat::Interpreter& interp, TCppScope_t scope, void* arena /*=nullptr*/, TCppIndex_t count /*=1UL*/) { - auto* Class = (Decl*)scope; - // FIXME: Diagnose. - if (!HasDefaultConstructor(Class)) - return nullptr; - auto* const Ctor = GetDefaultConstructor(interp, Class); - if (JitCall JC = MakeFunctionCallable(&interp, Ctor)) { - if (arena) { - JC.InvokeConstructor(&arena, count, {}, - (void*)~0); // Tell Invoke to use placement new. - return arena; - } + if (!Cpp::IsConstructor(scope) && !Cpp::IsClass(scope)) + return nullptr; + if (Cpp::IsClass(scope) && !HasDefaultConstructor(scope)) + return nullptr; - void* obj = nullptr; - JC.InvokeConstructor(&obj, count, {}, nullptr); - return obj; + TCppFunction_t ctor = nullptr; + if (Cpp::IsClass(scope)) + ctor = Cpp::GetDefaultConstructor(scope); + else // a ctor + ctor = scope; + + if (JitCall JC = MakeFunctionCallable(&interp, ctor)) { + // invoke the constructor (placement/heap) in one shot + // flag is non-null for placement new, null for normal new + void* is_arena = arena ? reinterpret_cast(1) : nullptr; + void* result = arena; + JC.InvokeConstructor(&result, count, /*args=*/{}, is_arena); + return result; } return nullptr; } diff --git a/unittests/CppInterOp/FunctionReflectionTest.cpp b/unittests/CppInterOp/FunctionReflectionTest.cpp index b69d7c701..099e55aa2 100644 --- a/unittests/CppInterOp/FunctionReflectionTest.cpp +++ b/unittests/CppInterOp/FunctionReflectionTest.cpp @@ -2122,20 +2122,24 @@ TEST(FunctionReflectionTest, Construct) { GTEST_SKIP() << "Disabled on Windows. Needs fixing."; #endif std::vector interpreter_args = {"-include", "new"}; - Cpp::CreateInterpreter(interpreter_args); + std::vector Decls, SubDecls; - Interp->declare(R"( + std::string code = R"( #include extern "C" int printf(const char*,...); class C { + public: int x; C() { x = 12345; printf("Constructor Executed"); } }; - )"); + void construct() { return; } + )"; + GetAllTopLevelDecls(code, Decls, false, interpreter_args); + GetAllSubDecls(Decls[1], SubDecls); testing::internal::CaptureStdout(); Cpp::TCppScope_t scope = Cpp::GetNamed("C"); Cpp::TCppObject_t object = Cpp::Construct(scope); @@ -2155,6 +2159,20 @@ TEST(FunctionReflectionTest, Construct) { EXPECT_EQ(output, "Constructor Executed"); output.clear(); + // Pass a constructor + testing::internal::CaptureStdout(); + where = Cpp::Allocate(scope); + EXPECT_TRUE(where == Cpp::Construct(SubDecls[3], where)); + EXPECT_TRUE(*(int*)where == 12345); + Cpp::Deallocate(scope, where); + output = testing::internal::GetCapturedStdout(); + EXPECT_EQ(output, "Constructor Executed"); + output.clear(); + + // Pass a non-class decl, this should fail + where = Cpp::Allocate(scope); + where = Cpp::Construct(Decls[2], where); + EXPECT_TRUE(where == nullptr); // C API testing::internal::CaptureStdout(); auto* I = clang_createInterpreterFromRawPtr(Cpp::GetInterpreter());