Skip to content

Fix method calls on custom Record subclasses #678

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,8 @@ async def _get_statement(
):
if record_class is None:
record_class = self._protocol.get_record_class()
else:
_check_record_class(record_class)

if use_cache:
statement = self._stmt_cache.get(
Expand Down Expand Up @@ -1980,14 +1982,12 @@ async def connect(dsn=None, *,
libpq-connect.html#LIBPQ-CONNSTRING
"""
if not issubclass(connection_class, Connection):
raise TypeError(
raise exceptions.InterfaceError(
'connection_class is expected to be a subclass of '
'asyncpg.Connection, got {!r}'.format(connection_class))

if not issubclass(record_class, protocol.Record):
raise TypeError(
'record_class is expected to be a subclass of '
'asyncpg.Record, got {!r}'.format(record_class))
if record_class is not protocol.Record:
_check_record_class(record_class)

if loop is None:
loop = asyncio.get_event_loop()
Expand Down Expand Up @@ -2253,4 +2253,25 @@ def _extract_stack(limit=10):
return ''.join(traceback.format_list(stack))


def _check_record_class(record_class):
if record_class is protocol.Record:
pass
elif (
isinstance(record_class, type)
and issubclass(record_class, protocol.Record)
):
if (
record_class.__new__ is not object.__new__
or record_class.__init__ is not object.__init__
):
raise exceptions.InterfaceError(
'record_class must not redefine __new__ or __init__'
)
else:
raise exceptions.InterfaceError(
'record_class is expected to be a subclass of '
'asyncpg.Record, got {!r}'.format(record_class)
)


_uid = 0
47 changes: 30 additions & 17 deletions asyncpg/protocol/record/recordobj.c
Original file line number Diff line number Diff line change
Expand Up @@ -227,26 +227,39 @@ record_richcompare(PyObject *v, PyObject *w, int op)
Py_ssize_t vlen, wlen;
int v_is_tuple = 0;
int w_is_tuple = 0;
int v_is_record = 0;
int w_is_record = 0;
int comp;

if (!ApgRecord_CheckExact(v)) {
if (!PyTuple_Check(v)) {
Py_RETURN_NOTIMPLEMENTED;
}
if (PyTuple_Check(v)) {
v_is_tuple = 1;
}
else if (ApgRecord_CheckExact(v)) {
v_is_record = 1;
}
else if (!ApgRecord_Check(v)) {
Py_RETURN_NOTIMPLEMENTED;
}

if (!ApgRecord_CheckExact(w)) {
if (!PyTuple_Check(w)) {
Py_RETURN_NOTIMPLEMENTED;
}
if (PyTuple_Check(w)) {
w_is_tuple = 1;
}
else if (ApgRecord_CheckExact(w)) {
w_is_record = 1;
}
else if (!ApgRecord_Check(w)) {
Py_RETURN_NOTIMPLEMENTED;
}


#define V_ITEM(i) \
(v_is_tuple ? (PyTuple_GET_ITEM(v, i)) : (ApgRecord_GET_ITEM(v, i)))
(v_is_tuple ? \
PyTuple_GET_ITEM(v, i) \
: (v_is_record ? ApgRecord_GET_ITEM(v, i) : PySequence_GetItem(v, i)))
#define W_ITEM(i) \
(w_is_tuple ? (PyTuple_GET_ITEM(w, i)) : (ApgRecord_GET_ITEM(w, i)))
(w_is_tuple ? \
PyTuple_GET_ITEM(w, i) \
: (w_is_record ? ApgRecord_GET_ITEM(w, i) : PySequence_GetItem(w, i)))

vlen = Py_SIZE(v);
wlen = Py_SIZE(w);
Expand Down Expand Up @@ -546,7 +559,7 @@ record_values(PyObject *o, PyObject *args)
static PyObject *
record_keys(PyObject *o, PyObject *args)
{
if (!ApgRecord_CheckExact(o)) {
if (!ApgRecord_Check(o)) {
PyErr_BadInternalCall();
return NULL;
}
Expand All @@ -558,7 +571,7 @@ record_keys(PyObject *o, PyObject *args)
static PyObject *
record_items(PyObject *o, PyObject *args)
{
if (!ApgRecord_CheckExact(o)) {
if (!ApgRecord_Check(o)) {
PyErr_BadInternalCall();
return NULL;
}
Expand All @@ -570,7 +583,7 @@ record_items(PyObject *o, PyObject *args)
static int
record_contains(ApgRecordObject *o, PyObject *arg)
{
if (!ApgRecord_CheckExact(o)) {
if (!ApgRecord_Check(o)) {
PyErr_BadInternalCall();
return -1;
}
Expand Down Expand Up @@ -686,7 +699,7 @@ record_iter_next(ApgRecordIterObject *it)
seq = it->it_seq;
if (seq == NULL)
return NULL;
assert(ApgRecord_CheckExact(seq));
assert(ApgRecord_Check(seq));

if (it->it_index < Py_SIZE(seq)) {
item = ApgRecord_GET_ITEM(seq, it->it_index);
Expand Down Expand Up @@ -742,7 +755,7 @@ record_iter(PyObject *seq)
{
ApgRecordIterObject *it;

if (!ApgRecord_CheckExact(seq)) {
if (!ApgRecord_Check(seq)) {
PyErr_BadInternalCall();
return NULL;
}
Expand Down Expand Up @@ -800,7 +813,7 @@ record_items_next(ApgRecordItemsObject *it)
if (seq == NULL) {
return NULL;
}
assert(ApgRecord_CheckExact(seq));
assert(ApgRecord_Check(seq));
assert(it->it_key_iter != NULL);

key = PyIter_Next(it->it_key_iter);
Expand Down Expand Up @@ -880,7 +893,7 @@ record_new_items_iter(PyObject *seq)
ApgRecordItemsObject *it;
PyObject *key_iter;

if (!ApgRecord_CheckExact(seq)) {
if (!ApgRecord_Check(seq)) {
PyErr_BadInternalCall();
return NULL;
}
Expand Down
1 change: 1 addition & 0 deletions asyncpg/protocol/record/recordobj.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ extern PyTypeObject ApgRecordItems_Type;

extern PyTypeObject ApgRecordDesc_Type;

#define ApgRecord_Check(self) PyObject_TypeCheck(self, &ApgRecord_Type)
#define ApgRecord_CheckExact(o) (Py_TYPE(o) == &ApgRecord_Type)
#define ApgRecordDesc_CheckExact(o) (Py_TYPE(o) == &ApgRecordDesc_Type)

Expand Down
61 changes: 61 additions & 0 deletions tests/test_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,3 +513,64 @@ async def test_record_subclass_04(self):

r = await ps.fetch()
self.assertIs(type(r[0]), asyncpg.Record)

async def test_record_subclass_05(self):
class MyRecord(asyncpg.Record):
pass

r = await self.con.fetchrow(
"SELECT 1 as a, '2' as b",
record_class=MyRecord,
)
self.assertIsInstance(r, MyRecord)

self.assertEqual(list(r.items()), [('a', 1), ('b', '2')])
self.assertEqual(list(r.keys()), ['a', 'b'])
self.assertEqual(list(r.values()), [1, '2'])
self.assertIn('b', r)
self.assertEqual(next(iter(r)), 1)

async def test_record_subclass_06(self):
class MyRecord(asyncpg.Record):
def __init__(self):
raise AssertionError('this is not supposed to be called')

class MyRecord2(asyncpg.Record):
def __new__(cls):
raise AssertionError('this is not supposed to be called')

class MyRecordBad:
pass

with self.assertRaisesRegex(
asyncpg.InterfaceError,
'record_class must not redefine __new__ or __init__',
):
await self.con.fetchrow(
"SELECT 1 as a, '2' as b",
record_class=MyRecord,
)

with self.assertRaisesRegex(
asyncpg.InterfaceError,
'record_class must not redefine __new__ or __init__',
):
await self.con.fetchrow(
"SELECT 1 as a, '2' as b",
record_class=MyRecord2,
)

with self.assertRaisesRegex(
asyncpg.InterfaceError,
'record_class is expected to be a subclass of asyncpg.Record',
):
await self.con.fetchrow(
"SELECT 1 as a, '2' as b",
record_class=MyRecordBad,
)

with self.assertRaisesRegex(
asyncpg.InterfaceError,
'record_class is expected to be a subclass of asyncpg.Record',
):
await self.connect(record_class=MyRecordBad)