diff --git a/docs/source/api.rst b/docs/source/api.rst index 9a4cc014..3481480f 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -486,7 +486,12 @@ Name of the database to query. .. Note:: - The default database can be set on the Neo4j instance settings. + The default database can be set on the Neo4j instance settings. + +.. Note:: + It is recommended to always specify the database explicitly when possible. + This allows the driver to work more efficiently, as it will not have to + resolve the home database first. .. code-block:: python @@ -499,6 +504,41 @@ Name of the database to query. :Default: ``neo4j.DEFAULT_DATABASE`` +.. _impersonated-user-ref: + +``impersonated_user`` +--------------------- +Name of the user to impersonate. +This means that all actions in the session will be executed in the security +context of the impersonated user. For this, the user for which the +:class:``Driver`` has been created needs to have the appropriate permissions. + +:Type: ``str``, None + + +.. py:data:: None + :noindex: + + Will not perform impersonation. + + +.. Note:: + + The server or all servers of the cluster need to support impersonation when. + Otherwise, the driver will raise :py:exc:`.ConfigurationError` + as soon as it encounters a server that does not. + + +.. code-block:: python + + from neo4j import GraphDatabase + driver = GraphDatabase.driver(uri, auth=(user, password)) + session = driver.session(impersonated_user="alice") + + +:Default: ``None`` + + .. _default-access-mode-ref: ``default_access_mode`` diff --git a/neo4j/__init__.py b/neo4j/__init__.py index 950bf6fa..707d0c43 100644 --- a/neo4j/__init__.py +++ b/neo4j/__init__.py @@ -329,11 +329,9 @@ def supports_multi_db(self): :return: Returns true if the server or cluster the driver connects to supports multi-databases, otherwise false. :rtype: bool """ - cx = self._pool.acquire(access_mode=READ_ACCESS, timeout=self._pool.workspace_config.connection_acquisition_timeout, database=self._pool.workspace_config.database) - support = cx.supports_multiple_databases - self._pool.release(cx) - - return support + with self.session() as session: + session._connect(READ_ACCESS) + return session._connection.supports_multiple_databases class BoltDriver(Direct, Driver): @@ -447,6 +445,7 @@ def _verify_routing_connectivity(self): routing_info[ix] = self._pool.fetch_routing_info( address=table.routers[0], database=self._default_workspace_config.database, + imp_user=self._default_workspace_config.impersonated_user, bookmarks=None, timeout=self._default_workspace_config .connection_acquisition_timeout diff --git a/neo4j/conf.py b/neo4j/conf.py index 80ad44c8..f74dd2e5 100644 --- a/neo4j/conf.py +++ b/neo4j/conf.py @@ -283,6 +283,10 @@ class WorkspaceConfig(Config): #: Fetch Size fetch_size = 1000 + #: User to impersonate + impersonated_user = None + # Note that you need appropriate permissions to do so. + class SessionConfig(WorkspaceConfig): """ Session configuration. diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 7d52c336..40644e72 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -392,7 +392,7 @@ def __del__(self): pass @abc.abstractmethod - def route(self, database=None, bookmarks=None): + def route(self, database=None, imp_user=None, bookmarks=None): """ Fetch a routing table from the server for the given `database`. For Bolt 4.3 and above, this appends a ROUTE message; for earlier versions, a procedure call is made via @@ -400,6 +400,7 @@ def route(self, database=None, bookmarks=None): sent to the network, and a response is fetched. :param database: database for which to fetch a routing table + :param imp_user: the user to impersonate :param bookmarks: iterable of bookmark values after which this transaction should begin :return: dictionary of raw routing data @@ -407,8 +408,8 @@ def route(self, database=None, bookmarks=None): pass @abc.abstractmethod - def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, - timeout=None, db=None, **handlers): + def run(self, query, parameters=None, mode=None, bookmarks=None, + metadata=None, timeout=None, db=None, imp_user=None, **handlers): """ Appends a RUN message to the output queue. :param query: Cypher query string @@ -418,6 +419,7 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, :param metadata: custom metadata dictionary to attach to the transaction :param timeout: timeout for transaction execution (seconds) :param db: name of the database against which to begin the transaction + :param imp_user: the user to impersonate :param handlers: handler functions passed into the returned Response object :return: Response object """ @@ -446,7 +448,8 @@ def pull(self, n=-1, qid=-1, **handlers): pass @abc.abstractmethod - def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers): + def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, + db=None, imp_user=None, **handlers): """ Appends a BEGIN message to the output queue. :param mode: access mode for routing - "READ" or "WRITE" (default) @@ -454,6 +457,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, :param metadata: custom metadata dictionary to attach to the transaction :param timeout: timeout for transaction execution (seconds) :param db: name of the database against which to begin the transaction + :param imp_user: the user to impersonate :param handlers: handler functions passed into the returned Response object :return: Response object """ @@ -826,7 +830,7 @@ def __repr__(self): return "<{} address={!r}>".format(self.__class__.__name__, self.address) def acquire(self, access_mode=None, timeout=None, database=None, - bookmarks=None): + imp_user=None, bookmarks=None): # The access_mode and database is not needed for a direct connection, its just there for consistency. return self._acquire(self.address, timeout) @@ -907,15 +911,24 @@ def get_default_database_router_addresses(self): def get_routing_table_for_default_database(self): return self.routing_tables[self.workspace_config.database] - def create_routing_table(self, database): - if database not in self.routing_tables: - self.routing_tables[database] = RoutingTable(database=database, routers=self.get_default_database_initial_router_addresses()) + def get_or_create_routing_table(self, database): + with self.refresh_lock: + if database not in self.routing_tables: + self.routing_tables[database] = RoutingTable( + database=database, + routers=self.get_default_database_initial_router_addresses() + ) + return self.routing_tables[database] - def fetch_routing_info(self, address, database, bookmarks, timeout): + def fetch_routing_info(self, address, database, imp_user, bookmarks, + timeout): """ Fetch raw routing info from a given router address. :param address: router address :param database: the database name to get routing table for + :param imp_user: the user to impersonate while fetching the routing + table + :type imp_user: str or None :param bookmarks: iterable of bookmark values after which the routing info should be fetched :param timeout: connection acquisition timeout in seconds @@ -929,7 +942,9 @@ def fetch_routing_info(self, address, database, bookmarks, timeout): cx = self._acquire(address, timeout) try: routing_table = cx.route( - database or self.workspace_config.database, bookmarks + database or self.workspace_config.database, + imp_user or self.workspace_config.impersonated_user, + bookmarks ) finally: self.release(cx) @@ -954,21 +969,26 @@ def fetch_routing_info(self, address, database, bookmarks, timeout): self.deactivate(address) return routing_table - def fetch_routing_table(self, *, address, timeout, database, bookmarks): + def fetch_routing_table(self, *, address, timeout, database, imp_user, + bookmarks): """ Fetch a routing table from a given router address. :param address: router address :param timeout: seconds :param database: the database name :type: str + :param imp_user: the user to impersonate while fetching the routing + table + :type imp_user: str or None :param bookmarks: bookmarks used when fetching routing table :return: a new RoutingTable instance or None if the given router is currently unable to provide routing information """ try: - new_routing_info = self.fetch_routing_info(address, database, - bookmarks, timeout) + new_routing_info = self.fetch_routing_info( + address, database, imp_user, bookmarks, timeout + ) except (ServiceUnavailable, SessionExpired): new_routing_info = None if not new_routing_info: @@ -977,7 +997,10 @@ def fetch_routing_table(self, *, address, timeout, database, bookmarks): else: servers = new_routing_info[0]["servers"] ttl = new_routing_info[0]["ttl"] - new_routing_table = RoutingTable.parse_routing_info(database=database, servers=servers, ttl=ttl) + database = new_routing_info[0].get("db", database) + new_routing_table = RoutingTable.parse_routing_info( + database=database, servers=servers, ttl=ttl + ) # Parse routing info and count the number of each type of server num_routers = len(new_routing_table.routers) @@ -1000,8 +1023,8 @@ def fetch_routing_table(self, *, address, timeout, database, bookmarks): # At least one of each is fine, so return this table return new_routing_table - def update_routing_table_from(self, *routers, database=None, - bookmarks=None): + def _update_routing_table_from(self, *routers, database=None, imp_user=None, + bookmarks=None, database_callback=None): """ Try to update routing tables with the given routers. :return: True if the routing table is successfully updated, @@ -1013,67 +1036,86 @@ def update_routing_table_from(self, *routers, database=None, new_routing_table = self.fetch_routing_table( address=address, timeout=self.pool_config.connection_timeout, - database=database, bookmarks=bookmarks + database=database, imp_user=imp_user, bookmarks=bookmarks ) if new_routing_table is not None: - self.routing_tables[database].update(new_routing_table) + new_databse = new_routing_table.database + self.get_or_create_routing_table(new_databse)\ + .update(new_routing_table) log.debug( "[#0000] C: address=%r (%r)", - address, self.routing_tables[database] + address, self.routing_tables[new_databse] ) + if callable(database_callback): + database_callback(new_databse) return True self.deactivate(router) return False - def update_routing_table(self, *, database, bookmarks): + def update_routing_table(self, *, database, imp_user, bookmarks, + database_callback=None): """ Update the routing table from the first router able to provide valid routing information. :param database: The database name + :param imp_user: the user to impersonate while fetching the routing + table + :type imp_user: str or None :param bookmarks: bookmarks used when fetching routing table + :param database_callback: A callback function that will be called with + the database name as only argument when a new routing table has been + acquired. This database name might different from `database` if that + was None and the underlying protocol supports reporting back the + actual database. :raise neo4j.exceptions.ServiceUnavailable: """ - # copied because it can be modified - existing_routers = set(self.routing_tables[database].routers) - - prefer_initial_routing_address = \ - self.routing_tables[database].missing_fresh_writer() + with self.refresh_lock: + # copied because it can be modified + existing_routers = set( + self.get_or_create_routing_table(database).routers + ) - if prefer_initial_routing_address: - # TODO: Test this state - if self.update_routing_table_from( - self.first_initial_routing_address, database=database, - bookmarks=bookmarks + prefer_initial_routing_address = \ + self.routing_tables[database].initialized_without_writers + + if prefer_initial_routing_address: + # TODO: Test this state + if self._update_routing_table_from( + self.first_initial_routing_address, database=database, + imp_user=imp_user, bookmarks=bookmarks, + database_callback=database_callback + ): + # Why is only the first initial routing address used? + return + if self._update_routing_table_from( + *(existing_routers - {self.first_initial_routing_address}), + database=database, imp_user=imp_user, bookmarks=bookmarks, + database_callback=database_callback ): - # Why is only the first initial routing address used? return - if self.update_routing_table_from( - *(existing_routers - {self.first_initial_routing_address}), - database=database, bookmarks=bookmarks - ): - return - if not prefer_initial_routing_address: - if self.update_routing_table_from( - self.first_initial_routing_address, database=database, - bookmarks=bookmarks - ): - # Why is only the first initial routing address used? - return + if not prefer_initial_routing_address: + if self._update_routing_table_from( + self.first_initial_routing_address, database=database, + imp_user=imp_user, bookmarks=bookmarks, + database_callback=database_callback + ): + # Why is only the first initial routing address used? + return - # None of the routers have been successful, so just fail - log.error("Unable to retrieve routing information") - raise ServiceUnavailable("Unable to retrieve routing information") + # None of the routers have been successful, so just fail + log.error("Unable to retrieve routing information") + raise ServiceUnavailable("Unable to retrieve routing information") def update_connection_pool(self, *, database): - servers = self.routing_tables[database].servers() + servers = self.get_or_create_routing_table(database).servers() for address in list(self.connections): if address.unresolved not in servers: super(Neo4jPool, self).deactivate(address) - def ensure_routing_table_is_fresh(self, *, access_mode, database, - bookmarks): + def ensure_routing_table_is_fresh(self, *, access_mode, database, imp_user, + bookmarks, database_callback=None): """ Update the routing table if stale. This method performs two freshness checks, before and after acquiring @@ -1087,24 +1129,29 @@ def ensure_routing_table_is_fresh(self, *, access_mode, database, :return: `True` if an update was required, `False` otherwise. """ from neo4j.api import READ_ACCESS - if self.routing_tables[database].is_fresh(readonly=(access_mode == READ_ACCESS)): - # Readers are fresh. - return False with self.refresh_lock: - - self.update_routing_table(database=database, bookmarks=bookmarks) + if self.get_or_create_routing_table(database)\ + .is_fresh(readonly=(access_mode == READ_ACCESS)): + # Readers are fresh. + return False + + self.update_routing_table( + database=database, imp_user=imp_user, bookmarks=bookmarks, + database_callback=database_callback + ) self.update_connection_pool(database=database) for database in list(self.routing_tables.keys()): # Remove unused databases in the routing table # Remove the routing table after a timeout = TTL + 30s log.debug("[#0000] C: database=%s", database) - if self.routing_tables[database].should_be_purged_from_memory() and database != self.workspace_config.database: + if (self.routing_tables[database].should_be_purged_from_memory() + and database != self.workspace_config.database): del self.routing_tables[database] return True - def _select_address(self, *, access_mode, database, bookmarks): + def _select_address(self, *, access_mode, database): from neo4j.api import READ_ACCESS """ Selects the address with the fewest in-use connections. """ @@ -1134,24 +1181,25 @@ def acquire(self, access_mode=None, timeout=None, database=None, if access_mode not in (WRITE_ACCESS, READ_ACCESS): raise ClientError("Non valid 'access_mode'; {}".format(access_mode)) if not timeout: - raise ClientError("'timeout' must be a float larger than 0; {}".format(timeout)) + raise ClientError("'timeout' must be a float larger than 0; {}" + .format(timeout)) from neo4j.api import check_access_mode access_mode = check_access_mode(access_mode) with self.refresh_lock: - self.create_routing_table(database) - log.debug("[#0000] C: %r", self.routing_tables) + log.debug("[#0000] C: %r", + self.routing_tables) self.ensure_routing_table_is_fresh( - access_mode=access_mode, database=database, bookmarks=bookmarks + access_mode=access_mode, database=database, imp_user=None, + bookmarks=bookmarks ) while True: try: - # Get an address for a connection that have the fewest in-use connections. - address = self._select_address( - access_mode=access_mode, database=database, - bookmarks=bookmarks - ) + # Get an address for a connection that have the fewest in-use + # connections. + address = self._select_address(access_mode=access_mode, + database=database) except (ReadServiceUnavailable, WriteServiceUnavailable) as err: raise SessionExpired("Failed to obtain connection towards '%s' server." % access_mode) from err try: diff --git a/neo4j/io/_bolt3.py b/neo4j/io/_bolt3.py index fe7608b0..4ed722d8 100644 --- a/neo4j/io/_bolt3.py +++ b/neo4j/io/_bolt3.py @@ -165,12 +165,22 @@ def hello(self): self.fetch_all() check_supported_server_product(self.server_info.agent) - def route(self, database=None, bookmarks=None): - if database is not None: # default database - raise ConfigurationError("Database name parameter for selecting database is not " - "supported in Bolt Protocol {!r}. Database name {!r}. " - "Server Agent {!r}.".format(Bolt3.PROTOCOL_VERSION, database, - self.server_info.agent)) + def route(self, database=None, imp_user=None, bookmarks=None): + if database is not None: + raise ConfigurationError( + "Database name parameter for selecting database is not " + "supported in Bolt Protocol {!r}. Database name {!r}. " + "Server Agent {!r}".format( + self.PROTOCOL_VERSION, database, self.server_info.agent + ) + ) + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) metadata = {} records = [] @@ -197,9 +207,22 @@ def fail(md): routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records] return routing_info - def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers): + def run(self, query, parameters=None, mode=None, bookmarks=None, + metadata=None, timeout=None, db=None, imp_user=None, **handlers): if db is not None: - raise ConfigurationError("Database name parameter for selecting database is not supported in Bolt Protocol {!r}. Database name {!r}.".format(Bolt3.PROTOCOL_VERSION, db)) + raise ConfigurationError( + "Database name parameter for selecting database is not " + "supported in Bolt Protocol {!r}. Database name {!r}.".format( + self.PROTOCOL_VERSION, db + ) + ) + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) if not parameters: parameters = {} extra = {} @@ -238,9 +261,22 @@ def pull(self, n=-1, qid=-1, **handlers): log.debug("[#%04X] C: PULL_ALL", self.local_port) self._append(b"\x3F", (), Response(self, "pull", **handlers)) - def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers): + def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, + db=None, imp_user=None, **handlers): if db is not None: - raise ConfigurationError("Database name parameter for selecting database is not supported in Bolt Protocol {!r}. Database name {!r}.".format(Bolt3.PROTOCOL_VERSION, db)) + raise ConfigurationError( + "Database name parameter for selecting database is not " + "supported in Bolt Protocol {!r}. Database name {!r}.".format( + self.PROTOCOL_VERSION, db + ) + ) + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) extra = {} if mode in (READ_ACCESS, "r"): extra["mode"] = "r" # It will default to mode "w" if nothing is specified diff --git a/neo4j/io/_bolt4.py b/neo4j/io/_bolt4.py index 4b7c2045..086f7baf 100644 --- a/neo4j/io/_bolt4.py +++ b/neo4j/io/_bolt4.py @@ -32,6 +32,7 @@ Version, ) from neo4j.exceptions import ( + ConfigurationError, DatabaseUnavailable, DriverError, ForbiddenOnReadOnlyDatabase, @@ -122,7 +123,14 @@ def hello(self): self.fetch_all() check_supported_server_product(self.server_info.agent) - def route(self, database=None, bookmarks=None): + def route(self, database=None, imp_user=None, bookmarks=None): + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) metadata = {} records = [] @@ -160,7 +168,15 @@ def fail(md): routing_info = [dict(zip(metadata.get("fields", ()), values)) for values in records] return routing_info - def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers): + def run(self, query, parameters=None, mode=None, bookmarks=None, + metadata=None, timeout=None, db=None, imp_user=None, **handlers): + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) if not parameters: parameters = {} extra = {} @@ -206,7 +222,14 @@ def pull(self, n=-1, qid=-1, **handlers): self._append(b"\x3F", (extra,), Response(self, "pull", **handlers)) def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, - db=None, **handlers): + db=None, imp_user=None, **handlers): + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) extra = {} if mode in (READ_ACCESS, "r"): extra["mode"] = "r" # It will default to mode "w" if nothing is specified @@ -376,7 +399,14 @@ class Bolt4x3(Bolt4x2): PROTOCOL_VERSION = Version(4, 3) - def route(self, database=None, bookmarks=None): + def route(self, database=None, imp_user=None, bookmarks=None): + if imp_user is not None: + raise ConfigurationError( + "Impersonation is not supported in Bolt Protocol {!r}. " + "Trying to impersonate {!r}.".format( + self.PROTOCOL_VERSION, imp_user + ) + ) def fail(md): from neo4j._exceptions import BoltRoutingError @@ -384,12 +414,15 @@ def fail(md): if code == "Neo.ClientError.Database.DatabaseNotFound": return # surface this error to the user elif code == "Neo.ClientError.Procedure.ProcedureNotFound": - raise BoltRoutingError("Server does not support routing", self.unresolved_address) + raise BoltRoutingError("Server does not support routing", + self.unresolved_address) else: - raise BoltRoutingError("Routing support broken on server", self.unresolved_address) + raise BoltRoutingError("Routing support broken on server", + self.unresolved_address) routing_context = self.routing_context or {} - log.debug("[#%04X] C: ROUTE %r %r", self.local_port, routing_context, database) + log.debug("[#%04X] C: ROUTE %r %r %r", self.local_port, + routing_context, bookmarks, database) metadata = {} if bookmarks is None: bookmarks = [] @@ -440,3 +473,103 @@ class Bolt4x4(Bolt4x3): """ PROTOCOL_VERSION = Version(4, 4) + + def route(self, database=None, imp_user=None, bookmarks=None): + def fail(md): + from neo4j._exceptions import BoltRoutingError + code = md.get("code") + if code == "Neo.ClientError.Database.DatabaseNotFound": + return # surface this error to the user + elif code == "Neo.ClientError.Procedure.ProcedureNotFound": + raise BoltRoutingError("Server does not support routing", + self.unresolved_address) + else: + raise BoltRoutingError("Routing support broken on server", + self.unresolved_address) + + routing_context = self.routing_context or {} + db_context = {} + if database is not None: + db_context.update(db=database) + if imp_user is not None: + db_context.update(imp_user=imp_user) + log.debug("[#%04X] C: ROUTE %r %r %r", self.local_port, + routing_context, bookmarks, db_context) + metadata = {} + if bookmarks is None: + bookmarks = [] + else: + bookmarks = list(bookmarks) + self._append(b"\x66", (routing_context, bookmarks, db_context), + response=Response(self, "route", + on_success=metadata.update, + on_failure=fail)) + self.send_all() + self.fetch_all() + return [metadata.get("rt")] + + def run(self, query, parameters=None, mode=None, bookmarks=None, + metadata=None, timeout=None, db=None, imp_user=None, **handlers): + if not parameters: + parameters = {} + extra = {} + if mode in (READ_ACCESS, "r"): + # It will default to mode "w" if nothing is specified + extra["mode"] = "r" + if db: + extra["db"] = db + if imp_user: + extra["imp_user"] = imp_user + if bookmarks: + try: + extra["bookmarks"] = list(bookmarks) + except TypeError: + raise TypeError("Bookmarks must be provided within an iterable") + if metadata: + try: + extra["tx_metadata"] = dict(metadata) + except TypeError: + raise TypeError("Metadata must be coercible to a dict") + if timeout: + try: + extra["tx_timeout"] = int(1000 * timeout) + except TypeError: + raise TypeError("Timeout must be specified as a number of " + "seconds") + fields = (query, parameters, extra) + log.debug("[#%04X] C: RUN %s", self.local_port, + " ".join(map(repr, fields))) + if query.upper() == u"COMMIT": + self._append(b"\x10", fields, CommitResponse(self, "run", + **handlers)) + else: + self._append(b"\x10", fields, Response(self, "run", **handlers)) + + def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, + db=None, imp_user=None, **handlers): + extra = {} + if mode in (READ_ACCESS, "r"): + # It will default to mode "w" if nothing is specified + extra["mode"] = "r" + if db: + extra["db"] = db + if imp_user: + extra["imp_user"] = imp_user + if bookmarks: + try: + extra["bookmarks"] = list(bookmarks) + except TypeError: + raise TypeError("Bookmarks must be provided within an iterable") + if metadata: + try: + extra["tx_metadata"] = dict(metadata) + except TypeError: + raise TypeError("Metadata must be coercible to a dict") + if timeout: + try: + extra["tx_timeout"] = int(1000 * timeout) + except TypeError: + raise TypeError("Timeout must be specified as a number of " + "seconds") + log.debug("[#%04X] C: BEGIN %r", self.local_port, extra) + self._append(b"\x11", (extra,), Response(self, "begin", **handlers)) diff --git a/neo4j/routing.py b/neo4j/routing.py index a0ef48c2..8303f4c2 100644 --- a/neo4j/routing.py +++ b/neo4j/routing.py @@ -110,6 +110,7 @@ def __init__(self, *, database, routers=(), readers=(), writers=(), ttl=0): self.routers = OrderedSet(routers) self.readers = OrderedSet(readers) self.writers = OrderedSet(writers) + self.initialized_without_writers = not self.writers self.last_updated_time = perf_counter() self.ttl = ttl self.database = database @@ -142,14 +143,6 @@ def is_fresh(self, readonly=False): log.debug("[#0000] C: Table has_server_for_mode=%r", has_server_for_mode) return not expired and self.routers and has_server_for_mode - def missing_fresh_writer(self): - """ Check if the routing table have a fresh write address. - - :return: Return true if it does not have a fresh write address. - :rtype: bool - """ - return not self.is_fresh(readonly=False) - def should_be_purged_from_memory(self): """ Check if the routing table is stale and not used for a long time and should be removed from memory. @@ -168,6 +161,7 @@ def update(self, new_routing_table): self.routers.replace(new_routing_table.routers) self.readers.replace(new_routing_table.readers) self.writers.replace(new_routing_table.writers) + self.initialized_without_writers = not self.writers self.last_updated_time = perf_counter() self.ttl = new_routing_table.ttl log.debug("[#0000] S: table=%r", self) diff --git a/neo4j/work/result.py b/neo4j/work/result.py index 75921b0b..647153bb 100644 --- a/neo4j/work/result.py +++ b/neo4j/work/result.py @@ -67,9 +67,10 @@ def _tx_ready_run(self, query, parameters, **kwparameters): # BEGIN+RUN does not carry any extra on the RUN message. # BEGIN {extra} # RUN "query" {parameters} {extra} - self._run(query, parameters, None, None, None, **kwparameters) + self._run(query, parameters, None, None, None, None, **kwparameters) - def _run(self, query, parameters, db, access_mode, bookmarks, **kwparameters): + def _run(self, query, parameters, db, imp_user, access_mode, bookmarks, + **kwparameters): query_text = str(query) # Query or string object query_metadata = getattr(query, "metadata", None) query_timeout = getattr(query, "timeout", None) @@ -104,6 +105,7 @@ def on_failed_attach(metadata): metadata=query_metadata, timeout=query_timeout, db=db, + imp_user=imp_user, on_success=on_attached, on_failure=on_failed_attach, ) diff --git a/neo4j/work/simple.py b/neo4j/work/simple.py index 15f08b2e..84a05777 100644 --- a/neo4j/work/simple.py +++ b/neo4j/work/simple.py @@ -41,6 +41,7 @@ TransientError, TransactionError, ) +from neo4j.io import Neo4jPool from neo4j.work import Workspace from neo4j.work.result import Result from neo4j.work.transaction import Transaction @@ -81,6 +82,9 @@ class Session(Workspace): # :class:`.Transaction` should be carried out. _bookmarks = None + # Sessions are supposed to cache the database on which to operate. + _cached_database = False + # The state this session is in. _state_failed = False @@ -106,7 +110,11 @@ def __exit__(self, exception_type, exception_value, traceback): self._state_failed = True self.close() - def _connect(self, access_mode, database): + def _set_cached_database(self, database): + self._cached_database = True + self._config.database = database + + def _connect(self, access_mode): if access_mode is None: access_mode = self._config.default_access_mode if self._connection: @@ -115,10 +123,27 @@ def _connect(self, access_mode, database): self._connection.send_all() self._connection.fetch_all() self._disconnect() + if not self._cached_database: + if (self._config.database is not None + or not isinstance(self._pool, Neo4jPool)): + self._set_cached_database(self._config.database) + else: + # This is the first time we open a connection to a server in a + # cluster environment for this session without explicitly + # configured database. Hence, we request a routing table update + # to try to fetch the home database. If provided by the server, + # we shall use this database explicitly for all subsequent + # actions within this session. + self._pool.update_routing_table( + database=self._config.database, + imp_user=self._config.impersonated_user, + bookmarks=self._bookmarks, + database_callback=self._set_cached_database + ) self._connection = self._pool.acquire( access_mode=access_mode, timeout=self._config.connection_acquisition_timeout, - database=database, + database=self._config.database, bookmarks=self._bookmarks ) @@ -218,7 +243,7 @@ def run(self, query, parameters=None, **kwparameters): self._autoResult._buffer_all() # This will buffer upp all records for the previous auto-transaction if not self._connection: - self._connect(self._config.default_access_mode, database=self._config.database) + self._connect(self._config.default_access_mode) cx = self._connection protocol_version = cx.PROTOCOL_VERSION server_info = cx.server_info @@ -231,7 +256,8 @@ def run(self, query, parameters=None, **kwparameters): ) self._autoResult._run( query, parameters, self._config.database, - self._config.default_access_mode, self._bookmarks, **kwparameters + self._config.impersonated_user, self._config.default_access_mode, + self._bookmarks, **kwparameters ) return self._autoResult @@ -266,16 +292,18 @@ def _transaction_error_handler(self, _): self._transaction = None self._disconnect() - def _open_transaction(self, *, access_mode, database, metadata=None, + def _open_transaction(self, *, access_mode, metadata=None, timeout=None): - self._connect(access_mode=access_mode, database=database) + self._connect(access_mode=access_mode) self._transaction = Transaction( self._connection, self._config.fetch_size, self._transaction_closed_handler, self._transaction_error_handler ) - self._transaction._begin(database, self._bookmarks, access_mode, - metadata, timeout) + self._transaction._begin( + self._config.database, self._config.impersonated_user, + self._bookmarks, access_mode, metadata, timeout + ) def begin_transaction(self, metadata=None, timeout=None): """ Begin a new unmanaged transaction. Creates a new :class:`.Transaction` within this session. @@ -312,7 +340,8 @@ def begin_transaction(self, metadata=None, timeout=None): if self._transaction: raise TransactionError("Explicit transaction already open") - self._open_transaction(access_mode=self._config.default_access_mode, database=self._config.database, metadata=metadata, timeout=timeout) + self._open_transaction(access_mode=self._config.default_access_mode, + metadata=metadata, timeout=timeout) return self._transaction @@ -332,7 +361,7 @@ def _run_transaction(self, access_mode, transaction_function, *args, **kwargs): while True: try: - self._open_transaction(access_mode=access_mode, database=self._config.database, metadata=metadata, timeout=timeout) + self._open_transaction(access_mode=access_mode, metadata=metadata, timeout=timeout) tx = self._transaction try: result = transaction_function(tx, *args, **kwargs) diff --git a/neo4j/work/transaction.py b/neo4j/work/transaction.py index 0d77f3e0..74676748 100644 --- a/neo4j/work/transaction.py +++ b/neo4j/work/transaction.py @@ -62,9 +62,11 @@ def __exit__(self, exception_type, exception_value, traceback): self.commit() self.close() - def _begin(self, database, bookmarks, access_mode, metadata, timeout): - self._connection.begin(bookmarks=bookmarks, metadata=metadata, - timeout=timeout, mode=access_mode, db=database) + def _begin(self, database, imp_user, bookmarks, access_mode, metadata, timeout): + self._connection.begin( + bookmarks=bookmarks, metadata=metadata, timeout=timeout, + mode=access_mode, db=database, imp_user=imp_user + ) self._error_handling_connection.send_all() self._error_handling_connection.fetch_all() diff --git a/testkitbackend/requests.py b/testkitbackend/requests.py index a23091fb..7b84b6ce 100644 --- a/testkitbackend/requests.py +++ b/testkitbackend/requests.py @@ -196,12 +196,14 @@ def NewSession(backend, data): elif access_mode == "w": access_mode = neo4j.WRITE_ACCESS else: - raise Exception("Unknown access mode:" + access_mode) + raise ValueError("Unknown access mode:" + access_mode) config = { "default_access_mode": access_mode, "bookmarks": data["bookmarks"], "database": data["database"], - "fetch_size": data.get("fetchSize", None) + "fetch_size": data.get("fetchSize", None), + "impersonated_user": data.get("impersonatedUser", None), + } session = driver.session(**config) key = backend.next_key() @@ -402,8 +404,7 @@ def ForcedRoutingTableUpdate(backend, data): database = data["database"] bookmarks = data["bookmarks"] with driver._pool.refresh_lock: - driver._pool.create_routing_table(database) - driver._pool.update_routing_table(database=database, + driver._pool.update_routing_table(database=database, imp_user=None, bookmarks=bookmarks) backend.send_response("Driver", {"id": driver_id}) diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 5c1291ab..0c1abf51 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -6,6 +6,8 @@ "Test makes assumptions about how verify_connectivity is implemented", "stub.routing.test_routing_v4x3.RoutingV4x3.test_should_successfully_acquire_rt_when_router_ip_changes": "Test makes assumptions about how verify_connectivity is implemented", + "stub.routing.test_routing_v4x4.RoutingV4x4.test_should_successfully_acquire_rt_when_router_ip_changes": + "Test makes assumptions about how verify_connectivity is implemented", "stub.retry.test_retry_clustering.TestRetryClustering.test_retry_ForbiddenOnReadOnlyDatabase_ChangingWriter": "Test makes assumptions about how verify_connectivity is implemented", "stub.authorization.test_authorization.TestAuthorizationV4x3.test_should_retry_on_auth_expired_on_begin_using_tx_function": @@ -31,6 +33,8 @@ "Feature:Auth:Bearer": true, "Feature:Auth:Custom": true, "Feature:Auth:Kerberos": true, + "Feature:Bolt:4.4": true, + "Feature:Impersonation": true, "AuthorizationExpiredTreatment": true, "Optimization:ConnectionReuse": true, "Optimization:EagerTransactionBegin": true, diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 64aa1f3d..015ba64d 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -299,7 +299,8 @@ def bolt_driver(target, auth): def neo4j_driver(target, auth): try: driver = GraphDatabase.neo4j_driver(target, auth=auth) - driver._pool.update_routing_table(database=None, bookmarks=None) + driver._pool.update_routing_table(database=None, imp_user=None, + bookmarks=None) except ServiceUnavailable as error: if isinstance(error.__cause__, BoltHandshakeError): pytest.skip(error.args[0]) diff --git a/tests/unit/io/test_class_bolt4x4.py b/tests/unit/io/test_class_bolt4x4.py index 562a720c..19378a1c 100644 --- a/tests/unit/io/test_class_bolt4x4.py +++ b/tests/unit/io/test_class_bolt4x4.py @@ -56,31 +56,43 @@ def test_conn_is_not_stale(fake_socket, set_stale): connection.set_stale() assert connection.stale() is set_stale - -def test_db_extra_in_begin(fake_socket): +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},) + ), +)) +def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): address = ("127.0.0.1", 7687) socket = fake_socket(address) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) - connection.begin(db="something") + connection.begin(*args, **kwargs) connection.send_all() - tag, fields = socket.pop_message() + tag, is_fields = socket.pop_message() assert tag == b"\x11" - assert len(fields) == 1 - assert fields[0] == {"db": "something"} - - -def test_db_extra_in_run(fake_socket): + assert tuple(is_fields) == expected_fields + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + (("", {}), {"imp_user": "imposter"}, ("", {}, {"imp_user": "imposter"})), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}) + ), +)) +def test_extra_in_run(fake_socket, args, kwargs, expected_fields): address = ("127.0.0.1", 7687) socket = fake_socket(address) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) - connection.run("", {}, db="something") + connection.run(*args, **kwargs) connection.send_all() - tag, fields = socket.pop_message() + tag, is_fields = socket.pop_message() assert tag == b"\x10" - assert len(fields) == 3 - assert fields[0] == "" - assert fields[1] == {} - assert fields[2] == {"db": "something"} + assert tuple(is_fields) == expected_fields def test_n_extra_in_discard(fake_socket): diff --git a/tests/unit/test_conf.py b/tests/unit/test_conf.py index ccd79501..6e685eda 100644 --- a/tests/unit/test_conf.py +++ b/tests/unit/test_conf.py @@ -63,6 +63,7 @@ "bookmarks": (), "default_access_mode": WRITE_ACCESS, "database": None, + "impersonated_user": None, "fetch_size": 100, } diff --git a/tests/unit/test_driver.py b/tests/unit/test_driver.py index 1e35cbf4..0c1192e4 100644 --- a/tests/unit/test_driver.py +++ b/tests/unit/test_driver.py @@ -124,7 +124,10 @@ def test_driver_trust_config_error( def test_driver_opens_write_session_by_default(uri, mocker): driver = GraphDatabase.driver(uri) from neo4j.work.transaction import Transaction - with driver.session() as session: + # we set a specific db, because else the driver would try to fetch a RT + # to get hold of the actual home database (which won't work in this + # unittest) + with driver.session(database="foobar") as session: acquire_mock = mocker.patch.object(session._pool, "acquire", autospec=True) tx_begin_mock = mocker.patch.object(Transaction, "_begin", @@ -140,6 +143,7 @@ def test_driver_opens_write_session_by_default(uri, mocker): tx, mocker.ANY, mocker.ANY, + mocker.ANY, WRITE_ACCESS, mocker.ANY, mocker.ANY diff --git a/tests/unit/work/test_result.py b/tests/unit/work/test_result.py index df33293a..21627a30 100644 --- a/tests/unit/work/test_result.py +++ b/tests/unit/work/test_result.py @@ -219,7 +219,7 @@ def test_result_iteration(method): records = [[1], [2], [3], [4], [5]] connection = ConnectionStub(records=Records(["x"], records)) result = Result(connection, HydratorStub(), 2, noop, noop) - result._run("CYPHER", {}, None, "r", None) + result._run("CYPHER", {}, None, None, "r", None) _fetch_and_compare_all_records(result, "x", records, method) @@ -232,9 +232,9 @@ def test_parallel_result_iteration(method, invert_fetch): records=(Records(["x"], records1), Records(["x"], records2)) ) result1 = Result(connection, HydratorStub(), 2, noop, noop) - result1._run("CYPHER1", {}, None, "r", None) + result1._run("CYPHER1", {}, None, None, "r", None) result2 = Result(connection, HydratorStub(), 2, noop, noop) - result2._run("CYPHER2", {}, None, "r", None) + result2._run("CYPHER2", {}, None, None, "r", None) if invert_fetch: _fetch_and_compare_all_records(result2, "x", records2, method) _fetch_and_compare_all_records(result1, "x", records1, method) @@ -252,9 +252,9 @@ def test_interwoven_result_iteration(method, invert_fetch): records=(Records(["x"], records1), Records(["y"], records2)) ) result1 = Result(connection, HydratorStub(), 2, noop, noop) - result1._run("CYPHER1", {}, None, "r", None) + result1._run("CYPHER1", {}, None, None, "r", None) result2 = Result(connection, HydratorStub(), 2, noop, noop) - result2._run("CYPHER2", {}, None, "r", None) + result2._run("CYPHER2", {}, None, None, "r", None) start = 0 for n in (1, 2, 3, 1, None): end = n if n is None else start + n @@ -276,7 +276,7 @@ def test_interwoven_result_iteration(method, invert_fetch): def test_result_peek(records, fetch_size): connection = ConnectionStub(records=Records(["x"], records)) result = Result(connection, HydratorStub(), fetch_size, noop, noop) - result._run("CYPHER", {}, None, "r", None) + result._run("CYPHER", {}, None, None, "r", None) for i in range(len(records) + 1): record = result.peek() if i == len(records): @@ -292,7 +292,7 @@ def test_result_peek(records, fetch_size): def test_result_single(records, fetch_size): connection = ConnectionStub(records=Records(["x"], records)) result = Result(connection, HydratorStub(), fetch_size, noop, noop) - result._run("CYPHER", {}, None, "r", None) + result._run("CYPHER", {}, None, None, "r", None) with pytest.warns(None) as warning_record: record = result.single() if not records: @@ -310,7 +310,7 @@ def test_result_single(records, fetch_size): def test_keys_are_available_before_and_after_stream(): connection = ConnectionStub(records=Records(["x"], [[1], [2]])) result = Result(connection, HydratorStub(), 1, noop, noop) - result._run("CYPHER", {}, None, "r", None) + result._run("CYPHER", {}, None, None, "r", None) assert list(result.keys()) == ["x"] list(result) assert list(result.keys()) == ["x"] @@ -323,7 +323,7 @@ def test_consume(records, consume_one, summary_meta): connection = ConnectionStub(records=Records(["x"], records), summary_meta=summary_meta) result = Result(connection, HydratorStub(), 1, noop, noop) - result._run("CYPHER", {}, None, "r", None) + result._run("CYPHER", {}, None, None, "r", None) if consume_one: try: next(iter(result)) @@ -356,7 +356,7 @@ def test_time_in_summary(t_first, t_last): run_meta=run_meta, summary_meta=summary_meta) result = Result(connection, HydratorStub(), 1, noop, noop) - result._run("CYPHER", {}, None, "r", None) + result._run("CYPHER", {}, None, None, "r", None) summary = result.consume() if t_first is not None: @@ -377,7 +377,7 @@ def test_counts_in_summary(): connection = ConnectionStub(records=Records(["n"], [[1], [2]])) result = Result(connection, HydratorStub(), 1, noop, noop) - result._run("CYPHER", {}, None, "r", None) + result._run("CYPHER", {}, None, None, "r", None) summary = result.consume() assert isinstance(summary.counters, SummaryCounters) @@ -389,7 +389,7 @@ def test_query_type(query_type): summary_meta={"type": query_type}) result = Result(connection, HydratorStub(), 1, noop, noop) - result._run("CYPHER", {}, None, "r", None) + result._run("CYPHER", {}, None, None, "r", None) summary = result.consume() assert isinstance(summary.query_type, str)