
Fix handling of OIDs >= 2**31 #300

Merged · 3 commits · May 31, 2018
54 changes: 38 additions & 16 deletions asyncpg/_testbase/__init__.py
@@ -205,16 +205,21 @@ def _format_loop_exception(self, context, n):
_default_cluster = None


def _start_cluster(ClusterCls, cluster_kwargs, server_settings,
initdb_options=None):
def _init_cluster(ClusterCls, cluster_kwargs, initdb_options=None):
cluster = ClusterCls(**cluster_kwargs)
cluster.init(**(initdb_options or {}))
cluster.trust_local_connections()
cluster.start(port='dynamic', server_settings=server_settings)
atexit.register(_shutdown_cluster, cluster)
return cluster


def _start_cluster(ClusterCls, cluster_kwargs, server_settings,
initdb_options=None):
cluster = _init_cluster(ClusterCls, cluster_kwargs, initdb_options)
cluster.start(port='dynamic', server_settings=server_settings)
return cluster


def _get_initdb_options(initdb_options=None):
if not initdb_options:
initdb_options = {}
@@ -228,7 +233,7 @@ def _get_initdb_options(initdb_options=None):
return initdb_options


def _start_default_cluster(server_settings={}, initdb_options=None):
def _init_default_cluster(initdb_options=None):
global _default_cluster

if _default_cluster is None:
@@ -237,9 +242,8 @@ def _start_default_cluster(server_settings={}, initdb_options=None):
# Using existing cluster, assuming it is initialized and running
_default_cluster = pg_cluster.RunningCluster()
else:
_default_cluster = _start_cluster(
_default_cluster = _init_cluster(
pg_cluster.TempCluster, cluster_kwargs={},
server_settings=server_settings,
initdb_options=_get_initdb_options(initdb_options))

return _default_cluster
@@ -248,7 +252,8 @@ def _start_default_cluster(server_settings={}, initdb_options=None):
def _shutdown_cluster(cluster):
if cluster.get_status() == 'running':
cluster.stop()
cluster.destroy()
if cluster.get_status() != 'not-initialized':
cluster.destroy()


def create_pool(dsn=None, *,
@@ -278,15 +283,40 @@ def get_server_settings(cls):
'log_connections': 'on'
}

@classmethod
def new_cluster(cls, ClusterCls, *, cluster_kwargs={}, initdb_options={}):
cluster = _init_cluster(ClusterCls, cluster_kwargs,
_get_initdb_options(initdb_options))
cls._clusters.append(cluster)
return cluster

@classmethod
def start_cluster(cls, cluster, *, server_settings={}):
cluster.start(port='dynamic', server_settings=server_settings)

@classmethod
def setup_cluster(cls):
cls.cluster = _start_default_cluster(cls.get_server_settings())
cls.cluster = _init_default_cluster()

if cls.cluster.get_status() != 'running':
cls.cluster.start(
port='dynamic', server_settings=cls.get_server_settings())

@classmethod
def setUpClass(cls):
super().setUpClass()
cls._clusters = []
cls.setup_cluster()

@classmethod
def tearDownClass(cls):
super().tearDownClass()
for cluster in cls._clusters:
if cluster is not _default_cluster:
cluster.stop()
cluster.destroy()
cls._clusters = []

@classmethod
def get_connection_spec(cls, kwargs={}):
conn_spec = cls.cluster.get_connection_spec()
@@ -309,14 +339,6 @@ def connect(cls, **kwargs):
conn_spec = cls.get_connection_spec(kwargs)
return pg_connection.connect(**conn_spec, loop=cls.loop)

@classmethod
def start_cluster(cls, ClusterCls, *,
cluster_kwargs={}, server_settings={},
initdb_options={}):
return _start_cluster(
ClusterCls, cluster_kwargs,
server_settings, _get_initdb_options(initdb_options))


class ProxiedClusterTestCase(ClusterTestCase):
@classmethod
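Taken together, the _testbase changes split cluster initialization (_init_cluster) from startup, so a test class can adjust a stopped-but-initialized cluster before it boots, and clusters registered via new_cluster() are stopped and destroyed in tearDownClass(). A hedged sketch of how a test might exercise large OIDs with these hooks (the class name and OID value are illustrative; reset_wal() is the new Cluster method added in asyncpg/cluster.py below):

    from asyncpg import cluster as pg_cluster
    from asyncpg import _testbase as tb


    class TestLargeOIDs(tb.ConnectedTestCase):
        LARGE_OID = 2 ** 31  # first OID past the signed int32 range

        @classmethod
        def setup_cluster(cls):
            # new_cluster() initializes but does not start the cluster,
            # leaving a window to advance the OID counter with pg_resetwal.
            cls.cluster = cls.new_cluster(pg_cluster.TempCluster)
            cls.cluster.reset_wal(oid=cls.LARGE_OID)
            cls.start_cluster(
                cls.cluster, server_settings=cls.get_server_settings())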
39 changes: 38 additions & 1 deletion asyncpg/cluster.py
@@ -106,7 +106,8 @@ def get_status(self):
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = process.stdout, process.stderr

if process.returncode == 4 or not os.listdir(self._data_dir):
if (process.returncode == 4 or not os.path.exists(self._data_dir) or
not os.listdir(self._data_dir)):
return 'not-initialized'
elif process.returncode == 3:
return 'stopped'
@@ -299,6 +300,42 @@ def get_connection_spec(self):
def override_connection_spec(self, **kwargs):
self._connection_spec_override = kwargs

def reset_wal(self, *, oid=None, xid=None):
status = self.get_status()
if status == 'not-initialized':
raise ClusterError(
'cannot modify WAL status: cluster is not initialized')

if status == 'running':
raise ClusterError(
'cannot modify WAL status: cluster is running')

opts = []
if oid is not None:
opts.extend(['-o', str(oid)])
if xid is not None:
opts.extend(['-x', str(xid)])
if not opts:
return

opts.append(self._data_dir)

try:
reset_wal = self._find_pg_binary('pg_resetwal')
except ClusterError:
reset_wal = self._find_pg_binary('pg_resetxlog')

process = subprocess.run(
[reset_wal] + opts,
stdout=subprocess.PIPE, stderr=subprocess.PIPE)

stderr = process.stderr

if process.returncode != 0:
raise ClusterError(
'pg_resetwal exited with status {:d}: {}'.format(
process.returncode, stderr.decode()))

def reset_hba(self):
"""Remove all records from pg_hba.conf."""
status = self.get_status()
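reset_wal() shells out to pg_resetwal (falling back to pg_resetxlog, its pre-PostgreSQL-10 name), which refuses to touch a running or uninitialized data directory; hence the two status checks above. A minimal standalone sketch, assuming the server binaries are discoverable:

    from asyncpg.cluster import TempCluster

    cluster = TempCluster()
    cluster.init()                   # data directory initialized, server stopped
    cluster.reset_wal(oid=2 ** 31)   # runs `pg_resetwal -o 2147483648 <datadir>`
    cluster.trust_local_connections()
    cluster.start(port='dynamic', server_settings={})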
4 changes: 2 additions & 2 deletions asyncpg/connection.py
@@ -913,12 +913,12 @@ async def set_type_codec(self, typename, *,
if not typeinfo:
raise ValueError('unknown type: {}.{}'.format(schema, typename))

oid = typeinfo['oid']
if typeinfo['kind'] != b'b' or typeinfo['elemtype']:
if not introspection.is_scalar_type(typeinfo):
raise ValueError(
'cannot use custom codec on non-scalar type {}.{}'.format(
schema, typename))

oid = typeinfo['oid']
self._protocol.get_settings().add_python_codec(
oid, typename, schema, 'scalar',
encoder, decoder, format)
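Delegating the kind check to introspection.is_scalar_type() slightly broadens set_type_codec(): domains ('d') and enums ('e') now pass, where the old inline test admitted only base types ('b'). A hedged sketch of what this enables (the domain name and codec choices are illustrative, not from this PR):

    import asyncpg


    async def add_domain_codec(conn: asyncpg.Connection) -> None:
        await conn.execute(
            'CREATE DOMAIN positive_int AS integer CHECK (VALUE > 0)')
        # With the old `kind != b'b'` test, the call below raised
        # "cannot use custom codec on non-scalar type".
        await conn.set_type_codec(
            'positive_int', schema='public',
            encoder=str, decoder=int, format='text')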
11 changes: 11 additions & 0 deletions asyncpg/introspection.py
@@ -145,3 +145,14 @@
WHERE
t.typname = $1 AND ns.nspname = $2
'''


# 'b' for a base type, 'd' for a domain, 'e' for enum.
SCALAR_TYPE_KINDS = (b'b', b'd', b'e')


def is_scalar_type(typeinfo) -> bool:
return (
typeinfo['kind'] in SCALAR_TYPE_KINDS and
not typeinfo['elemtype']
)
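is_scalar_type() only inspects the 'kind' and 'elemtype' fields of a record returned by the introspection query above, so plain mappings are enough to illustrate the behavior:

    from asyncpg.introspection import is_scalar_type

    assert is_scalar_type({'kind': b'b', 'elemtype': None})      # base type
    assert is_scalar_type({'kind': b'd', 'elemtype': None})      # domain
    assert is_scalar_type({'kind': b'e', 'elemtype': None})      # enum
    assert not is_scalar_type({'kind': b'c', 'elemtype': None})  # composite
    assert not is_scalar_type({'kind': b'b', 'elemtype': 23})    # array type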
6 changes: 3 additions & 3 deletions asyncpg/protocol/buffer.pxd
@@ -97,13 +97,13 @@ cdef class ReadBuffer:
cdef feed_data(self, data)
cdef inline _ensure_first_buf(self)
cdef _switch_to_next_buf(self)
cdef inline read_byte(self)
cdef inline char read_byte(self) except? -1
cdef inline const char* _try_read_bytes(self, ssize_t nbytes)
cdef inline _read(self, char *buf, ssize_t nbytes)
cdef read(self, ssize_t nbytes)
cdef inline const char* read_bytes(self, ssize_t n) except NULL
cdef inline read_int32(self)
cdef inline read_int16(self)
cdef inline int32_t read_int32(self) except? -1
cdef inline int16_t read_int16(self) except? -1
cdef inline read_cstr(self)
cdef int32_t has_message(self) except -1
cdef inline int32_t has_message_type(self, char mtype) except -1
6 changes: 3 additions & 3 deletions asyncpg/protocol/buffer.pyx
@@ -376,7 +376,7 @@ cdef class ReadBuffer:

return Memory.new(buf, result, nbytes)

cdef inline read_byte(self):
cdef inline char read_byte(self) except? -1:
cdef const char *first_byte

if ASYNCPG_DEBUG:
@@ -404,7 +404,7 @@
mem = <Memory>(self.read(n))
return mem.buf

cdef inline read_int32(self):
cdef inline int32_t read_int32(self) except? -1:
cdef:
Memory mem
const char *cbuf
@@ -417,7 +417,7 @@
mem = <Memory>(self.read(4))
return hton.unpack_int32(mem.buf)

cdef inline read_int16(self):
cdef inline int16_t read_int16(self) except? -1:
cdef:
Memory mem
const char *cbuf
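The buffer changes are a supporting cleanup: read_byte(), read_int32() and read_int16() now return C integers instead of boxed Python objects. A cdef function with a C return type needs an exception sentinel to propagate Python exceptions (under the Cython defaults of the time they would otherwise be printed and swallowed); `except? -1` marks -1 as "possibly an error", so the generated code also consults PyErr_Occurred() and legitimate -1 results keep working. A minimal Cython sketch of the pattern, not asyncpg code:

    cdef inline int16_t first_int16(const int16_t* data, Py_ssize_t n) except? -1:
        if n <= 0:
            # Propagates to the caller despite the C return type, because
            # the except? clause gives Cython a sentinel to check.
            raise ValueError('buffer is empty')
        return data[0]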
30 changes: 23 additions & 7 deletions asyncpg/protocol/codecs/base.pyx
@@ -373,6 +373,22 @@ cdef codec_decode_func_ex(ConnectionSettings settings, FastReadBuffer buf,
return (<Codec>arg).decode(settings, buf)


cdef uint32_t pylong_as_oid(val) except? 0xFFFFFFFFl:
cdef:
int64_t oid = 0
bint overflow = False

try:
oid = cpython.PyLong_AsLongLong(val)
except OverflowError:
overflow = True

if overflow or (oid < 0 or oid > UINT32_MAX):
raise OverflowError('OID value too large: {!r}'.format(val))

return <uint32_t>val


cdef class DataCodecConfig:
def __init__(self, cache_key):
try:
@@ -523,9 +539,10 @@ cdef class DataCodecConfig:
Codec core_codec
encode_func c_encoder = NULL
decode_func c_decoder = NULL
uint32_t oid = pylong_as_oid(typeoid)

if xformat == PG_XFORMAT_TUPLE:
core_codec = get_any_core_codec(typeoid, format, xformat)
core_codec = get_any_core_codec(oid, format, xformat)
if core_codec is None:
raise ValueError(
"{} type does not support 'tuple' exchange format".format(
@@ -538,7 +555,7 @@
self.remove_python_codec(typeoid, typename, typeschema)

self._local_type_codecs[typeoid] = \
Codec.new_python_codec(typeoid, typename, typeschema, typekind,
Codec.new_python_codec(oid, typename, typeschema, typekind,
encoder, decoder, c_encoder, c_decoder,
format, xformat)

@@ -551,19 +568,18 @@
cdef:
Codec codec
Codec target_codec
uint32_t oid = pylong_as_oid(typeoid)
uint32_t alias_oid

if format == PG_FORMAT_ANY:
formats = (PG_FORMAT_BINARY, PG_FORMAT_TEXT)
else:
formats = (format,)

for format in formats:
if self.get_codec(typeoid, format) is not None:
raise ValueError('cannot override codec for type {}'.format(
typeoid))

if isinstance(alias_to, int):
target_codec = self.get_codec(alias_to, format)
alias_oid = pylong_as_oid(alias_to)
target_codec = self.get_codec(alias_oid, format)
else:
target_codec = get_extra_codec(alias_to, format)

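pylong_as_oid() is the central guard added by this PR: the value is first taken as a signed 64-bit integer so that anything outside [0, UINT32_MAX] is caught before the final cast, and `except? 0xFFFFFFFFl` lets UINT32_MAX itself remain a valid return value (Cython disambiguates via PyErr_Occurred()). A reference-only Python rendering of its contract:

    UINT32_MAX = 2 ** 32 - 1


    def pylong_as_oid(val: int) -> int:
        # PostgreSQL OIDs are unsigned 32-bit; reject rather than wrap.
        if not (0 <= val <= UINT32_MAX):
            raise OverflowError('OID value too large: {!r}'.format(val))
        return val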
37 changes: 29 additions & 8 deletions asyncpg/protocol/codecs/int.pyx
@@ -27,9 +27,9 @@ cdef int2_encode(ConnectionSettings settings, WriteBuffer buf, obj):
except OverflowError:
overflow = 1

if overflow or val < -32768 or val > 32767:
if overflow or val < INT16_MIN or val > INT16_MAX:
raise OverflowError(
'int too big to be encoded as INT2: {!r}'.format(obj))
'int16 value out of range: {!r}'.format(obj))

buf.write_int32(2)
buf.write_int16(<int16_t>val)
@@ -49,10 +49,9 @@ cdef int4_encode(ConnectionSettings settings, WriteBuffer buf, obj):
overflow = 1

# "long" and "long long" have the same size for x86_64, need an extra check
if overflow or (sizeof(val) > 4 and (val < -2147483648 or
val > 2147483647)):
if overflow or (sizeof(val) > 4 and (val < INT32_MIN or val > INT32_MAX)):
raise OverflowError(
'int too big to be encoded as INT4: {!r}'.format(obj))
'int32 value out of range: {!r}'.format(obj))

buf.write_int32(4)
buf.write_int32(<int32_t>val)
@@ -62,6 +61,29 @@ cdef int4_decode(ConnectionSettings settings, FastReadBuffer buf):
return cpython.PyLong_FromLong(hton.unpack_int32(buf.read(4)))


cdef uint4_encode(ConnectionSettings settings, WriteBuffer buf, obj):
cdef int overflow = 0
cdef unsigned long val = 0

try:
val = cpython.PyLong_AsUnsignedLong(obj)
except OverflowError:
overflow = 1

# "long" and "long long" have the same size for x86_64, need an extra check
if overflow or (sizeof(val) > 4 and val > UINT32_MAX):
raise OverflowError(
'uint32 value out of range: {!r}'.format(obj))

buf.write_int32(4)
buf.write_int32(<int32_t>val)


cdef uint4_decode(ConnectionSettings settings, FastReadBuffer buf):
return cpython.PyLong_FromUnsignedLong(
<uint32_t>hton.unpack_int32(buf.read(4)))


cdef int8_encode(ConnectionSettings settings, WriteBuffer buf, obj):
cdef int overflow = 0
cdef long long val
@@ -72,10 +94,9 @@ cdef int8_encode(ConnectionSettings settings, WriteBuffer buf, obj):
overflow = 1

# Just in case for systems with "long long" bigger than 8 bytes
if overflow or (sizeof(val) > 8 and (val < -9223372036854775808 or
val > 9223372036854775807)):
if overflow or (sizeof(val) > 8 and (val < INT64_MIN or val > INT64_MAX)):
raise OverflowError(
'int too big to be encoded as INT8: {!r}'.format(obj))
'int64 value out of range: {!r}'.format(obj))

buf.write_int32(8)
buf.write_int64(<int64_t>val)
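The new uint4 pair exists because OIDs occupy the full unsigned 32-bit range while int4 is signed. Illustrative arithmetic with the standard struct module (not asyncpg internals) showing what breaks at 2**31:

    import struct

    oid = 2 ** 31                     # smallest OID that no longer fits in int4
    wire = struct.pack('!I', oid)     # unsigned 32-bit, as uint4_encode() writes
    assert struct.unpack('!I', wire)[0] == oid          # unsigned read: intact
    assert struct.unpack('!i', wire)[0] == -2147483648  # signed read: wraps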
4 changes: 2 additions & 2 deletions asyncpg/protocol/codecs/misc.pyx
@@ -36,8 +36,8 @@ cdef init_pseudo_codecs():

for oid_type in oid_types:
register_core_codec(oid_type,
<encode_func>&int4_encode,
<decode_func>&int4_decode,
<encode_func>&uint4_encode,
<decode_func>&uint4_decode,
PG_FORMAT_BINARY)

# reg* types -- these are really system catalog OIDs, but