Skip to content

Commit bd0c7cf

Browse files
committed
Handle environments without home dir
1 parent 247b1a5 commit bd0c7cf

File tree

3 files changed

+58
-20
lines changed

3 files changed

+58
-20
lines changed

asyncpg/compat.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import asyncio
99
import pathlib
1010
import platform
11+
import typing
1112

1213

1314
SYSTEM = platform.uname().system
@@ -18,7 +19,7 @@
1819

1920
CSIDL_APPDATA = 0x001a
2021

21-
def get_pg_home_directory() -> pathlib.Path:
22+
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
2223
# We cannot simply use expanduser() as that returns the user's
2324
# home directory, whereas Postgres stores its config in
2425
# %AppData% on Windows.
@@ -30,8 +31,11 @@ def get_pg_home_directory() -> pathlib.Path:
3031
return pathlib.Path(buf.value) / 'postgresql'
3132

3233
else:
33-
def get_pg_home_directory() -> pathlib.Path:
34-
return pathlib.Path.home()
34+
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
35+
try:
36+
return pathlib.Path.home()
37+
except (RuntimeError, KeyError):
38+
return None
3539

3640

3741
async def wait_closed(stream):

asyncpg/connect_utils.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,12 @@ def _parse_tls_version(tls_version):
249249
)
250250

251251

252-
def _dot_postgresql_path(filename) -> pathlib.Path:
253-
return (pathlib.Path.home() / '.postgresql' / filename).resolve()
252+
def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
253+
homedir = compat.get_pg_home_directory()
254+
if homedir is None:
255+
return None
256+
257+
return (homedir / '.postgresql' / filename).resolve()
254258

255259

256260
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
@@ -501,11 +505,14 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
501505
ssl.load_verify_locations(cafile=sslrootcert)
502506
ssl.verify_mode = ssl_module.CERT_REQUIRED
503507
else:
504-
sslrootcert = _dot_postgresql_path('root.crt')
505508
try:
509+
sslrootcert = _dot_postgresql_path('root.crt')
510+
assert sslrootcert is not None
506511
ssl.load_verify_locations(cafile=sslrootcert)
507-
except FileNotFoundError:
512+
except (AssertionError, FileNotFoundError):
508513
if sslmode > SSLMode.require:
514+
if sslrootcert is None:
515+
raise RuntimeError('Cannot determine home directory')
509516
raise ValueError(
510517
f'root certificate file "{sslrootcert}" does '
511518
f'not exist\nEither provide the file or '
@@ -526,18 +533,19 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
526533
ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
527534
else:
528535
sslcrl = _dot_postgresql_path('root.crl')
529-
try:
530-
ssl.load_verify_locations(cafile=sslcrl)
531-
except FileNotFoundError:
532-
pass
533-
else:
534-
ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
536+
if sslcrl is not None:
537+
try:
538+
ssl.load_verify_locations(cafile=sslcrl)
539+
except FileNotFoundError:
540+
pass
541+
else:
542+
ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
535543

536544
if sslkey is None:
537545
sslkey = os.getenv('PGSSLKEY')
538546
if not sslkey:
539547
sslkey = _dot_postgresql_path('postgresql.key')
540-
if not sslkey.exists():
548+
if sslkey is not None and not sslkey.exists():
541549
sslkey = None
542550
if not sslpassword:
543551
sslpassword = ''
@@ -549,12 +557,13 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
549557
)
550558
else:
551559
sslcert = _dot_postgresql_path('postgresql.crt')
552-
try:
553-
ssl.load_cert_chain(
554-
sslcert, keyfile=sslkey, password=lambda: sslpassword
555-
)
556-
except FileNotFoundError:
557-
pass
560+
if sslcert is not None:
561+
try:
562+
ssl.load_cert_chain(
563+
sslcert, keyfile=sslkey, password=lambda: sslpassword
564+
)
565+
except FileNotFoundError:
566+
pass
558567

559568
# OpenSSL 1.1.1 keylog file, copied from create_default_context()
560569
if hasattr(ssl, 'keylog_filename'):

tests/test_connect.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,13 @@ def mock_dot_postgresql(*, ca=True, crl=False, client=False, protected=False):
7171
yield
7272

7373

74+
@contextlib.contextmanager
75+
def mock_no_home_dir():
76+
with unittest.mock.patch(
77+
'pathlib.Path.home', unittest.mock.Mock(side_effect=RuntimeError)
78+
):
79+
yield
80+
7481
class TestSettings(tb.ConnectedTestCase):
7582

7683
async def test_get_settings_01(self):
@@ -1257,6 +1264,24 @@ async def test_connection_implicit_host(self):
12571264
user=conn_spec.get('user'))
12581265
await con.close()
12591266

1267+
@unittest.skipIf(os.environ.get('PGHOST'), 'unmanaged cluster')
1268+
async def test_connection_no_home_dir(self):
1269+
with mock_no_home_dir():
1270+
con = await self.connect(
1271+
dsn='postgresql://foo/',
1272+
user='postgres',
1273+
database='postgres',
1274+
host='localhost')
1275+
await con.fetchval('SELECT 42')
1276+
await con.close()
1277+
1278+
with self.assertRaisesRegex(RuntimeError, 'Cannot determine home directory'):
1279+
with mock_no_home_dir():
1280+
await self.connect(
1281+
host='localhost',
1282+
user='ssl_user',
1283+
ssl='verify-full')
1284+
12601285

12611286
class BaseTestSSLConnection(tb.ConnectedTestCase):
12621287
@classmethod

0 commit comments

Comments
 (0)