Skip to content

Commit b99824a

Browse files
bpo-40824: Do not mask errors in __iter__ in "in" and the operator module. (GH-20537)
Unexpected errors in calling the __iter__ method are no longer masked by TypeError in the "in" operator and functions operator.contains(), operator.indexOf() and operator.countOf(). (cherry picked from commit cafe1b6) Co-authored-by: Serhiy Storchaka <storchaka@gmail.com>
1 parent d5ee9b9 commit b99824a

File tree

4 files changed

+21
-1
lines changed

4 files changed

+21
-1
lines changed

Lib/test/test_iter.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ def __getitem__(self, i):
6262
return i
6363
__iter__ = None
6464

65+
class BadIterableClass:
66+
def __iter__(self):
67+
raise ZeroDivisionError
68+
6569
# Main test suite
6670

6771
class TestCase(unittest.TestCase):
@@ -637,6 +641,7 @@ def test_in_and_not_in(self):
637641

638642
self.assertRaises(TypeError, lambda: 3 in 12)
639643
self.assertRaises(TypeError, lambda: 3 not in map)
644+
self.assertRaises(ZeroDivisionError, lambda: 3 in BadIterableClass())
640645

641646
d = {"one": 1, "two": 2, "three": 3, 1j: 2j}
642647
for k in d:
@@ -719,6 +724,7 @@ def test_indexOf(self):
719724

720725
self.assertRaises(TypeError, indexOf, 42, 1)
721726
self.assertRaises(TypeError, indexOf, indexOf, indexOf)
727+
self.assertRaises(ZeroDivisionError, indexOf, BadIterableClass(), 1)
722728

723729
f = open(TESTFN, "w")
724730
try:
@@ -1006,6 +1012,7 @@ def test_free_after_iterating(self):
10061012
def test_error_iter(self):
10071013
for typ in (DefaultIterClass, NoIterClass):
10081014
self.assertRaises(TypeError, iter, typ())
1015+
self.assertRaises(ZeroDivisionError, iter, BadIterableClass())
10091016

10101017

10111018
def test_main():

Lib/test/test_operator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ def __mul__(self, other):
3535
def __rmul__(self, other):
3636
return other * self.lst
3737

38+
class BadIterable:
39+
def __iter__(self):
40+
raise ZeroDivisionError
41+
3842

3943
class OperatorTestCase:
4044
def test_lt(self):
@@ -142,6 +146,7 @@ def test_countOf(self):
142146
operator = self.module
143147
self.assertRaises(TypeError, operator.countOf)
144148
self.assertRaises(TypeError, operator.countOf, None, None)
149+
self.assertRaises(ZeroDivisionError, operator.countOf, BadIterable(), 1)
145150
self.assertEqual(operator.countOf([1, 2, 1, 3, 1, 4], 3), 1)
146151
self.assertEqual(operator.countOf([1, 2, 1, 3, 1, 4], 5), 0)
147152

@@ -176,6 +181,7 @@ def test_indexOf(self):
176181
operator = self.module
177182
self.assertRaises(TypeError, operator.indexOf)
178183
self.assertRaises(TypeError, operator.indexOf, None, None)
184+
self.assertRaises(ZeroDivisionError, operator.indexOf, BadIterable(), 1)
179185
self.assertEqual(operator.indexOf([4, 3, 2, 1], 3), 1)
180186
self.assertRaises(ValueError, operator.indexOf, [4, 3, 2, 1], 0)
181187

@@ -258,6 +264,7 @@ def test_contains(self):
258264
operator = self.module
259265
self.assertRaises(TypeError, operator.contains)
260266
self.assertRaises(TypeError, operator.contains, None, None)
267+
self.assertRaises(ZeroDivisionError, operator.contains, BadIterable(), 1)
261268
self.assertTrue(operator.contains(range(4), 2))
262269
self.assertFalse(operator.contains(range(4), 5))
263270

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Unexpected errors in calling the ``__iter__`` method are no longer masked by
2+
``TypeError`` in the :keyword:`in` operator and functions
3+
:func:`~operator.contains`, :func:`~operator.indexOf` and
4+
:func:`~operator.countOf` of the :mod:`operator` module.

Objects/abstract.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1999,7 +1999,9 @@ _PySequence_IterSearch(PyObject *seq, PyObject *obj, int operation)
19991999

20002000
it = PyObject_GetIter(seq);
20012001
if (it == NULL) {
2002-
type_error("argument of type '%.200s' is not iterable", seq);
2002+
if (PyErr_ExceptionMatches(PyExc_TypeError)) {
2003+
type_error("argument of type '%.200s' is not iterable", seq);
2004+
}
20032005
return -1;
20042006
}
20052007

0 commit comments

Comments
 (0)