Skip to content

Commit b4ce740

Browse files
committed
Consistently use postgres/postgres as database/user pair in tests
When running against a temporary cluster, make sure the default superuser and database name are 'postgres'. When PGHOST environment variable is specified, rely on the default connection spec heuristics.
1 parent 7b6c083 commit b4ce740

File tree

7 files changed

+77
-60
lines changed

7 files changed

+77
-60
lines changed

asyncpg/_testbase.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -128,16 +128,30 @@ def handler(loop, ctx):
128128
_default_cluster = None
129129

130130

131-
def _start_cluster(ClusterCls, cluster_kwargs, server_settings):
131+
def _start_cluster(ClusterCls, cluster_kwargs, server_settings,
132+
initdb_options=None):
132133
cluster = ClusterCls(**cluster_kwargs)
133-
cluster.init()
134+
cluster.init(**(initdb_options or {}))
134135
cluster.trust_local_connections()
135136
cluster.start(port='dynamic', server_settings=server_settings)
136137
atexit.register(_shutdown_cluster, cluster)
137138
return cluster
138139

139140

140-
def _start_default_cluster(server_settings={}):
141+
def _get_initdb_options(initdb_options=None):
142+
if not initdb_options:
143+
initdb_options = {}
144+
else:
145+
initdb_options = dict(initdb_options)
146+
147+
# Make the default superuser name stable.
148+
if 'username' not in initdb_options:
149+
initdb_options['username'] = 'postgres'
150+
151+
return initdb_options
152+
153+
154+
def _start_default_cluster(server_settings={}, initdb_options=None):
141155
global _default_cluster
142156

143157
if _default_cluster is None:
@@ -147,7 +161,9 @@ def _start_default_cluster(server_settings={}):
147161
_default_cluster = pg_cluster.RunningCluster()
148162
else:
149163
_default_cluster = _start_cluster(
150-
pg_cluster.TempCluster, {}, server_settings)
164+
pg_cluster.TempCluster, cluster_kwargs={},
165+
server_settings=server_settings,
166+
initdb_options=_get_initdb_options(initdb_options))
151167

152168
return _default_cluster
153169

@@ -193,15 +209,33 @@ def setUpClass(cls):
193209
super().setUpClass()
194210
cls.setup_cluster()
195211

196-
def create_pool(self, pool_class=pg_pool.Pool, **kwargs):
197-
conn_spec = self.cluster.get_connection_spec()
212+
@classmethod
213+
def get_connection_spec(cls, kwargs={}):
214+
conn_spec = cls.cluster.get_connection_spec()
198215
conn_spec.update(kwargs)
216+
if not os.environ.get('PGHOST'):
217+
if 'database' not in conn_spec:
218+
conn_spec['database'] = 'postgres'
219+
if 'user' not in conn_spec:
220+
conn_spec['user'] = 'postgres'
221+
return conn_spec
222+
223+
def create_pool(self, pool_class=pg_pool.Pool, **kwargs):
224+
conn_spec = self.get_connection_spec(kwargs)
199225
return create_pool(loop=self.loop, pool_class=pool_class, **conn_spec)
200226

227+
@classmethod
228+
def connect(cls, **kwargs):
229+
conn_spec = cls.get_connection_spec(kwargs)
230+
return pg_connection.connect(**conn_spec, loop=cls.loop)
231+
201232
@classmethod
202233
def start_cluster(cls, ClusterCls, *,
203-
cluster_kwargs={}, server_settings={}):
204-
return _start_cluster(ClusterCls, cluster_kwargs, server_settings)
234+
cluster_kwargs={}, server_settings={},
235+
initdb_options={}):
236+
return _start_cluster(
237+
ClusterCls, cluster_kwargs,
238+
server_settings, _get_initdb_options(initdb_options))
205239

206240

207241
def with_connection_options(**options):
@@ -223,13 +257,7 @@ def setUp(self):
223257
# Extract options set up with `with_connection_options`.
224258
test_func = getattr(self, self._testMethodName).__func__
225259
opts = getattr(test_func, '__connect_options__', {})
226-
if 'database' not in opts:
227-
opts = dict(opts)
228-
opts['database'] = 'postgres'
229-
230-
self.con = self.loop.run_until_complete(
231-
self.cluster.connect(loop=self.loop, **opts))
232-
260+
self.con = self.loop.run_until_complete(self.connect(**opts))
233261
self.server_version = self.con.get_server_version()
234262

