Skip to content

Commit 6fe46fd

Browse files
committed
Make it possible to override cluster connection parameters
1 parent bae39a7 commit 6fe46fd

File tree

2 files changed

+35
-28
lines changed

2 files changed

+35
-28
lines changed

asyncpg/_testbase.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,9 @@ def setUp(self):
129129
})
130130

131131
def create_pool(self, **kwargs):
132-
addr = self.cluster.get_connection_addr()
133-
return pg_pool.create_pool(host=addr[0], port=addr[1],
134-
loop=self.loop, **kwargs)
132+
conn_spec = self.cluster.get_connection_spec()
133+
conn_spec.update(kwargs)
134+
return pg_pool.create_pool(loop=self.loop, **conn_spec)
135135

136136

137137
class ConnectedTestCase(ClusterTestCase):

asyncpg/cluster.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __init__(self, data_dir, *, pg_config_path=None):
7777
self._daemon_pid = None
7878
self._daemon_process = None
7979
self._connection_addr = None
80+
self._connection_spec_override = None
8081

8182
def is_managed(self):
8283
return True
@@ -111,9 +112,9 @@ def get_status(self):
111112
process.returncode, stderr))
112113

113114
async def connect(self, loop=None, **kwargs):
114-
conn_addr = self.get_connection_addr()
115-
return await asyncpg.connect(
116-
host=conn_addr[0], port=conn_addr[1], loop=loop, **kwargs)
115+
conn_info = self.get_connection_spec()
116+
conn_info.update(kwargs)
117+
return await asyncpg.connect(loop=loop, **conn_info)
117118

118119
def init(self, **settings):
119120
"""Initialize cluster."""
@@ -130,14 +131,16 @@ def init(self, **settings):
130131

131132
process = subprocess.run(
132133
[self._pg_ctl, 'init', '-D', self._data_dir] + extra_args,
133-
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
134+
stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
134135

135-
stderr = process.stderr
136+
output = process.stdout
136137

137138
if process.returncode != 0:
138139
raise ClusterError(
139-
'pg_ctl init exited with status {:d}: {}'.format(
140-
process.returncode, stderr.decode()))
140+
'pg_ctl init exited with status {:d}:\n{}'.format(
141+
process.returncode, output.decode()))
142+
143+
return output.decode()
141144

142145
def start(self, wait=60, *, server_settings={}, **opts):
143146
"""Start the cluster."""
@@ -213,15 +216,27 @@ def destroy(self):
213216
else:
214217
raise ClusterError('cannot destroy {} cluster'.format(status))
215218

216-
def get_connection_addr(self):
219+
def _get_connection_spec(self):
220+
if self._connection_addr is None:
221+
self._connection_addr = self._connection_addr_from_pidfile()
222+
223+
if self._connection_addr is not None:
224+
if self._connection_spec_override:
225+
args = self._connection_addr.copy()
226+
args.update(self._connection_spec_override)
227+
return args
228+
else:
229+
return self._connection_addr
230+
231+
def get_connection_spec(self):
217232
status = self.get_status()
218233
if status != 'running':
219234
raise ClusterError('cluster is not running')
220235

221-
if self._connection_addr is None:
222-
self._connection_addr = self._connection_addr_from_pidfile()
236+
return self._get_connection_spec()
223237

224-
return self._connection_addr['host'], self._connection_addr['port']
238+
def override_connection_spec(self, **kwargs):
239+
self._connection_spec_override = kwargs
225240

226241
def reset_hba(self):
227242
"""Remove all records from pg_hba.conf."""
@@ -345,9 +360,8 @@ def _test_connection(self, timeout=60):
345360
try:
346361
for i in range(timeout):
347362
if self._connection_addr is None:
348-
self._connection_addr = \
349-
self._connection_addr_from_pidfile()
350-
if self._connection_addr is None:
363+
conn_spec = self._get_connection_spec()
364+
if conn_spec is None:
351365
time.sleep(1)
352366
continue
353367

@@ -441,21 +455,14 @@ def __init__(self, *,
441455

442456

443457
class RunningCluster(Cluster):
444-
def __init__(self, host=None, port=None):
445-
if host is None:
446-
host = os.environ.get('PGHOST') or 'localhost'
447-
448-
if port is None:
449-
port = os.environ.get('PGPORT') or 5432
450-
451-
self.host = host
452-
self.port = port
458+
def __init__(self, **kwargs):
459+
self.conn_spec = kwargs
453460

454461
def is_managed(self):
455462
return False
456463

457-
def get_connection_addr(self):
458-
return self.host, self.port
464+
def get_connection_spec(self):
465+
return self.conn_spec
459466

460467
def get_status(self):
461468
return 'running'

0 commit comments

Comments
 (0)