From 9b3a14e425150a42b08e1f046b4646e6e5939920 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 22 Mar 2024 11:41:40 -0700 Subject: [PATCH] [mlir] Make C/Python ExecutionEngine constructible with an Operation. This continues the long deprivileging of mlir.ir.Module as having any semantic meaning. Given the potential for silent/deadly failures by changing a C API signature, I added a new C API entrypoint with a new name and marked the original as deprecated. The `ExecutionEngine()` constructor was extended to accept either a `Module` or an `Operation`, so there should be no user-level API breakage. Test was added to verify. Python ExecutionEngine tests were modernized to use `Operation.parse` and explicit outer modules. --- mlir/include/mlir-c/ExecutionEngine.h | 9 +- .../Bindings/Python/ExecutionEngineModule.cpp | 25 ++++- .../CAPI/ExecutionEngine/ExecutionEngine.cpp | 16 ++- .../mlir/_mlir_libs/_mlirExecutionEngine.pyi | 4 +- mlir/test/python/execution_engine.py | 105 +++++++++++++----- 5 files changed, 119 insertions(+), 40 deletions(-) diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h index 99cddc5c2598d..311451a029181 100644 --- a/mlir/include/mlir-c/ExecutionEngine.h +++ b/mlir/include/mlir-c/ExecutionEngine.h @@ -42,8 +42,15 @@ DEFINE_C_API_STRUCT(MlirExecutionEngine, void); /// that will be loaded are specified via `numPaths` and `sharedLibPaths` /// respectively. /// TODO: figure out other options. +MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreateFromOp( + MlirOperation op, int optLevel, int numPaths, + const MlirStringRef *sharedLibPaths, bool enableObjectDump); + +// Deprecated variant which takes an MlirModule instead of an operation. +// This is being preserved as of 2024-Mar for short term consistency and should +// be dropped soon. MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate( - MlirModule op, int optLevel, int numPaths, + MlirModule module, int optLevel, int numPaths, const MlirStringRef *sharedLibPaths, bool enableObjectDump); /// Destroy an ExecutionEngine instance. diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index b3df30583fc96..9ed5ee80f97f8 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -71,15 +71,34 @@ PYBIND11_MODULE(_mlirExecutionEngine, m) { // Mapping of the top-level PassManager //---------------------------------------------------------------------------- py::class_(m, "ExecutionEngine", py::module_local()) - .def(py::init<>([](MlirModule module, int optLevel, + .def(py::init<>([](py::object operation_or_module, int optLevel, const std::vector &sharedLibPaths, bool enableObjectDump) { + // Manually type cast from either a Module or Operation. The + // automatic type casters do not handle such cascades well, + // so be explicit. + py::object capsule = mlirApiObjectToCapsule(operation_or_module); + MlirOperation module_op = + mlirPythonCapsuleToOperation(capsule.ptr()); + if (mlirOperationIsNull(module_op)) { + // If null, then a PyErr_Set has set an exception, which we must + // clear. + PyErr_Clear(); + MlirModule mod = mlirPythonCapsuleToModule(capsule.ptr()); + if (mlirModuleIsNull(mod)) { + throw py::type_error( + "ExecutionEngine expects a Module or Operation"); + } + module_op = mlirModuleGetOperation(mod); + } + llvm::SmallVector libPaths; for (const std::string &path : sharedLibPaths) libPaths.push_back({path.c_str(), path.length()}); MlirExecutionEngine executionEngine = - mlirExecutionEngineCreate(module, optLevel, libPaths.size(), - libPaths.data(), enableObjectDump); + mlirExecutionEngineCreateFromOp( + module_op, optLevel, libPaths.size(), libPaths.data(), + enableObjectDump); if (mlirExecutionEngineIsNull(executionEngine)) throw std::runtime_error( "Failure while creating the ExecutionEngine."); diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 507be9171d328..8bd7e8b354f34 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -20,9 +20,18 @@ using namespace mlir; extern "C" MlirExecutionEngine -mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, +mlirExecutionEngineCreate(MlirModule module, int optLevel, int numPaths, const MlirStringRef *sharedLibPaths, bool enableObjectDump) { + return mlirExecutionEngineCreateFromOp(mlirModuleGetOperation(module), + optLevel, numPaths, sharedLibPaths, + enableObjectDump); +} + +extern "C" MlirExecutionEngine +mlirExecutionEngineCreateFromOp(MlirOperation op, int optLevel, int numPaths, + const MlirStringRef *sharedLibPaths, + bool enableObjectDump) { static bool initOnce = [] { llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmParser(); // needed for inline_asm @@ -104,9 +113,8 @@ extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit, void *sym) { unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) { llvm::orc::SymbolMap symbolMap; - symbolMap[interner(unwrap(name))] = - { llvm::orc::ExecutorAddr::fromPtr(sym), - llvm::JITSymbolFlags::Exported }; + symbolMap[interner(unwrap(name))] = {llvm::orc::ExecutorAddr::fromPtr(sym), + llvm::JITSymbolFlags::Exported}; return symbolMap; }); } diff --git a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi index 893dab8a431fd..c32b5db13241c 100644 --- a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi @@ -4,7 +4,7 @@ # * Relative imports for cross-module references. # * Add __all__ -from typing import List, Sequence +from typing import List, Sequence,Union from ._mlir import ir as _ir @@ -13,7 +13,7 @@ __all__ = [ ] class ExecutionEngine: - def __init__(self, module: _ir.Module, opt_level: int = 2, shared_libs: Sequence[str] = ...) -> None: ... + def __init__(self, module: Union[_ir.Operation, _ir.Module], opt_level: int = 2, shared_libs: Sequence[str] = ...) -> None: ... def _CAPICreate(self) -> object: ... def _testing_release(self) -> None: ... def dump_to_object_file(self, file_name: str) -> None: ... diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py index e8b47007a8907..647e6667b69a3 100644 --- a/mlir/test/python/execution_engine.py +++ b/mlir/test/python/execution_engine.py @@ -21,17 +21,46 @@ def run(f): assert Context._get_live_count() == 0 -# Verify capsule interop. -# CHECK-LABEL: TEST: testCapsule -def testCapsule(): +# Verify capsule interop for passing an Operation. +# CHECK-LABEL: TEST: testAcceptsOperation +def testAcceptsOperation(): + with Context(): + module = Operation.parse( + r""" +builtin.module { +llvm.func @none() { +llvm.return +} +} + """ + ) + execution_engine = ExecutionEngine(module) + execution_engine_capsule = execution_engine._CAPIPtr + # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr + log(repr(execution_engine_capsule)) + execution_engine._testing_release() + execution_engine1 = ExecutionEngine._CAPICreate(execution_engine_capsule) + # CHECK: _mlirExecutionEngine.ExecutionEngine + log(repr(execution_engine1)) + + +run(testAcceptsOperation) + + +# Verify capsule interop for passing a Module. +# CHECK-LABEL: TEST: testAcceptsModule +def testAcceptsModule(): with Context(): module = Module.parse( r""" +builtin.module { llvm.func @none() { - llvm.return +llvm.return +} } """ ) + print("MODULE:", type(module)) execution_engine = ExecutionEngine(module) execution_engine_capsule = execution_engine._CAPIPtr # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr @@ -42,7 +71,7 @@ def testCapsule(): log(repr(execution_engine1)) -run(testCapsule) +run(testAcceptsModule) # Test invalid ExecutionEngine creation @@ -50,9 +79,11 @@ def testCapsule(): def testInvalidModule(): with Context(): # Builtin function - module = Module.parse( + module = Operation.parse( r""" + builtin.module { func.func @foo() { return } + } """ ) # CHECK: Got RuntimeError: Failure while creating the ExecutionEngine. @@ -69,7 +100,7 @@ def lowerToLLVM(module): pm = PassManager.parse( "builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)" ) - pm.run(module.operation) + pm.run(module) return module @@ -77,10 +108,12 @@ def lowerToLLVM(module): # CHECK-LABEL: TEST: testInvokeVoid def testInvokeVoid(): with Context(): - module = Module.parse( + module = Operation.parse( r""" +builtin.module { func.func @void() attributes { llvm.emit_c_interface } { return +} } """ ) @@ -96,11 +129,13 @@ def testInvokeVoid(): # CHECK-LABEL: TEST: testInvokeFloatAdd def testInvokeFloatAdd(): with Context(): - module = Module.parse( + module = Operation.parse( r""" +builtin.module { func.func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } { %add = arith.addf %arg0, %arg1 : f32 return %add : f32 +} } """ ) @@ -129,13 +164,15 @@ def callback(a, b): with Context(): # The module just forwards to a runtime function known as "some_callback_into_python". - module = Module.parse( + module = Operation.parse( r""" +builtin.module { func.func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } { %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32) return %resf : f32 } func.func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface } +} """ ) execution_engine = ExecutionEngine(lowerToLLVM(module)) @@ -168,13 +205,15 @@ def callback(a): with Context(): # The module just forwards to a runtime function known as "some_callback_into_python". - module = Module.parse( + module = Operation.parse( r""" +builtin.module { func.func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } { call @some_callback_into_python(%arg0) : (memref<*xf32>) -> () return } func.func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface } +} """ ) execution_engine = ExecutionEngine(lowerToLLVM(module)) @@ -221,13 +260,15 @@ def callback(a): with Context(): # The module just forwards to a runtime function known as "some_callback_into_python". - module = Module.parse( + module = Operation.parse( r""" +builtin.module { func.func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } { call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> () return } func.func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface } +} """ ) execution_engine = ExecutionEngine(lowerToLLVM(module)) @@ -262,8 +303,9 @@ def callback(a): with Context(): # The module takes a subview of the argument memref and calls the callback with it - module = Module.parse( + module = Operation.parse( r""" +builtin.module { func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} { %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref, index, index, index %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref to memref<2xf32, strided<[1], offset: 3>> @@ -272,6 +314,7 @@ def callback(a): return } func.func private @some_callback_into_python(memref>) attributes {llvm.emit_c_interface} +} """ ) execution_engine = ExecutionEngine(lowerToLLVM(module)) @@ -301,8 +344,9 @@ def callback(a): with Context(): # The module takes a subview of the argument memref, casts it to an unranked memref and # calls the callback with it. - module = Module.parse( + module = Operation.parse( r""" +builtin.module { func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} { %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref, index, index, index %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref to memref<2xf32, strided<[1], offset: 3>> @@ -311,6 +355,7 @@ def callback(a): return } func.func private @some_callback_into_python(memref<*xf32>) attributes {llvm.emit_c_interface} +} """ ) execution_engine = ExecutionEngine(lowerToLLVM(module)) @@ -330,9 +375,9 @@ def callback(a): # CHECK-LABEL: TEST: testMemrefAdd def testMemrefAdd(): with Context(): - module = Module.parse( + module = Operation.parse( """ - module { + builtin.module { func.func @main(%arg0: memref<1xf32>, %arg1: memref, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } { %0 = arith.constant 0 : index %1 = memref.load %arg0[%0] : memref<1xf32> @@ -372,9 +417,9 @@ def testMemrefAdd(): # CHECK-LABEL: TEST: testF16MemrefAdd def testF16MemrefAdd(): with Context(): - module = Module.parse( + module = Operation.parse( """ - module { + builtin.module { func.func @main(%arg0: memref<1xf16>, %arg1: memref<1xf16>, %arg2: memref<1xf16>) attributes { llvm.emit_c_interface } { @@ -422,9 +467,9 @@ def testF16MemrefAdd(): # CHECK-LABEL: TEST: testComplexMemrefAdd def testComplexMemrefAdd(): with Context(): - module = Module.parse( + module = Operation.parse( """ - module { + builtin.module { func.func @main(%arg0: memref<1xcomplex>, %arg1: memref<1xcomplex>, %arg2: memref<1xcomplex>) attributes { llvm.emit_c_interface } { @@ -472,9 +517,9 @@ def testComplexMemrefAdd(): # CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd def testComplexUnrankedMemrefAdd(): with Context(): - module = Module.parse( + module = Operation.parse( """ - module { + builtin.module { func.func @main(%arg0: memref<*xcomplex>, %arg1: memref<*xcomplex>, %arg2: memref<*xcomplex>) attributes { llvm.emit_c_interface } { @@ -525,9 +570,9 @@ def testComplexUnrankedMemrefAdd(): # CHECK-LABEL: TEST: testDynamicMemrefAdd2D def testDynamicMemrefAdd2D(): with Context(): - module = Module.parse( + module = Operation.parse( """ - module { + builtin.module { func.func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} { %c0 = arith.constant 0 : index %c2 = arith.constant 2 : index @@ -589,9 +634,9 @@ def testDynamicMemrefAdd2D(): # CHECK-LABEL: TEST: testSharedLibLoad def testSharedLibLoad(): with Context(): - module = Module.parse( + module = Operation.parse( """ - module { + builtin.module { func.func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } { %c0 = arith.constant 0 : index %cst42 = arith.constant 42.0 : f32 @@ -640,9 +685,9 @@ def testSharedLibLoad(): # CHECK-LABEL: TEST: testNanoTime def testNanoTime(): with Context(): - module = Module.parse( + module = Operation.parse( """ - module { + builtin.module { func.func @main() attributes { llvm.emit_c_interface } { %now = call @nanoTime() : () -> i64 %memref = memref.alloca() : memref<1xi64> @@ -686,9 +731,9 @@ def testDumpToObjectFile(): try: with Context(): - module = Module.parse( + module = Operation.parse( """ - module { + builtin.module { func.func @main() attributes { llvm.emit_c_interface } { return }