diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index e361f33a0d836..7b121d4df3286 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -73,7 +73,6 @@ DEFINE_C_API_STRUCT(MlirValue, const void); /// /// A named attribute is essentially a (name, attribute) pair where the name is /// a string. - struct MlirNamedAttribute { MlirIdentifier name; MlirAttribute attribute; @@ -698,6 +697,24 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, /// ownership is transferred to the block of the other operation. MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, MlirOperation other); + +/// Traversal order for operation walk. +typedef enum MlirWalkOrder { + MlirWalkPreOrder, + MlirWalkPostOrder +} MlirWalkOrder; + +/// Operation walker type. The handler is passed an (opaque) reference to an +/// operation a pointer to a `userData`. +typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData); + +/// Walks operation `op` in `walkOrder` and calls `callback` on that operation. +/// `*userData` is passed to the callback as well and can be used to tunnel some +/// some context or other data into the callback. +MLIR_CAPI_EXPORTED +void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, + void *userData, MlirWalkOrder walkOrder); + //===----------------------------------------------------------------------===// // Region API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 389a4621c14e5..a8ea1a381edb9 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -635,6 +635,11 @@ size_t PyMlirContext::clearLiveOperations() { return numInvalidated; } +void PyMlirContext::setOperationInvalid(MlirOperation op) { + if (liveOperations.contains(op.ptr)) + liveOperations[op.ptr].second->setInvalid(); +} + size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } pybind11::object PyMlirContext::contextEnter() { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index c5412e735dddc..26292885711a4 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -209,6 +209,11 @@ class PyMlirContext { /// place. size_t clearLiveOperations(); + /// Sets an operation invalid. This is useful for when some non-bindings + /// code destroys the operation and the bindings need to made aware. For + /// example, in the case when pass manager is run. + void setOperationInvalid(MlirOperation op); + /// Gets the count of live modules associated with this context. /// Used for testing. size_t getLiveModuleCount(); diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index cdbfcfbc22957..2175cea79960c 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -13,6 +13,7 @@ #include "mlir-c/Pass.h" namespace py = pybind11; +using namespace py::literals; using namespace mlir; using namespace mlir::python; @@ -63,8 +64,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { mlirStringRefCreate(anchorOp.data(), anchorOp.size())); return new PyPassManager(passManager); }), - py::arg("anchor_op") = py::str("any"), - py::arg("context") = py::none(), + "anchor_op"_a = py::str("any"), "context"_a = py::none(), "Create a new PassManager for the current (or provided) Context.") .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule) @@ -82,7 +82,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { [](PyPassManager &passManager, bool enable) { mlirPassManagerEnableVerifier(passManager.get(), enable); }, - py::arg("enable"), "Enable / disable verify-each.") + "enable"_a, "Enable / disable verify-each.") .def_static( "parse", [](const std::string &pipeline, DefaultingPyMlirContext context) { @@ -96,7 +96,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { throw py::value_error(std::string(errorMsg.join())); return new PyPassManager(passManager); }, - py::arg("pipeline"), py::arg("context") = py::none(), + "pipeline"_a, "context"_a = py::none(), "Parse a textual pass-pipeline and return a top-level PassManager " "that can be applied on a Module. Throw a ValueError if the pipeline " "can't be parsed") @@ -111,12 +111,35 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { if (mlirLogicalResultIsFailure(status)) throw py::value_error(std::string(errorMsg.join())); }, - py::arg("pipeline"), + "pipeline"_a, "Add textual pipeline elements to the pass manager. Throws a " "ValueError if the pipeline can't be parsed.") .def( "run", - [](PyPassManager &passManager, PyOperationBase &op) { + [](PyPassManager &passManager, PyOperationBase &op, + bool invalidateOps) { + if (invalidateOps) { + typedef struct { + PyOperation &rootOp; + bool rootSeen; + } callBackData; + callBackData data{op.getOperation(), false}; + // Mark all ops below the op that the passmanager will be rooted + // at (but not op itself - note the preorder) as invalid. + MlirOperationWalkCallback invalidatingCallback = + [](MlirOperation op, void *userData) { + callBackData *data = static_cast(userData); + if (LLVM_LIKELY(data->rootSeen)) + data->rootOp.getOperation() + .getContext() + ->setOperationInvalid(op); + else + data->rootSeen = true; + }; + mlirOperationWalk(op.getOperation(), invalidatingCallback, + static_cast(&data), MlirWalkPreOrder); + } + // Actually run the pass manager. PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); MlirLogicalResult status = mlirPassManagerRunOnOp( passManager.get(), op.getOperation().get()); @@ -124,7 +147,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { throw MLIRError("Failure while executing pass pipeline", errors.take()); }, - py::arg("operation"), + "operation"_a, "invalidate_ops"_a = true, "Run the pass manager on the provided operation, raising an " "MLIRError on failure.") .def( diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index c1abbbe364611..0a5151751873f 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -25,6 +25,7 @@ #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" +#include "mlir/IR/Visitors.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Parser/Parser.h" @@ -705,6 +706,20 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) { return unwrap(op)->moveBefore(unwrap(other)); } +void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, + void *userData, MlirWalkOrder walkOrder) { + switch (walkOrder) { + + case MlirWalkPreOrder: + unwrap(op)->walk( + [callback, userData](Operation *op) { callback(wrap(op), userData); }); + break; + case MlirWalkPostOrder: + unwrap(op)->walk( + [callback, userData](Operation *op) { callback(wrap(op), userData); }); + } +} + //===----------------------------------------------------------------------===// // Region API. //===----------------------------------------------------------------------===// diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c index a181332e219db..d5ca8884306c3 100644 --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -2210,6 +2210,51 @@ int testSymbolTable(MlirContext ctx) { return 0; } +typedef struct { + const char *x; +} callBackData; + +void walkCallBack(MlirOperation op, void *rootOpVoid) { + fprintf(stderr, "%s: %s\n", ((callBackData *)(rootOpVoid))->x, + mlirIdentifierStr(mlirOperationGetName(op)).data); +} + +int testOperationWalk(MlirContext ctx) { + // CHECK-LABEL: @testOperationWalk + fprintf(stderr, "@testOperationWalk\n"); + + const char *moduleString = "module {\n" + " func.func @foo() {\n" + " %1 = arith.constant 10: i32\n" + " arith.addi %1, %1: i32\n" + " return\n" + " }\n" + "}"; + MlirModule module = + mlirModuleCreateParse(ctx, mlirStringRefCreateFromCString(moduleString)); + + callBackData data; + data.x = "i love you"; + + // CHECK: i love you: arith.constant + // CHECK: i love you: arith.addi + // CHECK: i love you: func.return + // CHECK: i love you: func.func + // CHECK: i love you: builtin.module + mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack, + (void *)(&data), MlirWalkPostOrder); + + data.x = "i don't love you"; + // CHECK: i don't love you: builtin.module + // CHECK: i don't love you: func.func + // CHECK: i don't love you: arith.constant + // CHECK: i don't love you: arith.addi + // CHECK: i don't love you: func.return + mlirOperationWalk(mlirModuleGetOperation(module), walkCallBack, + (void *)(&data), MlirWalkPreOrder); + return 0; +} + int testDialectRegistry(void) { fprintf(stderr, "@testDialectRegistry\n"); @@ -2349,6 +2394,8 @@ int main(void) { return 14; if (testDialectRegistry()) return 15; + if (testOperationWalk(ctx)) + return 16; testExplicitThreadPools(); testDiagnostics(); diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py index 4b3a02ac42bd9..e7f79ddc75113 100644 --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -4,6 +4,8 @@ from mlir.ir import * from mlir.passmanager import * from mlir.dialects.func import FuncOp +from mlir.dialects.builtin import ModuleOp + # Log everything to stderr and flush so that we have a unified stream to match # errors/info emitted by MLIR to stderr. @@ -33,6 +35,7 @@ def testCapsule(): run(testCapsule) + # CHECK-LABEL: TEST: testConstruct @run def testConstruct(): @@ -68,6 +71,7 @@ def testParseSuccess(): run(testParseSuccess) + # Verify successful round-trip. # CHECK-LABEL: TEST: testParseSpacedPipeline def testParseSpacedPipeline(): @@ -84,6 +88,7 @@ def testParseSpacedPipeline(): run(testParseSpacedPipeline) + # Verify failure on unregistered pass. # CHECK-LABEL: TEST: testParseFail def testParseFail(): @@ -102,6 +107,7 @@ def testParseFail(): run(testParseFail) + # Check that adding to a pass manager works # CHECK-LABEL: TEST: testAdd @run @@ -147,6 +153,7 @@ def testRunPipeline(): # CHECK: func.return , 1 run(testRunPipeline) + # CHECK-LABEL: TEST: testRunPipelineError @run def testRunPipelineError(): @@ -162,4 +169,94 @@ def testRunPipelineError(): # CHECK: error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation # CHECK: note: "-":1:1: see current operation: "test.op"() : () -> () # CHECK: > - print(f"Exception: <{e}>") + log(f"Exception: <{e}>") + + +# CHECK-LABEL: TEST: testPostPassOpInvalidation +@run +def testPostPassOpInvalidation(): + with Context() as ctx: + module = ModuleOp.parse( + """ + module { + arith.constant 10 + func.func @foo() { + arith.constant 10 + return + } + } + """ + ) + + # CHECK: invalidate_ops=False + log("invalidate_ops=False") + + outer_const_op = module.body.operations[0] + # CHECK: %[[VAL0:.*]] = arith.constant 10 : i64 + log(outer_const_op) + + func_op = module.body.operations[1] + # CHECK: func.func @[[FOO:.*]]() { + # CHECK: %[[VAL1:.*]] = arith.constant 10 : i64 + # CHECK: return + # CHECK: } + log(func_op) + + inner_const_op = func_op.body.blocks[0].operations[0] + # CHECK: %[[VAL1]] = arith.constant 10 : i64 + log(inner_const_op) + + PassManager.parse("builtin.module(canonicalize)").run( + module, invalidate_ops=False + ) + # CHECK: func.func @foo() { + # CHECK: return + # CHECK: } + log(func_op) + + # CHECK: func.func @foo() { + # CHECK: return + # CHECK: } + log(module) + + # CHECK: invalidate_ops=True + log("invalidate_ops=True") + + module = ModuleOp.parse( + """ + module { + arith.constant 10 + func.func @foo() { + arith.constant 10 + return + } + } + """ + ) + outer_const_op = module.body.operations[0] + func_op = module.body.operations[1] + inner_const_op = func_op.body.blocks[0].operations[0] + + PassManager.parse("builtin.module(canonicalize)").run(module) + try: + log(func_op) + except RuntimeError as e: + # CHECK: the operation has been invalidated + log(e) + + try: + log(outer_const_op) + except RuntimeError as e: + # CHECK: the operation has been invalidated + log(e) + + try: + log(inner_const_op) + except RuntimeError as e: + # CHECK: the operation has been invalidated + log(e) + + # CHECK: func.func @foo() { + # CHECK: return + # CHECK: } + log(module)