Skip to content

[mlir][python] Clear PyOperations instead of invalidating them. #70044

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/include/mlir-c/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -705,12 +705,12 @@ typedef enum MlirWalkOrder {
} MlirWalkOrder;

/// Operation walker type. The handler is passed an (opaque) reference to an
/// operation a pointer to a `userData`.
/// operation and a pointer to a `userData`.
typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData);

/// Walks operation `op` in `walkOrder` and calls `callback` on that operation.
/// `*userData` is passed to the callback as well and can be used to tunnel some
/// some context or other data into the callback.
/// context or other data into the callback.
MLIR_CAPI_EXPORTED
void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
void *userData, MlirWalkOrder walkOrder);
Expand Down
9 changes: 6 additions & 3 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -635,9 +635,12 @@ size_t PyMlirContext::clearLiveOperations() {
return numInvalidated;
}

void PyMlirContext::setOperationInvalid(MlirOperation op) {
if (liveOperations.contains(op.ptr))
liveOperations[op.ptr].second->setInvalid();
void PyMlirContext::clearOperation(MlirOperation op) {
auto it = liveOperations.find(op.ptr);
if (it != liveOperations.end()) {
it->second.second->setInvalid();
liveOperations.erase(it);
}
}

size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
Expand Down
9 changes: 5 additions & 4 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,11 @@ class PyMlirContext {
/// place.
size_t clearLiveOperations();

/// Sets an operation invalid. This is useful for when some non-bindings
/// code destroys the operation and the bindings need to made aware. For
/// example, in the case when pass manager is run.
void setOperationInvalid(MlirOperation op);
/// Removes an operation from the live operations map and sets it invalid.
/// This is useful for when some non-bindings code destroys the operation and
/// the bindings need to made aware. For example, in the case when pass
/// manager is run.
void clearOperation(MlirOperation op);

/// Gets the count of live modules associated with this context.
/// Used for testing.
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Bindings/Python/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,8 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
[](MlirOperation op, void *userData) {
callBackData *data = static_cast<callBackData *>(userData);
if (LLVM_LIKELY(data->rootSeen))
data->rootOp.getOperation()
.getContext()
->setOperationInvalid(op);
data->rootOp.getOperation().getContext()->clearOperation(
op);
else
data->rootSeen = true;
};
Expand Down
25 changes: 23 additions & 2 deletions mlir/test/python/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,14 @@ def testRunPipelineError():
@run
def testPostPassOpInvalidation():
with Context() as ctx:
log_op_count = lambda: log("live ops:", ctx._get_live_operation_count())

# CHECK: invalidate_ops=False
log("invalidate_ops=False")

# CHECK: live ops: 0
log_op_count()

module = ModuleOp.parse(
"""
module {
Expand All @@ -188,8 +196,8 @@ def testPostPassOpInvalidation():
"""
)

# CHECK: invalidate_ops=False
log("invalidate_ops=False")
# CHECK: live ops: 1
log_op_count()

outer_const_op = module.body.operations[0]
# CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
Expand All @@ -206,6 +214,9 @@ def testPostPassOpInvalidation():
# CHECK: %[[VAL1]] = arith.constant 10 : i64
log(inner_const_op)

# CHECK: live ops: 4
log_op_count()

PassManager.parse("builtin.module(canonicalize)").run(
module, invalidate_ops=False
)
Expand All @@ -222,6 +233,9 @@ def testPostPassOpInvalidation():
# CHECK: invalidate_ops=True
log("invalidate_ops=True")

# CHECK: live ops: 4
log_op_count()

module = ModuleOp.parse(
"""
module {
Expand All @@ -237,7 +251,14 @@ def testPostPassOpInvalidation():
func_op = module.body.operations[1]
inner_const_op = func_op.body.blocks[0].operations[0]

# CHECK: live ops: 4
log_op_count()

PassManager.parse("builtin.module(canonicalize)").run(module)

# CHECK: live ops: 1
log_op_count()

try:
log(func_op)
except RuntimeError as e:
Expand Down