Skip to content

Commit c8f1e20

Browse files
committed
Allow setting custom codecs on domains and enumerated types
Previously this was disallowed for no particular reason.
1 parent 5301e67 commit c8f1e20

File tree

3 files changed

+47
-2
lines changed

3 files changed

+47
-2
lines changed

asyncpg/connection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -913,12 +913,12 @@ async def set_type_codec(self, typename, *,
913913
if not typeinfo:
914914
raise ValueError('unknown type: {}.{}'.format(schema, typename))
915915

916-
oid = typeinfo['oid']
917-
if typeinfo['kind'] != b'b' or typeinfo['elemtype']:
916+
if not introspection.is_scalar_type(typeinfo):
918917
raise ValueError(
919918
'cannot use custom codec on non-scalar type {}.{}'.format(
920919
schema, typename))
921920

921+
oid = typeinfo['oid']
922922
self._protocol.get_settings().add_python_codec(
923923
oid, typename, schema, 'scalar',
924924
encoder, decoder, format)

asyncpg/introspection.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,14 @@
145145
WHERE
146146
t.typname = $1 AND ns.nspname = $2
147147
'''
148+
149+
150+
# 'b' for a base type, 'd' for a domain, 'e' for enum.
151+
SCALAR_TYPE_KINDS = (b'b', b'd', b'e')
152+
153+
154+
def is_scalar_type(typeinfo) -> bool:
155+
return (
156+
typeinfo['kind'] in SCALAR_TYPE_KINDS and
157+
not typeinfo['elemtype']
158+
)

tests/test_codecs.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,40 @@ def hstore_encoder(obj):
10821082
DROP EXTENSION hstore
10831083
''')
10841084

1085+
async def test_custom_codec_on_domain(self):
1086+
"""Test encoding/decoding using a custom codec on a domain."""
1087+
await self.con.execute('''
1088+
CREATE DOMAIN custom_codec_t AS int
1089+
''')
1090+
1091+
try:
1092+
await self.con.set_type_codec(
1093+
'custom_codec_t',
1094+
encoder=lambda v: str(v),
1095+
decoder=lambda v: int(v))
1096+
1097+
v = await self.con.fetchval('SELECT $1::custom_codec_t', 10)
1098+
self.assertEqual(v, 10)
1099+
finally:
1100+
await self.con.execute('DROP DOMAIN custom_codec_t')
1101+
1102+
async def test_custom_codec_on_enum(self):
1103+
"""Test encoding/decoding using a custom codec on an enum."""
1104+
await self.con.execute('''
1105+
CREATE TYPE custom_codec_t AS ENUM ('foo', 'bar', 'baz')
1106+
''')
1107+
1108+
try:
1109+
await self.con.set_type_codec(
1110+
'custom_codec_t',
1111+
encoder=lambda v: str(v).lstrip('enum :'),
1112+
decoder=lambda v: 'enum: ' + str(v))
1113+
1114+
v = await self.con.fetchval('SELECT $1::custom_codec_t', 'foo')
1115+
self.assertEqual(v, 'enum: foo')
1116+
finally:
1117+
await self.con.execute('DROP TYPE custom_codec_t')
1118+
10851119
async def test_custom_codec_override_binary(self):
10861120
"""Test overriding core codecs."""
10871121
import json

0 commit comments

Comments
 (0)