Skip to content

Commit 4153c7e

Browse files
committed
[MLIR:Python] Fix race on PyOperations.
Joint work with @vfdev-5 We found the following TSAN race report in JAX's CI: jax-ml/jax#28551 ``` WARNING: ThreadSanitizer: data race (pid=35893) Read of size 1 at 0x7fffca320cb9 by thread T57 (mutexes: read M0): #0 mlir::python::PyOperation::checkValid() const /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:1300:8 (libjax_common.so+0x41e8b1d) (BuildId: 55242ad732cdae54) #1 mlir::python::populateIRCore(nanobind::module_&)::$_57::operator()(mlir::python::PyOperationBase&) const /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:3221:40 (libjax_common.so+0x41e8b1d) #2 _object* nanobind::detail::func_create<true, true, mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef, mlir::python::PyOperationBase&, 0ul, nanobind::is_method, nanobind::is_getter, nanobind::rv_policy>(mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef (*)(mlir::python::PyOperationBase&), std::integer_sequence<unsigned long, 0ul>, nanobind::is_method const&, nanobind::is_getter const&, nanobind::rv_policy const&)::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::operator()(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*) const /proc/self/cwd/external/nanobind/include/nanobind/nb_func.h:275:24 (libjax_common.so+0x41e8b1d) #3 _object* nanobind::detail::func_create<true, true, mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef, mlir::python::PyOperationBase&, 0ul, nanobind::is_method, nanobind::is_getter, nanobind::rv_policy>(mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef (*)(mlir::python::PyOperationBase&), std::integer_sequence<unsigned long, 0ul>, nanobind::is_method const&, nanobind::is_getter const&, nanobind::rv_policy const&)::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::__invoke(void*, _object**, unsigned char*, 
nanobind::rv_policy, nanobind::detail::cleanup_list*) /proc/self/cwd/external/nanobind/include/nanobind/nb_func.h:219:14 (libjax_common.so+0x41e8b1d) ... Previous write of size 1 at 0x7fffca320cb9 by thread T56 (mutexes: read M0): #0 mlir::python::PyOperation::setInvalid() /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRModule.h:729:29 (libjax_common.so+0x419f012) (BuildId: 55242ad732cdae54) #1 mlir::python::PyMlirContext::clearOperation(MlirOperation) /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:741:10 (libjax_common.so+0x419f012) #2 mlir::python::PyOperation::~PyOperation() /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:1213:19 (libjax_common.so+0x41a414b) (BuildId: 55242ad732cdae54) #3 void nanobind::detail::wrap_destruct<mlir::python::PyOperation>(void*) /proc/self/cwd/external/nanobind/include/nanobind/nb_class.h:245:21 (libjax_common.so+0x41ecf21) (BuildId: 55242ad732cdae54) #4 nanobind::detail::inst_dealloc(_object*) /proc/self/cwd/external/nanobind/src/nb_type.cpp:255:13 (libjax_common.so+0x3284136) (BuildId: 55242ad732cdae54) #5 _Py_Dealloc /project/cpython/Objects/object.c:3025:5 (python3.14+0x2a2422) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b) #6 _Py_MergeZeroLocalRefcount /project/cpython/Objects/object.c (python3.14+0x2a2422) #7 Py_DECREF(_object*) /proc/self/cwd/external/python_x86_64-unknown-linux-gnu-freethreaded/include/python3.14t/refcount.h:387:13 (libjax_common.so+0x41aaadc) (BuildId: 55242ad732cdae54) ... ``` At the simplest level, the `valid` field of a PyOperation must be protected by a lock, because it may be concurrently accessed from multiple threads. Much more interesting, however is how we get into the situation described by the two stack traces above in the first place. The scenario that triggers this is the following: * thread T56 holds the last Python reference on a PyOperation, and decides to release it. 
* After T56 starts to release its reference, but before T56 removes the PyOperation from the liveOperations map, a second thread T57 comes along and looks up the same MlirOperation in the liveOperations map. * Finding the operation to be present, thread T57 increments the reference count of that PyOperation and returns it to the caller. This is illegal! Python is in the process of calling the destructor of that object, and once an object is in that state it cannot be safely revived. To fix this, whenever we increment the reference count of a PyOperation that we found via the liveOperations map and to which we only hold a non-owning reference, we must use the Python 3.14+ API `PyUnstable_TryIncRef`, which exists precisely for this scenario (python/cpython#128844). That API does not exist under Python 3.13, so we need a backport of it in that case, for which we use the backport that both nanobind and pybind11 also use. Fixes jax-ml/jax#28551
1 parent 2d287f5 commit 4153c7e

File tree

3 files changed

+215
-45
lines changed

3 files changed

+215
-45
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 137 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,75 @@ class PyOpOperandIterator {
635635
MlirOpOperand opOperand;
636636
};
637637

638+
#if !defined(Py_GIL_DISABLED)
639+
inline void enableTryIncRef(nb::handle obj) noexcept {}
640+
inline bool tryIncRef(nb::handle obj) noexcept {
641+
if (Py_REFCNT(obj.ptr()) > 0) {
642+
Py_INCREF(obj.ptr());
643+
return true;
644+
}
645+
return false;
646+
}
647+
648+
#elif PY_VERSION_HEX >= 0x030E00A5
649+
650+
// CPython 3.14 provides an unstable API for these.
651+
inline void enableTryIncRef(nb::handle obj) noexcept {
652+
PyUnstable_EnableTryIncRef(obj.ptr());
653+
}
654+
inline bool tryIncRef(nb::handle obj) noexcept {
655+
return PyUnstable_TryIncRef(obj.ptr());
656+
}
657+
658+
#else
659+
660+
// For CPython 3.13 there is no API for this, and so we must implement our own.
661+
// This code originates from https://github.com/wjakob/nanobind/pull/865/files.
662+
void enableTryIncRef(nb::handle h) noexcept {
663+
// Since this is called during object construction, we know that we have
664+
// the only reference to the object and can use a non-atomic write.
665+
PyObject *obj = h.ptr();
666+
assert(h->ob_ref_shared == 0);
667+
h->ob_ref_shared = _Py_REF_MAYBE_WEAKREF;
668+
}
669+
670+
bool tryIncRef(nb::handle h) noexcept {
671+
PyObject *obj = h.ptr();
672+
// See
673+
// https://github.com/python/cpython/blob/d05140f9f77d7dfc753dd1e5ac3a5962aaa03eff/Include/internal/pycore_object.h#L761
674+
uint32_t local = _Py_atomic_load_uint32_relaxed(&obj->ob_ref_local);
675+
local += 1;
676+
if (local == 0) {
677+
// immortal
678+
return true;
679+
}
680+
if (_Py_IsOwnedByCurrentThread(obj)) {
681+
_Py_atomic_store_uint32_relaxed(&obj->ob_ref_local, local);
682+
#ifdef Py_REF_DEBUG
683+
_Py_INCREF_IncRefTotal();
684+
#endif
685+
return true;
686+
}
687+
Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&obj->ob_ref_shared);
688+
for (;;) {
689+
// If the shared refcount is zero and the object is either merged
690+
// or may not have weak references, then we cannot incref it.
691+
if (shared == 0 || shared == _Py_REF_MERGED) {
692+
return false;
693+
}
694+
695+
if (_Py_atomic_compare_exchange_ssize(&obj->ob_ref_shared, &shared,
696+
shared +
697+
(1 << _Py_REF_SHARED_SHIFT))) {
698+
#ifdef Py_REF_DEBUG
699+
_Py_INCREF_IncRefTotal();
700+
#endif
701+
return true;
702+
}
703+
}
704+
}
705+
#endif
706+
638707
} // namespace
639708

