Skip to content

Add support for specifying multiple host addresses when connecting #365

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 30, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 107 additions & 43 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,10 @@ def _read_password_file(passfile: pathlib.Path) \

def _read_password_from_pgpass(
*, passfile: typing.Optional[pathlib.Path],
hosts: typing.List[typing.Union[str, typing.Tuple[str, int]]],
port: int, database: str, user: str):
hosts: typing.List[str],
ports: typing.List[int],
database: str,
user: str):
"""Parse the pgpass file and return the matching password.

:return:
Expand All @@ -116,7 +118,7 @@ def _read_password_from_pgpass(
if not passtab:
return None

for host in hosts:
for host, port in zip(hosts, ports):
if host.startswith('/'):
# Unix sockets get normalized into 'localhost'
host = 'localhost'
Expand All @@ -137,27 +139,83 @@ def _read_password_from_pgpass(
return None


def _validate_port_spec(hosts, port):
if isinstance(port, list):
# If there is a list of ports, its length must
# match that of the host list.
if len(port) != len(hosts):
raise exceptions.InterfaceError(
'could not match {} port numbers to {} hosts'.format(
len(port), len(hosts)))
else:
port = [port for _ in range(len(hosts))]

return port


def _parse_hostlist(hostlist, port):
if ',' in hostlist:
# A comma-separated list of host addresses.
hostspecs = hostlist.split(',')
else:
hostspecs = [hostlist]

hosts = []
hostlist_ports = []

if not port:
portspec = os.environ.get('PGPORT')
if portspec:
if ',' in portspec:
default_port = [int(p) for p in portspec.split(',')]
else:
default_port = int(portspec)
else:
default_port = 5432

default_port = _validate_port_spec(hostspecs, default_port)

else:
port = _validate_port_spec(hostspecs, port)

for i, hostspec in enumerate(hostspecs):
addr, _, hostspec_port = hostspec.partition(':')
hosts.append(addr)

if not port:
if hostspec_port:
hostlist_ports.append(int(hostspec_port))
else:
hostlist_ports.append(default_port[i])

if not port:
port = hostlist_ports

return hosts, port


def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password, passfile, database, ssl,
connect_timeout, server_settings):
if host is not None and not isinstance(host, str):
raise TypeError(
'host argument is expected to be str, got {!r}'.format(
type(host)))
# `auth_hosts` is the version of host information for the purposes
# of reading the pgpass file.
auth_hosts = None

if dsn:
parsed = urllib.parse.urlparse(dsn)

if parsed.scheme not in {'postgresql', 'postgres'}:
raise ValueError(
'invalid DSN: scheme is expected to be either of '
'invalid DSN: scheme is expected to be either '
'"postgresql" or "postgres", got {!r}'.format(parsed.scheme))

if parsed.port and port is None:
port = int(parsed.port)
if not host and parsed.netloc:
if '@' in parsed.netloc:
auth, _, hostspec = parsed.netloc.partition('@')
else:
hostspec = parsed.netloc

if parsed.hostname and host is None:
host = parsed.hostname
host, port = _parse_hostlist(hostspec, port)

if parsed.path and database is None:
database = parsed.path
Expand All @@ -178,13 +236,13 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,

if 'host' in query:
val = query.pop('host')
if host is None:
host = val
if not host and val:
host, port = _parse_hostlist(val, port)

if 'port' in query:
val = int(query.pop('port'))
if port is None:
port = val
val = query.pop('port')
if not port and val:
port = [int(p) for p in val.split(',')]

if 'dbname' in query:
val = query.pop('dbname')
Expand Down Expand Up @@ -222,40 +280,44 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
else:
server_settings = {**query, **server_settings}

# On env-var -> connection parameter conversion read here:
# https://www.postgresql.org/docs/current/static/libpq-envars.html
# Note that env values may be an empty string in cases when
# the variable is "unset" by setting it to an empty value
# `auth_hosts` is the version of host information for the purposes
# of reading the pgpass file.
auth_hosts = None
if host is None:
host = os.getenv('PGHOST')
if not host:
auth_hosts = ['localhost']
if not host:
hostspec = os.environ.get('PGHOST')
if hostspec:
host, port = _parse_hostlist(hostspec, port)

if _system == 'Windows':
host = ['localhost']
else:
host = ['/tmp', '/private/tmp',
'/var/pgsql_socket', '/run/postgresql',
'localhost']
if not host:
auth_hosts = ['localhost']

if _system == 'Windows':
host = ['localhost']
else:
host = ['/run/postgresql', '/var/run/postgresql',
'/tmp', '/private/tmp', 'localhost']

if not isinstance(host, list):
host = [host]

if auth_hosts is None:
auth_hosts = host

if port is None:
port = os.getenv('PGPORT')
if port:
port = int(port)
if not port:
portspec = os.environ.get('PGPORT')
if portspec:
if ',' in portspec:
port = [int(p) for p in portspec.split(',')]
else:
port = int(portspec)
else:
port = 5432

elif isinstance(port, (list, tuple)):
port = [int(p) for p in port]

else:
port = int(port)

port = _validate_port_spec(host, port)

if user is None:
user = os.getenv('PGUSER')
if not user:
Expand Down Expand Up @@ -293,19 +355,20 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,

if passfile is not None:
password = _read_password_from_pgpass(
hosts=auth_hosts, port=port, database=database, user=user,
hosts=auth_hosts, ports=port,
database=database, user=user,
passfile=passfile)

addrs = []
for h in host:
for h, p in zip(host, port):
if h.startswith('/'):
# UNIX socket name
if '.s.PGSQL.' not in h:
h = os.path.join(h, '.s.PGSQL.{}'.format(port))
h = os.path.join(h, '.s.PGSQL.{}'.format(p))
addrs.append(h)
else:
# TCP host/port
addrs.append((h, port))
addrs.append((h, p))

if not addrs:
raise ValueError(
Expand All @@ -329,7 +392,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
sslmode = SSLMODES[ssl]
except KeyError:
modes = ', '.join(SSLMODES.keys())
raise ValueError('`sslmode` parameter must be one of ' + modes)
raise exceptions.InterfaceError(
'`sslmode` parameter must be one of: {}'.format(modes))

# sslmode 'allow' is currently handled as 'prefer' because we're
# missing the "retry with SSL" behavior for 'allow', but do have the
Expand Down
100 changes: 74 additions & 26 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1506,77 +1506,120 @@ async def connect(dsn=None, *,
server_settings=None):
r"""A coroutine to establish a connection to a PostgreSQL server.

