diff --git a/asyncpg/_testbase/__init__.py b/asyncpg/_testbase/__init__.py
index e141dc18..4c5e94ac 100644
--- a/asyncpg/_testbase/__init__.py
+++ b/asyncpg/_testbase/__init__.py
@@ -205,16 +205,21 @@ def _format_loop_exception(self, context, n):
 _default_cluster = None
 
 
-def _start_cluster(ClusterCls, cluster_kwargs, server_settings,
-                   initdb_options=None):
+def _init_cluster(ClusterCls, cluster_kwargs, initdb_options=None):
     cluster = ClusterCls(**cluster_kwargs)
     cluster.init(**(initdb_options or {}))
     cluster.trust_local_connections()
-    cluster.start(port='dynamic', server_settings=server_settings)
     atexit.register(_shutdown_cluster, cluster)
     return cluster
 
 
+def _start_cluster(ClusterCls, cluster_kwargs, server_settings,
+                   initdb_options=None):
+    cluster = _init_cluster(ClusterCls, cluster_kwargs, initdb_options)
+    cluster.start(port='dynamic', server_settings=server_settings)
+    return cluster
+
+
 def _get_initdb_options(initdb_options=None):
     if not initdb_options:
         initdb_options = {}
@@ -228,7 +233,7 @@ def _get_initdb_options(initdb_options=None):
     return initdb_options
 
 
-def _start_default_cluster(server_settings={}, initdb_options=None):
+def _init_default_cluster(initdb_options=None):
     global _default_cluster
 
     if _default_cluster is None:
@@ -237,9 +242,8 @@ def _start_default_cluster(server_settings={}, initdb_options=None):
             # Using existing cluster, assuming it is initialized and running
             _default_cluster = pg_cluster.RunningCluster()
         else:
-            _default_cluster = _start_cluster(
+            _default_cluster = _init_cluster(
                 pg_cluster.TempCluster, cluster_kwargs={},
-                server_settings=server_settings,
                 initdb_options=_get_initdb_options(initdb_options))
 
     return _default_cluster
@@ -248,7 +252,8 @@ def _start_default_cluster(server_settings={}, initdb_options=None):
 def _shutdown_cluster(cluster):
     if cluster.get_status() == 'running':
         cluster.stop()
-    cluster.destroy()
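+    # A cluster that was never initialized has no data directory to remove.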
+    if cluster.get_status() != 'not-initialized':
+        cluster.destroy()
 
 
 def create_pool(dsn=None, *,
@@ -278,15 +283,40 @@ def get_server_settings(cls):
             'log_connections': 'on'
         }
 
+    @classmethod
+    def new_cluster(cls, ClusterCls, *, cluster_kwargs={}, initdb_options={}):
+        cluster = _init_cluster(ClusterCls, cluster_kwargs,
+                                _get_initdb_options(initdb_options))
+        cls._clusters.append(cluster)
+        return cluster
+
+    @classmethod
+    def start_cluster(cls, cluster, *, server_settings={}):
+        cluster.start(port='dynamic', server_settings=server_settings)
+
     @classmethod
     def setup_cluster(cls):
-        cls.cluster = _start_default_cluster(cls.get_server_settings())
+        cls.cluster = _init_default_cluster()
+
+        if cls.cluster.get_status() != 'running':
+            cls.cluster.start(
+                port='dynamic', server_settings=cls.get_server_settings())
 
     @classmethod
     def setUpClass(cls):
         super().setUpClass()
+        cls._clusters = []
         cls.setup_cluster()
 
+    @classmethod
+    def tearDownClass(cls):
+        super().tearDownClass()
+        for cluster in cls._clusters:
+            if cluster is not _default_cluster:
+                cluster.stop()
+                cluster.destroy()
+        cls._clusters = []
+
     @classmethod
     def get_connection_spec(cls, kwargs={}):
         conn_spec = cls.cluster.get_connection_spec()
@@ -309,14 +339,6 @@ def connect(cls, **kwargs):
         conn_spec = cls.get_connection_spec(kwargs)
         return pg_connection.connect(**conn_spec, loop=cls.loop)
 
-    @classmethod
-    def start_cluster(cls, ClusterCls, *,
-                      cluster_kwargs={}, server_settings={},
-                      initdb_options={}):
-        return _start_cluster(
-            ClusterCls, cluster_kwargs,
-            server_settings, _get_initdb_options(initdb_options))
-
 
 class ProxiedClusterTestCase(ClusterTestCase):
     @classmethod
diff --git a/asyncpg/cluster.py b/asyncpg/cluster.py
index 31a40e37..33723669 100644
--- a/asyncpg/cluster.py
+++ b/asyncpg/cluster.py
@@ -106,7 +106,8 @@ def get_status(self):
             stdout=subprocess.PIPE, stderr=subprocess.PIPE)
         stdout, stderr = process.stdout, process.stderr
 
-        if process.returncode == 4 or not os.listdir(self._data_dir):
+        if (process.returncode == 4 or not os.path.exists(self._data_dir) or
+                not os.listdir(self._data_dir)):
             return 'not-initialized'
         elif process.returncode == 3:
             return 'stopped'
@@ -299,6 +300,42 @@ def get_connection_spec(self):
     def override_connection_spec(self, **kwargs):
         self._connection_spec_override = kwargs
 
+    def reset_wal(self, *, oid=None, xid=None):
+        status = self.get_status()
+        if status == 'not-initialized':
+            raise ClusterError(
+                'cannot modify WAL status: cluster is not initialized')
+
+        if status == 'running':
+            raise ClusterError(
+                'cannot modify WAL status: cluster is running')
+
+        opts = []
+        if oid is not None:
+            opts.extend(['-o', str(oid)])
+        if xid is not None:
+            opts.extend(['-x', str(xid)])
+        if not opts:
+            return
+
+        opts.append(self._data_dir)
+
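+        # pg_resetxlog was renamed to pg_resetwal in PostgreSQL 10;
+        # try the new name first and fall back to the old one.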
+        try:
+            reset_wal = self._find_pg_binary('pg_resetwal')
+        except ClusterError:
+            reset_wal = self._find_pg_binary('pg_resetxlog')
+
+        process = subprocess.run(
+            [reset_wal] + opts,
+            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+        stderr = process.stderr
+
+        if process.returncode != 0:
+            raise ClusterError(
+                'pg_resetwal exited with status {:d}: {}'.format(
+                    process.returncode, stderr.decode()))
+
     def reset_hba(self):
         """Remove all records from pg_hba.conf."""
         status = self.get_status()
diff --git a/asyncpg/connection.py b/asyncpg/connection.py
index d4f36e06..11b01a4c 100644
--- a/asyncpg/connection.py
+++ b/asyncpg/connection.py
@@ -913,12 +913,12 @@ async def set_type_codec(self, typename, *,
         if not typeinfo:
             raise ValueError('unknown type: {}.{}'.format(schema, typename))
 
-        oid = typeinfo['oid']
-        if typeinfo['kind'] != b'b' or typeinfo['elemtype']:
+        if not introspection.is_scalar_type(typeinfo):
             raise ValueError(
                 'cannot use custom codec on non-scalar type {}.{}'.format(
                     schema, typename))
 
+        oid = typeinfo['oid']
         self._protocol.get_settings().add_python_codec(
             oid, typename, schema, 'scalar',
             encoder, decoder, format)
diff --git a/asyncpg/introspection.py b/asyncpg/introspection.py
index ad19c1d1..201f4341 100644
--- a/asyncpg/introspection.py
+++ b/asyncpg/introspection.py
@@ -145,3 +145,14 @@
 WHERE
     t.typname = $1 AND ns.nspname = $2
 '''
+
+
+# 'b' for a base type, 'd' for a domain, 'e' for enum.
+SCALAR_TYPE_KINDS = (b'b', b'd', b'e')
+
+
+def is_scalar_type(typeinfo) -> bool:
+    return (
+        typeinfo['kind'] in SCALAR_TYPE_KINDS and
+        not typeinfo['elemtype']
+    )
diff --git a/asyncpg/protocol/buffer.pxd b/asyncpg/protocol/buffer.pxd
index 4126d973..caca282e 100644
--- a/asyncpg/protocol/buffer.pxd
+++ b/asyncpg/protocol/buffer.pxd
@@ -97,13 +97,13 @@ cdef class ReadBuffer:
     cdef feed_data(self, data)
     cdef inline _ensure_first_buf(self)
     cdef _switch_to_next_buf(self)
-    cdef inline read_byte(self)
+    cdef inline char read_byte(self) except? -1
     cdef inline const char* _try_read_bytes(self, ssize_t nbytes)
     cdef inline _read(self, char *buf, ssize_t nbytes)
     cdef read(self, ssize_t nbytes)
     cdef inline const char* read_bytes(self, ssize_t n) except NULL
-    cdef inline read_int32(self)
-    cdef inline read_int16(self)
+    cdef inline int32_t read_int32(self) except? -1
+    cdef inline int16_t read_int16(self) except? -1
     cdef inline read_cstr(self)
     cdef int32_t has_message(self) except -1
     cdef inline int32_t has_message_type(self, char mtype) except -1
diff --git a/asyncpg/protocol/buffer.pyx b/asyncpg/protocol/buffer.pyx
index 84fa2f13..77dea5eb 100644
--- a/asyncpg/protocol/buffer.pyx
+++ b/asyncpg/protocol/buffer.pyx
@@ -376,7 +376,7 @@ cdef class ReadBuffer:
 
         return Memory.new(buf, result, nbytes)
 
-    cdef inline read_byte(self):
+    cdef inline char read_byte(self) except? -1:
         cdef const char *first_byte
 
         if ASYNCPG_DEBUG:
@@ -404,7 +404,7 @@ cdef class ReadBuffer:
         mem = <Memory>(self.read(n))
         return mem.buf
 
-    cdef inline read_int32(self):
+    cdef inline int32_t read_int32(self) except? -1:
         cdef:
             Memory mem
             const char *cbuf
@@ -417,7 +417,7 @@ cdef class ReadBuffer:
         mem = <Memory>(self.read(4))
         return hton.unpack_int32(mem.buf)
 
-    cdef inline read_int16(self):
+    cdef inline int16_t read_int16(self) except? -1:
         cdef:
             Memory mem
             const char *cbuf
diff --git a/asyncpg/protocol/codecs/base.pyx b/asyncpg/protocol/codecs/base.pyx
index 74a706a5..c1348781 100644
--- a/asyncpg/protocol/codecs/base.pyx
+++ b/asyncpg/protocol/codecs/base.pyx
@@ -373,6 +373,22 @@ cdef codec_decode_func_ex(ConnectionSettings settings, FastReadBuffer buf,
     return (<Codec>arg).decode(settings, buf)
 
 
+cdef uint32_t pylong_as_oid(val) except? 0xFFFFFFFFl:
+    cdef:
+        int64_t oid = 0
+        bint overflow = False
+
+    try:
+        oid = cpython.PyLong_AsLongLong(val)
+    except OverflowError:
+        overflow = True
+
+    if overflow or (oid < 0 or oid > UINT32_MAX):
+        raise OverflowError('OID value too large: {!r}'.format(val))
+
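+    # The range check above guarantees the value fits into uint32_t;
+    # returning the original Python int lets Cython coerce it.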
+    return val
+
+
 cdef class DataCodecConfig:
     def __init__(self, cache_key):
         try:
@@ -523,9 +539,10 @@ cdef class DataCodecConfig:
             Codec core_codec
             encode_func c_encoder = NULL
             decode_func c_decoder = NULL
+            uint32_t oid = pylong_as_oid(typeoid)
 
         if xformat == PG_XFORMAT_TUPLE:
-            core_codec = get_any_core_codec(typeoid, format, xformat)
+            core_codec = get_any_core_codec(oid, format, xformat)
             if core_codec is None:
                 raise ValueError(
                     "{} type does not support 'tuple' exchange format".format(
@@ -538,7 +555,7 @@ cdef class DataCodecConfig:
         self.remove_python_codec(typeoid, typename, typeschema)
 
         self._local_type_codecs[typeoid] = \
-            Codec.new_python_codec(typeoid, typename, typeschema, typekind,
+            Codec.new_python_codec(oid, typename, typeschema, typekind,
                                    encoder, decoder, c_encoder, c_decoder,
                                    format, xformat)
 
@@ -551,6 +568,8 @@ cdef class DataCodecConfig:
         cdef:
             Codec codec
             Codec target_codec
+            uint32_t oid = pylong_as_oid(typeoid)
+            uint32_t alias_oid
 
         if format == PG_FORMAT_ANY:
             formats = (PG_FORMAT_BINARY, PG_FORMAT_TEXT)
@@ -558,12 +577,9 @@ cdef class DataCodecConfig:
             formats = (format,)
 
         for format in formats:
-            if self.get_codec(typeoid, format) is not None:
-                raise ValueError('cannot override codec for type {}'.format(
-                    typeoid))
-
             if isinstance(alias_to, int):
-                target_codec = self.get_codec(alias_to, format)
+                alias_oid = pylong_as_oid(alias_to)
+                target_codec = self.get_codec(alias_oid, format)
             else:
                 target_codec = get_extra_codec(alias_to, format)
 
diff --git a/asyncpg/protocol/codecs/int.pyx b/asyncpg/protocol/codecs/int.pyx
index 30acb176..eb65d85c 100644
--- a/asyncpg/protocol/codecs/int.pyx
+++ b/asyncpg/protocol/codecs/int.pyx
@@ -27,9 +27,9 @@ cdef int2_encode(ConnectionSettings settings, WriteBuffer buf, obj):
     except OverflowError:
         overflow = 1
 
-    if overflow or val < -32768 or val > 32767:
+    if overflow or val < INT16_MIN or val > INT16_MAX:
         raise OverflowError(
-            'int too big to be encoded as INT2: {!r}'.format(obj))
+            'int16 value out of range: {!r}'.format(obj))
 
     buf.write_int32(2)
     buf.write_int16(val)
@@ -49,10 +49,9 @@ cdef int4_encode(ConnectionSettings settings, WriteBuffer buf, obj):
         overflow = 1
 
     # "long" and "long long" have the same size for x86_64, need an extra check
-    if overflow or (sizeof(val) > 4 and (val < -2147483648 or
-                                         val > 2147483647)):
+    if overflow or (sizeof(val) > 4 and (val < INT32_MIN or val > INT32_MAX)):
         raise OverflowError(
-            'int too big to be encoded as INT4: {!r}'.format(obj))
+            'int32 value out of range: {!r}'.format(obj))
 
     buf.write_int32(4)
     buf.write_int32(val)
@@ -62,6 +61,29 @@ cdef int4_decode(ConnectionSettings settings, FastReadBuffer buf):
     return cpython.PyLong_FromLong(hton.unpack_int32(buf.read(4)))
 
 
+cdef uint4_encode(ConnectionSettings settings, WriteBuffer buf, obj):
+    cdef int overflow = 0
+    cdef unsigned long val = 0
+
+    try:
+        val = cpython.PyLong_AsUnsignedLong(obj)
+    except OverflowError:
+        overflow = 1
+
+    # "long" and "long long" have the same size for x86_64, need an extra check
+    if overflow or (sizeof(val) > 4 and val > UINT32_MAX):
+        raise OverflowError(
+            'uint32 value out of range: {!r}'.format(obj))
+
+    buf.write_int32(4)
+    buf.write_int32(<int32_t>val)
+
+
+cdef uint4_decode(ConnectionSettings settings, FastReadBuffer buf):
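+    # Cast to uint32_t before boxing: unpack_int32 returns a signed value,
+    # and sign extension would corrupt OIDs above 2**31.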
+    return cpython.PyLong_FromUnsignedLong(
+        <uint32_t>hton.unpack_int32(buf.read(4)))
+
+
 cdef int8_encode(ConnectionSettings settings, WriteBuffer buf, obj):
     cdef int overflow = 0
     cdef long long val
@@ -72,10 +94,9 @@ cdef int8_encode(ConnectionSettings settings, WriteBuffer buf, obj):
         overflow = 1
 
     # Just in case for systems with "long long" bigger than 8 bytes
-    if overflow or (sizeof(val) > 8 and (val < -9223372036854775808 or
-                                         val > 9223372036854775807)):
+    if overflow or (sizeof(val) > 8 and (val < INT64_MIN or val > INT64_MAX)):
         raise OverflowError(
-            'int too big to be encoded as INT8: {!r}'.format(obj))
+            'int64 value out of range: {!r}'.format(obj))
 
     buf.write_int32(8)
     buf.write_int64(val)
diff --git a/asyncpg/protocol/codecs/misc.pyx b/asyncpg/protocol/codecs/misc.pyx
index 6d6cc910..3f5a6178 100644
--- a/asyncpg/protocol/codecs/misc.pyx
+++ b/asyncpg/protocol/codecs/misc.pyx
@@ -36,8 +36,8 @@ cdef init_pseudo_codecs():
 
     for oid_type in oid_types:
         register_core_codec(oid_type,
-                            <encode_func>&int4_encode,
-                            <decode_func>&int4_decode,
+                            <encode_func>&uint4_encode,
+                            <decode_func>&uint4_decode,
                             PG_FORMAT_BINARY)
 
     # reg* types -- these are really system catalog OIDs, but
diff --git a/asyncpg/protocol/codecs/tid.pyx b/asyncpg/protocol/codecs/tid.pyx
index fef7e7c3..a64e1901 100644
--- a/asyncpg/protocol/codecs/tid.pyx
+++ b/asyncpg/protocol/codecs/tid.pyx
@@ -23,9 +23,9 @@ cdef tid_encode(ConnectionSettings settings, WriteBuffer buf, obj):
         overflow = 1
 
     # "long" and "long long" have the same size for x86_64, need an extra check
-    if overflow or (sizeof(block) > 4 and block > 4294967295):
+    if overflow or (sizeof(block) > 4 and block > UINT32_MAX):
         raise OverflowError(
-            'block too big to be encoded as UINT4: {!r}'.format(obj[0]))
+            'tuple id block value out of range: {!r}'.format(obj[0]))
 
     try:
         offset = cpython.PyLong_AsUnsignedLong(obj[1])
@@ -35,7 +35,7 @@ cdef tid_encode(ConnectionSettings settings, WriteBuffer buf, obj):
 
     if overflow or offset > 65535:
         raise OverflowError(
-            'offset too big to be encoded as UINT2: {!r}'.format(obj[1]))
+            'tuple id offset value out of range: {!r}'.format(obj[1]))
 
     buf.write_int32(6)
     buf.write_int32(<int32_t>block)
diff --git a/asyncpg/protocol/prepared_stmt.pyx b/asyncpg/protocol/prepared_stmt.pyx
index e69369d3..e8ea038c 100644
--- a/asyncpg/protocol/prepared_stmt.pyx
+++ b/asyncpg/protocol/prepared_stmt.pyx
@@ -163,7 +163,7 @@ cdef class PreparedStatementState:
             list cols_names
             object cols_mapping
             tuple row
-            int oid
+            uint32_t oid
             Codec codec
             list codecs
 
@@ -183,7 +183,7 @@ cdef class PreparedStatementState:
             cols_mapping[col_name] = i
             cols_names.append(col_name)
             oid = row[3]
-            codec = self.settings.get_data_codec(oid)
+            codec = self.settings.get_data_codec(<uint32_t>oid)
             if codec is None or not codec.has_decoder():
                 raise RuntimeError('no decoder for OID {}'.format(oid))
             if not codec.is_binary():
@@ -198,7 +198,7 @@ cdef class PreparedStatementState:
 
     cdef _ensure_args_encoder(self):
         cdef:
-            int p_oid
+            uint32_t p_oid
             Codec codec
             list codecs = []
 
@@ -207,7 +207,7 @@ cdef class PreparedStatementState:
 
         for i from 0 <= i < self.args_num:
             p_oid = self.parameters_desc[i]
-            codec = self.settings.get_data_codec(p_oid)
+            codec = self.settings.get_data_codec(<uint32_t>p_oid)
             if codec is None or not codec.has_encoder():
                 raise RuntimeError('no encoder for OID {}'.format(p_oid))
             if codec.type not in {}:
@@ -290,14 +290,14 @@ cdef _decode_parameters_desc(object desc):
     cdef:
         ReadBuffer reader
        int16_t nparams
-        int32_t p_oid
+        uint32_t p_oid
        list result = []
 
     reader = ReadBuffer.new_message_parser(desc)
     nparams = reader.read_int16()
 
     for i from 0 <= i < nparams:
-        p_oid = reader.read_int32()
+        p_oid = <uint32_t>reader.read_int32()
         result.append(p_oid)
 
     return result
@@ -310,9 +310,9 @@ cdef _decode_row_desc(object desc):
         int16_t nfields
 
         bytes f_name
-        int32_t f_table_oid
+        uint32_t f_table_oid
         int16_t f_column_num
-        int32_t f_dt_oid
+        uint32_t f_dt_oid
         int16_t f_dt_size
         int32_t f_dt_mod
         int16_t f_format
@@ -325,9 +325,9 @@ cdef _decode_row_desc(object desc):
 
     for i from 0 <= i < nfields:
         f_name = reader.read_cstr()
-        f_table_oid = reader.read_int32()
+        f_table_oid = <uint32_t>reader.read_int32()
         f_column_num = reader.read_int16()
-        f_dt_oid = reader.read_int32()
+        f_dt_oid = <uint32_t>reader.read_int32()
         f_dt_size = reader.read_int16()
         f_dt_mod = reader.read_int32()
         f_format = reader.read_int16()
diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx
index bff77b5e..5aed74e5 100644
--- a/asyncpg/protocol/protocol.pyx
+++ b/asyncpg/protocol/protocol.pyx
@@ -18,7 +18,9 @@
 import socket
 import time
 
 from libc.stdint cimport int8_t, uint8_t, int16_t, uint16_t, \
-                         int32_t, uint32_t, int64_t, uint64_t
+                         int32_t, uint32_t, int64_t, uint64_t, \
+                         INT16_MIN, INT16_MAX, INT32_MIN, INT32_MAX, \
+                         UINT32_MAX, INT64_MIN, INT64_MAX
 
 from asyncpg.protocol cimport record
diff --git a/tests/test_codecs.py b/tests/test_codecs.py
index 8a01721f..152e1e48 100644
--- a/tests/test_codecs.py
+++ b/tests/test_codecs.py
@@ -9,12 +9,15 @@
 import decimal
 import ipaddress
 import math
+import os
 import random
 import struct
+import unittest
 import uuid
 
 import asyncpg
 from asyncpg import _testbase as tb
+from asyncpg import cluster as pg_cluster
 
 
 def _timezone(offset):
@@ -388,6 +391,11 @@ def _timezone(offset):
         (0, 65535),
         (4294967295, 65535),
     ]),
+    ('oid', 'oid', [
+        0,
+        10,
+        4294967295
+    ])
 ]
 
 
@@ -581,7 +589,7 @@ async def test_invalid_input(self):
                 '2',
                 'aa',
             ]),
-            ('smallint', OverflowError, 'int too big to be encoded as INT2', [
+            ('smallint', OverflowError, 'int16 value out of range', [
                 2**256,  # check for the same exception for any big numbers
                 decimal.Decimal("2000000000000000000000000000000"),
                 0xffff,
@@ -597,7 +605,7 @@ async def test_invalid_input(self):
                 '2',
                 'aa',
             ]),
-            ('int', OverflowError, 'int too big to be encoded as INT4', [
+            ('int', OverflowError, 'int32 value out of range', [
                 2**256,  # check for the same exception for any big numbers
                 decimal.Decimal("2000000000000000000000000000000"),
                 0xffffffff,
@@ -608,7 +616,7 @@ async def test_invalid_input(self):
                 '2',
                 'aa',
             ]),
-            ('bigint', OverflowError, 'int too big to be encoded as INT8', [
+            ('bigint', OverflowError, 'int64 value out of range', [
                 2**256,  # check for the same exception for any big numbers
                 decimal.Decimal("2000000000000000000000000000000"),
                 0xffffffffffffffff,
@@ -630,19 +638,23 @@ async def test_invalid_input(self):
                 [1, 2, 3],
                 (4,),
             ]),
-            ('tid', OverflowError, 'block too big to be encoded as UINT4', [
+            ('tid', OverflowError, 'tuple id block value out of range', [
                 (-1, 0),
                 (2**256, 0),
                 (0xffffffff + 1, 0),
                 (2**32, 0),
             ]),
-            ('tid', OverflowError, 'offset too big to be encoded as UINT2', [
+            ('tid', OverflowError, 'tuple id offset value out of range', [
                 (0, -1),
                 (0, 2**256),
                 (0, 0xffff + 1),
                 (0, 0xffffffff),
                 (0, 65536),
             ]),
+            ('oid', OverflowError, 'uint32 value out of range', [
+                2 ** 32,
+                -1,
+            ]),
         ]
 
         for typname, errcls, errmsg, data in cases:
@@ -1082,6 +1094,40 @@ def hstore_encoder(obj):
                 DROP EXTENSION hstore
             ''')
 
+    async def test_custom_codec_on_domain(self):
+        """Test encoding/decoding using a custom codec on a domain."""
+        await self.con.execute('''
+            CREATE DOMAIN custom_codec_t AS int
+        ''')
+
+        try:
+            await self.con.set_type_codec(
+                'custom_codec_t',
+                encoder=lambda v: str(v),
+                decoder=lambda v: int(v))
+
+            v = await self.con.fetchval('SELECT $1::custom_codec_t', 10)
+            self.assertEqual(v, 10)
+        finally:
+            await self.con.execute('DROP DOMAIN custom_codec_t')
+
+    async def test_custom_codec_on_enum(self):
+        """Test encoding/decoding using a custom codec on an enum."""
+        await self.con.execute('''
+            CREATE TYPE custom_codec_t AS ENUM ('foo', 'bar', 'baz')
+        ''')
+
+        try:
+            await self.con.set_type_codec(
+                'custom_codec_t',
+                encoder=lambda v: str(v).lstrip('enum :'),
+                decoder=lambda v: 'enum: ' + str(v))
+
+            v = await self.con.fetchval('SELECT $1::custom_codec_t', 'foo')
+            self.assertEqual(v, 'enum: foo')
+        finally:
+            await self.con.execute('DROP TYPE custom_codec_t')
+
     async def test_custom_codec_override_binary(self):
         """Test overriding core codecs."""
         import json
@@ -1511,3 +1557,31 @@ async def test_enum_and_range(self):
     async def test_no_result(self):
         st = await self.con.prepare('rollback')
         self.assertTupleEqual(st.get_attributes(), ())
+
+
+@unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing')
+class TestCodecsLargeOIDs(tb.ConnectedTestCase):
+    @classmethod
+    def setup_cluster(cls):
+        cls.cluster = cls.new_cluster(pg_cluster.TempCluster)
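+        # 2147483648 is 2**31: newly created type OIDs will not fit in int32.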
+        cls.cluster.reset_wal(oid=2147483648)
+        cls.start_cluster(cls.cluster)
+
+    async def test_custom_codec_large_oid(self):
+        await self.con.execute('CREATE DOMAIN test_domain_t AS int')
+        try:
+            oid = await self.con.fetchval('''
+                SELECT oid FROM pg_type WHERE typname = 'test_domain_t'
+            ''')
+            self.assertEqual(oid, 2147483648)
+
+            await self.con.set_type_codec(
+                'test_domain_t',
+                encoder=lambda v: str(v),
+                decoder=lambda v: int(v))
+
+            v = await self.con.fetchval('SELECT $1::test_domain_t', 10)
+            self.assertEqual(v, 10)
+
+        finally:
+            await self.con.execute('DROP DOMAIN test_domain_t')
diff --git a/tests/test_connect.py b/tests/test_connect.py
index c187ab62..48cf14fd 100644
--- a/tests/test_connect.py
+++ b/tests/test_connect.py
@@ -756,8 +756,9 @@ def get_server_settings(cls):
 
     @classmethod
     def setup_cluster(cls):
-        cls.cluster = cls.start_cluster(
-            pg_cluster.TempCluster, server_settings=cls.get_server_settings())
+        cls.cluster = cls.new_cluster(pg_cluster.TempCluster)
+        cls.start_cluster(
+            cls.cluster, server_settings=cls.get_server_settings())
 
     def setUp(self):
         super().setUp()
diff --git a/tests/test_pool.py b/tests/test_pool.py
index eba49f7d..d35fe9f0 100644
--- a/tests/test_pool.py
+++ b/tests/test_pool.py
@@ -743,17 +743,17 @@ class MyException(Exception):
 
 
 @unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing')
-class TestHotStandby(tb.ConnectedTestCase):
+class TestHotStandby(tb.ClusterTestCase):
     @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-
-        cls.master_cluster = cls.start_cluster(
-            pg_cluster.TempCluster,
+    def setup_cluster(cls):
+        cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster)
+        cls.start_cluster(
+            cls.master_cluster,
             server_settings={
                 'max_wal_senders': 10,
                 'wal_level': 'hot_standby'
-            })
+            }
+        )
 
         con = None
 
@@ -771,27 +771,24 @@ def setUpClass(cls):
 
             conn_spec = cls.master_cluster.get_connection_spec()
 
-            cls.standby_cluster = cls.start_cluster(
+            cls.standby_cluster = cls.new_cluster(
                 pg_cluster.HotStandbyCluster,
                 cluster_kwargs={
                     'master': conn_spec,
                     'replication_user': 'replication'
-                },
+                }
+            )
+            cls.start_cluster(
+                cls.standby_cluster,
                 server_settings={
                     'hot_standby': True
-                })
+                }
+            )
         finally:
             if con is not None:
                 cls.loop.run_until_complete(con.close())
 
-    @classmethod
-    def tearDownMethod(cls):
-        cls.standby_cluster.stop()
-        cls.standby_cluster.destroy()
-        cls.master_cluster.stop()
-        cls.master_cluster.destroy()
-
     def create_pool(self, **kwargs):
         conn_spec = self.standby_cluster.get_connection_spec()
         conn_spec.update(kwargs)