3
3
"BasicConnection" ,
4
4
]
5
5
6
+ import json
6
7
from abc import ABC , abstractmethod
7
8
from typing import Any , List , Optional
8
9
9
- from arangoasync .auth import Auth
10
+ import jwt
11
+
12
+ from arangoasync .auth import Auth , JwtToken
10
13
from arangoasync .compression import CompressionManager , DefaultCompressionManager
11
14
from arangoasync .exceptions import (
12
15
ClientConnectionError ,
13
16
ConnectionAbortedError ,
17
+ JWTRefreshError ,
14
18
ServerConnectionError ,
15
19
)
16
20
from arangoasync .http import HTTPClient
@@ -63,6 +67,7 @@ def prep_response(self, request: Request, resp: Response) -> Response:
63
67
Raises:
64
68
ServerConnectionError: If the response status code is not successful.
65
69
"""
70
+ # TODO needs refactoring such that it does not throw
66
71
resp .is_success = 200 <= resp .status_code < 300
67
72
if resp .status_code in {401 , 403 }:
68
73
raise ServerConnectionError (resp , request , "Authentication failed." )
@@ -154,7 +159,18 @@ def __init__(
154
159
self ._auth = auth
155
160
156
161
async def send_request (self , request : Request ) -> Response :
157
- """Send an HTTP request to the ArangoDB server."""
162
+ """Send an HTTP request to the ArangoDB server.
163
+
164
+ Args:
165
+ request (Request): HTTP request.
166
+
167
+ Returns:
168
+ Response: HTTP response
169
+
170
+ Raises:
171
+ ArangoClientError: If an error occurred from the client side.
172
+ ArangoServerError: If an error occurred from the server side.
173
+ """
158
174
if request .data is not None and self ._compression .needs_compression (
159
175
request .data
160
176
):
@@ -169,3 +185,126 @@ async def send_request(self, request: Request) -> Response:
169
185
request .auth = self ._auth
170
186
171
187
return await self .process_request (request )
188
+
189
+
190
+ class JwtConnection (BaseConnection ):
191
+ """Connection to a specific ArangoDB database, using JWT authentication.
192
+
193
+ Allows for basic authentication to be used (username and password),
194
+ together with JWT.
195
+
196
+ Args:
197
+ sessions (list): List of client sessions.
198
+ host_resolver (HostResolver): Host resolver.
199
+ http_client (HTTPClient): HTTP client.
200
+ db_name (str): Database name.
201
+ compression (CompressionManager | None): Compression manager.
202
+ auth (Auth | None): Authentication information.
203
+ token (JwtToken | None): JWT token.
204
+
205
+ Raises:
206
+ ValueError: If neither token nor auth is provided.
207
+ """
208
+
209
+ def __init__ (
210
+ self ,
211
+ sessions : List [Any ],
212
+ host_resolver : HostResolver ,
213
+ http_client : HTTPClient ,
214
+ db_name : str ,
215
+ compression : Optional [CompressionManager ] = None ,
216
+ auth : Optional [Auth ] = None ,
217
+ token : Optional [JwtToken ] = None ,
218
+ ) -> None :
219
+ super ().__init__ (sessions , host_resolver , http_client , db_name , compression )
220
+ self ._auth = auth
221
+ self ._expire_leeway : int = 0
222
+ self ._token : Optional [JwtToken ] = None
223
+ self ._auth_header : Optional [str ] = None
224
+ self .set_token (token )
225
+
226
+ if self ._token is None and self ._auth is None :
227
+ raise ValueError ("Either token or auth must be provided." )
228
+
229
+ async def refresh_token (self ) -> None :
230
+ """Refresh the JWT token.
231
+
232
+ Raises:
233
+ JWTRefreshError: If the token can't be refreshed.
234
+ """
235
+ if self ._auth is None :
236
+ raise JWTRefreshError ("Auth must be provided to refresh the token." )
237
+
238
+ data = json .dumps (
239
+ dict (username = self ._auth .username , password = self ._auth .password ),
240
+ separators = ("," , ":" ),
241
+ ensure_ascii = False ,
242
+ )
243
+ request = Request (
244
+ method = Method .POST ,
245
+ endpoint = "/_open/auth" ,
246
+ data = data .encode ("utf-8" ),
247
+ )
248
+
249
+ try :
250
+ resp = await self .process_request (request )
251
+ except ConnectionAbortedError as e :
252
+ raise JWTRefreshError (str (e )) from e
253
+ except ServerConnectionError as e :
254
+ raise JWTRefreshError (str (e )) from e
255
+
256
+ if not resp .is_success :
257
+ raise JWTRefreshError (
258
+ f"Failed to refresh the JWT token: "
259
+ f"{ resp .status_code } { resp .status_text } "
260
+ )
261
+
262
+ token = json .loads (resp .raw_body )
263
+ try :
264
+ self .set_token (JwtToken (token ["jwt" ]))
265
+ except jwt .ExpiredSignatureError as e :
266
+ raise JWTRefreshError (
267
+ "Failed to refresh the JWT token: got an expired token"
268
+ ) from e
269
+
270
+ def set_token (self , value : Optional [JwtToken ]) -> None :
271
+ """Set the JWT token.
272
+
273
+ Args:
274
+ value (JwtToken | None): JWT token.
275
+ Setting it to None will cause the token to be automatically
276
+ refreshed on the next request, if auth information is provided.
277
+ """
278
+ self ._token = value
279
+ self ._auth_header = f"bearer { self ._token .token } " if self ._token else None
280
+
281
+ async def send_request (self , request : Request ) -> Response :
282
+ """Send an HTTP request to the ArangoDB server.
283
+
284
+ Args:
285
+ request (Request): HTTP request.
286
+
287
+ Returns:
288
+ Response: HTTP response
289
+
290
+ Raises:
291
+ ArangoClientError: If an error occurred from the client side.
292
+ ArangoServerError: If an error occurred from the server side.
293
+ """
294
+ if self ._auth_header is not None :
295
+ request .headers ["authorization" ] = self ._auth_header
296
+ else :
297
+ await self .refresh_token ()
298
+
299
+ try :
300
+ resp = await self .process_request (request )
301
+ if (
302
+ resp .status_code == 401
303
+ and self ._token is not None
304
+ and self ._token .needs_refresh (self ._expire_leeway )
305
+ ):
306
+ await self .refresh_token ()
307
+ return await self .process_request (request )
308
+ except ServerConnectionError as e :
309
+ # TODO modify after refactoring of prep_response, so we can inspect response
310
+ raise e
0 commit comments