Skip to content

Commit 92baac1

Browse files
committed
Log query errors, clean up context managers
1 parent 107f196 commit 92baac1

File tree

3 files changed

+74
-82
lines changed

3 files changed

+74
-82
lines changed

asyncpg/connection.py

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import asyncpg
1010
import collections
1111
import collections.abc
12+
import contextlib
1213
import functools
1314
import itertools
1415
import inspect
@@ -226,10 +227,9 @@ def add_query_logger(self, callback):
226227
"""Add a logger that will be called when queries are executed.
227228
228229
:param callable callback:
229-
A callable or a coroutine function receiving two arguments:
230-
**connection**: a Connection the callback is registered with.
231-
**query**: a LoggedQuery containing the query, args, timeout, and
232-
elapsed.
230+
A callable or a coroutine function receiving one argument:
231+
**record**: a LoggedQuery containing `query`, `args`, `timeout`,
232+
`elapsed`, `addr`, `params`, and `exception`.
233233
234234
.. versionadded:: 0.29.0
235235
"""
@@ -339,9 +339,8 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
339339
self._check_open()
340340

341341
if not args:
342-
with utils.timer() as t:
342+
with self._time_and_log(query, args, timeout):
343343
result = await self._protocol.query(query, timeout)
344-
self._log_query(query, args, timeout, t.elapsed)
345344
return result
346345

