Skip to content

Commit a2fa7b2

Browse files
committed
Prohibit non-iterable containers to be passed as array input
Make the array input type check for correct ABCs instead of producing a cryptic error when a non-iterable container is passed.
1 parent 2f558c2 commit a2fa7b2

File tree

3 files changed

+39
-13
lines changed

3 files changed

+39
-13
lines changed

asyncpg/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@
1515
__all__ = ('connect', 'create_pool', 'Record', 'Connection') + \
1616
exceptions.__all__ # NOQA
1717

18-
__version__ = '0.15.0'
18+
__version__ = '0.16.0.dev0'

asyncpg/protocol/codecs/array.pyx

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

77

8-
from collections.abc import Container as ContainerABC
8+
from collections.abc import Iterable as IterableABC, Sized as SizedABC
99

1010

1111
DEF ARRAY_MAXDIM = 6 # defined in postgresql/src/includes/c.h
@@ -30,13 +30,18 @@ cdef inline bint _is_trivial_container(object obj):
3030
PyByteArray_Check(obj) or PyMemoryView_Check(obj)
3131

3232

33-
cdef inline _is_container(object obj):
34-
return not _is_trivial_container(obj) and isinstance(obj, ContainerABC)
33+
cdef inline _is_array_iterable(object obj):
34+
return (
35+
isinstance(obj, IterableABC) and
36+
isinstance(obj, SizedABC) and
37+
not _is_trivial_container(obj)
38+
)
3539

3640

37-
cdef inline _is_sub_array(object obj):
38-
return not _is_trivial_container(obj) and isinstance(obj, ContainerABC) \
39-
and not cpython.PyTuple_Check(obj)
41+
cdef inline _is_sub_array_iterable(object obj):
42+
# Sub-arrays have a specialized check, because we treat
43+
# nested tuples as records.
44+
return _is_array_iterable(obj) and not cpython.PyTuple_Check(obj)
4045

4146

4247
cdef _get_array_shape(object obj, int32_t *dims, int32_t *ndims):
@@ -56,7 +61,7 @@ cdef _get_array_shape(object obj, int32_t *dims, int32_t *ndims):
5661
dims[ndims[0] - 1] = <int32_t>mylen
5762

5863
for elem in obj:
59-
if _is_sub_array(elem):
64+
if _is_sub_array_iterable(elem):
6065
if elemlen == -2:
6166
elemlen = len(elem)
6267
if elemlen > _MAXINT32:
@@ -101,9 +106,9 @@ cdef inline array_encode(ConnectionSettings settings, WriteBuffer buf,
101106
int32_t ndims = 1
102107
int32_t i
103108

104-
if not _is_container(obj):
109+
if not _is_array_iterable(obj):
105110
raise TypeError(
106-
'a non-trivial iterable expected (got type {!r})'.format(
111+
'a sized iterable container expected (got type {!r})'.format(
107112
type(obj).__name__))
108113

109114
_get_array_shape(obj, dims, &ndims)
@@ -247,9 +252,9 @@ cdef inline textarray_encode(ConnectionSettings settings, WriteBuffer buf,
247252
int32_t ndims = 1
248253
int32_t i
249254

250-
if not _is_container(obj):
255+
if not _is_array_iterable(obj):
251256
raise TypeError(
252-
'a non-trivial iterable expected (got type {!r})'.format(
257+
'a sized iterable container expected (got type {!r})'.format(
253258
type(obj).__name__))
254259

255260
_get_array_shape(obj, dims, &ndims)

tests/test_codecs.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,27 @@ async def test_arrays(self):
661661

662662
self.assertEqual(result, case, err_msg)
663663

664+
# A sized iterable is fine as array input.
665+
class Iterable:
666+
def __iter__(self):
667+
return iter([1, 2, 3])
668+
669+
def __len__(self):
670+
return 3
671+
672+
result = await self.con.fetchval("SELECT $1::int[]", Iterable())
673+
self.assertEqual(result, [1, 2, 3])
674+
675+
# A pure container is _not_ OK for array input.
676+
class SomeContainer:
677+
def __contains__(self, item):
678+
return False
679+
680+
with self.assertRaisesRegex(TypeError,
681+
'sized iterable container expected'):
682+
result = await self.con.fetchval("SELECT $1::int[]",
683+
SomeContainer())
684+
664685
with self.assertRaisesRegex(ValueError, 'dimensions'):
665686
await self.con.fetchval(
666687
"SELECT $1::int[]",
@@ -687,7 +708,7 @@ async def test_arrays(self):
687708
[[1], ['t'], [2]])
688709

689710
with self.assertRaisesRegex(TypeError,
690-
'non-trivial iterable expected'):
711+
'sized iterable container expected'):
691712
await self.con.fetchval(
692713
"SELECT $1::int[]",
693714
1)

0 commit comments

Comments
 (0)