Skip to content

Prohibit custom codecs on domains #663

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,9 +1160,18 @@ async def set_type_codec(self, typename, *,
self._check_open()
typeinfo = await self._introspect_type(typename, schema)
if not introspection.is_scalar_type(typeinfo):
raise ValueError(
raise exceptions.InterfaceError(
'cannot use custom codec on non-scalar type {}.{}'.format(
schema, typename))
if introspection.is_domain_type(typeinfo):
raise exceptions.UnsupportedClientFeatureError(
'custom codecs on domain types are not supported',
hint='Set the codec on the base type.',
detail=(
'PostgreSQL does not distinguish domains from '
'their base types in query results at the protocol level.'
)
)

oid = typeinfo['oid']
self._protocol.get_settings().add_python_codec(
Expand Down
16 changes: 15 additions & 1 deletion asyncpg/exceptions/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError')
'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError',
'UnsupportedClientFeatureError')


def _is_asyncpg_class(cls):
Expand Down Expand Up @@ -209,11 +210,24 @@ def __init__(self, msg, *, detail=None, hint=None):
InterfaceMessage.__init__(self, detail=detail, hint=hint)
Exception.__init__(self, msg)

def with_msg(self, msg):
return type(self)(
msg,
detail=self.detail,
hint=self.hint,
).with_traceback(
self.__traceback__
)


class DataError(InterfaceError, ValueError):
"""An error caused by invalid query input."""


class UnsupportedClientFeatureError(InterfaceError):
"""Requested feature is unsupported by asyncpg."""


class InterfaceWarning(InterfaceMessage, UserWarning):
"""A warning caused by an improper use of asyncpg API."""

Expand Down
4 changes: 4 additions & 0 deletions asyncpg/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,7 @@ def is_scalar_type(typeinfo) -> bool:
typeinfo['kind'] in SCALAR_TYPE_KINDS and
not typeinfo['elemtype']
)


def is_domain_type(typeinfo) -> bool:
return typeinfo['kind'] == b'd'
9 changes: 4 additions & 5 deletions asyncpg/protocol/codecs/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ cdef class Codec:
self.decoder = <codec_decode_func>&self.decode_array_text
elif type == CODEC_RANGE:
if format != PG_FORMAT_BINARY:
raise NotImplementedError(
raise exceptions.UnsupportedClientFeatureError(
'cannot decode type "{}"."{}": text encoding of '
'range types is not supported'.format(schema, name))
self.encoder = <codec_encode_func>&self.encode_range
self.decoder = <codec_decode_func>&self.decode_range
elif type == CODEC_COMPOSITE:
if format != PG_FORMAT_BINARY:
raise NotImplementedError(
raise exceptions.UnsupportedClientFeatureError(
'cannot decode type "{}"."{}": text encoding of '
'composite types is not supported'.format(schema, name))
self.encoder = <codec_encode_func>&self.encode_composite
Expand Down Expand Up @@ -675,9 +675,8 @@ cdef class DataCodecConfig:
# added builtin types, for which this version of
# asyncpg is lacking support.
#
raise NotImplementedError(
'unhandled standard data type {!r} (OID {})'.format(
name, oid))
raise exceptions.UnsupportedClientFeatureError(
f'unhandled standard data type {name!r} (OID {oid})')
else:
# This is a non-BKI type, and as such, has no
# stable OID, so no possibility of a builtin codec.
Expand Down
13 changes: 12 additions & 1 deletion asyncpg/protocol/codecs/record.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,20 @@ cdef anonymous_record_decode(ConnectionSettings settings, FRBuffer *buf):
return result


cdef anonymous_record_encode(ConnectionSettings settings, WriteBuffer buf, obj):
raise exceptions.UnsupportedClientFeatureError(
'input of anonymous composite types is not supported',
hint=(
'Consider declaring an explicit composite type and '
'using it to cast the argument.'
),
detail='PostgreSQL does not implement anonymous composite type input.'
)


cdef init_record_codecs():
register_core_codec(RECORDOID,
<encode_func>NULL,
<encode_func>anonymous_record_encode,
<decode_func>anonymous_record_decode,
PG_FORMAT_BINARY)

Expand Down
8 changes: 5 additions & 3 deletions asyncpg/protocol/prepared_stmt.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,11 @@ cdef class PreparedStatementState:
except (AssertionError, exceptions.InternalClientError):
# These are internal errors and should raise as-is.
raise
except exceptions.InterfaceError:
# This is already a descriptive error.
raise
except exceptions.InterfaceError as e:
# This is already a descriptive error, but annotate
# with argument name for clarity.
raise e.with_msg(
f'query argument ${idx + 1}: {e.args[0]}') from None
except Exception as e:
# Everything else is assumed to be an encoding error
# due to invalid input.
Expand Down
43 changes: 21 additions & 22 deletions tests/test_codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,13 @@ async def test_composites(self):

self.assertEqual(res, (None, 1234, '5678', (42, '42')))

with self.assertRaisesRegex(
asyncpg.UnsupportedClientFeatureError,
'query argument \\$1: input of anonymous '
'composite types is not supported',
):
await self.con.fetchval("SELECT (1, 'foo') = $1", (1, 'foo'))

try:
st = await self.con.prepare('''
SELECT ROW(
Expand Down Expand Up @@ -1075,7 +1082,7 @@ async def test_extra_codec_alias(self):
# This should fail, as there is no binary codec for
# my_dec_t and text decoding of composites is not
# implemented.
with self.assertRaises(NotImplementedError):
with self.assertRaises(asyncpg.UnsupportedClientFeatureError):
res = await self.con.fetchval('''
SELECT ($1::my_dec_t, 'a=>1'::hstore)::rec_t AS result
''', 44)
Expand Down Expand Up @@ -1132,7 +1139,7 @@ def hstore_encoder(obj):
self.assertEqual(at[0].type, pt[0])

err = 'cannot use custom codec on non-scalar type public._hstore'
with self.assertRaisesRegex(ValueError, err):
with self.assertRaisesRegex(asyncpg.InterfaceError, err):
await self.con.set_type_codec('_hstore',
encoder=hstore_encoder,
decoder=hstore_decoder)
Expand All @@ -1144,7 +1151,7 @@ def hstore_encoder(obj):
try:
err = 'cannot use custom codec on non-scalar type ' + \
'public.mytype'
with self.assertRaisesRegex(ValueError, err):
with self.assertRaisesRegex(asyncpg.InterfaceError, err):
await self.con.set_type_codec(
'mytype', encoder=hstore_encoder,
decoder=hstore_decoder)
Expand Down Expand Up @@ -1245,13 +1252,14 @@ async def test_custom_codec_on_domain(self):
''')

try:
await self.con.set_type_codec(
'custom_codec_t',
encoder=lambda v: str(v),
decoder=lambda v: int(v))

v = await self.con.fetchval('SELECT $1::custom_codec_t', 10)
self.assertEqual(v, 10)
with self.assertRaisesRegex(
asyncpg.UnsupportedClientFeatureError,
'custom codecs on domain types are not supported'
):
await self.con.set_type_codec(
'custom_codec_t',
encoder=lambda v: str(v),
decoder=lambda v: int(v))
finally:
await self.con.execute('DROP DOMAIN custom_codec_t')

Expand Down Expand Up @@ -1650,7 +1658,7 @@ async def test_unknown_type_text_fallback(self):
# Text encoding of ranges and composite types
# is not supported yet.
with self.assertRaisesRegex(
RuntimeError,
asyncpg.UnsupportedClientFeatureError,
'text encoding of range types is not supported'):

await self.con.fetchval('''
Expand All @@ -1659,7 +1667,7 @@ async def test_unknown_type_text_fallback(self):
''', ['a', 'z'])

with self.assertRaisesRegex(
RuntimeError,
asyncpg.UnsupportedClientFeatureError,
'text encoding of composite types is not supported'):

await self.con.fetchval('''
Expand Down Expand Up @@ -1831,7 +1839,7 @@ async def test_custom_codec_large_oid(self):

expected_oid = self.LARGE_OID
if self.server_version >= (11, 0):
# PostgreSQL 11 automatically create a domain array type
# PostgreSQL 11 automatically creates a domain array type
# _before_ the domain type, so the expected OID is
# off by one.
expected_oid += 1
Expand All @@ -1842,14 +1850,5 @@ async def test_custom_codec_large_oid(self):
v = await self.con.fetchval('SELECT $1::test_domain_t', 10)
self.assertEqual(v, 10)

# Test that custom codec logic handles large OIDs
await self.con.set_type_codec(
'test_domain_t',
encoder=lambda v: str(v),
decoder=lambda v: int(v))

v = await self.con.fetchval('SELECT $1::test_domain_t', 10)
self.assertEqual(v, 10)

finally:
await self.con.execute('DROP DOMAIN test_domain_t')