The connection parameters may be specified either as a connection
URI in *dsn*, or as specific keyword arguments, or both.
If both *dsn* and keyword arguments are specified, the latter
override the corresponding values parsed from the connection URI.
The default values for the majority of arguments can be specified
using `environment variables <postgres envvars>`_.

Returns a new :class:`~asyncpg.connection.Connection` object.

:param dsn:
Connection arguments specified using as a single string in the
following format:
``postgres://user:pass@host:port/database?option=value``
`libpq connection URI format`_:
``postgres://user:password@host:port/database?option=value``.
The following options are recognized by asyncpg: host, port,
user, database (or dbname), password, passfile, sslmode.
Unlike libpq, asyncpg will treat unrecognized options
as `server settings`_ to be used for the connection.

:param host:
database host address or a path to the directory containing
database server UNIX socket (defaults to the default UNIX socket,
or the value of the ``PGHOST`` environment variable, if set).
Database host address as one of the following:

- an IP address or a domain name;
- an absolute path to the directory containing the database
server Unix-domain socket (not supported on Windows);
- a sequence of any of the above, in which case the addresses
will be tried in order, and the first successful connection
will be returned.

If not specified, asyncpg will try the following, in order:

- host address(es) parsed from the *dsn* argument,
- the value of the ``PGHOST`` environment variable,
- on Unix, common directories used for PostgreSQL Unix-domain
sockets: ``"/run/postgresql"``, ``"/var/run/postgresl"``,
``"/var/pgsql_socket"``, ``"/private/tmp"``, and ``"/tmp"``,
- ``"localhost"``.

