Skip to content

Commit 8f5d576

Browse files
committed
BasicConnection supports authentication and compression
1 parent 7940d6d commit 8f5d576

11 files changed

+383
-85
lines changed

arangoasync/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1+
import logging
2+
13
from .version import __version__
4+
5+
logger = logging.getLogger(__name__)

arangoasync/auth.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77

88
import jwt
99

10-
from arangoasync.exceptions import JWTExpiredError
11-
1210

1311
@dataclass
1412
class Auth:
@@ -32,6 +30,7 @@ class JwtToken:
3230
token (str | bytes): JWT token.
3331
3432
Raises:
33+
TypeError: If the token type is not str or bytes.
3534
JWTExpiredError: If the token expired.
3635
"""
3736

@@ -49,29 +48,27 @@ def token(self, token: str | bytes) -> None:
4948
"""Set token.
5049
5150
Raises:
52-
JWTExpiredError: If the token expired.
51+
jwt.ExpiredSignatureError: If the token expired.
5352
"""
5453
self._token = token
5554
self._validate()
5655

5756
def _validate(self) -> None:
5857
"""Validate the token."""
59-
if type(self._token) is not str:
60-
raise TypeError("Token must be a string")
61-
try:
62-
jwt_payload = jwt.decode(
63-
self._token,
64-
issuer="arangodb",
65-
algorithms=["HS256"],
66-
options={
67-
"require_exp": True,
68-
"require_iat": True,
69-
"verify_iat": True,
70-
"verify_exp": True,
71-
"verify_signature": False,
72-
},
73-
)
74-
except jwt.ExpiredSignatureError:
75-
raise JWTExpiredError("JWT token has expired")
58+
if type(self._token) not in (str, bytes):
59+
raise TypeError("Token must be str or bytes")
60+
61+
jwt_payload = jwt.decode(
62+
self._token,
63+
issuer="arangodb",
64+
algorithms=["HS256"],
65+
options={
66+
"require_exp": True,
67+
"require_iat": True,
68+
"verify_iat": True,
69+
"verify_exp": True,
70+
"verify_signature": False,
71+
},
72+
)
7673

7774
self._token_exp = jwt_payload["exp"]

arangoasync/compression.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
__all__ = [
2+
"AcceptEncoding",
3+
"ContentEncoding",
4+
"CompressionManager",
5+
"DefaultCompressionManager",
6+
]
7+
8+
import zlib
9+
from abc import ABC, abstractmethod
10+
from enum import Enum, auto
11+
from typing import Optional
12+
13+
14+
class AcceptEncoding(Enum):
15+
"""Valid accepted encodings for the Accept-Encoding header."""
16+
17+
DEFLATE = auto()
18+
GZIP = auto()
19+
IDENTITY = auto()
20+
21+
22+
class ContentEncoding(Enum):
23+
"""Valid content encodings for the Content-Encoding header."""
24+
25+
DEFLATE = auto()
26+
GZIP = auto()
27+
28+
29+
class CompressionManager(ABC): # pragma: no cover
30+
"""Abstract base class for handling request/response compression."""
31+
32+
@abstractmethod
33+
def needs_compression(self, data: str | bytes) -> bool:
34+
"""Determine if the data needs to be compressed
35+
36+
Args:
37+
data (str | bytes): Data to check
38+
39+
Returns:
40+
bool: True if the data needs to be compressed
41+
"""
42+
raise NotImplementedError
43+
44+
@abstractmethod
45+
def compress(self, data: str | bytes) -> bytes:
46+
"""Compress the data
47+
48+
Args:
49+
data (str | bytes): Data to compress
50+
51+
Returns:
52+
bytes: Compressed data
53+
"""
54+
raise NotImplementedError
55+
56+
@abstractmethod
57+
def content_encoding(self) -> str:
58+
"""Return the content encoding.
59+
60+
This is the value of the Content-Encoding header in the HTTP request.
61+
Must match the encoding used in the compress method.
62+
63+
Returns:
64+
str: Content encoding
65+
"""
66+
raise NotImplementedError
67+
68+
@abstractmethod
69+
def accept_encoding(self) -> str | None:
70+
"""Return the accept encoding.
71+
72+
This is the value of the Accept-Encoding header in the HTTP request.
73+
Currently, only deflate and "gzip" are supported.
74+
75+
Returns:
76+
str: Accept encoding
77+
"""
78+
raise NotImplementedError
79+
80+
81+
class DefaultCompressionManager(CompressionManager):
82+
"""Compress requests using the deflate algorithm.
83+
84+
Args:
85+
threshold (int): Will compress requests to the server if
86+
the size of the request body (in bytes) is at least the value of this option.
87+
Setting it to -1 will disable request compression (default).
88+
level (int): Compression level. Defaults to 6.
89+
accept (str | None): Accepted encoding. By default, there is
90+
no compression of responses.
91+
"""
92+
93+
def __init__(
94+
self,
95+
threshold: int = -1,
96+
level: int = 6,
97+
accept: Optional[AcceptEncoding] = None,
98+
) -> None:
99+
self._threshold = threshold
100+
self._level = level
101+
self._content_encoding = ContentEncoding.DEFLATE.name.lower()
102+
self._accept_encoding = accept.name.lower() if accept else None
103+
104+
def needs_compression(self, data: str | bytes) -> bool:
105+
return self._threshold != -1 and len(data) >= self._threshold
106+
107+
def compress(self, data: str | bytes) -> bytes:
108+
if data is not None:
109+
if isinstance(data, bytes):
110+
return zlib.compress(data, self._level)
111+
return zlib.compress(data.encode("utf-8"), self._level)
112+
return b""
113+
114+
def content_encoding(self) -> str:
115+
return self._content_encoding
116+
117+
def accept_encoding(self) -> str | None:
118+
return self._accept_encoding

arangoasync/connection.py

Lines changed: 83 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,15 @@
44
]
55

66
from abc import ABC, abstractmethod
7-
from typing import Any, List
8-
7+
from typing import Any, List, Optional
8+
9+
from arangoasync.auth import Auth
10+
from arangoasync.compression import CompressionManager, DefaultCompressionManager
11+
from arangoasync.exceptions import (
12+
ClientConnectionError,
13+
ConnectionAbortedError,
14+
ServerConnectionError,
15+
)
916
from arangoasync.http import HTTPClient
1017
from arangoasync.request import Method, Request
1118
from arangoasync.resolver import HostResolver
@@ -20,6 +27,7 @@ class BaseConnection(ABC):
2027
host_resolver (HostResolver): Host resolver.
2128
http_client (HTTPClient): HTTP client.
2229
db_name (str): Database name.
30+
compression (CompressionManager | None): Compression manager.
2331
"""
2432

