Skip to content

Commit 9626270

Browse files
committed
Working on JWT Connection Support
1 parent bd76c61 commit 9626270

File tree

7 files changed

+1545
-12
lines changed

7 files changed

+1545
-12
lines changed

arangoasync/auth.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"JwtToken",
44
]
55

6+
import time
67
from dataclasses import dataclass
78

89
import jwt
@@ -27,24 +28,24 @@ class JwtToken:
2728
"""JWT token.
2829
2930
Args:
30-
token (str | bytes): JWT token.
31+
token (str): JWT token.
3132
3233
Raises:
3334
TypeError: If the token type is not str or bytes.
34-
JWTExpiredError: If the token expired.
35+
jwt.ExpiredSignatureError: If the token expired.
3536
"""
3637

37-
def __init__(self, token: str | bytes) -> None:
38+
def __init__(self, token: str) -> None:
3839
self._token = token
3940
self._validate()
4041

4142
@property
42-
def token(self) -> str | bytes:
43+
def token(self) -> str:
4344
"""Get token."""
4445
return self._token
4546

4647
@token.setter
47-
def token(self, token: str | bytes) -> None:
48+
def token(self, token: str) -> None:
4849
"""Set token.
4950
5051
Raises:
@@ -53,9 +54,22 @@ def token(self, token: str | bytes) -> None:
5354
self._token = token
5455
self._validate()
5556

57+
def needs_refresh(self, leeway: int = 0) -> bool:
58+
"""Check if the token needs to be refreshed.
59+
60+
Args:
61+
leeway (int): Leeway in seconds, before official expiration,
62+
when to consider the token expired.
63+
64+
Returns:
65+
bool: True if the token needs to be refreshed, False otherwise.
66+
"""
67+
refresh: bool = int(time.time()) > self._token_exp - leeway
68+
return refresh
69+
5670
def _validate(self) -> None:
5771
"""Validate the token."""
58-
if type(self._token) not in (str, bytes):
72+
if type(self._token) is not str:
5973
raise TypeError("Token must be str or bytes")
6074

6175
jwt_payload = jwt.decode(

arangoasync/compression.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,11 @@ def level(self, value: int) -> None:
120120
self._level = value
121121

122122
@property
123-
def accept_encoding(self) -> str | None:
123+
def accept_encoding(self) -> Optional[str]:
124124
return self._accept_encoding
125125

126126
@accept_encoding.setter
127-
def accept_encoding(self, value: AcceptEncoding | None) -> None:
127+
def accept_encoding(self, value: Optional[AcceptEncoding]) -> None:
128128
self._accept_encoding = value.name.lower() if value else None
129129

130130
@property

arangoasync/connection.py

Lines changed: 141 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,18 @@
33
"BasicConnection",
44
]
55

6+
import json
67
from abc import ABC, abstractmethod
78
from typing import Any, List, Optional
89

9-
from arangoasync.auth import Auth
10+
import jwt
11+
12+
from arangoasync.auth import Auth, JwtToken
1013
from arangoasync.compression import CompressionManager, DefaultCompressionManager
1114
from arangoasync.exceptions import (
1215
ClientConnectionError,
1316
ConnectionAbortedError,
17+
JWTRefreshError,
1418
ServerConnectionError,
1519
)
1620
from arangoasync.http import HTTPClient
@@ -63,6 +67,7 @@ def prep_response(self, request: Request, resp: Response) -> Response:
6367
Raises:
6468
ServerConnectionError: If the response status code is not successful.
6569
"""
70+
# TODO needs refactoring such that it does not throw
6671
resp.is_success = 200 <= resp.status_code < 300
6772
if resp.status_code in {401, 403}:
6873
raise ServerConnectionError(resp, request, "Authentication failed.")
@@ -154,7 +159,18 @@ def __init__(
154159
self._auth = auth
155160

156161
async def send_request(self, request: Request) -> Response:
157-
"""Send an HTTP request to the ArangoDB server."""
162+
"""Send an HTTP request to the ArangoDB server.
163+
164+
Args:
165+
request (Request): HTTP request.
166+
167+
Returns:
168+
Response: HTTP response
169+
170+
Raises:
171+
ArangoClientError: If an error occurred from the client side.
172+
ArangoServerError: If an error occurred from the server side.
173+
"""
158174
if request.data is not None and self._compression.needs_compression(
159175
request.data
160176
):
@@ -169,3 +185,126 @@ async def send_request(self, request: Request) -> Response:
169185
request.auth = self._auth
170186

171187
return await self.process_request(request)
188+
189+
190+
class JwtConnection(BaseConnection):
191+
"""Connection to a specific ArangoDB database, using JWT authentication.
192+
193+
Allows for basic authentication to be used (username and password),
194+
together with JWT.
195+
196+
Args:
197+
sessions (list): List of client sessions.
198+
host_resolver (HostResolver): Host resolver.
199+
http_client (HTTPClient): HTTP client.
200+
db_name (str): Database name.
201+
compression (CompressionManager | None): Compression manager.
202+
auth (Auth | None): Authentication information.
203+
token (JwtToken | None): JWT token.
204+
205+
Raises:
206+
ValueError: If neither token nor auth is provided.
207+
"""
208+
209+
def __init__(
210+
self,
211+
sessions: List[Any],
212+
host_resolver: HostResolver,
213+
http_client: HTTPClient,
214+
db_name: str,
215+
compression: Optional[CompressionManager] = None,
216+
auth: Optional[Auth] = None,
217+
token: Optional[JwtToken] = None,
218+
) -> None:
219+
super().__init__(sessions, host_resolver, http_client, db_name, compression)
220+
self._auth = auth
221+
self._expire_leeway: int = 0
222+
self._token: Optional[JwtToken] = None
223+
self._auth_header: Optional[str] = None
224+
self.set_token(token)
225+
226+
if self._token is None and self._auth is None:
227+
raise ValueError("Either token or auth must be provided.")
228+
229+
async def refresh_token(self) -> None:
230+
"""Refresh the JWT token.
231+
232+
Raises:
233+
JWTRefreshError: If the token can't be refreshed.
234+
"""
235+
if self._auth is None:
236+
raise JWTRefreshError("Auth must be provided to refresh the token.")
237+
238+
data = json.dumps(
239+
dict(username=self._auth.username, password=self._auth.password),
240+
separators=(",", ":"),
241+
ensure_ascii=False,
242+
)
243+
request = Request(
244+
method=Method.POST,
245+
endpoint="/_open/auth",
246+
data=data.encode("utf-8"),
247+
)
248+
249+
try:
250+
resp = await self.process_request(request)
251+
except ConnectionAbortedError as e:
252+
raise JWTRefreshError(str(e)) from e
253+
except ServerConnectionError as e:
254+
raise JWTRefreshError(str(e)) from e
255+
256+
if not resp.is_success:
257+
raise JWTRefreshError(
258+
f"Failed to refresh the JWT token: "
259+
f"{resp.status_code} {resp.status_text}"
260+
)
261+
262+
token = json.loads(resp.raw_body)
263+
try:
264+
self.set_token(JwtToken(token["jwt"]))
265+
except jwt.ExpiredSignatureError as e:
266+
raise JWTRefreshError(
267+
"Failed to refresh the JWT token: got an expired token"
268+
) from e
269+
270+
def set_token(self, value: Optional[JwtToken]) -> None:
271+
"""Set the JWT token.
272+
273+
Args:
274+
value (JwtToken | None): JWT token.
275+
Setting it to None will cause the token to be automatically
276+
refreshed on the next request, if auth information is provided.
277+
"""
278+
self._token = value
279+
self._auth_header = f"bearer {self._token.token}" if self._token else None
280+
281+
async def send_request(self, request: Request) -> Response:
282+
"""Send an HTTP request to the ArangoDB server.
283+
284+
Args:
285+
request (Request): HTTP request.
286+
287+
Returns:
288+
Response: HTTP response
289+
290+
Raises:
291+
ArangoClientError: If an error occurred from the client side.
292+
ArangoServerError: If an error occurred from the server side.
293+
"""
294+
if self._auth_header is not None:
295+
request.headers["authorization"] = self._auth_header
296+
else:
297+
await self.refresh_token()
298+
299+
try:
300+
resp = await self.process_request(request)
301+
if (
302+
resp.status_code == 401
303+
and self._token is not None
304+
and self._token.needs_refresh(self._expire_leeway)
305+
):
306+
await self.refresh_token()
307+
return await self.process_request(request)
308+
except ServerConnectionError as e:
309+
# TODO modify after refactoring of prep_response, so we can inspect response
310+
raise e

0 commit comments

Comments
 (0)