235263
def tearDown(self):

asyncpg/cluster.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,7 @@ def _test_connection(self, timeout=60):
448448
try:
449449
con = loop.run_until_complete(
450450
asyncpg.connect(database='postgres',
451+
user='postgres',
451452
timeout=5, loop=loop,
452453
**self._connection_addr))
453454
except (OSError, asyncio.TimeoutError,

tests/test_codecs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,7 +1011,7 @@ async def test_custom_codec_override_binary(self):
10111011
"""Test overriding core codecs."""
10121012
import json
10131013

1014-
conn = await self.cluster.connect(database='postgres', loop=self.loop)
1014+
conn = await self.connect()
10151015
try:
10161016
def _encoder(value):
10171017
return json.dumps(value).encode('utf-8')
@@ -1035,7 +1035,7 @@ async def test_custom_codec_override_text(self):
10351035
"""Test overriding core codecs."""
10361036
import json
10371037

1038-
conn = await self.cluster.connect(database='postgres', loop=self.loop)
1038+
conn = await self.connect()
10391039
try:
10401040
def _encoder(value):
10411041
return json.dumps(value)
@@ -1087,7 +1087,7 @@ async def test_custom_codec_override_tuple(self):
10871087
('interval', (2, 3, 1), '2 mons 3 days 00:00:00.000001')
10881088
]
10891089

1090-
conn = await self.cluster.connect(database='postgres', loop=self.loop)
1090+
conn = await self.connect()
10911091

10921092
def _encoder(value):
10931093
return tuple(value)

tests/test_connect.py

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -137,57 +137,50 @@ async def _try_connect(self, **kwargs):
137137
if _system == 'Windows':
138138
for tried in range(3):
139139
try:
140-
return await self.cluster.connect(**kwargs)
140+
return await self.connect(**kwargs)
141141
except asyncpg.ConnectionDoesNotExistError:
142142
pass
143143

144-
return await self.cluster.connect(**kwargs)
144+
return await self.connect(**kwargs)
145145

146146
async def test_auth_bad_user(self):
147147
with self.assertRaises(
148148
asyncpg.InvalidAuthorizationSpecificationError):
149-
await self._try_connect(user='__nonexistent__',
150-
database='postgres',
151-
loop=self.loop)
149+
await self._try_connect(user='__nonexistent__')
152150

153151
async def test_auth_trust(self):
154-
conn = await self.cluster.connect(
155-
user='trust_user', database='postgres', loop=self.loop)
152+
conn = await self.connect(user='trust_user')
156153
await conn.close()
157154

158155
async def test_auth_reject(self):
159156
with self.assertRaisesRegex(
160157
asyncpg.InvalidAuthorizationSpecificationError,
161158
'pg_hba.conf rejects connection'):
162-
await self._try_connect(
163-
user='reject_user', database='postgres',
164-
loop=self.loop)
159+
await self._try_connect(user='reject_user')
165160

166161
async def test_auth_password_cleartext(self):
167-
conn = await self.cluster.connect(
168-
user='password_user', database='postgres',
169-
password='correctpassword', loop=self.loop)
162+
conn = await self.connect(
163+
user='password_user',
164+
password='correctpassword')
170165
await conn.close()
171166

172167
with self.assertRaisesRegex(
173168
asyncpg.InvalidPasswordError,
174169
'password authentication failed for user "password_user"'):
175170
await self._try_connect(
176-
user='password_user', database='postgres',
177-
password='wrongpassword', loop=self.loop)
171+
user='password_user',
172+
password='wrongpassword')
178173

179174
async def test_auth_password_md5(self):
180-
conn = await self.cluster.connect(
181-
user='md5_user', database='postgres', password='correctpassword',
182-
loop=self.loop)
175+
conn = await self.connect(
176+
user='md5_user', password='correctpassword')
183177
await conn.close()
184178

185179
with self.assertRaisesRegex(
186180
asyncpg.InvalidPasswordError,
187181
'password authentication failed for user "md5_user"'):
188182
await self._try_connect(
189-
user='md5_user', database='postgres', password='wrongpassword',
190-
loop=self.loop)
183+
user='md5_user', password='wrongpassword')
191184

192185
async def test_auth_unsupported(self):
193186
pass
@@ -494,11 +487,9 @@ async def test_connection_ssl_to_no_ssl_server(self):
494487
ssl_context.load_verify_locations(SSL_CA_CERT_FILE)
495488