2533
def __init__(
@@ -28,39 +36,85 @@ def __init__(
2836
host_resolver: HostResolver,
2937
http_client: HTTPClient,
3038
db_name: str,
39+
compression: Optional[CompressionManager] = None,
3140
) -> None:
3241
self._sessions = sessions
3342
self._db_endpoint = f"/_db/{db_name}"
3443
self._host_resolver = host_resolver
3544
self._http_client = http_client
3645
self._db_name = db_name
46+
self._compression = compression or DefaultCompressionManager()
3747

3848
@property
3949
def db_name(self) -> str:
4050
"""Return the database name."""
4151
return self._db_name
4252

43-
def prep_response(selfs, resp: Response) -> None:
44-
"""Prepare response for return."""
45-
# TODO: Populate response fields
53+
def prep_response(self, request: Request, resp: Response) -> Response:
54+
"""Prepare response for return.
55+
56+
Args:
57+
request (Request): Request object.
58+
resp (Response): Response object.
59+
60+
Returns:
61+
Response: Response object
62+
63+
Raises:
64+
ServerConnectionError: If the response status code is not successful.
65+
"""
66+
resp.is_success = 200 <= resp.status_code < 300
67+
if not resp.is_success:
68+
raise ServerConnectionError(resp, request)
69+
return resp
4670

4771
async def process_request(self, request: Request) -> Response:
48-
"""Process request."""
49-
# TODO add accept-encoding header option
50-
# TODO regulate number of tries
51-
# TODO error handling
72+
"""Process request, potentially trying multiple hosts.
73+
74+
Args:
75+
request (Request): Request object.
76+
77+
Returns:
78+
Response: Response object.
79+
80+
Raises:
81+
ConnectionAbortedError: If can't connect to host(s) within limit.
82+
"""
83+
84+
ex_host_index = -1
5285
host_index = self._host_resolver.get_host_index()
53-
return await self._http_client.send_request(self._sessions[host_index], request)
86+
for tries in range(self._host_resolver.max_tries):
87+
try:
88+
resp = await self._http_client.send_request(
89+
self._sessions[host_index], request
90+
)
91+
return self.prep_response(request, resp)
92+
except ClientConnectionError:
93+
ex_host_index = host_index
94+
host_index = self._host_resolver.get_host_index()
95+
if ex_host_index == host_index:
96+
self._host_resolver.change_host()
97+
host_index = self._host_resolver.get_host_index()
98+
99+
raise ConnectionAbortedError(
100+
f"Can't connect to host(s) within limit ({self._host_resolver.max_tries})"
101+
)
54102

55103
async def ping(self) -> int:
56104
"""Ping host to check if connection is established.
57105
58106
Returns:
59107
int: Response status code.
108+
109+
Raises:
110+
ServerConnectionError: If the response status code is not successful.
60111
"""
61112
request = Request(method=Method.GET, endpoint="/_api/collection")
62113
resp = await self.send_request(request)
63-
# TODO check raise ServerConnectionError
114+
if resp.status_code in {401, 403}:
115+
raise ServerConnectionError(resp, request, "Authentication failed.")
116+
if not resp.is_success:
117+
raise ServerConnectionError(resp, request, "Bad server response.")
64118
return resp.status_code
65119

66120
@abstractmethod
@@ -86,6 +140,8 @@ class BasicConnection(BaseConnection):
86140
host_resolver (HostResolver): Host resolver.
87141
http_client (HTTPClient): HTTP client.
88142
db_name (str): Database name.
143+
compression (CompressionManager | None): Compression manager.
144+
auth (Auth | None): Authentication information.
89145
"""
90146

91147
def __init__(
@@ -94,11 +150,23 @@ def __init__(
94150
host_resolver: HostResolver,
95151
http_client: HTTPClient,
96152
db_name: str,
153+
compression: Optional[CompressionManager] = None,
154+
auth: Optional[Auth] = None,
97155
) -> None:
98-
super().__init__(sessions, host_resolver, http_client, db_name)
156+
super().__init__(sessions, host_resolver, http_client, db_name, compression)
157+
self._auth = auth
99158

100159
async def send_request(self, request: Request) -> Response:
101160
"""Send an HTTP request to the ArangoDB server."""
102-
response = await self.process_request(request)
103-
self.prep_response(response)
104-
return response
161+
if request.data is not None and self._compression.needs_compression(
162+
request.data
163+
):
164+
request.data = self._compression.compress(request.data)
165+
request.headers["content-encoding"] = self._compression.content_encoding()
166+
if self._compression.accept_encoding():
167+
request.headers["accept-encoding"] = self._compression.accept_encoding()
168+
169+
if self._auth:
170+
request.auth = self._auth
171+
172+
return await self.process_request(request)

0 commit comments

Comments
 (0)