From f78913117993948723069b5968e209d4bbd00659 Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Sat, 20 Jan 2018 14:11:54 -0500 Subject: [PATCH] Initialize statement codecs immediately after Prepare Currently, the statement codecs are populated just before the first Bind is issued. This is problematic because, in the time between Prepare and that first Bind, the codec cache for derived types (arrays, composites, etc.) may have been purged by the installation of a custom codec or by a general schema state invalidation. Fix this by populating the codecs immediately after the statement data types have been resolved. Fixes: #241. --- asyncpg/connection.py | 20 ++++++++++++++------ asyncpg/protocol/prepared_stmt.pxd | 1 + asyncpg/protocol/prepared_stmt.pyx | 7 ++++--- asyncpg/protocol/protocol.pyx | 7 +++++-- tests/test_introspection.py | 27 +++++++++++++++++++++++++++ 5 files changed, 51 insertions(+), 11 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index e21c12c4..23141f4d 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -291,12 +291,20 @@ async def _get_statement(self, query, timeout, *, named: bool=False, types, intro_stmt = await self.__execute( self._intro_query, (list(ready),), 0, timeout) self._protocol.get_settings().register_data_types(types) - if not intro_stmt.name and not statement.name: - # The introspection query has used an anonymous statement, - # which has blown away the anonymous statement we've prepared - # for the query, so we need to re-prepare it. - statement = await self._protocol.prepare( - stmt_name, query, timeout) + # The introspection query has used an anonymous statement, + # which has blown away the anonymous statement we've prepared + # for the query, so we need to re-prepare it. + need_reprepare = not intro_stmt.name and not statement.name + else: + need_reprepare = False + + # Now that types have been resolved, populate the codec pipeline + # for the statement. 
+ statement._init_codecs() + + if need_reprepare: + await self._protocol.prepare( + stmt_name, query, timeout, state=statement) if use_cache: self._stmt_cache.put(query, statement) diff --git a/asyncpg/protocol/prepared_stmt.pxd b/asyncpg/protocol/prepared_stmt.pxd index 8dab35b1..9749113c 100644 --- a/asyncpg/protocol/prepared_stmt.pxd +++ b/asyncpg/protocol/prepared_stmt.pxd @@ -30,6 +30,7 @@ cdef class PreparedStatementState: tuple rows_codecs cdef _encode_bind_msg(self, args) + cpdef _init_codecs(self) cdef _ensure_rows_decoder(self) cdef _ensure_args_encoder(self) cdef _set_row_desc(self, object desc) diff --git a/asyncpg/protocol/prepared_stmt.pyx b/asyncpg/protocol/prepared_stmt.pyx index 3edb56f0..e69369d3 100644 --- a/asyncpg/protocol/prepared_stmt.pyx +++ b/asyncpg/protocol/prepared_stmt.pyx @@ -82,6 +82,10 @@ cdef class PreparedStatementState: else: return True + cpdef _init_codecs(self): + self._ensure_args_encoder() + self._ensure_rows_decoder() + def attach(self): self.refs += 1 @@ -101,9 +105,6 @@ cdef class PreparedStatementState: raise exceptions.InterfaceError( 'the number of query arguments cannot exceed 32767') - self._ensure_args_encoder() - self._ensure_rows_decoder() - writer = WriteBuffer.new() num_args_passed = len(args) diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index 09fc8c11..983c0ea1 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -146,7 +146,8 @@ cdef class BaseProtocol(CoreProtocol): self.is_reading = False self.transport.pause_reading() - async def prepare(self, stmt_name, query, timeout): + async def prepare(self, stmt_name, query, timeout, + PreparedStatementState state=None): if self.cancel_waiter is not None: await self.cancel_waiter if self.cancel_sent_waiter is not None: @@ -160,7 +161,9 @@ cdef class BaseProtocol(CoreProtocol): try: self._prepare(stmt_name, query) # network op self.last_query = query - self.statement = PreparedStatementState(stmt_name, query, 
self) + if state is None: + state = PreparedStatementState(stmt_name, query, self) + self.statement = state except Exception as ex: waiter.set_exception(ex) self._coreproto_error() diff --git a/tests/test_introspection.py b/tests/test_introspection.py index d46095f8..fcf5885d 100644 --- a/tests/test_introspection.py +++ b/tests/test_introspection.py @@ -5,6 +5,8 @@ # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +import json + from asyncpg import _testbase as tb from asyncpg import connection as apg_con @@ -98,3 +100,28 @@ async def test_introspection_no_stmt_cache_03(self): "SELECT $1::int[], '{foo}'".format(foo='a' * 10000), [1, 2]) self.assertEqual(apg_con._uid, old_uid + 1) + + async def test_introspection_sticks_for_ps(self): + # Test that the introspected codec pipeline for a prepared + # statement is not affected by a subsequent codec cache bust. + + ps = await self.con._prepare('SELECT $1::json[]', use_cache=True) + + try: + # Setting a custom codec blows the codec cache for derived types. + await self.con.set_type_codec( + 'json', encoder=lambda v: v, decoder=json.loads, + schema='pg_catalog', format='text' + ) + + # The originally prepared statement should still be OK and + # use the previously selected codec. + self.assertEqual(await ps.fetchval(['{"foo": 1}']), ['{"foo": 1}']) + + # The new query uses the custom codec. + v = await self.con.fetchval('SELECT $1::json[]', ['{"foo": 1}']) + self.assertEqual(v, [{'foo': 1}]) + + finally: + await self.con.reset_type_codec( + 'json', schema='pg_catalog')