diff --git a/arango/client.py b/arango/client.py index eec64a7b..71b4eb51 100644 --- a/arango/client.py +++ b/arango/client.py @@ -31,6 +31,10 @@ class ArangoClient: multiple host URLs are provided). Accepted values are "roundrobin" and "random". Any other value defaults to round robin. :type host_resolver: str + :param resolver_max_tries: Number of attempts to process an HTTP request + before throwing a ConnectionAbortedError. Must not be lower than the + number of hosts. + :type resolver_max_tries: int :param http_client: User-defined HTTP client. :type http_client: arango.http.HTTPClient :param serializer: User-defined JSON serializer. Must be a callable @@ -48,6 +52,7 @@ def __init__( self, hosts: Union[str, Sequence[str]] = "http://127.0.0.1:8529", host_resolver: str = "roundrobin", + resolver_max_tries: Optional[int] = None, http_client: Optional[HTTPClient] = None, serializer: Callable[..., str] = lambda x: dumps(x), deserializer: Callable[[str], Any] = lambda x: loads(x), @@ -61,11 +66,11 @@ def __init__( self._host_resolver: HostResolver if host_count == 1: - self._host_resolver = SingleHostResolver() + self._host_resolver = SingleHostResolver(1, resolver_max_tries) elif host_resolver == "random": - self._host_resolver = RandomHostResolver(host_count) + self._host_resolver = RandomHostResolver(host_count, resolver_max_tries) else: - self._host_resolver = RoundRobinHostResolver(host_count) + self._host_resolver = RoundRobinHostResolver(host_count, resolver_max_tries) self._http = http_client or DefaultHTTPClient() self._serializer = serializer diff --git a/arango/connection.py b/arango/connection.py index 658e43fb..d49edfe3 100644 --- a/arango/connection.py +++ b/arango/connection.py @@ -6,13 +6,14 @@ "JwtSuperuserConnection", ] +import logging import sys import time from abc import abstractmethod -from typing import Any, Callable, Optional, Sequence, Union +from typing import Any, Callable, Optional, Sequence, Set, Tuple, Union import jwt -from requests import Session +from requests import ConnectionError, Session from requests_toolbelt import MultipartEncoder from arango.exceptions import JWTAuthError, ServerConnectionError @@ -110,6 +111,48 @@ def prep_response(self, resp: Response, deserialize: bool = True) -> Response: resp.is_success = http_ok and resp.error_code is None return resp + def process_request( + self, host_index: int, request: Request, auth: Optional[Tuple[str, str]] = None + ) -> Response: + """Execute a request until a valid response has been returned. + + :param host_index: The index of the first host to try + :type host_index: int + :param request: HTTP request. + :type request: arango.request.Request + :return: HTTP response. + :rtype: arango.response.Response + """ + tries = 0 + indexes_to_filter: Set[int] = set() + while tries < self._host_resolver.max_tries: + try: + resp = self._http.send_request( + session=self._sessions[host_index], + method=request.method, + url=self._url_prefixes[host_index] + request.endpoint, + params=request.params, + data=self.normalize_data(request.data), + headers=request.headers, + auth=auth, + ) + + return self.prep_response(resp, request.deserialize) + except ConnectionError: + url = self._url_prefixes[host_index] + request.endpoint + logging.debug(f"ConnectionError: {url}") + + if len(indexes_to_filter) == self._host_resolver.host_count - 1: + indexes_to_filter.clear() + indexes_to_filter.add(host_index) + + host_index = self._host_resolver.get_host_index(indexes_to_filter) + tries += 1 + + raise ConnectionAbortedError( + f"Can't connect to host(s) within limit ({self._host_resolver.max_tries})" + ) + def prep_bulk_err_response(self, parent_response: Response, body: Json) -> Response: """Build and return a bulk error response. @@ -227,16 +270,7 @@ def send_request(self, request: Request) -> Response: :rtype: arango.response.Response """ host_index = self._host_resolver.get_host_index() - resp = self._http.send_request( - session=self._sessions[host_index], - method=request.method, - url=self._url_prefixes[host_index] + request.endpoint, - params=request.params, - data=self.normalize_data(request.data), - headers=request.headers, - auth=self._auth, - ) - return self.prep_response(resp, request.deserialize) + return self.process_request(host_index, request, auth=self._auth) class JwtConnection(BaseConnection): @@ -302,15 +336,7 @@ def send_request(self, request: Request) -> Response: if self._auth_header is not None: request.headers["Authorization"] = self._auth_header - resp = self._http.send_request( - session=self._sessions[host_index], - method=request.method, - url=self._url_prefixes[host_index] + request.endpoint, - params=request.params, - data=self.normalize_data(request.data), - headers=request.headers, - ) - resp = self.prep_response(resp, request.deserialize) + resp = self.process_request(host_index, request) # Refresh the token and retry on HTTP 401 and error code 11. if resp.error_code != 11 or resp.status_code != 401: @@ -325,15 +351,7 @@ def send_request(self, request: Request) -> Response: if self._auth_header is not None: request.headers["Authorization"] = self._auth_header - resp = self._http.send_request( - session=self._sessions[host_index], - method=request.method, - url=self._url_prefixes[host_index] + request.endpoint, - params=request.params, - data=self.normalize_data(request.data), - headers=request.headers, - ) - return self.prep_response(resp, request.deserialize) + return self.process_request(host_index, request) def refresh_token(self) -> None: """Get a new JWT token for the current user (cannot be a superuser). @@ -349,13 +367,7 @@ def refresh_token(self) -> None: host_index = self._host_resolver.get_host_index() - resp = self._http.send_request( - session=self._sessions[host_index], - method=request.method, - url=self._url_prefixes[host_index] + request.endpoint, - data=self.normalize_data(request.data), - ) - resp = self.prep_response(resp) + resp = self.process_request(host_index, request) if not resp.is_success: raise JWTAuthError(resp, request) @@ -429,12 +441,4 @@ def send_request(self, request: Request) -> Response: host_index = self._host_resolver.get_host_index() request.headers["Authorization"] = self._auth_header - resp = self._http.send_request( - session=self._sessions[host_index], - method=request.method, - url=self._url_prefixes[host_index] + request.endpoint, - params=request.params, - data=self.normalize_data(request.data), - headers=request.headers, - ) - return self.prep_response(resp, request.deserialize) + return self.process_request(host_index, request) diff --git a/arango/resolver.py b/arango/resolver.py index 72dfe8bb..06a7aa77 100644 --- a/arango/resolver.py +++ b/arango/resolver.py @@ -7,40 +7,62 @@ import random from abc import ABC, abstractmethod +from typing import Optional, Set class HostResolver(ABC): # pragma: no cover """Abstract base class for host resolvers.""" + def __init__(self, host_count: int = 1, max_tries: Optional[int] = None) -> None: + max_tries = max_tries or host_count * 3 + if max_tries < host_count: + raise ValueError("max_tries cannot be less than host_count") + + self._host_count = host_count + self._max_tries = max_tries + @abstractmethod - def get_host_index(self) -> int: + def get_host_index(self, indexes_to_filter: Optional[Set[int]] = None) -> int: raise NotImplementedError + @property + def host_count(self) -> int: + return self._host_count + + @property + def max_tries(self) -> int: + return self._max_tries + class SingleHostResolver(HostResolver): """Single host resolver.""" - def get_host_index(self) -> int: + def get_host_index(self, indexes_to_filter: Optional[Set[int]] = None) -> int: return 0 class RandomHostResolver(HostResolver): """Random host resolver.""" - def __init__(self, host_count: int) -> None: - self._max = host_count - 1 + def __init__(self, host_count: int, max_tries: Optional[int] = None) -> None: + super().__init__(host_count, max_tries) + + def get_host_index(self, indexes_to_filter: Optional[Set[int]] = None) -> int: + host_index = None + indexes_to_filter = indexes_to_filter or set() + while host_index is None or host_index in indexes_to_filter: + host_index = random.randint(0, self.host_count - 1) - def get_host_index(self) -> int: - return random.randint(0, self._max) + return host_index class RoundRobinHostResolver(HostResolver): """Round-robin host resolver.""" - def __init__(self, host_count: int) -> None: + def __init__(self, host_count: int, max_tries: Optional[int] = None) -> None: + super().__init__(host_count, max_tries) self._index = -1 - self._count = host_count - def get_host_index(self) -> int: - self._index = (self._index + 1) % self._count + def get_host_index(self, indexes_to_filter: Optional[Set[int]] = None) -> int: + self._index = (self._index + 1) % self.host_count return self._index diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 0598addd..ff5630a1 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -1,3 +1,7 @@ +from typing import Set + +import pytest + from arango.resolver import ( RandomHostResolver, RoundRobinHostResolver, @@ -5,6 +9,11 @@ ) +def test_bad_resolver(): + with pytest.raises(ValueError): + RandomHostResolver(3, 2) + + def test_resolver_single_host(): resolver = SingleHostResolver() for _ in range(20): @@ -16,6 +25,21 @@ def test_resolver_random_host(): for _ in range(20): assert 0 <= resolver.get_host_index() < 10 + resolver = RandomHostResolver(3) + indexes_to_filter: Set[int] = set() + + index_a = resolver.get_host_index() + indexes_to_filter.add(index_a) + + index_b = resolver.get_host_index(indexes_to_filter) + indexes_to_filter.add(index_b) + assert index_b != index_a + + index_c = resolver.get_host_index(indexes_to_filter) + indexes_to_filter.clear() + indexes_to_filter.add(index_c) + assert index_c not in [index_a, index_b] + def test_resolver_round_robin(): resolver = RoundRobinHostResolver(10)