Skip to content

feature: host fallback support #184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 10, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 43 additions & 46 deletions arango/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import sys
import time
from abc import abstractmethod
from typing import Any, Callable, Optional, Sequence, Union
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import jwt
from requests import Session
from requests import ConnectionError, Session
from requests_toolbelt import MultipartEncoder

from arango.exceptions import JWTAuthError, ServerConnectionError
Expand Down Expand Up @@ -110,6 +110,42 @@ def prep_response(self, resp: Response, deserialize: bool = True) -> Response:
resp.is_success = http_ok and resp.error_code is None
return resp

def process_response(
self, host_index: int, request: Request, auth: Optional[Tuple[str, str]] = None
) -> Response:
"""Execute a request until a valid response has been returned.

:param host_index: The index of the first host to try
:type host_index: int
:param request: HTTP request.
:type request: arango.request.Request
:return: HTTP response.
:rtype: arango.response.Response
"""
tries = 0
while tries < self._host_resolver._max_tries:
try:
resp = self._http.send_request(
session=self._sessions[host_index],
method=request.method,
url=self._url_prefixes[host_index] + request.endpoint,
params=request.params,
data=self.normalize_data(request.data),
headers=request.headers,
auth=auth,
)

return self.prep_response(resp, request.deserialize)
except ConnectionError:
tries += 1
host_index = self._host_resolver.get_host_index(
prev_host_index=host_index
)

raise ConnectionAbortedError(
"Unable to establish connection to host(s) within limit"
)

def prep_bulk_err_response(self, parent_response: Response, body: Json) -> Response:
"""Build and return a bulk error response.

Expand Down Expand Up @@ -227,16 +263,7 @@ def send_request(self, request: Request) -> Response:
:rtype: arango.response.Response
"""
host_index = self._host_resolver.get_host_index()
resp = self._http.send_request(
session=self._sessions[host_index],
method=request.method,
url=self._url_prefixes[host_index] + request.endpoint,
params=request.params,
data=self.normalize_data(request.data),
headers=request.headers,
auth=self._auth,
)
return self.prep_response(resp, request.deserialize)
return self.process_response(host_index, request, auth=self._auth)


class JwtConnection(BaseConnection):
Expand Down Expand Up @@ -302,15 +329,7 @@ def send_request(self, request: Request) -> Response:
if self._auth_header is not None:
request.headers["Authorization"] = self._auth_header

resp = self._http.send_request(
session=self._sessions[host_index],
method=request.method,
url=self._url_prefixes[host_index] + request.endpoint,
params=request.params,
data=self.normalize_data(request.data),
headers=request.headers,
)
resp = self.prep_response(resp, request.deserialize)
resp = self.process_response(host_index, request)

# Refresh the token and retry on HTTP 401 and error code 11.
if resp.error_code != 11 or resp.status_code != 401:
Expand All @@ -325,15 +344,7 @@ def send_request(self, request: Request) -> Response:
if self._auth_header is not None:
request.headers["Authorization"] = self._auth_header

resp = self._http.send_request(
session=self._sessions[host_index],
method=request.method,
url=self._url_prefixes[host_index] + request.endpoint,
params=request.params,
data=self.normalize_data(request.data),
headers=request.headers,
)
return self.prep_response(resp, request.deserialize)
return self.process_response(host_index, request)

def refresh_token(self) -> None:
"""Get a new JWT token for the current user (cannot be a superuser).
Expand All @@ -349,13 +360,7 @@ def refresh_token(self) -> None:

host_index = self._host_resolver.get_host_index()

resp = self._http.send_request(
session=self._sessions[host_index],
method=request.method,
url=self._url_prefixes[host_index] + request.endpoint,
data=self.normalize_data(request.data),
)
resp = self.prep_response(resp)
resp = self.process_response(host_index, request)

if not resp.is_success:
raise JWTAuthError(resp, request)
Expand Down Expand Up @@ -429,12 +434,4 @@ def send_request(self, request: Request) -> Response:
host_index = self._host_resolver.get_host_index()
request.headers["Authorization"] = self._auth_header

resp = self._http.send_request(
session=self._sessions[host_index],
method=request.method,
url=self._url_prefixes[host_index] + request.endpoint,
params=request.params,
data=self.normalize_data(request.data),
headers=request.headers,
)
return self.prep_response(resp, request.deserialize)
return self.process_response(host_index, request)
27 changes: 22 additions & 5 deletions arango/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,24 @@

import random
from abc import ABC, abstractmethod
from typing import Optional, Set


class HostResolver(ABC): # pragma: no cover
"""Abstract base class for host resolvers."""

def __init__(self) -> None:
self._max_tries: int = 3

@abstractmethod
def get_host_index(self) -> int:
def get_host_index(self, prev_host_index: Optional[int] = None) -> int:
raise NotImplementedError


class SingleHostResolver(HostResolver):
"""Single host resolver."""

def get_host_index(self) -> int:
def get_host_index(self, prev_host_index: Optional[int] = None) -> int:
return 0


Expand All @@ -29,18 +33,31 @@ class RandomHostResolver(HostResolver):

def __init__(self, host_count: int) -> None:
self._max = host_count - 1
self._max_tries = host_count * 3
self._prev_host_indexes: Set[int] = set()

def get_host_index(self, prev_host_index: Optional[int] = None) -> int:
if prev_host_index is not None:
if len(self._prev_host_indexes) == self._max:
self._prev_host_indexes.clear()

self._prev_host_indexes.add(prev_host_index)

host_index = None
while host_index is None or host_index in self._prev_host_indexes:
host_index = random.randint(0, self._max)

def get_host_index(self) -> int:
return random.randint(0, self._max)
return host_index


class RoundRobinHostResolver(HostResolver):
"""Round-robin host resolver."""

def __init__(self, host_count: int) -> None:
self._max_tries = host_count * 3
self._index = -1
self._count = host_count

def get_host_index(self) -> int:
def get_host_index(self, prev_host_index: Optional[int] = None) -> int:
self._index = (self._index + 1) % self._count
return self._index
12 changes: 12 additions & 0 deletions tests/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@ def test_resolver_random_host():
for _ in range(20):
assert 0 <= resolver.get_host_index() < 10

resolver = RandomHostResolver(2)
index_a = resolver.get_host_index()
index_b = resolver.get_host_index(prev_host_index=index_a)
index_c = resolver.get_host_index(prev_host_index=index_b)
assert index_c in [index_a, index_b]

resolver = RandomHostResolver(3)
index_a = resolver.get_host_index()
index_b = resolver.get_host_index(prev_host_index=index_a)
index_c = resolver.get_host_index(prev_host_index=index_b)
assert index_c not in [index_a, index_b]


def test_resolver_round_robin():
resolver = RoundRobinHostResolver(10)
Expand Down