Skip to content

Commit 1b6bef8

Browse files
gh-129107: make bytearray iterator thread safe (#130096)
Co-authored-by: Kumar Aditya <kumaraditya@python.org>
1 parent 388e1ca commit 1b6bef8

File tree

3 files changed

+81
-26
lines changed

3 files changed

+81
-26
lines changed

Lib/test/test_bytes.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2455,9 +2455,6 @@ def check(funcs, a=None, *args):
24552455
with threading_helper.start_threads(threads):
24562456
pass
24572457

2458-
for thread in threads:
2459-
threading_helper.join_thread(thread)
2460-
24612458
# hard errors
24622459

24632460
check([clear] + [reduce] * 10)
@@ -2519,6 +2516,44 @@ def check(funcs, a=None, *args):
25192516
check([clear] + [upper] * 10, bytearray(b'a' * 0x400000))
25202517
check([clear] + [zfill] * 10, bytearray(b'1' * 0x200000))
25212518

2519+
@unittest.skipUnless(support.Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
2520+
@threading_helper.reap_threads
2521+
@threading_helper.requires_working_threading()
2522+
def test_free_threading_bytearrayiter(self):
2523+
# Non-deterministic, but has a good chance of failing if bytearrayiter is not free-threading safe.
2524+
# We are fishing for an "Assertion failed: object has negative ref count" and tsan races.
2525+
2526+
def iter_next(b, it):
2527+
b.wait()
2528+
list(it)
2529+
2530+
def iter_reduce(b, it):
2531+
b.wait()
2532+
it.__reduce__()
2533+
2534+
def iter_setstate(b, it):
2535+
b.wait()
2536+
it.__setstate__(0)
2537+
2538+
def check(funcs, it):
2539+
barrier = threading.Barrier(len(funcs))
2540+
threads = []
2541+
2542+
for func in funcs:
2543+
thread = threading.Thread(target=func, args=(barrier, it))
2544+
2545+
threads.append(thread)
2546+
2547+
with threading_helper.start_threads(threads):
2548+
pass
2549+
2550+
for _ in range(10):
2551+
ba = bytearray(b'0' * 0x4000) # this is a load-bearing variable, do not remove
2552+
2553+
check([iter_next] * 10, iter(ba))
2554+
check([iter_next] + [iter_reduce] * 10, iter(ba)) # for tsan
2555+
check([iter_next] + [iter_setstate] * 10, iter(ba)) # for tsan
2556+
25222557

25232558
if __name__ == "__main__":
25242559
unittest.main()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make the :class:`bytearray` iterator thread-safe under :term:`free threading`.

Objects/bytearrayobject.c

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2856,31 +2856,44 @@ static PyObject *
28562856
bytearrayiter_next(PyObject *self)
28572857
{
28582858
bytesiterobject *it = _bytesiterobject_CAST(self);
2859-
PyByteArrayObject *seq;
2859+
int val;
28602860

28612861
assert(it != NULL);
2862-
seq = it->it_seq;
2863-
if (seq == NULL)
2862+
Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
2863+
if (index < 0) {
28642864
return NULL;
2865+
}
2866+
PyByteArrayObject *seq = it->it_seq;
28652867
assert(PyByteArray_Check(seq));
28662868

2867-
if (it->it_index < PyByteArray_GET_SIZE(seq)) {
2868-
return _PyLong_FromUnsignedChar(
2869-
(unsigned char)PyByteArray_AS_STRING(seq)[it->it_index++]);
2869+
Py_BEGIN_CRITICAL_SECTION(seq);
2870+
if (index < Py_SIZE(seq)) {
2871+
val = (unsigned char)PyByteArray_AS_STRING(seq)[index];
2872+
}
2873+
else {
2874+
val = -1;
28702875
}
2876+
Py_END_CRITICAL_SECTION();
28712877

2872-
it->it_seq = NULL;
2873-
Py_DECREF(seq);
2874-
return NULL;
2878+
if (val == -1) {
2879+
FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, -1);
2880+
#ifndef Py_GIL_DISABLED
2881+
Py_CLEAR(it->it_seq);
2882+
#endif
2883+
return NULL;
2884+
}
2885+
FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, index + 1);
2886+
return _PyLong_FromUnsignedChar((unsigned char)val);
28752887
}
28762888

28772889
static PyObject *
28782890
bytearrayiter_length_hint(PyObject *self, PyObject *Py_UNUSED(ignored))
28792891
{
28802892
bytesiterobject *it = _bytesiterobject_CAST(self);
28812893
Py_ssize_t len = 0;
2882-
if (it->it_seq) {
2883-
len = PyByteArray_GET_SIZE(it->it_seq) - it->it_index;
2894+
Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
2895+
if (index >= 0) {
2896+
len = PyByteArray_GET_SIZE(it->it_seq) - index;
28842897
if (len < 0) {
28852898
len = 0;
28862899
}
@@ -2900,27 +2913,33 @@ bytearrayiter_reduce(PyObject *self, PyObject *Py_UNUSED(ignored))
29002913
* call must be before access of iterator pointers.
29012914
* see issue #101765 */
29022915
bytesiterobject *it = _bytesiterobject_CAST(self);
2903-
if (it->it_seq != NULL) {
2904-
return Py_BuildValue("N(O)n", iter, it->it_seq, it->it_index);
2905-
} else {
2906-
return Py_BuildValue("N(())", iter);
2916+
Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
2917+
if (index >= 0) {
2918+
return Py_BuildValue("N(O)n", iter, it->it_seq, index);
29072919
}
2920+
return Py_BuildValue("N(())", iter);
29082921
}
29092922

29102923
static PyObject *
29112924
bytearrayiter_setstate(PyObject *self, PyObject *state)
29122925
{
29132926
Py_ssize_t index = PyLong_AsSsize_t(state);
2914-
if (index == -1 && PyErr_Occurred())
2927+
if (index == -1 && PyErr_Occurred()) {
29152928
return NULL;
2929+
}
29162930

29172931
bytesiterobject *it = _bytesiterobject_CAST(self);
2918-
if (it->it_seq != NULL) {
2919-
if (index < 0)
2920-
index = 0;
2921-
else if (index > PyByteArray_GET_SIZE(it->it_seq))
2922-
index = PyByteArray_GET_SIZE(it->it_seq); /* iterator exhausted */
2923-
it->it_index = index;
2932+
if (FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index) >= 0) {
2933+
if (index < -1) {
2934+
index = -1;
2935+
}
2936+
else {
2937+
Py_ssize_t size = PyByteArray_GET_SIZE(it->it_seq);
2938+
if (index > size) {
2939+
index = size; /* iterator at end */
2940+
}
2941+
}
2942+
FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, index);
29242943
}
29252944
Py_RETURN_NONE;
29262945
}
@@ -2982,7 +3001,7 @@ bytearray_iter(PyObject *seq)
29823001
it = PyObject_GC_New(bytesiterobject, &PyByteArrayIter_Type);
29833002
if (it == NULL)
29843003
return NULL;
2985-
it->it_index = 0;
3004+
it->it_index = 0; // -1 indicates exhausted
29863005
it->it_seq = (PyByteArrayObject *)Py_NewRef(seq);
29873006
_PyObject_GC_TRACK(it);
29883007
return (PyObject *)it;

0 commit comments

Comments
 (0)