From 1d650ed99c094b7a7fe99813f78f3f90ef2c1ac8 Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Wed, 19 Sep 2018 12:48:00 -0400 Subject: [PATCH] Add support for specifying multiple host addresses when connecting The behavior matches that of libpq. Multiple hosts can now be specified in the DSN, e.g. `postgres://host1,host2:5433`. The `host` and `port` arguments now also accept lists. Like libpq, asyncpg will select the first host it can successfully connect to. Closes: #257 Related: #352 --- asyncpg/connect_utils.py | 150 ++++++++++++++++++++++++++++----------- asyncpg/connection.py | 100 +++++++++++++++++++------- tests/test_connect.py | 62 +++++++++++++++- 3 files changed, 242 insertions(+), 70 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index be2dde2b..ac97a4c0 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -104,8 +104,10 @@ def _read_password_file(passfile: pathlib.Path) \ def _read_password_from_pgpass( *, passfile: typing.Optional[pathlib.Path], - hosts: typing.List[typing.Union[str, typing.Tuple[str, int]]], - port: int, database: str, user: str): + hosts: typing.List[str], + ports: typing.List[int], + database: str, + user: str): """Parse the pgpass file and return the matching password. :return: @@ -116,7 +118,7 @@ def _read_password_from_pgpass( if not passtab: return None - for host in hosts: + for host, port in zip(hosts, ports): if host.startswith('/'): # Unix sockets get normalized into 'localhost' host = 'localhost' @@ -137,27 +139,83 @@ def _read_password_from_pgpass( return None +def _validate_port_spec(hosts, port): + if isinstance(port, list): + # If there is a list of ports, its length must + # match that of the host list. + if len(port) != len(hosts): + raise exceptions.InterfaceError( + 'could not match {} port numbers to {} hosts'.format( + len(port), len(hosts))) + else: + port = [port for _ in range(len(hosts))] + + return port + + +def _parse_hostlist(hostlist, port): + if ',' in hostlist: + # A comma-separated list of host addresses. + hostspecs = hostlist.split(',') + else: + hostspecs = [hostlist] + + hosts = [] + hostlist_ports = [] + + if not port: + portspec = os.environ.get('PGPORT') + if portspec: + if ',' in portspec: + default_port = [int(p) for p in portspec.split(',')] + else: + default_port = int(portspec) + else: + default_port = 5432 + + default_port = _validate_port_spec(hostspecs, default_port) + + else: + port = _validate_port_spec(hostspecs, port) + + for i, hostspec in enumerate(hostspecs): + addr, _, hostspec_port = hostspec.partition(':') + hosts.append(addr) + + if not port: + if hostspec_port: + hostlist_ports.append(int(hostspec_port)) + else: + hostlist_ports.append(default_port[i]) + + if not port: + port = hostlist_ports + + return hosts, port + + def _parse_connect_dsn_and_args(*, dsn, host, port, user, password, passfile, database, ssl, connect_timeout, server_settings): - if host is not None and not isinstance(host, str): - raise TypeError( - 'host argument is expected to be str, got {!r}'.format( - type(host))) + # `auth_hosts` is the version of host information for the purposes + # of reading the pgpass file. + auth_hosts = None if dsn: parsed = urllib.parse.urlparse(dsn) if parsed.scheme not in {'postgresql', 'postgres'}: raise ValueError( - 'invalid DSN: scheme is expected to be either of ' + 'invalid DSN: scheme is expected to be either ' '"postgresql" or "postgres", got {!r}'.format(parsed.scheme)) - if parsed.port and port is None: - port = int(parsed.port) + if not host and parsed.netloc: + if '@' in parsed.netloc: + auth, _, hostspec = parsed.netloc.partition('@') + else: + hostspec = parsed.netloc - if parsed.hostname and host is None: - host = parsed.hostname + host, port = _parse_hostlist(hostspec, port) if parsed.path and database is None: database = parsed.path @@ -178,13 +236,13 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if 'host' in query: val = query.pop('host') - if host is None: - host = val + if not host and val: + host, port = _parse_hostlist(val, port) if 'port' in query: - val = int(query.pop('port')) - if port is None: - port = val + val = query.pop('port') + if not port and val: + port = [int(p) for p in val.split(',')] if 'dbname' in query: val = query.pop('dbname') @@ -222,24 +280,19 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, else: server_settings = {**query, **server_settings} - # On env-var -> connection parameter conversion read here: - # https://www.postgresql.org/docs/current/static/libpq-envars.html - # Note that env values may be an empty string in cases when - # the variable is "unset" by setting it to an empty value - # `auth_hosts` is the version of host information for the purposes - # of reading the pgpass file. - auth_hosts = None - if host is None: - host = os.getenv('PGHOST') - if not host: - auth_hosts = ['localhost'] + if not host: + hostspec = os.environ.get('PGHOST') + if hostspec: + host, port = _parse_hostlist(hostspec, port) - if _system == 'Windows': - host = ['localhost'] - else: - host = ['/tmp', '/private/tmp', - '/var/pgsql_socket', '/run/postgresql', - 'localhost'] + if not host: + auth_hosts = ['localhost'] + + if _system == 'Windows': + host = ['localhost'] + else: + host = ['/run/postgresql', '/var/run/postgresql', + '/tmp', '/private/tmp', 'localhost'] if not isinstance(host, list): host = [host] @@ -247,15 +300,24 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if auth_hosts is None: auth_hosts = host - if port is None: - port = os.getenv('PGPORT') - if port: - port = int(port) + if not port: + portspec = os.environ.get('PGPORT') + if portspec: + if ',' in portspec: + port = [int(p) for p in portspec.split(',')] + else: + port = int(portspec) else: port = 5432 + + elif isinstance(port, (list, tuple)): + port = [int(p) for p in port] + else: port = int(port) + port = _validate_port_spec(host, port) + if user is None: user = os.getenv('PGUSER') if not user: @@ -293,19 +355,20 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if passfile is not None: password = _read_password_from_pgpass( - hosts=auth_hosts, port=port, database=database, user=user, + hosts=auth_hosts, ports=port, + database=database, user=user, passfile=passfile) addrs = [] - for h in host: + for h, p in zip(host, port): if h.startswith('/'): # UNIX socket name if '.s.PGSQL.' not in h: - h = os.path.join(h, '.s.PGSQL.{}'.format(port)) + h = os.path.join(h, '.s.PGSQL.{}'.format(p)) addrs.append(h) else: # TCP host/port - addrs.append((h, port)) + addrs.append((h, p)) if not addrs: raise ValueError( @@ -329,7 +392,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, sslmode = SSLMODES[ssl] except KeyError: modes = ', '.join(SSLMODES.keys()) - raise ValueError('`sslmode` parameter must be one of ' + modes) + raise exceptions.InterfaceError( + '`sslmode` parameter must be one of: {}'.format(modes)) # sslmode 'allow' is currently handled as 'prefer' because we're # missing the "retry with SSL" behavior for 'allow', but do have the diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 9e0dba8e..25047226 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -1506,77 +1506,120 @@ async def connect(dsn=None, *, server_settings=None): r"""A coroutine to establish a connection to a PostgreSQL server. + The connection parameters may be specified either as a connection + URI in *dsn*, or as specific keyword arguments, or both. + If both *dsn* and keyword arguments are specified, the latter + override the corresponding values parsed from the connection URI. + The default values for the majority of arguments can be specified + using `environment variables `_. + Returns a new :class:`~asyncpg.connection.Connection` object. :param dsn: Connection arguments specified using as a single string in the - following format: - ``postgres://user:pass@host:port/database?option=value`` + `libpq connection URI format`_: + ``postgres://user:password@host:port/database?option=value``. + The following options are recognized by asyncpg: host, port, + user, database (or dbname), password, passfile, sslmode. + Unlike libpq, asyncpg will treat unrecognized options + as `server settings`_ to be used for the connection. :param host: - database host address or a path to the directory containing - database server UNIX socket (defaults to the default UNIX socket, - or the value of the ``PGHOST`` environment variable, if set). + Database host address as one of the following: + + - an IP address or a domain name; + - an absolute path to the directory containing the database + server Unix-domain socket (not supported on Windows); + - a sequence of any of the above, in which case the addresses + will be tried in order, and the first successful connection + will be returned. + + If not specified, asyncpg will try the following, in order: + + - host address(es) parsed from the *dsn* argument, + - the value of the ``PGHOST`` environment variable, + - on Unix, common directories used for PostgreSQL Unix-domain + sockets: ``"/run/postgresql"``, ``"/var/run/postgresl"``, + ``"/var/pgsql_socket"``, ``"/private/tmp"``, and ``"/tmp"``, + - ``"localhost"``. :param port: - connection port number (defaults to ``5432``, or the value of - the ``PGPORT`` environment variable, if set) + Port number to connect to at the server host + (or Unix-domain socket file extension). If multiple host + addresses were specified, this parameter may specify a + sequence of port numbers of the same length as the host sequence, + or it may specify a single port number to be used for all host + addresses. + + If not specified, the value parsed from the *dsn* argument is used, + or the value of the ``PGPORT`` environment variable, or ``5432`` if + neither is specified. :param user: - the name of the database role used for authentication - (defaults to the name of the effective user of the process - making the connection, or the value of ``PGUSER`` environment - variable, if set) + The name of the database role used for authentication. + + If not specified, the value parsed from the *dsn* argument is used, + or the value of the ``PGUSER`` environment variable, or the + operating system name of the user running the application. :param database: - the name of the database (defaults to the value of ``PGDATABASE`` - environment variable, if set.) + The name of the database to connect to. + + If not specified, the value parsed from the *dsn* argument is used, + or the value of the ``PGDATABASE`` environment variable, or the + operating system name of the user running the application. :param password: - password used for authentication + Password to be used for authentication, if the server requires + one. If not specified, the value parsed from the *dsn* argument + is used, or the value of the ``PGPASSWORD`` environment variable. + Note that the use of the environment variable is discouraged as + other users and applications may be able to read it without needing + specific privileges. It is recommended to use *passfile* instead. :param passfile: - the name of the file used to store passwords + The name of the file used to store passwords (defaults to ``~/.pgpass``, or ``%APPDATA%\postgresql\pgpass.conf`` - on Windows) + on Windows). :param loop: An asyncio event loop instance. If ``None``, the default event loop will be used. :param float timeout: - connection timeout in seconds. + Connection timeout in seconds. :param int statement_cache_size: - the size of prepared statement LRU cache. Pass ``0`` to + The size of prepared statement LRU cache. Pass ``0`` to disable the cache. :param int max_cached_statement_lifetime: - the maximum time in seconds a prepared statement will stay + The maximum time in seconds a prepared statement will stay in the cache. Pass ``0`` to allow statements be cached indefinitely. :param int max_cacheable_statement_size: - the maximum size of a statement that can be cached (15KiB by + The maximum size of a statement that can be cached (15KiB by default). Pass ``0`` to allow all statements to be cached regardless of their size. :param float command_timeout: - the default timeout for operations on this connection + The default timeout for operations on this connection (the default is ``None``: no timeout). :param ssl: - pass ``True`` or an `ssl.SSLContext `_ instance to + Pass ``True`` or an `ssl.SSLContext `_ instance to require an SSL connection. If ``True``, a default SSL context returned by `ssl.create_default_context() `_ will be used. :param dict server_settings: - an optional dict of server runtime parameters. Refer to - PostgreSQL documentation for a `list of supported options`_. + An optional dict of server runtime parameters. Refer to + PostgreSQL documentation for + a `list of supported options `_. :param Connection connection_class: - class of the returned connection object. Must be a subclass of + Class of the returned connection object. Must be a subclass of :class:`~asyncpg.connection.Connection`. :return: A :class:`~asyncpg.connection.Connection` instance. @@ -1613,8 +1656,13 @@ class of the returned connection object. Must be a subclass of .. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext .. _create_default_context: https://docs.python.org/3/library/ssl.html#ssl.create_default_context - .. _list of supported options: + .. _server settings: https://www.postgresql.org/docs/current/static/runtime-config.html + .. _postgres envvars: + https://www.postgresql.org/docs/current/static/libpq-envars.html + .. _libpq connection URI format: + https://www.postgresql.org/docs/current/static/\ + libpq-connect.html#LIBPQ-CONNSTRING """ if not issubclass(connection_class, Connection): raise TypeError( diff --git a/tests/test_connect.py b/tests/test_connect.py index 119148cb..7226bebc 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -23,6 +23,7 @@ from asyncpg import connection from asyncpg import connect_utils from asyncpg import cluster as pg_cluster +from asyncpg import exceptions from asyncpg.serverversion import split_server_version_string _system = platform.uname().system @@ -284,6 +285,57 @@ class TestConnectParams(tb.TestCase): 'database': 'abcdef'}) }, + { + 'dsn': 'postgresql://user@host1,host2/db', + 'result': ([('host1', 5432), ('host2', 5432)], { + 'database': 'db', + 'user': 'user', + }) + }, + + { + 'dsn': 'postgresql://user@host1:1111,host2:2222/db', + 'result': ([('host1', 1111), ('host2', 2222)], { + 'database': 'db', + 'user': 'user', + }) + }, + + { + 'env': { + 'PGHOST': 'host1:1111,host2:2222', + 'PGUSER': 'foo', + }, + 'dsn': 'postgresql:///db', + 'result': ([('host1', 1111), ('host2', 2222)], { + 'database': 'db', + 'user': 'foo', + }) + }, + + { + 'env': { + 'PGUSER': 'foo', + }, + 'dsn': 'postgresql:///db?host=host1:1111,host2:2222', + 'result': ([('host1', 1111), ('host2', 2222)], { + 'database': 'db', + 'user': 'foo', + }) + }, + + { + 'env': { + 'PGUSER': 'foo', + }, + 'dsn': 'postgresql:///db', + 'host': ['host1', 'host2'], + 'result': ([('host1', 5432), ('host2', 5432)], { + 'database': 'db', + 'user': 'foo', + }) + }, + { 'dsn': 'postgresql://user3:123123@localhost:5555/' 'abcdef?param=sss¶m=123&host=testhost&user=testuser' @@ -332,6 +384,14 @@ class TestConnectParams(tb.TestCase): 'dsn': 'pq:///dbname?host=/unix_sock/test&user=spam', 'error': (ValueError, 'invalid DSN') }, + { + 'dsn': 'postgresql://host1,host2,host3/db', + 'port': [111, 222], + 'error': ( + exceptions.InterfaceError, + 'could not match 2 port numbers to 3 hosts' + ) + }, ] @contextlib.contextmanager @@ -409,7 +469,7 @@ def run_testcase(self, testcase): # this because different SSLContexts don't compare equal. if isinstance(v, type) and isinstance(result[1].get(k), v): result[1][k] = v - self.assertEqual(expected, result) + self.assertEqual(expected, result, 'Testcase: {}'.format(testcase)) def test_test_connect_params_environ(self): self.assertNotIn('AAAAAAAAAA123', os.environ)