Skip to content

Commit 2b93ee5

Browse files
committed
Implement Record.get()
`Record.get()` allows record objects to better masquerade as dicts. Fixes: #330.
1 parent 0ba8a3a commit 2b93ee5

File tree

3 files changed

+100
-36
lines changed

3 files changed

+100
-36
lines changed

asyncpg/protocol/record/recordobj.c

Lines changed: 83 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,62 @@ record_item(ApgRecordObject *o, Py_ssize_t i)
254254
}
255255

256256

257+
typedef enum item_by_name_result {
258+
APG_ITEM_FOUND = 0,
259+
APG_ERROR = -1,
260+
APG_ITEM_NOT_FOUND = -2
261+
} item_by_name_result_t;
262+
263+
264+
/* Lookup a record value by its name. Return 0 on success, -2 if the
265+
* value was not found (with KeyError set), and -1 on all other errors.
266+
*/
267+
static item_by_name_result_t
268+
record_item_by_name(ApgRecordObject *o, PyObject *item, PyObject **result)
269+
{
270+
PyObject *mapped;
271+
PyObject *val;
272+
Py_ssize_t i;
273+
274+
mapped = PyObject_GetItem(o->desc->mapping, item);
275+
if (mapped == NULL) {
276+
goto noitem;
277+
}
278+
279+
if (!PyIndex_Check(mapped)) {
280+
Py_DECREF(mapped);
281+
goto error;
282+
}
283+
284+
i = PyNumber_AsSsize_t(mapped, PyExc_IndexError);
285+
Py_DECREF(mapped);
286+
287+
if (i < 0) {
288+
if (PyErr_Occurred())
289+
PyErr_Clear();
290+
goto error;
291+
}
292+
293+
val = record_item(o, i);
294+
if (val == NULL) {
295+
PyErr_Clear();
296+
goto error;
297+
}
298+
299+
*result = val;
300+
301+
return APG_ITEM_FOUND;
302+
303+
noitem:
304+
PyErr_SetObject(PyExc_KeyError, item);
305+
return APG_ITEM_NOT_FOUND;
306+
307+
error:
308+
PyErr_SetString(PyExc_RuntimeError, "invalid record descriptor");
309+
return APG_ERROR;
310+
}
311+
312+
257313
static PyObject *
258314
record_subscript(ApgRecordObject* o, PyObject* item)
259315
{
@@ -299,42 +355,13 @@ record_subscript(ApgRecordObject* o, PyObject* item)
299355
}
300356
}
301357
else {
302-
PyObject *mapped;
303-
mapped = PyObject_GetItem(o->desc->mapping, item);
304-
if (mapped != NULL) {
305-
Py_ssize_t i;
306-
PyObject *result;
307-
308-
if (!PyIndex_Check(mapped)) {
309-
Py_DECREF(mapped);
310-
goto noitem;
311-
}
312-
313-
i = PyNumber_AsSsize_t(mapped, PyExc_IndexError);
314-
Py_DECREF(mapped);
315-
316-
if (i < 0) {
317-
if (PyErr_Occurred()) {
318-
PyErr_Clear();
319-
}
320-
goto noitem;
321-
}
358+
PyObject* result;
322359

323-
result = record_item(o, i);
324-
if (result == NULL) {
325-
PyErr_Clear();
326-
goto noitem;
327-
}
360+
if (record_item_by_name(o, item, &result) < 0)
361+
return NULL;
362+
else
328363
return result;
329-
}
330-
else {
331-
goto noitem;
332-
}
333364
}
334-
335-
noitem:
336-
_PyErr_SetKeyError(item);
337-
return NULL;
338365
}
339366

340367

@@ -483,6 +510,28 @@ record_contains(ApgRecordObject *o, PyObject *arg)
483510
}
484511

485512

513+
static PyObject *
514+
record_get(ApgRecordObject* o, PyObject* args)
515+
{
516+
PyObject *key;
517+
PyObject *defval = Py_None;
518+
PyObject *val = NULL;
519+
int res;
520+
521+
if (!PyArg_UnpackTuple(args, "get", 1, 2, &key, &defval))
522+
return NULL;
523+
524+
res = record_item_by_name(o, key, &val);
525+
if (res == APG_ITEM_NOT_FOUND) {
526+
PyErr_Clear();
527+
Py_INCREF(defval);
528+
val = defval;
529+
}
530+
531+
return val;
532+
}
533+
534+
486535
static PySequenceMethods record_as_sequence = {
487536
(lenfunc)record_length, /* sq_length */
488537
0, /* sq_concat */
@@ -506,6 +555,7 @@ static PyMethodDef record_methods[] = {
506555
{"values", (PyCFunction)record_values, METH_NOARGS},
507556
{"keys", (PyCFunction)record_keys, METH_NOARGS},
508557
{"items", (PyCFunction)record_items, METH_NOARGS},
558+
{"get", (PyCFunction)record_get, METH_VARARGS},
509559
{NULL, NULL} /* sentinel */
510560
};
511561

docs/api/index.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,13 @@ items either by a numeric index or by a field name:
302302

303303
Return an iterator over the *values* of the record *r*.
304304

305+
.. describe:: get(name[, default])
306+
307+
Return the value for *name* if the record has a field named *name*,
308+
else return *default*. If *default* is not given, return ``None``.
309+
310+
.. versionadded:: 0.18
311+
305312
.. method:: values()
306313

307314
Return an iterator over the record values.

tests/test_record.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_record_gc(self):
4949
mapping = {key: val}
5050
with self.checkref(key, val):
5151
r = Record(mapping, (0,))
52-
with self.assertRaises(KeyError):
52+
with self.assertRaises(RuntimeError):
5353
r[key]
5454
del r
5555

@@ -58,7 +58,7 @@ def test_record_gc(self):
5858
mapping = {key: val}
5959
with self.checkref(key, val):
6060
r = Record(mapping, (0,))
61-
with self.assertRaises(KeyError):
61+
with self.assertRaises(RuntimeError):
6262
r[key]
6363
del r
6464

@@ -90,7 +90,7 @@ def test_record_len_getindex(self):
9090
with self.assertRaisesRegex(KeyError, 'spam'):
9191
Record(None, (1,))['spam']
9292

93-
with self.assertRaisesRegex(KeyError, 'spam'):
93+
with self.assertRaisesRegex(RuntimeError, 'invalid record descriptor'):
9494
Record({'spam': 123}, (1,))['spam']
9595

9696
def test_record_slice(self):
@@ -272,6 +272,13 @@ def test_record_cmp(self):
272272
sorted([r1, r2, r3, r4, r5, r6, r7]),
273273
[r1, r2, r3, r6, r7, r4, r5])
274274

275+
def test_record_get(self):
276+
r = Record(R_AB, (42, 43))
277+
with self.checkref(r):
278+
self.assertEqual(r.get('a'), 42)
279+
self.assertEqual(r.get('nonexistent'), None)
280+
self.assertEqual(r.get('nonexistent', 'default'), 'default')
281+
275282
def test_record_not_pickleable(self):
276283
r = Record(R_A, (42,))
277284
with self.assertRaises(Exception):

0 commit comments

Comments
 (0)