diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst index 00925ae920aad9..aece7abc0c87db 100644 --- a/Doc/library/itertools.rst +++ b/Doc/library/itertools.rst @@ -56,6 +56,7 @@ Iterator Arguments Results :func:`groupby` iterable[, key] sub-iterators grouped by value of key(v) ``groupby(['A','B','DEF'], len) → (1, A B) (3, DEF)`` :func:`islice` seq, [start,] stop [, step] elements from seq[start:stop:step] ``islice('ABCDEFG', 2, None) → C D E F G`` :func:`pairwise` iterable (p[0], p[1]), (p[1], p[2]) ``pairwise('ABCDEFG') → AB BC CD DE EF FG`` +:func:`serialize` iterable p0, p1, p2, ... ``serialize([1,4,6]) → 1 4 6`` :func:`starmap` func, seq func(\*seq[0]), func(\*seq[1]), ... ``starmap(pow, [(2,5), (3,2), (10,3)]) → 32 9 1000`` :func:`takewhile` predicate, seq seq[0], seq[1], until predicate fails ``takewhile(lambda x: x<5, [1,4,6,3,8]) → 1 4`` :func:`tee` it, n it1, it2, ... itn splits one iterator into n ``tee('ABC', 2) → A B C, A B C`` @@ -648,6 +649,19 @@ loops that truncate the stream. >>> list(map(pow, range(10), repeat(2))) [0, 1, 4, 9, 16, 25, 36, 49, 64, 81] +.. function:: serialize(iterable) + + Make an iterator that returns the elements of *iterable* one at a time while serializing concurrent calls to :func:`next`, so that the returned iterator can be shared between threads. + + Roughly equivalent to:: + + class serialize(Iterator): + def __init__(self, it): + self._it = iter(it) + self._lock = Lock() + def __next__(self): + with self._lock: + return next(self._it) .. 
class non_atomic_iterator:
    """Iterator that yields pairs by calling ``next()`` twice on its source.

    Each ``__next__`` makes two separate calls on the underlying iterator,
    so calls interleaved between threads can produce a pair whose elements
    are not consecutive unless the calls are serialized.
    """

    def __init__(self, it):
        self.it = iter(it)

    def __iter__(self):
        return self

    def __next__(self):
        a = next(self.it)
        b = next(self.it)
        return a, b


def count():
    """Yield 1, 2, 3, ... without end."""
    i = 0
    while True:
        i += 1
        yield i


class SerializeThreading(unittest.TestCase):
    """Exercise ``itertools.serialize`` from several threads at once."""

    @threading_helper.reap_threads
    def test_serialize(self):
        """Pairs drawn through ``serialize`` must stay consecutive."""
        number_of_threads = 10
        number_of_iterations = 10
        barrier = Barrier(number_of_threads)
        # An exception raised inside a worker thread does not propagate to
        # the test, so workers record problems here and the main thread
        # asserts on the list after joining them.
        errors = []

        def work(it):
            try:
                # Line the workers up so they really compete for `it`.
                barrier.wait()
                while True:
                    try:
                        a, b = next(it)
                    except StopIteration:
                        break
                    if a + 1 != b:
                        errors.append((a, b))
            except Exception as exc:
                errors.append(exc)

        data = tuple(range(400))
        for _ in range(number_of_iterations):
            serialize_iterator = serialize(non_atomic_iterator(data))
            worker_threads = [
                Thread(target=work, args=[serialize_iterator])
                for _ in range(number_of_threads)
            ]

            with threading_helper.start_threads(worker_threads):
                pass

            self.assertEqual(errors, [])
            barrier.reset()

    @threading_helper.reap_threads
    def test_serialize_generator(self):
        """Concurrent ``next()`` on a serialized generator must not raise."""
        number_of_threads = 5
        number_of_iterations = 4
        barrier = Barrier(number_of_threads)
        # Same reasoning as in test_serialize: surface worker exceptions
        # in the main thread instead of losing them.
        errors = []

        def work(it):
            try:
                barrier.wait()
                for _ in range(1_000):
                    try:
                        next(it)
                    except StopIteration:
                        break
            except Exception as exc:
                errors.append(exc)

        for _ in range(number_of_iterations):
            generator = serialize(count())
            worker_threads = [
                Thread(target=work, args=[generator])
                for _ in range(number_of_threads)
            ]

            with threading_helper.start_threads(worker_threads):
                pass

            self.assertEqual(errors, [])
            barrier.reset()
def test_serialize(self):
    """Check that serialize() passes items through and rejects bad input."""
    samples = ("123", "", range(1000), ('do', 1.2), range(2000,2200,5))
    for s in samples:
        for g in (G, I, Ig, S, L, R):
            # Wrapping must not change the items or their order.
            reference = list(g(s))
            self.assertEqual(list(serialize(g(s))), reference)

        # The X and N helpers are expected to be refused at construction.
        for helper in (X, N):
            unusable = helper(s)
            with self.assertRaises(TypeError):
                serialize(unusable)

        # The E helper is expected to fail only once it is consumed.
        failing = serialize(E(s))
        with self.assertRaises(ZeroDivisionError):
            list(failing)

    # Objects that are not iterable at all are rejected as well.
    for arg in [1, True, sys]:
        with self.assertRaises(TypeError):
            serialize(arg)