496489
with self.assertRaisesRegex(ConnectionError, 'rejected SSL'):
497-
await self.cluster.connect(
490+
await self.connect(
498491
host='localhost',
499492
user='ssl_user',
500-
database='postgres',
501-
loop=self.loop,
502493
ssl=ssl_context)
503494

504495
async def test_connection_ssl_unix(self):
@@ -507,15 +498,17 @@ async def test_connection_ssl_unix(self):
507498

508499
with self.assertRaisesRegex(asyncpg.InterfaceError,
509500
'can only be enabled for TCP addresses'):
510-
await self.cluster.connect(
501+
await self.connect(
511502
host='/tmp',
512-
loop=self.loop,
513503
ssl=ssl_context)
514504

515505
async def test_connection_implicit_host(self):
516-
conn_spec = self.cluster.get_connection_spec()
506+
conn_spec = self.get_connection_spec()
517507
con = await asyncpg.connect(
518-
port=conn_spec.get('port'), database='postgres', loop=self.loop)
508+
port=conn_spec.get('port'),
509+
database=conn_spec.get('database'),
510+
user=conn_spec.get('user'),
511+
loop=self.loop)
519512
await con.close()
520513

521514

@@ -576,11 +569,9 @@ async def test_ssl_connection_custom_context(self):
576569
ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
577570
ssl_context.load_verify_locations(SSL_CA_CERT_FILE)
578571

579-
con = await self.cluster.connect(
572+
con = await self.connect(
580573
host='localhost',
581574
user='ssl_user',
582-
database='postgres',
583-
loop=self.loop,
584575
ssl=ssl_context)
585576

586577
try:
@@ -595,11 +586,9 @@ async def test_ssl_connection_custom_context(self):
595586

596587
async def test_ssl_connection_default_context(self):
597588
with self.assertRaisesRegex(ssl.SSLError, 'verify failed'):
598-
await self.cluster.connect(
589+
await self.connect(
599590
host='localhost',
600591
user='ssl_user',
601-
database='postgres',
602-
loop=self.loop,
603592
ssl=True)
604593

605594
async def test_ssl_connection_pool(self):

tests/test_introspection.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ class TestIntrospection(tb.ConnectedTestCase):
1616
@classmethod
1717
def setUpClass(cls):
1818
super().setUpClass()
19-
cls.adminconn = cls.loop.run_until_complete(
20-
cls.cluster.connect(database='postgres', loop=cls.loop))
19+
cls.adminconn = cls.loop.run_until_complete(cls.connect())
2120
cls.loop.run_until_complete(
2221
cls.adminconn.execute('CREATE DATABASE asyncpg_intro_test'))
2322

tests/test_pool.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,8 @@ def setUpClass(cls):
656656

657657
try:
658658
con = cls.loop.run_until_complete(
659-
cls.master_cluster.connect(database='postgres', loop=cls.loop))
659+
cls.master_cluster.connect(
660+
database='postgres', user='postgres', loop=cls.loop))
660661

661662
cls.loop.run_until_complete(
662663
con.execute('''
@@ -696,8 +697,9 @@ def create_pool(self, **kwargs):
696697
async def test_standby_pool_01(self):
697698
for n in {1, 3, 5, 10, 20, 100}:
698699
with self.subTest(tasksnum=n):
699-
pool = await self.create_pool(database='postgres',
700-
min_size=5, max_size=10)
700+
pool = await self.create_pool(
701+
database='postgres', user='postgres',
702+
min_size=5, max_size=10)
701703

702704
async def worker():
703705
con = await pool.acquire()
@@ -710,7 +712,7 @@ async def worker():
710712

711713
async def test_standby_cursors(self):
712714
con = await self.standby_cluster.connect(
713-
database='postgres', loop=self.loop)
715+
database='postgres', user='postgres', loop=self.loop)
714716

715717
try:
716718
async with con.transaction():

tests/test_timeout.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,7 @@ async def test_invalid_timeout(self):
114114
with self.subTest(command_timeout=command_timeout):
115115
with self.assertRaisesRegex(ValueError,
116116
'invalid command_timeout'):
117-
await self.cluster.connect(
118-
database='postgres', loop=self.loop,
119-
command_timeout=command_timeout)
117+
await self.connect(command_timeout=command_timeout)
120118

121119
# Note: negative timeouts are OK for method calls.
122120
for methname in {'fetch', 'fetchrow', 'fetchval', 'execute'}:

0 commit comments

Comments
 (0)