347346
_, status, _ = await self._execute(
@@ -1412,6 +1411,7 @@ def _cleanup(self):
14121411
self._mark_stmts_as_closed()
14131412
self._listeners.clear()
14141413
self._log_listeners.clear()
1414+
self._query_loggers.clear()
14151415
self._clean_tasks()
14161416

14171417
def _clean_tasks(self):
@@ -1695,15 +1695,15 @@ async def _execute(
16951695
)
16961696
return result
16971697

1698+
@contextlib.contextmanager
16981699
def logger(self, callback):
16991700
"""Context manager that adds `callback` to the list of query loggers,
17001701
and removes it upon exit.
17011702
17021703
:param callable callback:
1703-
A callable or a coroutine function receiving two arguments:
1704-
**connection**: a Connection the callback is registered with.
1705-
**query**: a LoggedQuery containing the query, args, timeout, and
1706-
elapsed.
1704+
A callable or a coroutine function receiving one argument:
1705+
**record**: a LoggedQuery containing `query`, `args`, `timeout`,
1706+
`elapsed`, `addr`, and `params`.
17071707
17081708
Example:
17091709
@@ -1721,18 +1721,35 @@ def __call__(self, conn, record):
17211721
17221722
.. versionadded:: 0.29.0
17231723
"""
1724-
return _LoggingContext(self, callback)
1725-
1726-
def _log_query(self, query, args, timeout, elapsed):
1727-
if not self._query_loggers:
1728-
return
1729-
con_ref = self._unwrap()
1730-
record = LoggedQuery(query, args, timeout, elapsed)
1731-
for cb in self._query_loggers:
1732-
if cb.is_async:
1733-
self._loop.create_task(cb.cb(con_ref, record))
1734-
else:
1735-
self._loop.call_soon(cb.cb, con_ref, record)
1724+
self.add_query_logger(callback)
1725+
yield callback
1726+
self.remove_query_logger(callback)
1727+
1728+
@contextlib.contextmanager
1729+
def _time_and_log(self, query, args, timeout):
1730+
start = time.monotonic()
1731+
exception = None
1732+
try:
1733+
yield
1734+
except Exception as ex:
1735+
exception = ex
1736+
raise
1737+
finally:
1738+
elapsed = time.monotonic() - start
1739+
record = LoggedQuery(
1740+
query=query,
1741+
args=args,
1742+
timeout=timeout,
1743+
elapsed=elapsed,
1744+
addr=self._addr,
1745+
params=self._params,
1746+
exception=exception,
1747+
)
1748+
for cb in self._query_loggers:
1749+
if cb.is_async:
1750+
self._loop.create_task(cb.cb(record))
1751+
else:
1752+
self._loop.call_soon(cb.cb, record)
17361753

17371754
async def __execute(
17381755
self,
@@ -1748,25 +1765,23 @@ async def __execute(
17481765
executor = lambda stmt, timeout: self._protocol.bind_execute(
17491766
stmt, args, '', limit, return_status, timeout)
17501767
timeout = self._protocol._get_timeout(timeout)
1751-
with utils.timer() as t:
1768+
with self._time_and_log(query, args, timeout):
17521769
result, stmt = await self._do_execute(
17531770
query,
17541771
executor,
17551772
timeout,
17561773
record_class=record_class,
17571774
ignore_custom_codec=ignore_custom_codec,
17581775
)
1759-
self._log_query(query, args, timeout, t.elapsed)
17601776
return result, stmt
17611777

17621778
async def _executemany(self, query, args, timeout):
17631779
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
17641780
stmt, args, '', timeout)
17651781
timeout = self._protocol._get_timeout(timeout)
17661782
with self._stmt_exclusive_section:
1767-
with utils.timer() as t:
1783+
with self._time_and_log(query, args, timeout):
17681784
result, _ = await self._do_execute(query, executor, timeout)
1769-
self._log_query(query, args, timeout, t.elapsed)
17701785
return result
17711786

17721787
async def _do_execute(
@@ -2401,25 +2416,10 @@ class _ConnectionProxy:
24012416

24022417
LoggedQuery = collections.namedtuple(
24032418
'LoggedQuery',
2404-
['query', 'args', 'timeout', 'elapsed'])
2419+
['query', 'args', 'timeout', 'elapsed', 'exception', 'addr', 'params'])
24052420
LoggedQuery.__doc__ = 'Log record of an executed query.'
24062421

24072422

2408-
class _LoggingContext:
2409-
__slots__ = ('_conn', '_cb')
2410-
2411-
def __init__(self, conn, callback):
2412-
self._conn = conn
2413-
self._cb = callback
2414-
2415-
def __enter__(self):
2416-
self._conn.add_query_logger(self._cb)
2417-
return self._cb
2418-
2419-
def __exit__(self, *exc_info):
2420-
self._conn.remove_query_logger(self._cb)
2421-
2422-
24232423
ServerCapabilities = collections.namedtuple(
24242424
'ServerCapabilities',
24252425
['advisory_locks', 'notifications', 'plpgsql', 'sql_reset',

asyncpg/utils.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77

88
import re
9-
import time
109

1110

1211
def _quote_ident(ident):
@@ -44,28 +43,3 @@ async def _mogrify(conn, query, args):
4443
# Finally, replace $n references with text values.
4544
return re.sub(
4645
r'\$(\d+)\b', lambda m: textified[int(m.group(1)) - 1], query)
47-
48-
49-
class timer:
50-
__slots__ = ('start', 'elapsed')
51-
52-
def __init__(self):
53-
self.start = time.monotonic()
54-
self.elapsed = None
55-
56-
@property
57-
def current(self):
58-
return time.monotonic() - self.start
59-
60-
def restart(self):
61-
self.start = time.monotonic()
62-
63-
def stop(self):
64-
self.elapsed = self.current
65-
66-
def __enter__(self):
67-
self.restart()
68-
return self
69-
70-
def __exit__(self, *exc):
71-
self.stop()

tests/test_logging.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,48 @@
11
import asyncio
22

33
from asyncpg import _testbase as tb
4+
from asyncpg import exceptions
5+
6+
7+
class LogCollector:
8+
def __init__(self):
9+
self.records = []
10+
11+
def __call__(self, record):
12+
self.records.append(record)
413

514

615
class TestQueryLogging(tb.ConnectedTestCase):
716

817
async def test_logging_context(self):
918
queries = asyncio.Queue()
1019

11-
def query_saver(conn, record):
20+
def query_saver(record):
1221
queries.put_nowait(record)
1322

14-
class QuerySaver:
15-
def __init__(self):
16-
self.queries = []
17-
18-
def __call__(self, conn, record):
19-
self.queries.append(record.query)
20-
2123
with self.con.logger(query_saver):
2224
self.assertEqual(len(self.con._query_loggers), 1)
23-
with self.con.logger(QuerySaver()) as log:
25+
await self.con.execute("SELECT 1")
26+
with self.con.logger(LogCollector()) as log:
2427
self.assertEqual(len(self.con._query_loggers), 2)
25-
await self.con.execute("SELECT 1")
26-
27-
record = await queries.get()
28-
self.assertEqual(record.query, "SELECT 1")
29-
self.assertEqual(log.queries, ["SELECT 1"])
28+
await self.con.execute("SELECT 2")
29+
30+
r1 = await queries.get()
31+
r2 = await queries.get()
32+
self.assertEqual(r1.query, "SELECT 1")
33+
self.assertEqual(r2.query, "SELECT 2")
34+
self.assertEqual(len(log.records), 1)
35+
self.assertEqual(log.records[0].query, "SELECT 2")
3036
self.assertEqual(len(self.con._query_loggers), 0)
37+
38+
async def test_error_logging(self):
39+
with self.con.logger(LogCollector()) as log:
40+
with self.assertRaises(exceptions.UndefinedColumnError):
41+
await self.con.execute("SELECT x")
42+
43+
await asyncio.sleep(0) # wait for logging
44+
self.assertEqual(len(log.records), 1)
45+
self.assertEqual(
46+
type(log.records[0].exception),
47+
exceptions.UndefinedColumnError
48+
)

0 commit comments

Comments
 (0)