6
6
"JwtSuperuserConnection" ,
7
7
]
8
8
9
+ import logging
9
10
import sys
10
11
import time
11
12
from abc import abstractmethod
12
- from typing import Any , Callable , Optional , Sequence , Union
13
+ from typing import Any , Callable , Optional , Sequence , Set , Tuple , Union
13
14
14
15
import jwt
15
- from requests import Session
16
+ from requests import ConnectionError , Session
16
17
from requests_toolbelt import MultipartEncoder
17
18
18
19
from arango .exceptions import JWTAuthError , ServerConnectionError
@@ -110,6 +111,48 @@ def prep_response(self, resp: Response, deserialize: bool = True) -> Response:
110
111
resp .is_success = http_ok and resp .error_code is None
111
112
return resp
112
113
114
+ def process_request (
115
+ self , host_index : int , request : Request , auth : Optional [Tuple [str , str ]] = None
116
+ ) -> Response :
117
+ """Execute a request until a valid response has been returned.
118
+
119
+ :param host_index: The index of the first host to try
120
+ :type host_index: int
121
+ :param request: HTTP request.
122
+ :type request: arango.request.Request
123
+ :return: HTTP response.
124
+ :rtype: arango.response.Response
125
+ """
126
+ tries = 0
127
+ indexes_to_filter : Set [int ] = set ()
128
+ while tries < self ._host_resolver .max_tries :
129
+ try :
130
+ resp = self ._http .send_request (
131
+ session = self ._sessions [host_index ],
132
+ method = request .method ,
133
+ url = self ._url_prefixes [host_index ] + request .endpoint ,
134
+ params = request .params ,
135
+ data = self .normalize_data (request .data ),
136
+ headers = request .headers ,
137
+ auth = auth ,
138
+ )
139
+
140
+ return self .prep_response (resp , request .deserialize )
141
+ except ConnectionError :
142
+ url = self ._url_prefixes [host_index ] + request .endpoint
143
+ logging .debug (f"ConnectionError: { url } " )
144
+
145
+ if len (indexes_to_filter ) == self ._host_resolver .host_count - 1 :
146
+ indexes_to_filter .clear ()
147
+ indexes_to_filter .add (host_index )
148
+
149
+ host_index = self ._host_resolver .get_host_index (indexes_to_filter )
150
+ tries += 1
151
+
152
+ raise ConnectionAbortedError (
153
+ f"Can't connect to host(s) within limit ({ self ._host_resolver .max_tries } )"
154
+ )
155
+
113
156
def prep_bulk_err_response (self , parent_response : Response , body : Json ) -> Response :
114
157
"""Build and return a bulk error response.
115
158
@@ -227,16 +270,7 @@ def send_request(self, request: Request) -> Response:
227
270
:rtype: arango.response.Response
228
271
"""
229
272
host_index = self ._host_resolver .get_host_index ()
230
- resp = self ._http .send_request (
231
- session = self ._sessions [host_index ],
232
- method = request .method ,
233
- url = self ._url_prefixes [host_index ] + request .endpoint ,
234
- params = request .params ,
235
- data = self .normalize_data (request .data ),
236
- headers = request .headers ,
237
- auth = self ._auth ,
238
- )
239
- return self .prep_response (resp , request .deserialize )
273
+ return self .process_request (host_index , request , auth = self ._auth )
240
274
241
275
242
276
class JwtConnection (BaseConnection ):
@@ -302,15 +336,7 @@ def send_request(self, request: Request) -> Response:
302
336
if self ._auth_header is not None :
303
337
request .headers ["Authorization" ] = self ._auth_header
304
338
305
- resp = self ._http .send_request (
306
- session = self ._sessions [host_index ],
307
- method = request .method ,
308
- url = self ._url_prefixes [host_index ] + request .endpoint ,
309
- params = request .params ,
310
- data = self .normalize_data (request .data ),
311
- headers = request .headers ,
312
- )
313
- resp = self .prep_response (resp , request .deserialize )
339
+ resp = self .process_request (host_index , request )
314
340
315
341
# Refresh the token and retry on HTTP 401 and error code 11.
316
342
if resp .error_code != 11 or resp .status_code != 401 :
@@ -325,15 +351,7 @@ def send_request(self, request: Request) -> Response:
325
351
if self ._auth_header is not None :
326
352
request .headers ["Authorization" ] = self ._auth_header
327
353
328
- resp = self ._http .send_request (
329
- session = self ._sessions [host_index ],
330
- method = request .method ,
331
- url = self ._url_prefixes [host_index ] + request .endpoint ,
332
- params = request .params ,
333
- data = self .normalize_data (request .data ),
334
- headers = request .headers ,
335
- )
336
- return self .prep_response (resp , request .deserialize )
354
+ return self .process_request (host_index , request )
337
355
338
356
def refresh_token (self ) -> None :
339
357
"""Get a new JWT token for the current user (cannot be a superuser).
@@ -349,13 +367,7 @@ def refresh_token(self) -> None:
349
367
350
368
host_index = self ._host_resolver .get_host_index ()
351
369
352
- resp = self ._http .send_request (
353
- session = self ._sessions [host_index ],
354
- method = request .method ,
355
- url = self ._url_prefixes [host_index ] + request .endpoint ,
356
- data = self .normalize_data (request .data ),
357
- )
358
- resp = self .prep_response (resp )
370
+ resp = self .process_request (host_index , request )
359
371
360
372
if not resp .is_success :
361
373
raise JWTAuthError (resp , request )
@@ -429,12 +441,4 @@ def send_request(self, request: Request) -> Response:
429
441
host_index = self ._host_resolver .get_host_index ()
430
442
request .headers ["Authorization" ] = self ._auth_header
431
443
432
- resp = self ._http .send_request (
433
- session = self ._sessions [host_index ],
434
- method = request .method ,
435
- url = self ._url_prefixes [host_index ] + request .endpoint ,
436
- params = request .params ,
437
- data = self .normalize_data (request .data ),
438
- headers = request .headers ,
439
- )
440
- return self .prep_response (resp , request .deserialize )
444
+ return self .process_request (host_index , request )
0 commit comments