640709
//------------------------------------------------------------------------------
@@ -706,11 +775,17 @@ size_t PyMlirContext::getLiveOperationCount() {
706775
return liveOperations.size();
707776
}
708777

709-
std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
710-
std::vector<PyOperation *> liveObjects;
778+
std::vector<nb::object> PyMlirContext::getLiveOperationObjects() {
779+
std::vector<nb::object> liveObjects;
711780
nb::ft_lock_guard lock(liveOperationsMutex);
712-
for (auto &entry : liveOperations)
713-
liveObjects.push_back(entry.second.second);
781+
for (auto &entry : liveOperations) {
782+
// It is not safe to unconditionally increment the reference count here
783+
// because an operation that is in the process of being deleted by another
784+
// thread may still be present in the map.
785+
if (tryIncRef(entry.second.first)) {
786+
liveObjects.push_back(nb::steal(entry.second.first));
787+
}
788+
}
714789
return liveObjects;
715790
}
716791

@@ -720,25 +795,26 @@ size_t PyMlirContext::clearLiveOperations() {
720795
{
721796
nb::ft_lock_guard lock(liveOperationsMutex);
722797
std::swap(operations, liveOperations);
798+
for (auto &op : operations)
799+
op.second.second->setInvalidLocked();
723800
}
724-
for (auto &op : operations)
725-
op.second.second->setInvalid();
726801
size_t numInvalidated = operations.size();
727802
return numInvalidated;
728803
}
729804

