From 1c3071dc523e803bc9f5f8dad21f79a9c8617535 Mon Sep 17 00:00:00 2001 From: Emma Harper Smith Date: Mon, 19 May 2025 12:25:23 -0400 Subject: [PATCH] Disallow sharing zstd (de)compressor contexts According to the zstd author, it is not possible to share Zstandard objects across thread boundaries. To resolve this, we check if the object was created on the current thread and raise a RuntimeError if it is not. The tests are updated to ensure that the error is raised if a (de)compression context is shared across threads. --- Lib/test/test_zstd.py | 51 ++++++++---------------------------- Modules/_zstd/_zstdmodule.h | 17 ++++++++++++ Modules/_zstd/compressor.c | 19 +++++++++----- Modules/_zstd/decompressor.c | 15 ++++++++--- 4 files changed, 51 insertions(+), 51 deletions(-) diff --git a/Lib/test/test_zstd.py b/Lib/test/test_zstd.py index 53ca592ea38828..a510b7a3d5d552 100644 --- a/Lib/test/test_zstd.py +++ b/Lib/test/test_zstd.py @@ -2430,83 +2430,54 @@ def test_buffer_protocol(self): self.assertEqual(f.write(arr), LENGTH) self.assertEqual(f.tell(), LENGTH) -@unittest.skip("it fails for now, see gh-133885") + class FreeThreadingMethodTests(unittest.TestCase): @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled') @threading_helper.reap_threads @threading_helper.requires_working_threading() - def test_compress_locking(self): + def test_compressor_cannot_share(self): input = b'a'* (16*_1K) num_threads = 8 - comp = ZstdCompressor() - parts = [] - for _ in range(num_threads): - res = comp.compress(input, ZstdCompressor.FLUSH_BLOCK) - if res: - parts.append(res) - rest1 = comp.flush() - expected = b''.join(parts) + rest1 comp = ZstdCompressor() - output = [] - def run_method(method, input_data, output_data): - res = method(input_data, ZstdCompressor.FLUSH_BLOCK) - if res: - output_data.append(res) + def run_method(method, input_data): + with self.assertRaises(RuntimeError): + method(input_data, ZstdCompressor.FLUSH_BLOCK) threads = [] for i in range(num_threads): - thread = threading.Thread(target=run_method, args=(comp.compress, input, output)) + thread = threading.Thread(target=run_method, args=(comp.compress, input)) threads.append(thread) with threading_helper.start_threads(threads): pass - rest2 = comp.flush() - self.assertEqual(rest1, rest2) - actual = b''.join(output) + rest2 - self.assertEqual(expected, actual) - @unittest.skipUnless(Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled') @threading_helper.reap_threads @threading_helper.requires_working_threading() - def test_decompress_locking(self): + def test_decompressor_cannot_share(self): input = compress(b'a'* (16*_1K)) num_threads = 8 # to ensure we decompress over multiple calls, set maxsize window_size = _1K * 16//num_threads - decomp = ZstdDecompressor() - parts = [] - for _ in range(num_threads): - res = decomp.decompress(input, window_size) - if res: - parts.append(res) - expected = b''.join(parts) - comp = ZstdDecompressor() - output = [] - def run_method(method, input_data, output_data): - res = method(input_data, window_size) - if res: - output_data.append(res) + def run_method(method, input_data): + with self.assertRaises(RuntimeError): + method(input_data, window_size) threads = [] for i in range(num_threads): - thread = threading.Thread(target=run_method, args=(comp.decompress, input, output)) + thread = threading.Thread(target=run_method, args=(comp.decompress, input)) threads.append(thread) with threading_helper.start_threads(threads): pass - actual = b''.join(output) - self.assertEqual(expected, actual) - - if __name__ == "__main__": unittest.main() diff --git a/Modules/_zstd/_zstdmodule.h b/Modules/_zstd/_zstdmodule.h index b36486442c6567..31fef0ec3d6022 100644 --- a/Modules/_zstd/_zstdmodule.h +++ b/Modules/_zstd/_zstdmodule.h @@ -52,4 +52,21 @@ extern void set_parameter_error(const _zstd_state* const state, int is_compress, int key_v, int value_v); +static inline int +check_object_shared(PyObject *ob, char *type) +{ +#if defined(Py_GIL_DISABLED) + if (!_Py_IsOwnedByCurrentThread(ob)) + { + PyErr_Format(PyExc_RuntimeError, + "%s cannot be shared across multiple threads.", + type); + return 1; + } + return 0; +#else + return 0; +#endif +} + #endif // !ZSTD_MODULE_H diff --git a/Modules/_zstd/compressor.c b/Modules/_zstd/compressor.c index 38baee2be1e95b..6a788d6f6d5c62 100644 --- a/Modules/_zstd/compressor.c +++ b/Modules/_zstd/compressor.c @@ -575,6 +575,12 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data, { PyObject *ret; + /* Check we are on the same thread as the compressor was created */ + if (check_object_shared((PyObject *)self, "ZstdCompressor") > 0) + { + return NULL; + } + /* Check mode value */ if (mode != ZSTD_e_continue && mode != ZSTD_e_flush && @@ -587,9 +593,6 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data, return NULL; } - /* Thread-safe code */ - Py_BEGIN_CRITICAL_SECTION(self); - /* Compress */ if (self->use_multithread && mode == ZSTD_e_continue) { ret = compress_mt_continue_impl(self, data); @@ -607,7 +610,6 @@ _zstd_ZstdCompressor_compress_impl(ZstdCompressor *self, Py_buffer *data, /* Resetting cctx's session never fail */ ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only); } - Py_END_CRITICAL_SECTION(); return ret; } @@ -632,6 +634,12 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode) { PyObject *ret; + /* Check we are on the same thread as the compressor was created */ + if (check_object_shared((PyObject *)self, "ZstdCompressor") > 0) + { + return NULL; + } + /* Check mode value */ if (mode != ZSTD_e_end && mode != ZSTD_e_flush) { PyErr_SetString(PyExc_ValueError, @@ -641,8 +649,6 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode) return NULL; } - /* Thread-safe code */ - Py_BEGIN_CRITICAL_SECTION(self); ret = compress_impl(self, NULL, mode); if (ret) { @@ -654,7 +660,6 @@ _zstd_ZstdCompressor_flush_impl(ZstdCompressor *self, int mode) /* Resetting cctx's session never fail */ ZSTD_CCtx_reset(self->cctx, ZSTD_reset_session_only); } - Py_END_CRITICAL_SECTION(); return ret; } diff --git a/Modules/_zstd/decompressor.c b/Modules/_zstd/decompressor.c index 58f9c9f804e549..645e3fb954b777 100644 --- a/Modules/_zstd/decompressor.c +++ b/Modules/_zstd/decompressor.c @@ -639,6 +639,12 @@ _zstd_ZstdDecompressor_unused_data_get_impl(ZstdDecompressor *self) { PyObject *ret; + /* Check we are on the same thread as the decompressor was created */ + if (check_object_shared((PyObject *)self, "ZstdDecompressor") > 0) + { + return NULL; + } + if (!self->eof) { return Py_GetConstant(Py_CONSTANT_EMPTY_BYTES); } @@ -692,11 +698,12 @@ _zstd_ZstdDecompressor_decompress_impl(ZstdDecompressor *self, /*[clinic end generated code: output=a4302b3c940dbec6 input=6463dfdf98091caa]*/ { PyObject *ret; - /* Thread-safe code */ - Py_BEGIN_CRITICAL_SECTION(self); - + /* Check we are on the same thread as the decompressor was created */ + if (check_object_shared((PyObject *)self, "ZstdDecompressor") > 0) + { + return NULL; + } ret = stream_decompress(self, data, max_length); - Py_END_CRITICAL_SECTION(); return ret; }