return_value = itertools_serialize_impl(type, iterable); + +exit: + return return_value; +} +/*[clinic end generated code: output=5f9393576be897be input=a9049054013a1b77]*/ diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index 943c1e8607b38f..a4b0393782aebc 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -32,6 +32,7 @@ typedef struct { PyTypeObject *permutations_type; PyTypeObject *product_type; PyTypeObject *repeat_type; + PyTypeObject *serialize_type; PyTypeObject *starmap_type; PyTypeObject *takewhile_type; PyTypeObject *tee_type; @@ -85,8 +86,9 @@ class itertools.compress "compressobject *" "clinic_state()->compress_type" class itertools.filterfalse "filterfalseobject *" "clinic_state()->filterfalse_type" class itertools.count "countobject *" "clinic_state()->count_type" class itertools.pairwise "pairwiseobject *" "clinic_state()->pairwise_type" +class itertools.serialize "serializeobject *" "clinic_state()->serialize_type" [clinic start generated code]*/ -/*[clinic end generated code: output=da39a3ee5e6b4b0d input=aa48fe4de9d4080f]*/ +/*[clinic end generated code: output=da39a3ee5e6b4b0d input=1261e430ec3a27e1]*/ #define clinic_state() (find_state_by_type(type)) #define clinic_state_by_cls() (get_module_state_by_cls(base_tp)) @@ -3697,6 +3699,100 @@ static PyType_Spec repeat_spec = { .slots = repeat_slots, }; +/* serialize object **************************************************************/ + +typedef struct { + PyObject_HEAD + PyObject *it; +} serializeobject; + +#define serializeobject_CAST(op) ((serializeobject *)(op)) + +/*[clinic input] +@classmethod +itertools.serialize.__new__ + iterable: object + / +Make an iterator thread-safe [tbd] + +[clinic start generated code]*/ + +static PyObject * +itertools_serialize_impl(PyTypeObject *type, PyObject *iterable) +/*[clinic end generated code: output=abd19e483759f1b7 input=0099ab7fd57cdc9f]*/ +{ + /* Get iterator. 
*/ + PyObject *it = PyObject_GetIter(iterable); + if (it == NULL) + return NULL; + + serializeobject *lz = (serializeobject *)type->tp_alloc(type, 0); + lz->it = it; + + return (PyObject *)lz; +} + +static void +serialize_dealloc(PyObject *op) +{ + serializeobject *lz = serializeobject_CAST(op); + PyTypeObject *tp = Py_TYPE(lz); + PyObject_GC_UnTrack(lz); + Py_XDECREF(lz->it); + tp->tp_free(lz); + Py_DECREF(tp); +} + +static int +serialize_traverse(PyObject *op, visitproc visit, void *arg) +{ + serializeobject *lz = serializeobject_CAST(op); + Py_VISIT(Py_TYPE(lz)); + Py_VISIT(lz->it); + return 0; +} + +static PyObject * +serialize_next(PyObject *op) +{ + serializeobject *lz = serializeobject_CAST(op); + PyObject *result = NULL; + + Py_BEGIN_CRITICAL_SECTION(op); // or lock on op->it ? + PyObject *it = lz->it; + if (it != NULL) { + result = PyIter_Next(lz->it); + if (result == NULL) { + /* Note: StopIteration is already cleared by PyIter_Next() */ + if (PyErr_Occurred()) + return NULL; + Py_CLEAR(lz->it); + } + } + Py_END_CRITICAL_SECTION(); + return result; +} + +static PyType_Slot serialize_slots[] = { + {Py_tp_dealloc, serialize_dealloc}, + {Py_tp_getattro, PyObject_GenericGetAttr}, + {Py_tp_doc, (void *)itertools_serialize__doc__}, + {Py_tp_traverse, serialize_traverse}, + {Py_tp_iter, PyObject_SelfIter}, + {Py_tp_iternext, serialize_next}, + {Py_tp_new, itertools_serialize}, + {Py_tp_free, PyObject_GC_Del}, + {0, NULL}, +}; + +static PyType_Spec serialize_spec = { + .name = "itertools.serialize", + .basicsize = sizeof(serializeobject), + .flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_BASETYPE | + Py_TPFLAGS_IMMUTABLETYPE), + .slots = serialize_slots, +}; + /* ziplongest object *********************************************************/ @@ -3963,6 +4059,7 @@ itertoolsmodule_traverse(PyObject *mod, visitproc visit, void *arg) Py_VISIT(state->permutations_type); Py_VISIT(state->product_type); Py_VISIT(state->repeat_type); + 
Py_VISIT(state->serialize_type); Py_VISIT(state->starmap_type); Py_VISIT(state->takewhile_type); Py_VISIT(state->tee_type); @@ -3992,6 +4089,7 @@ itertoolsmodule_clear(PyObject *mod) Py_CLEAR(state->permutations_type); Py_CLEAR(state->product_type); Py_CLEAR(state->repeat_type); + Py_CLEAR(state->serialize_type); Py_CLEAR(state->starmap_type); Py_CLEAR(state->takewhile_type); Py_CLEAR(state->tee_type); @@ -4038,6 +4136,7 @@ itertoolsmodule_exec(PyObject *mod) ADD_TYPE(mod, state->permutations_type, &permutations_spec); ADD_TYPE(mod, state->product_type, &product_spec); ADD_TYPE(mod, state->repeat_type, &repeat_spec); + ADD_TYPE(mod, state->serialize_type, &serialize_spec); ADD_TYPE(mod, state->starmap_type, &starmap_spec); ADD_TYPE(mod, state->takewhile_type, &takewhile_spec); ADD_TYPE(mod, state->tee_type, &tee_spec);