Skip to content

Commit c162909

Browse files
committed
Fix method calls on custom Record subclasses
The current implementation has a bunch of `CheckExact` calls in front of most `Record` methods, effectively making them unusable in subclasses. Relax the check to include the subclasses. While at it, add a check that `__init__` and `__new__` are not redefined on the provided Record subclass. Doing so is pointless, because `Record` instance initialization effectively bypasses both, so raise an `InterfaceError` to avoid needless confusion. Fixes: #676
1 parent 3d0e23f commit c162909

File tree

4 files changed

+118
-22
lines changed

4 files changed

+118
-22
lines changed

asyncpg/connection.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,8 @@ async def _get_statement(
354354
):
355355
if record_class is None:
356356
record_class = self._protocol.get_record_class()
357+
else:
358+
_check_record_class(record_class)
357359

358360
if use_cache:
359361
statement = self._stmt_cache.get(
@@ -1980,14 +1982,12 @@ async def connect(dsn=None, *,
19801982
libpq-connect.html#LIBPQ-CONNSTRING
19811983
"""
19821984
if not issubclass(connection_class, Connection):
1983-
raise TypeError(
1985+
raise exceptions.InterfaceError(
19841986
'connection_class is expected to be a subclass of '
19851987
'asyncpg.Connection, got {!r}'.format(connection_class))
19861988

1987-
if not issubclass(record_class, protocol.Record):
1988-
raise TypeError(
1989-
'record_class is expected to be a subclass of '
1990-
'asyncpg.Record, got {!r}'.format(record_class))
1989+
if record_class is not protocol.Record:
1990+
_check_record_class(record_class)
19911991

19921992
if loop is None:
19931993
loop = asyncio.get_event_loop()
@@ -2253,4 +2253,25 @@ def _extract_stack(limit=10):
22532253
return ''.join(traceback.format_list(stack))
22542254

22552255

2256+
def _check_record_class(record_class):
2257+
if record_class is protocol.Record:
2258+
pass
2259+
elif (
2260+
isinstance(record_class, type)
2261+
and issubclass(record_class, protocol.Record)
2262+
):
2263+
if (
2264+
record_class.__new__ is not object.__new__
2265+
or record_class.__init__ is not object.__init__
2266+
):
2267+
raise exceptions.InterfaceError(
2268+
'record_class must not redefine __new__ or __init__'
2269+
)
2270+
else:
2271+
raise exceptions.InterfaceError(
2272+
'record_class is expected to be a subclass of '
2273+
'asyncpg.Record, got {!r}'.format(record_class)
2274+
)
2275+
2276+
22562277
_uid = 0

asyncpg/protocol/record/recordobj.c

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -227,26 +227,39 @@ record_richcompare(PyObject *v, PyObject *w, int op)
227227
Py_ssize_t vlen, wlen;
228228
int v_is_tuple = 0;
229229
int w_is_tuple = 0;
230+
int v_is_record = 0;
231+
int w_is_record = 0;
230232
int comp;
231233

232-
if (!ApgRecord_CheckExact(v)) {
233-
if (!PyTuple_Check(v)) {
234-
Py_RETURN_NOTIMPLEMENTED;
235-
}
234+
if (PyTuple_Check(v)) {
236235
v_is_tuple = 1;
237236
}
237+
else if (ApgRecord_CheckExact(v)) {
238+
v_is_record = 1;
239+
}
240+
else if (!ApgRecord_Check(v)) {
241+
Py_RETURN_NOTIMPLEMENTED;
242+
}
238243

239-
if (!ApgRecord_CheckExact(w)) {
240-
if (!PyTuple_Check(w)) {
241-
Py_RETURN_NOTIMPLEMENTED;
242-
}
244+
if (PyTuple_Check(w)) {
243245
w_is_tuple = 1;
244246
}
247+
else if (ApgRecord_CheckExact(w)) {
248+
w_is_record = 1;
249+
}
250+
else if (!ApgRecord_Check(w)) {
251+
Py_RETURN_NOTIMPLEMENTED;
252+
}
253+
245254

246255
#define V_ITEM(i) \
247-
(v_is_tuple ? (PyTuple_GET_ITEM(v, i)) : (ApgRecord_GET_ITEM(v, i)))
256+
(v_is_tuple ? \
257+
PyTuple_GET_ITEM(v, i) \
258+
: (v_is_record ? ApgRecord_GET_ITEM(v, i) : PySequence_GetItem(v, i)))
248259
#define W_ITEM(i) \
249-
(w_is_tuple ? (PyTuple_GET_ITEM(w, i)) : (ApgRecord_GET_ITEM(w, i)))
260+
(w_is_tuple ? \
261+
PyTuple_GET_ITEM(w, i) \
262+
: (w_is_record ? ApgRecord_GET_ITEM(w, i) : PySequence_GetItem(w, i)))
250263

251264
vlen = Py_SIZE(v);
252265
wlen = Py_SIZE(w);
@@ -546,7 +559,7 @@ record_values(PyObject *o, PyObject *args)
546559
static PyObject *
547560
record_keys(PyObject *o, PyObject *args)
548561
{
549-
if (!ApgRecord_CheckExact(o)) {
562+
if (!ApgRecord_Check(o)) {
550563
PyErr_BadInternalCall();
551564
return NULL;
552565
}
@@ -558,7 +571,7 @@ record_keys(PyObject *o, PyObject *args)
558571
static PyObject *
559572
record_items(PyObject *o, PyObject *args)
560573
{
561-
if (!ApgRecord_CheckExact(o)) {
574+
if (!ApgRecord_Check(o)) {
562575
PyErr_BadInternalCall();
563576
return NULL;
564577
}
@@ -570,7 +583,7 @@ record_items(PyObject *o, PyObject *args)
570583
static int
571584
record_contains(ApgRecordObject *o, PyObject *arg)
572585
{
573-
if (!ApgRecord_CheckExact(o)) {
586+
if (!ApgRecord_Check(o)) {
574587
PyErr_BadInternalCall();
575588
return -1;
576589
}
@@ -686,7 +699,7 @@ record_iter_next(ApgRecordIterObject *it)
686699
seq = it->it_seq;
687700
if (seq == NULL)
688701
return NULL;
689-
assert(ApgRecord_CheckExact(seq));
702+
assert(ApgRecord_Check(seq));
690703

691704
if (it->it_index < Py_SIZE(seq)) {
692705
item = ApgRecord_GET_ITEM(seq, it->it_index);
@@ -742,7 +755,7 @@ record_iter(PyObject *seq)
742755
{
743756
ApgRecordIterObject *it;
744757

745-
if (!ApgRecord_CheckExact(seq)) {
758+
if (!ApgRecord_Check(seq)) {
746759
PyErr_BadInternalCall();
747760
return NULL;
748761
}
@@ -800,7 +813,7 @@ record_items_next(ApgRecordItemsObject *it)
800813
if (seq == NULL) {
801814
return NULL;
802815
}
803-
assert(ApgRecord_CheckExact(seq));
816+
assert(ApgRecord_Check(seq));
804817
assert(it->it_key_iter != NULL);
805818

806819
key = PyIter_Next(it->it_key_iter);
@@ -880,7 +893,7 @@ record_new_items_iter(PyObject *seq)
880893
ApgRecordItemsObject *it;
881894
PyObject *key_iter;
882895

883-
if (!ApgRecord_CheckExact(seq)) {
896+
if (!ApgRecord_Check(seq)) {
884897
PyErr_BadInternalCall();
885898
return NULL;
886899
}

asyncpg/protocol/record/recordobj.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ extern PyTypeObject ApgRecordItems_Type;
3737

3838
extern PyTypeObject ApgRecordDesc_Type;
3939

40+
#define ApgRecord_Check(self) PyObject_TypeCheck(self, &ApgRecord_Type)
4041
#define ApgRecord_CheckExact(o) (Py_TYPE(o) == &ApgRecord_Type)
4142
#define ApgRecordDesc_CheckExact(o) (Py_TYPE(o) == &ApgRecordDesc_Type)
4243

tests/test_record.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,3 +513,64 @@ async def test_record_subclass_04(self):
513513

514514
r = await ps.fetch()
515515
self.assertIs(type(r[0]), asyncpg.Record)
516+
517+
async def test_record_subclass_05(self):
518+
class MyRecord(asyncpg.Record):
519+
pass
520+
521+
r = await self.con.fetchrow(
522+
"SELECT 1 as a, '2' as b",
523+
record_class=MyRecord,
524+
)
525+
self.assertIsInstance(r, MyRecord)
526+
527+
self.assertEqual(list(r.items()), [('a', 1), ('b', '2')])
528+
self.assertEqual(list(r.keys()), ['a', 'b'])
529+
self.assertEqual(list(r.values()), [1, '2'])
530+
self.assertIn('b', r)
531+
self.assertEqual(next(iter(r)), 1)
532+
533+
async def test_record_subclass_06(self):
534+
class MyRecord(asyncpg.Record):
535+
def __init__(self):
536+
raise AssertionError('this is not supposed to be called')
537+
538+
class MyRecord2(asyncpg.Record):
539+
def __new__(cls):
540+
raise AssertionError('this is not supposed to be called')
541+
542+
class MyRecordBad:
543+
pass
544+
545+
with self.assertRaisesRegex(
546+
asyncpg.InterfaceError,
547+
'record_class must not redefine __new__ or __init__',
548+
):
549+
await self.con.fetchrow(
550+
"SELECT 1 as a, '2' as b",
551+
record_class=MyRecord,
552+
)
553+
554+
with self.assertRaisesRegex(
555+
asyncpg.InterfaceError,
556+
'record_class must not redefine __new__ or __init__',
557+
):
558+
await self.con.fetchrow(
559+
"SELECT 1 as a, '2' as b",
560+
record_class=MyRecord2,
561+
)
562+
563+
with self.assertRaisesRegex(
564+
asyncpg.InterfaceError,
565+
'record_class is expected to be a subclass of asyncpg.Record',
566+
):
567+
await self.con.fetchrow(
568+
"SELECT 1 as a, '2' as b",
569+
record_class=MyRecordBad,
570+
)
571+
572+
with self.assertRaisesRegex(
573+
asyncpg.InterfaceError,
574+
'record_class is expected to be a subclass of asyncpg.Record',
575+
):
576+
await self.connect(record_class=MyRecordBad)

0 commit comments

Comments
 (0)