730-
void PyMlirContext::clearOperation(MlirOperation op) {
731-
PyOperation *py_op;
732-
{
733-
nb::ft_lock_guard lock(liveOperationsMutex);
734-
auto it = liveOperations.find(op.ptr);
735-
if (it == liveOperations.end()) {
736-
return;
737-
}
738-
py_op = it->second.second;
739-
liveOperations.erase(it);
805+
void PyMlirContext::clearOperationLocked(MlirOperation op) {
806+
auto it = liveOperations.find(op.ptr);
807+
if (it == liveOperations.end()) {
808+
return;
740809
}
741-
py_op->setInvalid();
810+
PyOperation *py_op = it->second.second;
811+
py_op->setInvalidLocked();
812+
liveOperations.erase(it);
813+
}
814+
815+
void PyMlirContext::clearOperation(MlirOperation op) {
816+
nb::ft_lock_guard lock(liveOperationsMutex);
817+
clearOperationLocked(op);
742818
}
743819

744820
void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
@@ -766,14 +842,14 @@ void PyMlirContext::clearOperationsInside(MlirOperation op) {
766842
clearOperationsInside(opRef->getOperation());
767843
}
768844

769-
void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
845+
void PyMlirContext::clearOperationAndInsideLocked(PyOperationBase &op) {
770846
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
771847
void *userData) {
772848
PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
773-
contextRef->clearOperation(op);
849+
contextRef->clearOperationLocked(op);
774850
return MlirWalkResult::MlirWalkResultAdvance;
775851
};
776-
mlirOperationWalk(op.getOperation(), invalidatingCallback,
852+
mlirOperationWalk(op.getOperation().getLocked(), invalidatingCallback,
777853
&op.getOperation().getContext(), MlirWalkPreOrder);
778854
}
779855

@@ -1203,19 +1279,23 @@ PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
12031279
: BaseContextObject(std::move(contextRef)), operation(operation) {}
12041280

12051281
PyOperation::~PyOperation() {
1282+
PyMlirContextRef context = getContext();
1283+
nb::ft_lock_guard lock(context->liveOperationsMutex);
12061284
// If the operation has already been invalidated there is nothing to do.
12071285
if (!valid)
12081286
return;
12091287

12101288
// Otherwise, invalidate the operation and remove it from live map when it is
12111289
// attached.
12121290
if (isAttached()) {
1213-
getContext()->clearOperation(*this);
1291+
// Since the operation was valid, we know that it is this object present
1292+
// in the map, not some other object.
1293+
context->liveOperations.erase(operation.ptr);
12141294
} else {
12151295
// And destroy it when it is detached, i.e. owned by Python, in which case
12161296
// all nested operations must be invalidated at removed from the live map as
12171297
// well.
1218-
erase();
1298+
eraseLocked();
12191299
}
12201300
}
12211301

@@ -1241,6 +1321,7 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
12411321
// Create.
12421322
PyOperationRef unownedOperation =
12431323
makeObjectRef<PyOperation>(std::move(contextRef), operation);
1324+
enableTryIncRef(unownedOperation.getObject());
12441325
unownedOperation->handle = unownedOperation.getObject();
12451326
if (parentKeepAlive) {
12461327
unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
@@ -1254,18 +1335,26 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
12541335
nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
12551336
auto &liveOperations = contextRef->liveOperations;
12561337
auto it = liveOperations.find(operation.ptr);
1257-
if (it == liveOperations.end()) {
1258-
// Create.
1259-
PyOperationRef result = createInstance(std::move(contextRef), operation,
1260-
std::move(parentKeepAlive));
1261-
liveOperations[operation.ptr] =
1262-
std::make_pair(result.getObject(), result.get());
1263-
return result;
1338+
if (it != liveOperations.end()) {
1339+
PyOperation *existing = it->second.second;
1340+
nb::handle pyRef = it->second.first;
1341+
1342+
// Try to increment the reference count of the existing entry. This can fail
1343+
// if the object is in the process of being destroyed by another thread.
1344+
if (tryIncRef(pyRef)) {
1345+
return PyOperationRef(existing, nb::steal<nb::object>(pyRef));
1346+
}
1347+
1348+
// Mark the existing entry as invalid, since we are about to replace it.
1349+
existing->setInvalidLocked();
12641350
}
1265-
// Use existing.
1266-
PyOperation *existing = it->second.second;
1267-
nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1268-
return PyOperationRef(existing, std::move(pyRef));
1351+
1352+
// Create a new wrapper object.
1353+
PyOperationRef result = createInstance(std::move(contextRef), operation,
1354+
std::move(parentKeepAlive));
1355+
liveOperations[operation.ptr] =
1356+
std::make_pair(result.getObject(), result.get());
1357+
return result;
12691358
}
12701359

12711360
PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
@@ -1297,11 +1386,17 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
12971386
}
12981387

12991388
void PyOperation::checkValid() const {
1389+
nb::ft_lock_guard lock(getContext()->liveOperationsMutex);
1390+
checkValidLocked();
1391+
}
1392+
1393+
void PyOperation::checkValidLocked() const {
13001394
if (!valid) {
13011395
throw std::runtime_error("the operation has been invalidated");
13021396
}
13031397
}
13041398

1399+
13051400
void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
13061401
bool enableDebugInfo, bool prettyDebugInfo,
13071402
bool printGenericOpForm, bool useLocalScope,
@@ -1638,12 +1733,17 @@ nb::object PyOperation::createOpView() {
16381733
return nb::cast(PyOpView(getRef().getObject()));
16391734
}
16401735

1641-
void PyOperation::erase() {
1642-
checkValid();
1643-
getContext()->clearOperationAndInside(*this);
1736+
void PyOperation::eraseLocked() {
1737+
checkValidLocked();
1738+
getContext()->clearOperationAndInsideLocked(*this);
16441739
mlirOperationDestroy(operation);
16451740
}
16461741

1742+
void PyOperation::erase() {
1743+
nb::ft_lock_guard lock(getContext()->liveOperationsMutex);
1744+
eraseLocked();
1745+
}
1746+
16471747
namespace {
16481748
/// CRTP base class for Python MLIR values that subclass Value and should be
16491749
/// castable from it. The value hierarchy is one level deep and is not supposed
@@ -2324,7 +2424,7 @@ void PySymbolTable::erase(PyOperationBase &symbol) {
23242424
// The operation is also erased, so we must invalidate it. There may be Python
23252425
// references to this operation so we don't want to delete it from the list of
23262426
// live operations here.
2327-
symbol.getOperation().valid = false;
2427+
symbol.getOperation().setInvalid();
23282428
}
23292429

23302430
void PySymbolTable::dunderDel(const std::string &name) {

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class PyObjectRef {
8383
}
8484

8585
T *get() { return referrent; }
86-
T *operator->() {
86+
T *operator->() const {
8787
assert(referrent && object);
8888
return referrent;
8989
}
@@ -229,7 +229,7 @@ class PyMlirContext {
229229
static size_t getLiveCount();
230230

231231
/// Get a list of Python objects which are still in the live context map.
232-
std::vector<PyOperation *> getLiveOperationObjects();
232+
std::vector<nanobind::object> getLiveOperationObjects();
233233

234234
/// Gets the count of live operations associated with this context.
235235
/// Used for testing.
@@ -254,9 +254,10 @@ class PyMlirContext {
254254
void clearOperationsInside(PyOperationBase &op);
255255
void clearOperationsInside(MlirOperation op);
256256

257-
/// Clears the operaiton _and_ all operations inside using
258-
/// `clearOperation(MlirOperation)`.
259-
void clearOperationAndInside(PyOperationBase &op);
257+
/// Clears the operation _and_ all operations inside using
258+
/// `clearOperation(MlirOperation)`. Requires that liveOperations mutex is
259+
/// held.
260+
void clearOperationAndInsideLocked(PyOperationBase &op);
260261

261262
/// Gets the count of live modules associated with this context.
262263
/// Used for testing.
@@ -278,6 +279,9 @@ class PyMlirContext {
278279
struct ErrorCapture;
279280

280281
private:
282+
// Similar to clearOperation, but requires the liveOperations mutex to be held
283+
void clearOperationLocked(MlirOperation op);
284+
281285
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
282286
// preserving the relationship that an MlirContext maps to a single
283287
// PyMlirContext wrapper. This could be replaced in the future with an
@@ -302,6 +306,9 @@ class PyMlirContext {
302306
// attempt to access it will raise an error.
303307
using LiveOperationMap =
304308
llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
309+
310+
// liveOperationsMutex guards both liveOperations and the valid field of
311+
// PyOperation objects in free-threading mode.
305312
nanobind::ft_mutex liveOperationsMutex;
306313

307314
// Guarded by liveOperationsMutex in free-threading mode.
@@ -336,6 +343,7 @@ class BaseContextObject {
336343
}
337344

338345
/// Accesses the context reference.
346+
const PyMlirContextRef &getContext() const { return contextRef; }
339347
PyMlirContextRef &getContext() { return contextRef; }
340348

341349
private:
@@ -677,6 +685,10 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
677685
checkValid();
678686
return operation;
679687
}
688+
MlirOperation getLocked() const {
689+
checkValidLocked();
690+
return operation;
691+
}
680692

681693
PyOperationRef getRef() {
682694
return PyOperationRef(this, nanobind::borrow<nanobind::object>(handle));
@@ -692,6 +704,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
692704
attached = false;
693705
}
694706
void checkValid() const;
707+
void checkValidLocked() const;
695708

696709
/// Gets the owning block or raises an exception if the operation has no
697710
/// owning block.
@@ -725,19 +738,27 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
725738
/// parent context's live operations map, and sets the valid bit false.
726739
void erase();
727740

728-
/// Invalidate the operation.
729-
void setInvalid() { valid = false; }
730-
731741
/// Clones this operation.
732742
nanobind::object clone(const nanobind::object &ip);
733743

744+
/// Invalidate the operation.
745+
void setInvalid() {
746+
nanobind::ft_lock_guard lock(getContext()->liveOperationsMutex);
747+
setInvalidLocked();
748+
}
749+
/// Like setInvalid(), but requires the liveOperations mutex to be held.
750+
void setInvalidLocked() { valid = false; }
751+
734752
PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
735753

736754
private:
737755
static PyOperationRef createInstance(PyMlirContextRef contextRef,
738756
MlirOperation operation,
739757
nanobind::object parentKeepAlive);
740758

759+
// Like erase(), but requires the caller to hold the liveOperationsMutex.
760+
void eraseLocked();
761+
741762
MlirOperation operation;
742763
nanobind::handle handle;
743764
// Keeps the parent alive, regardless of whether it is an Operation or
@@ -748,6 +769,9 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
748769
// ir_operation.py regarding testing corresponding lifetime guarantees.
749770
nanobind::object parentKeepAlive;
750771
bool attached = true;
772+
773+
// Guarded by 'context->liveOperationsMutex'. Valid objects must be present
774+
// in context->liveOperations.
751775
bool valid = true;
752776

753777
friend class PyOperationBase;

0 commit comments

Comments
 (0)