:param port:
connection port number (defaults to ``5432``, or the value of
the ``PGPORT`` environment variable, if set)
Port number to connect to at the server host
(or Unix-domain socket file extension). If multiple host
addresses were specified, this parameter may specify a
sequence of port numbers of the same length as the host sequence,
or it may specify a single port number to be used for all host
addresses.

If not specified, the value parsed from the *dsn* argument is used,
or the value of the ``PGPORT`` environment variable, or ``5432`` if
neither is specified.

:param user:
the name of the database role used for authentication
(defaults to the name of the effective user of the process
making the connection, or the value of ``PGUSER`` environment
variable, if set)
The name of the database role used for authentication.

If not specified, the value parsed from the *dsn* argument is used,
or the value of the ``PGUSER`` environment variable, or the
operating system name of the user running the application.

:param database:
the name of the database (defaults to the value of ``PGDATABASE``
environment variable, if set.)
The name of the database to connect to.

If not specified, the value parsed from the *dsn* argument is used,
or the value of the ``PGDATABASE`` environment variable, or the
operating system name of the user running the application.

:param password:
password used for authentication
Password to be used for authentication, if the server requires
one. If not specified, the value parsed from the *dsn* argument
is used, or the value of the ``PGPASSWORD`` environment variable.
Note that the use of the environment variable is discouraged as
other users and applications may be able to read it without needing
specific privileges. It is recommended to use *passfile* instead.

:param passfile:
the name of the file used to store passwords
The name of the file used to store passwords
(defaults to ``~/.pgpass``, or ``%APPDATA%\postgresql\pgpass.conf``
on Windows)
on Windows).

:param loop:
An asyncio event loop instance. If ``None``, the default
event loop will be used.

:param float timeout:
connection timeout in seconds.
Connection timeout in seconds.

:param int statement_cache_size:
the size of prepared statement LRU cache. Pass ``0`` to
The size of prepared statement LRU cache. Pass ``0`` to
disable the cache.

:param int max_cached_statement_lifetime:
the maximum time in seconds a prepared statement will stay
The maximum time in seconds a prepared statement will stay
in the cache. Pass ``0`` to allow statements be cached
indefinitely.

:param int max_cacheable_statement_size:
the maximum size of a statement that can be cached (15KiB by
The maximum size of a statement that can be cached (15KiB by
default). Pass ``0`` to allow all statements to be cached
regardless of their size.

:param float command_timeout:
the default timeout for operations on this connection
The default timeout for operations on this connection
(the default is ``None``: no timeout).

:param ssl:
pass ``True`` or an `ssl.SSLContext <SSLContext_>`_ instance to
Pass ``True`` or an `ssl.SSLContext <SSLContext_>`_ instance to
require an SSL connection. If ``True``, a default SSL context
returned by `ssl.create_default_context() <create_default_context_>`_
will be used.

:param dict server_settings:
an optional dict of server runtime parameters. Refer to
PostgreSQL documentation for a `list of supported options`_.
An optional dict of server runtime parameters. Refer to
PostgreSQL documentation for
a `list of supported options <server settings>`_.

:param Connection connection_class:
class of the returned connection object. Must be a subclass of
Class of the returned connection object. Must be a subclass of
:class:`~asyncpg.connection.Connection`.

:return: A :class:`~asyncpg.connection.Connection` instance.
Expand Down Expand Up @@ -1613,8 +1656,13 @@ class of the returned connection object. Must be a subclass of
.. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext
.. _create_default_context:
https://docs.python.org/3/library/ssl.html#ssl.create_default_context
.. _list of supported options:
.. _server settings:
https://www.postgresql.org/docs/current/static/runtime-config.html
.. _postgres envvars:
https://www.postgresql.org/docs/current/static/libpq-envars.html
.. _libpq connection URI format:
https://www.postgresql.org/docs/current/static/\
libpq-connect.html#LIBPQ-CONNSTRING
"""
if not issubclass(connection_class, Connection):
raise TypeError(
Expand Down
Loading