1
1
__all__ = [
2
2
"BaseConnection" ,
3
3
"BasicConnection" ,
4
+ "JwtConnection" ,
5
+ "JwtSuperuserConnection" ,
4
6
]
5
7
6
8
import json
9
11
10
12
import jwt
11
13
14
+ from arangoasync import errno , logger
12
15
from arangoasync .auth import Auth , JwtToken
13
16
from arangoasync .compression import CompressionManager , DefaultCompressionManager
14
17
from arangoasync .exceptions import (
@@ -55,25 +58,45 @@ def db_name(self) -> str:
55
58
"""Return the database name."""
56
59
return self ._db_name
57
60
58
- def prep_response (self , request : Request , resp : Response ) -> Response :
59
- """Prepare response for return.
61
+ @staticmethod
62
+ def raise_for_status (request : Request , resp : Response ) -> None :
63
+ """Raise an exception based on the response.
60
64
61
65
Args:
62
66
request (Request): Request object.
63
67
resp (Response): Response object.
64
68
65
- Returns:
66
- Response: Response object
67
-
68
69
Raises:
69
70
ServerConnectionError: If the response status code is not successful.
70
71
"""
71
- # TODO needs refactoring such that it does not throw
72
- resp .is_success = 200 <= resp .status_code < 300
73
72
if resp .status_code in {401 , 403 }:
74
73
raise ServerConnectionError (resp , request , "Authentication failed." )
75
74
if not resp .is_success :
76
75
raise ServerConnectionError (resp , request , "Bad server response." )
76
+
77
+ @staticmethod
78
+ def prep_response (request : Request , resp : Response ) -> Response :
79
+ """Prepare response for return.
80
+
81
+ Args:
82
+ request (Request): Request object.
83
+ resp (Response): Response object.
84
+
85
+ Returns:
86
+ Response: Response object
87
+ """
88
+ resp .is_success = 200 <= resp .status_code < 300
89
+ if not resp .is_success :
90
+ try :
91
+ body = json .loads (resp .raw_body )
92
+ except json .JSONDecodeError as e :
93
+ logger .debug (
94
+ f"Failed to decode response body: { e } (from request { request } )"
95
+ )
96
+ else :
97
+ if body .get ("error" ) is True :
98
+ resp .error_code = body .get ("errorNum" )
99
+ resp .error_message = body .get ("errorMessage" )
77
100
return resp
78
101
79
102
async def process_request (self , request : Request ) -> Response :
@@ -86,7 +109,7 @@ async def process_request(self, request: Request) -> Response:
86
109
Response: Response object.
87
110
88
111
Raises:
89
- ConnectionAbortedError: If can't connect to host(s) within limit.
112
+ ConnectionAbortedError: If it can't connect to host(s) within limit.
90
113
"""
91
114
92
115
host_index = self ._host_resolver .get_host_index ()
@@ -100,6 +123,7 @@ async def process_request(self, request: Request) -> Response:
100
123
ex_host_index = host_index
101
124
host_index = self ._host_resolver .get_host_index ()
102
125
if ex_host_index == host_index :
126
+ # Force change host if the same host is selected
103
127
self ._host_resolver .change_host ()
104
128
host_index = self ._host_resolver .get_host_index ()
105
129
@@ -117,8 +141,8 @@ async def ping(self) -> int:
117
141
ServerConnectionError: If the response status code is not successful.
118
142
"""
119
143
request = Request (method = Method .GET , endpoint = "/_api/collection" )
120
- request .headers = {"abde" : "fghi" }
121
144
resp = await self .send_request (request )
145
+ self .raise_for_status (request , resp )
122
146
return resp .status_code
123
147
124
148
@abstractmethod
@@ -257,15 +281,15 @@ async def refresh_token(self) -> None:
257
281
if self ._auth is None :
258
282
raise JWTRefreshError ("Auth must be provided to refresh the token." )
259
283
260
- data = json .dumps (
284
+ auth_data = json .dumps (
261
285
dict (username = self ._auth .username , password = self ._auth .password ),
262
286
separators = ("," , ":" ),
263
287
ensure_ascii = False ,
264
288
)
265
289
request = Request (
266
290
method = Method .POST ,
267
291
endpoint = "/_open/auth" ,
268
- data = data .encode ("utf-8" ),
292
+ data = auth_data .encode ("utf-8" ),
269
293
)
270
294
271
295
try :
@@ -310,16 +334,86 @@ async def send_request(self, request: Request) -> Response:
310
334
311
335
request .headers ["authorization" ] = self ._auth_header
312
336
313
- try :
314
- resp = await self .process_request (request )
315
- if (
316
- resp .status_code == 401 # Unauthorized
317
- and self ._token is not None
318
- and self ._token .needs_refresh (self ._expire_leeway )
319
- ):
320
- await self .refresh_token ()
321
- return await self .process_request (request ) # Retry with new token
322
- except ServerConnectionError :
323
- # TODO modify after refactoring of prep_response, so we can inspect response
337
+ resp = await self .process_request (request )
338
+ if (
339
+ resp .status_code == errno .HTTP_UNAUTHORIZED
340
+ and self ._token is not None
341
+ and self ._token .needs_refresh (self ._expire_leeway )
342
+ ):
343
+ # If the token has expired, refresh it and retry the request
324
344
await self .refresh_token ()
325
- return await self .process_request (request ) # Retry with new token
345
+ resp = await self .process_request (request )
346
+ self .raise_for_status (request , resp )
347
+ return resp
348
+
349
+
350
+ class JwtSuperuserConnection (BaseConnection ):
351
+ """Connection to a specific ArangoDB database, using superuser JWT.
352
+
353
+ The JWT token is not refreshed and (username and password) are not required.
354
+
355
+ Args:
356
+ sessions (list): List of client sessions.
357
+ host_resolver (HostResolver): Host resolver.
358
+ http_client (HTTPClient): HTTP client.
359
+ db_name (str): Database name.
360
+ compression (CompressionManager | None): Compression manager.
361
+ token (JwtToken | None): JWT token.
362
+ """
363
+
364
+ def __init__ (
365
+ self ,
366
+ sessions : List [Any ],
367
+ host_resolver : HostResolver ,
368
+ http_client : HTTPClient ,
369
+ db_name : str ,
370
+ compression : Optional [CompressionManager ] = None ,
371
+ token : Optional [JwtToken ] = None ,
372
+ ) -> None :
373
+ super ().__init__ (sessions , host_resolver , http_client , db_name , compression )
374
+ self ._expire_leeway : int = 0
375
+ self ._token : Optional [JwtToken ] = None
376
+ self ._auth_header : Optional [str ] = None
377
+ self .token = token
378
+
379
+ @property
380
+ def token (self ) -> Optional [JwtToken ]:
381
+ """Get the JWT token.
382
+
383
+ Returns:
384
+ JwtToken | None: JWT token.
385
+ """
386
+ return self ._token
387
+
388
+ @token .setter
389
+ def token (self , token : Optional [JwtToken ]) -> None :
390
+ """Set the JWT token.
391
+
392
+ Args:
393
+ token (JwtToken | None): JWT token.
394
+ Setting it to None will cause the token to be automatically
395
+ refreshed on the next request, if auth information is provided.
396
+ """
397
+ self ._token = token
398
+ self ._auth_header = f"bearer { self ._token .token } " if self ._token else None
399
+
400
+ async def send_request (self , request : Request ) -> Response :
401
+ """Send an HTTP request to the ArangoDB server.
402
+
403
+ Args:
404
+ request (Request): HTTP request.
405
+
406
+ Returns:
407
+ Response: HTTP response
408
+
409
+ Raises:
410
+ ArangoClientError: If an error occurred from the client side.
411
+ ArangoServerError: If an error occurred from the server side.
412
+ """
413
+ if self ._auth_header is None :
414
+ raise AuthHeaderError ("Failed to generate authorization header." )
415
+ request .headers ["authorization" ] = self ._auth_header
416
+
417
+ resp = await self .process_request (request )
418
+ self .raise_for_status (request , resp )
419
+ return resp
0 commit comments