Skip to content

Commit e642e5b

Browse files
authored
JwtSuperuserConnection (#15)
* Implemented JwtSuperuserConnection * Fixed failing test
1 parent 44f4fa0 commit e642e5b

File tree

6 files changed

+260
-56
lines changed

6 files changed

+260
-56
lines changed

arangoasync/auth.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import time
77
from dataclasses import dataclass
8+
from typing import Optional
89

910
import jwt
1011

@@ -39,6 +40,38 @@ def __init__(self, token: str) -> None:
3940
self._token = token
4041
self._validate()
4142

43+
@staticmethod
44+
def generate_token(
45+
secret: str | bytes,
46+
iat: Optional[int] = None,
47+
exp: int = 3600,
48+
iss: str = "arangodb",
49+
server_id: str = "client",
50+
) -> "JwtToken":
51+
"""Generate and return a JWT token.
52+
53+
Args:
54+
secret (str | bytes): JWT secret.
55+
iat (int): Time the token was issued in seconds. Defaults to current time.
56+
exp (int): Time to expire in seconds.
57+
iss (str): Issuer.
58+
server_id (str): Server ID.
59+
60+
Returns:
61+
str: JWT token.
62+
"""
63+
iat = iat or int(time.time())
64+
token = jwt.encode(
65+
payload={
66+
"iat": iat,
67+
"exp": iat + exp,
68+
"iss": iss,
69+
"server_id": server_id,
70+
},
71+
key=secret,
72+
)
73+
return JwtToken(token)
74+
4275
@property
4376
def token(self) -> str:
4477
"""Get token."""

arangoasync/connection.py

Lines changed: 117 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
__all__ = [
22
"BaseConnection",
33
"BasicConnection",
4+
"JwtConnection",
5+
"JwtSuperuserConnection",
46
]
57

68
import json
@@ -9,6 +11,7 @@
911

1012
import jwt
1113

14+
from arangoasync import errno, logger
1215
from arangoasync.auth import Auth, JwtToken
1316
from arangoasync.compression import CompressionManager, DefaultCompressionManager
1417
from arangoasync.exceptions import (
@@ -55,25 +58,45 @@ def db_name(self) -> str:
5558
"""Return the database name."""
5659
return self._db_name
5760

58-
def prep_response(self, request: Request, resp: Response) -> Response:
59-
"""Prepare response for return.
61+
@staticmethod
62+
def raise_for_status(request: Request, resp: Response) -> None:
63+
"""Raise an exception based on the response.
6064
6165
Args:
6266
request (Request): Request object.
6367
resp (Response): Response object.
6468
65-
Returns:
66-
Response: Response object
67-
6869
Raises:
6970
ServerConnectionError: If the response status code is not successful.
7071
"""
71-
# TODO needs refactoring such that it does not throw
72-
resp.is_success = 200 <= resp.status_code < 300
7372
if resp.status_code in {401, 403}:
7473
raise ServerConnectionError(resp, request, "Authentication failed.")
7574
if not resp.is_success:
7675
raise ServerConnectionError(resp, request, "Bad server response.")
76+
77+
@staticmethod
78+
def prep_response(request: Request, resp: Response) -> Response:
79+
"""Prepare response for return.
80+
81+
Args:
82+
request (Request): Request object.
83+
resp (Response): Response object.
84+
85+
Returns:
86+
Response: Response object
87+
"""
88+
resp.is_success = 200 <= resp.status_code < 300
89+
if not resp.is_success:
90+
try:
91+
body = json.loads(resp.raw_body)
92+
except json.JSONDecodeError as e:
93+
logger.debug(
94+
f"Failed to decode response body: {e} (from request {request})"
95+
)
96+
else:
97+
if body.get("error") is True:
98+
resp.error_code = body.get("errorNum")
99+
resp.error_message = body.get("errorMessage")
77100
return resp
78101

79102
async def process_request(self, request: Request) -> Response:
@@ -86,7 +109,7 @@ async def process_request(self, request: Request) -> Response:
86109
Response: Response object.
87110
88111
Raises:
89-
ConnectionAbortedError: If can't connect to host(s) within limit.
112+
ConnectionAbortedError: If it can't connect to host(s) within limit.
90113
"""
91114

92115
host_index = self._host_resolver.get_host_index()
@@ -100,6 +123,7 @@ async def process_request(self, request: Request) -> Response:
100123
ex_host_index = host_index
101124
host_index = self._host_resolver.get_host_index()
102125
if ex_host_index == host_index:
126+
# Force change host if the same host is selected
103127
self._host_resolver.change_host()
104128
host_index = self._host_resolver.get_host_index()
105129

@@ -117,8 +141,8 @@ async def ping(self) -> int:
117141
ServerConnectionError: If the response status code is not successful.
118142
"""
119143
request = Request(method=Method.GET, endpoint="/_api/collection")
120-
request.headers = {"abde": "fghi"}
121144
resp = await self.send_request(request)
145+
self.raise_for_status(request, resp)
122146
return resp.status_code
123147

124148
@abstractmethod
@@ -257,15 +281,15 @@ async def refresh_token(self) -> None:
257281
if self._auth is None:
258282
raise JWTRefreshError("Auth must be provided to refresh the token.")
259283

260-
data = json.dumps(
284+
auth_data = json.dumps(
261285
dict(username=self._auth.username, password=self._auth.password),
262286
separators=(",", ":"),
263287
ensure_ascii=False,
264288
)
265289
request = Request(
266290
method=Method.POST,
267291
endpoint="/_open/auth",
268-
data=data.encode("utf-8"),
292+
data=auth_data.encode("utf-8"),
269293
)
270294

271295
try:
@@ -310,16 +334,86 @@ async def send_request(self, request: Request) -> Response:
310334

311335
request.headers["authorization"] = self._auth_header
312336

313-
try:
314-
resp = await self.process_request(request)
315-
if (
316-
resp.status_code == 401 # Unauthorized
317-
and self._token is not None
318-
and self._token.needs_refresh(self._expire_leeway)
319-
):
320-
await self.refresh_token()
321-
return await self.process_request(request) # Retry with new token
322-
except ServerConnectionError:
323-
# TODO modify after refactoring of prep_response, so we can inspect response
337+
resp = await self.process_request(request)
338+
if (
339+
resp.status_code == errno.HTTP_UNAUTHORIZED
340+
and self._token is not None
341+
and self._token.needs_refresh(self._expire_leeway)
342+
):
343+
# If the token has expired, refresh it and retry the request
324344
await self.refresh_token()
325-
return await self.process_request(request) # Retry with new token
345+
resp = await self.process_request(request)
346+
self.raise_for_status(request, resp)
347+
return resp
348+
349+
350+
class JwtSuperuserConnection(BaseConnection):
351+
"""Connection to a specific ArangoDB database, using superuser JWT.
352+
353+
The JWT token is not refreshed and (username and password) are not required.
354+
355+
Args:
356+
sessions (list): List of client sessions.
357+
host_resolver (HostResolver): Host resolver.
358+
http_client (HTTPClient): HTTP client.
359+
db_name (str): Database name.
360+
compression (CompressionManager | None): Compression manager.
361+
token (JwtToken | None): JWT token.
362+
"""
363+
364+
def __init__(
365+
self,
366+
sessions: List[Any],
367+
host_resolver: HostResolver,
368+
http_client: HTTPClient,
369+
db_name: str,
370+
compression: Optional[CompressionManager] = None,
371+
token: Optional[JwtToken] = None,
372+
) -> None:
373+
super().__init__(sessions, host_resolver, http_client, db_name, compression)
374+
self._expire_leeway: int = 0
375+
self._token: Optional[JwtToken] = None
376+
self._auth_header: Optional[str] = None
377+
self.token = token
378+
379+
@property
380+
def token(self) -> Optional[JwtToken]:
381+
"""Get the JWT token.
382+
383+
Returns:
384+
JwtToken | None: JWT token.
385+
"""
386+
return self._token
387+
388+
@token.setter
389+
def token(self, token: Optional[JwtToken]) -> None:
390+
"""Set the JWT token.
391+
392+
Args:
393+
token (JwtToken | None): JWT token.
394+
Setting it to None will cause the token to be automatically
395+
refreshed on the next request, if auth information is provided.
396+
"""
397+
self._token = token
398+
self._auth_header = f"bearer {self._token.token}" if self._token else None
399+
400+
async def send_request(self, request: Request) -> Response:
401+
"""Send an HTTP request to the ArangoDB server.
402+
403+
Args:
404+
request (Request): HTTP request.
405+
406+
Returns:
407+
Response: HTTP response
408+
409+
Raises:
410+
ArangoClientError: If an error occurred from the client side.
411+
ArangoServerError: If an error occurred from the server side.
412+
"""
413+
if self._auth_header is None:
414+
raise AuthHeaderError("Failed to generate authorization header.")
415+
request.headers["authorization"] = self._auth_header
416+
417+
resp = await self.process_request(request)
418+
self.raise_for_status(request, resp)
419+
return resp

arangoasync/request.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,6 @@ def normalized_params(self) -> Params:
102102
normalized_params[key] = str(value)
103103

104104
return normalized_params
105+
106+
def __repr__(self) -> str:
107+
return f"<{self.method.name} {self.endpoint}>"

tests/conftest.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
import pytest_asyncio
55

6-
from tests.helpers import generate_jwt
6+
from arangoasync.auth import JwtToken
77

88

99
@dataclass
@@ -45,8 +45,8 @@ def pytest_configure(config):
4545
global_data.url = url
4646
global_data.root = config.getoption("root")
4747
global_data.password = config.getoption("password")
48-
global_data.secret = generate_jwt(config.getoption("secret"))
49-
global_data.token = generate_jwt(global_data.secret)
48+
global_data.secret = config.getoption("secret")
49+
global_data.token = JwtToken.generate_token(global_data.secret)
5050

5151

5252
@pytest.fixture(autouse=False)
@@ -76,6 +76,7 @@ def sys_db_name():
7676

7777
@pytest_asyncio.fixture
7878
async def client_session():
79+
"""Make sure we close all sessions after the test is done."""
7980
sessions = []
8081

8182
def get_client_session(client, url):

tests/helpers.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

0 commit comments

Comments
 (0)