diff --git a/asyncpg/connection.py b/asyncpg/connection.py index d33db090..120e3623 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -354,6 +354,8 @@ async def _get_statement( ): if record_class is None: record_class = self._protocol.get_record_class() + else: + _check_record_class(record_class) if use_cache: statement = self._stmt_cache.get( @@ -1980,14 +1982,12 @@ async def connect(dsn=None, *, libpq-connect.html#LIBPQ-CONNSTRING """ if not issubclass(connection_class, Connection): - raise TypeError( + raise exceptions.InterfaceError( 'connection_class is expected to be a subclass of ' 'asyncpg.Connection, got {!r}'.format(connection_class)) - if not issubclass(record_class, protocol.Record): - raise TypeError( - 'record_class is expected to be a subclass of ' - 'asyncpg.Record, got {!r}'.format(record_class)) + if record_class is not protocol.Record: + _check_record_class(record_class) if loop is None: loop = asyncio.get_event_loop() @@ -2253,4 +2253,25 @@ def _extract_stack(limit=10): return ''.join(traceback.format_list(stack)) +def _check_record_class(record_class): + if record_class is protocol.Record: + pass + elif ( + isinstance(record_class, type) + and issubclass(record_class, protocol.Record) + ): + if ( + record_class.__new__ is not object.__new__ + or record_class.__init__ is not object.__init__ + ): + raise exceptions.InterfaceError( + 'record_class must not redefine __new__ or __init__' + ) + else: + raise exceptions.InterfaceError( + 'record_class is expected to be a subclass of ' + 'asyncpg.Record, got {!r}'.format(record_class) + ) + + _uid = 0 diff --git a/asyncpg/protocol/record/recordobj.c b/asyncpg/protocol/record/recordobj.c index b734ee9b..8ee27f59 100644 --- a/asyncpg/protocol/record/recordobj.c +++ b/asyncpg/protocol/record/recordobj.c @@ -227,26 +227,39 @@ record_richcompare(PyObject *v, PyObject *w, int op) Py_ssize_t vlen, wlen; int v_is_tuple = 0; int w_is_tuple = 0; + int v_is_record = 0; + int w_is_record = 0; int comp; - if (!ApgRecord_CheckExact(v)) { - if (!PyTuple_Check(v)) { - Py_RETURN_NOTIMPLEMENTED; - } + if (PyTuple_Check(v)) { v_is_tuple = 1; } + else if (ApgRecord_CheckExact(v)) { + v_is_record = 1; + } + else if (!ApgRecord_Check(v)) { + Py_RETURN_NOTIMPLEMENTED; + } - if (!ApgRecord_CheckExact(w)) { - if (!PyTuple_Check(w)) { - Py_RETURN_NOTIMPLEMENTED; - } + if (PyTuple_Check(w)) { w_is_tuple = 1; } + else if (ApgRecord_CheckExact(w)) { + w_is_record = 1; + } + else if (!ApgRecord_Check(w)) { + Py_RETURN_NOTIMPLEMENTED; + } + #define V_ITEM(i) \ - (v_is_tuple ? (PyTuple_GET_ITEM(v, i)) : (ApgRecord_GET_ITEM(v, i))) + (v_is_tuple ? \ + PyTuple_GET_ITEM(v, i) \ + : (v_is_record ? ApgRecord_GET_ITEM(v, i) : PySequence_GetItem(v, i))) #define W_ITEM(i) \ - (w_is_tuple ? (PyTuple_GET_ITEM(w, i)) : (ApgRecord_GET_ITEM(w, i))) + (w_is_tuple ? \ + PyTuple_GET_ITEM(w, i) \ + : (w_is_record ? ApgRecord_GET_ITEM(w, i) : PySequence_GetItem(w, i))) vlen = Py_SIZE(v); wlen = Py_SIZE(w); @@ -546,7 +559,7 @@ record_values(PyObject *o, PyObject *args) static PyObject * record_keys(PyObject *o, PyObject *args) { - if (!ApgRecord_CheckExact(o)) { + if (!ApgRecord_Check(o)) { PyErr_BadInternalCall(); return NULL; } @@ -558,7 +571,7 @@ record_keys(PyObject *o, PyObject *args) static PyObject * record_items(PyObject *o, PyObject *args) { - if (!ApgRecord_CheckExact(o)) { + if (!ApgRecord_Check(o)) { PyErr_BadInternalCall(); return NULL; } @@ -570,7 +583,7 @@ record_items(PyObject *o, PyObject *args) static int record_contains(ApgRecordObject *o, PyObject *arg) { - if (!ApgRecord_CheckExact(o)) { + if (!ApgRecord_Check(o)) { PyErr_BadInternalCall(); return -1; } @@ -686,7 +699,7 @@ record_iter_next(ApgRecordIterObject *it) seq = it->it_seq; if (seq == NULL) return NULL; - assert(ApgRecord_CheckExact(seq)); + assert(ApgRecord_Check(seq)); if (it->it_index < Py_SIZE(seq)) { item = ApgRecord_GET_ITEM(seq, it->it_index); @@ -742,7 +755,7 @@ record_iter(PyObject *seq) { ApgRecordIterObject *it; - if (!ApgRecord_CheckExact(seq)) { + if (!ApgRecord_Check(seq)) { PyErr_BadInternalCall(); return NULL; } @@ -800,7 +813,7 @@ record_items_next(ApgRecordItemsObject *it) if (seq == NULL) { return NULL; } - assert(ApgRecord_CheckExact(seq)); + assert(ApgRecord_Check(seq)); assert(it->it_key_iter != NULL); key = PyIter_Next(it->it_key_iter); @@ -880,7 +893,7 @@ record_new_items_iter(PyObject *seq) ApgRecordItemsObject *it; PyObject *key_iter; - if (!ApgRecord_CheckExact(seq)) { + if (!ApgRecord_Check(seq)) { PyErr_BadInternalCall(); return NULL; } diff --git a/asyncpg/protocol/record/recordobj.h b/asyncpg/protocol/record/recordobj.h index 2c6c1f1c..373c8967 100644 --- a/asyncpg/protocol/record/recordobj.h +++ b/asyncpg/protocol/record/recordobj.h @@ -37,6 +37,7 @@ extern PyTypeObject ApgRecordItems_Type; extern PyTypeObject ApgRecordDesc_Type; +#define ApgRecord_Check(self) PyObject_TypeCheck(self, &ApgRecord_Type) #define ApgRecord_CheckExact(o) (Py_TYPE(o) == &ApgRecord_Type) #define ApgRecordDesc_CheckExact(o) (Py_TYPE(o) == &ApgRecordDesc_Type) diff --git a/tests/test_record.py b/tests/test_record.py index 8abe90ee..5b85fb4d 100644 --- a/tests/test_record.py +++ b/tests/test_record.py @@ -513,3 +513,64 @@ async def test_record_subclass_04(self): r = await ps.fetch() self.assertIs(type(r[0]), asyncpg.Record) + + async def test_record_subclass_05(self): + class MyRecord(asyncpg.Record): + pass + + r = await self.con.fetchrow( + "SELECT 1 as a, '2' as b", + record_class=MyRecord, + ) + self.assertIsInstance(r, MyRecord) + + self.assertEqual(list(r.items()), [('a', 1), ('b', '2')]) + self.assertEqual(list(r.keys()), ['a', 'b']) + self.assertEqual(list(r.values()), [1, '2']) + self.assertIn('b', r) + self.assertEqual(next(iter(r)), 1) + + async def test_record_subclass_06(self): + class MyRecord(asyncpg.Record): + def __init__(self): + raise AssertionError('this is not supposed to be called') + + class MyRecord2(asyncpg.Record): + def __new__(cls): + raise AssertionError('this is not supposed to be called') + + class MyRecordBad: + pass + + with self.assertRaisesRegex( + asyncpg.InterfaceError, + 'record_class must not redefine __new__ or __init__', + ): + await self.con.fetchrow( + "SELECT 1 as a, '2' as b", + record_class=MyRecord, + ) + + with self.assertRaisesRegex( + asyncpg.InterfaceError, + 'record_class must not redefine __new__ or __init__', + ): + await self.con.fetchrow( + "SELECT 1 as a, '2' as b", + record_class=MyRecord2, + ) + + with self.assertRaisesRegex( + asyncpg.InterfaceError, + 'record_class is expected to be a subclass of asyncpg.Record', + ): + await self.con.fetchrow( + "SELECT 1 as a, '2' as b", + record_class=MyRecordBad, + ) + + with self.assertRaisesRegex( + asyncpg.InterfaceError, + 'record_class is expected to be a subclass of asyncpg.Record', + ): + await self.connect(record_class=MyRecordBad)