Skip to content

Commit 1900f4a

Browse files
authored
feature: host fallback support (#184)
1 parent c65f5ee commit 1900f4a

File tree

4 files changed

+114
-59
lines changed

4 files changed

+114
-59
lines changed

arango/client.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ class ArangoClient:
3131
multiple host URLs are provided). Accepted values are "roundrobin" and
3232
"random". Any other value defaults to round robin.
3333
:type host_resolver: str
34+
:param resolver_max_tries: Number of attempts to process an HTTP request
35+
before throwing a ConnectionAbortedError. Must not be lower than the
36+
number of hosts.
37+
:type resolver_max_tries: int
3438
:param http_client: User-defined HTTP client.
3539
:type http_client: arango.http.HTTPClient
3640
:param serializer: User-defined JSON serializer. Must be a callable
@@ -48,6 +52,7 @@ def __init__(
4852
self,
4953
hosts: Union[str, Sequence[str]] = "http://127.0.0.1:8529",
5054
host_resolver: str = "roundrobin",
55+
resolver_max_tries: Optional[int] = None,
5156
http_client: Optional[HTTPClient] = None,
5257
serializer: Callable[..., str] = lambda x: dumps(x),
5358
deserializer: Callable[[str], Any] = lambda x: loads(x),
@@ -61,11 +66,11 @@ def __init__(
6166
self._host_resolver: HostResolver
6267

6368
if host_count == 1:
64-
self._host_resolver = SingleHostResolver()
69+
self._host_resolver = SingleHostResolver(1, resolver_max_tries)
6570
elif host_resolver == "random":
66-
self._host_resolver = RandomHostResolver(host_count)
71+
self._host_resolver = RandomHostResolver(host_count, resolver_max_tries)
6772
else:
68-
self._host_resolver = RoundRobinHostResolver(host_count)
73+
self._host_resolver = RoundRobinHostResolver(host_count, resolver_max_tries)
6974

7075
self._http = http_client or DefaultHTTPClient()
7176
self._serializer = serializer

arango/connection.py

Lines changed: 50 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
"JwtSuperuserConnection",
77
]
88

9+
import logging
910
import sys
1011
import time
1112
from abc import abstractmethod
12-
from typing import Any, Callable, Optional, Sequence, Union
13+
from typing import Any, Callable, Optional, Sequence, Set, Tuple, Union
1314

1415
import jwt
15-
from requests import Session
16+
from requests import ConnectionError, Session
1617
from requests_toolbelt import MultipartEncoder
1718

1819
from arango.exceptions import JWTAuthError, ServerConnectionError
@@ -110,6 +111,48 @@ def prep_response(self, resp: Response, deserialize: bool = True) -> Response:
110111
resp.is_success = http_ok and resp.error_code is None
111112
return resp
112113

114+
def process_request(
115+
self, host_index: int, request: Request, auth: Optional[Tuple[str, str]] = None
116+
) -> Response:
117+
"""Execute a request until a valid response has been returned.
118+
119+
:param host_index: The index of the first host to try
120+
:type host_index: int
121+
:param request: HTTP request.
122+
:type request: arango.request.Request
123+
:return: HTTP response.
124+
:rtype: arango.response.Response
125+
"""
126+
tries = 0
127+
indexes_to_filter: Set[int] = set()
128+
while tries < self._host_resolver.max_tries:
129+
try:
130+
resp = self._http.send_request(
131+
session=self._sessions[host_index],
132+
method=request.method,
133+
url=self._url_prefixes[host_index] + request.endpoint,
134+
params=request.params,
135+
data=self.normalize_data(request.data),
136+
headers=request.headers,
137+
auth=auth,
138+
)
139+
140+
return self.prep_response(resp, request.deserialize)
141+
except ConnectionError:
142+
url = self._url_prefixes[host_index] + request.endpoint
143+
logging.debug(f"ConnectionError: {url}")
144+
145+
if len(indexes_to_filter) == self._host_resolver.host_count - 1:
146+
indexes_to_filter.clear()
147+
indexes_to_filter.add(host_index)
148+
149+
host_index = self._host_resolver.get_host_index(indexes_to_filter)
150+
tries += 1
151+
152+
raise ConnectionAbortedError(
153+
f"Can't connect to host(s) within limit ({self._host_resolver.max_tries})"
154+
)
155+
113156
def prep_bulk_err_response(self, parent_response: Response, body: Json) -> Response:
114157
"""Build and return a bulk error response.
115158
@@ -227,16 +270,7 @@ def send_request(self, request: Request) -> Response:
227270
:rtype: arango.response.Response
228271
"""
229272
host_index = self._host_resolver.get_host_index()
230-
resp = self._http.send_request(
231-
session=self._sessions[host_index],
232-
method=request.method,
233-
url=self._url_prefixes[host_index] + request.endpoint,
234-
params=request.params,
235-
data=self.normalize_data(request.data),
236-
headers=request.headers,
237-
auth=self._auth,
238-
)
239-
return self.prep_response(resp, request.deserialize)
273+
return self.process_request(host_index, request, auth=self._auth)
240274

241275

242276
class JwtConnection(BaseConnection):
@@ -302,15 +336,7 @@ def send_request(self, request: Request) -> Response:
302336
if self._auth_header is not None:
303337
request.headers["Authorization"] = self._auth_header
304338

305-
resp = self._http.send_request(
306-
session=self._sessions[host_index],
307-
method=request.method,
308-
url=self._url_prefixes[host_index] + request.endpoint,
309-
params=request.params,
310-
data=self.normalize_data(request.data),
311-
headers=request.headers,
312-
)
313-
resp = self.prep_response(resp, request.deserialize)
339+
resp = self.process_request(host_index, request)
314340

315341
# Refresh the token and retry on HTTP 401 and error code 11.
316342
if resp.error_code != 11 or resp.status_code != 401:
@@ -325,15 +351,7 @@ def send_request(self, request: Request) -> Response:
325351
if self._auth_header is not None:
326352
request.headers["Authorization"] = self._auth_header
327353

328-
resp = self._http.send_request(
329-
session=self._sessions[host_index],
330-
method=request.method,
331-
url=self._url_prefixes[host_index] + request.endpoint,
332-
params=request.params,
333-
data=self.normalize_data(request.data),
334-
headers=request.headers,
335-
)
336-
return self.prep_response(resp, request.deserialize)
354+
return self.process_request(host_index, request)
337355

338356
def refresh_token(self) -> None:
339357
"""Get a new JWT token for the current user (cannot be a superuser).
@@ -349,13 +367,7 @@ def refresh_token(self) -> None:
349367

350368
host_index = self._host_resolver.get_host_index()
351369

352-
resp = self._http.send_request(
353-
session=self._sessions[host_index],
354-
method=request.method,
355-
url=self._url_prefixes[host_index] + request.endpoint,
356-
data=self.normalize_data(request.data),
357-
)
358-
resp = self.prep_response(resp)
370+
resp = self.process_request(host_index, request)
359371

360372
if not resp.is_success:
361373
raise JWTAuthError(resp, request)
@@ -429,12 +441,4 @@ def send_request(self, request: Request) -> Response:
429441
host_index = self._host_resolver.get_host_index()
430442
request.headers["Authorization"] = self._auth_header
431443

432-
resp = self._http.send_request(
433-
session=self._sessions[host_index],
434-
method=request.method,
435-
url=self._url_prefixes[host_index] + request.endpoint,
436-
params=request.params,
437-
data=self.normalize_data(request.data),
438-
headers=request.headers,
439-
)
440-
return self.prep_response(resp, request.deserialize)
444+
return self.process_request(host_index, request)

arango/resolver.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,40 +7,62 @@
77

88
import random
99
from abc import ABC, abstractmethod
10+
from typing import Optional, Set
1011

1112

1213
class HostResolver(ABC): # pragma: no cover
1314
"""Abstract base class for host resolvers."""
1415

16+
def __init__(self, host_count: int = 1, max_tries: Optional[int] = None) -> None:
17+
max_tries = max_tries or host_count * 3
18+
if max_tries < host_count:
19+
raise ValueError("max_tries cannot be less than host_count")
20+
21+
self._host_count = host_count
22+
self._max_tries = max_tries
23+
1524
@abstractmethod
16-
def get_host_index(self) -> int:
25+
def get_host_index(self, indexes_to_filter: Optional[Set[int]] = None) -> int:
1726
raise NotImplementedError
1827

28+
@property
29+
def host_count(self) -> int:
30+
return self._host_count
31+
32+
@property
33+
def max_tries(self) -> int:
34+
return self._max_tries
35+
1936

2037
class SingleHostResolver(HostResolver):
2138
"""Single host resolver."""
2239

23-
def get_host_index(self) -> int:
40+
def get_host_index(self, indexes_to_filter: Optional[Set[int]] = None) -> int:
2441
return 0
2542

2643

2744
class RandomHostResolver(HostResolver):
2845
"""Random host resolver."""
2946

30-
def __init__(self, host_count: int) -> None:
31-
self._max = host_count - 1
47+
def __init__(self, host_count: int, max_tries: Optional[int] = None) -> None:
48+
super().__init__(host_count, max_tries)
49+
50+
def get_host_index(self, indexes_to_filter: Optional[Set[int]] = None) -> int:
51+
host_index = None
52+
indexes_to_filter = indexes_to_filter or set()
53+
while host_index is None or host_index in indexes_to_filter:
54+
host_index = random.randint(0, self.host_count - 1)
3255

33-
def get_host_index(self) -> int:
34-
return random.randint(0, self._max)
56+
return host_index
3557

3658

3759
class RoundRobinHostResolver(HostResolver):
3860
"""Round-robin host resolver."""
3961

40-
def __init__(self, host_count: int) -> None:
62+
def __init__(self, host_count: int, max_tries: Optional[int] = None) -> None:
63+
super().__init__(host_count, max_tries)
4164
self._index = -1
42-
self._count = host_count
4365

44-
def get_host_index(self) -> int:
45-
self._index = (self._index + 1) % self._count
66+
def get_host_index(self, indexes_to_filter: Optional[Set[int]] = None) -> int:
67+
self._index = (self._index + 1) % self.host_count
4668
return self._index

tests/test_resolver.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
1+
from typing import Set
2+
3+
import pytest
4+
15
from arango.resolver import (
26
RandomHostResolver,
37
RoundRobinHostResolver,
48
SingleHostResolver,
59
)
610

711

12+
def test_bad_resolver():
13+
with pytest.raises(ValueError):
14+
RandomHostResolver(3, 2)
15+
16+
817
def test_resolver_single_host():
918
resolver = SingleHostResolver()
1019
for _ in range(20):
@@ -16,6 +25,21 @@ def test_resolver_random_host():
1625
for _ in range(20):
1726
assert 0 <= resolver.get_host_index() < 10
1827

28+
resolver = RandomHostResolver(3)
29+
indexes_to_filter: Set[int] = set()
30+
31+
index_a = resolver.get_host_index()
32+
indexes_to_filter.add(index_a)
33+
34+
index_b = resolver.get_host_index(indexes_to_filter)
35+
indexes_to_filter.add(index_b)
36+
assert index_b != index_a
37+
38+
index_c = resolver.get_host_index(indexes_to_filter)
39+
indexes_to_filter.clear()
40+
indexes_to_filter.add(index_c)
41+
assert index_c not in [index_a, index_b]
42+
1943

2044
def test_resolver_round_robin():
2145
resolver = RoundRobinHostResolver(10)

0 commit comments

Comments
 (0)