Skip to content

Commit 7940d6d

Browse files
committed
Introducing BasicConnection
1 parent 41a9bda commit 7940d6d

12 files changed

+452
-6
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,4 @@ repos:
3838
hooks:
3939
- id: mypy
4040
files: ^arangoasync/
41-
additional_dependencies: ['types-requests', "types-setuptools"]
41+
additional_dependencies: ["types-requests", "types-setuptools"]

arangoasync/auth.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
__all__ = [
2+
"Auth",
3+
"JwtToken",
4+
]
5+
6+
from dataclasses import dataclass
7+
8+
import jwt
9+
10+
from arangoasync.exceptions import JWTExpiredError
11+
12+
13+
@dataclass
14+
class Auth:
15+
"""Authentication details for the ArangoDB instance.
16+
17+
Attributes:
18+
username (str): Username.
19+
password (str): Password.
20+
encoding (str): Encoding for the password (default: utf-8)
21+
"""
22+
23+
username: str
24+
password: str
25+
encoding: str = "utf-8"
26+
27+
28+
class JwtToken:
29+
"""JWT token.
30+
31+
Args:
32+
token (str | bytes): JWT token.
33+
34+
Raises:
35+
JWTExpiredError: If the token expired.
36+
"""
37+
38+
def __init__(self, token: str | bytes) -> None:
39+
self._token = token
40+
self._validate()
41+
42+
@property
43+
def token(self) -> str | bytes:
44+
"""Get token."""
45+
return self._token
46+
47+
@token.setter
48+
def token(self, token: str | bytes) -> None:
49+
"""Set token.
50+
51+
Raises:
52+
JWTExpiredError: If the token expired.
53+
"""
54+
self._token = token
55+
self._validate()
56+
57+
def _validate(self) -> None:
58+
"""Validate the token."""
59+
if type(self._token) is not str:
60+
raise TypeError("Token must be a string")
61+
try:
62+
jwt_payload = jwt.decode(
63+
self._token,
64+
issuer="arangodb",
65+
algorithms=["HS256"],
66+
options={
67+
"require_exp": True,
68+
"require_iat": True,
69+
"verify_iat": True,
70+
"verify_exp": True,
71+
"verify_signature": False,
72+
},
73+
)
74+
except jwt.ExpiredSignatureError:
75+
raise JWTExpiredError("JWT token has expired")
76+
77+
self._token_exp = jwt_payload["exp"]

arangoasync/connection.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
__all__ = [
2+
"BaseConnection",
3+
"BasicConnection",
4+
]
5+
6+
from abc import ABC, abstractmethod
7+
from typing import Any, List
8+
9+
from arangoasync.http import HTTPClient
10+
from arangoasync.request import Method, Request
11+
from arangoasync.resolver import HostResolver
12+
from arangoasync.response import Response
13+
14+
15+
class BaseConnection(ABC):
16+
"""Blueprint for connection to a specific ArangoDB database.
17+
18+
Args:
19+
sessions (list): List of client sessions.
20+
host_resolver (HostResolver): Host resolver.
21+
http_client (HTTPClient): HTTP client.
22+
db_name (str): Database name.
23+
"""
24+
25+
def __init__(
26+
self,
27+
sessions: List[Any],
28+
host_resolver: HostResolver,
29+
http_client: HTTPClient,
30+
db_name: str,
31+
) -> None:
32+
self._sessions = sessions
33+
self._db_endpoint = f"/_db/{db_name}"
34+
self._host_resolver = host_resolver
35+
self._http_client = http_client
36+
self._db_name = db_name
37+
38+
@property
39+
def db_name(self) -> str:
40+
"""Return the database name."""
41+
return self._db_name
42+
43+
def prep_response(selfs, resp: Response) -> None:
44+
"""Prepare response for return."""
45+
# TODO: Populate response fields
46+
47+
async def process_request(self, request: Request) -> Response:
48+
"""Process request."""
49+
# TODO add accept-encoding header option
50+
# TODO regulate number of tries
51+
# TODO error handling
52+
host_index = self._host_resolver.get_host_index()
53+
return await self._http_client.send_request(self._sessions[host_index], request)
54+
55+
async def ping(self) -> int:
56+
"""Ping host to check if connection is established.
57+
58+
Returns:
59+
int: Response status code.
60+
"""
61+
request = Request(method=Method.GET, endpoint="/_api/collection")
62+
resp = await self.send_request(request)
63+
# TODO check raise ServerConnectionError
64+
return resp.status_code
65+
66+
@abstractmethod
67+
async def send_request(self, request: Request) -> Response: # pragma: no cover
68+
"""Send an HTTP request to the ArangoDB server.
69+
70+
Args:
71+
request (Request): HTTP request.
72+
73+
Returns:
74+
Response: HTTP response.
75+
"""
76+
raise NotImplementedError
77+
78+
79+
class BasicConnection(BaseConnection):
80+
"""Connection to a specific ArangoDB database.
81+
82+
Allows for basic authentication to be used (username and password).
83+
84+
Args:
85+
sessions (list): List of client sessions.
86+
host_resolver (HostResolver): Host resolver.
87+
http_client (HTTPClient): HTTP client.
88+
db_name (str): Database name.
89+
"""
90+
91+
def __init__(
92+
self,
93+
sessions: List[Any],
94+
host_resolver: HostResolver,
95+
http_client: HTTPClient,
96+
db_name: str,
97+
) -> None:
98+
super().__init__(sessions, host_resolver, http_client, db_name)
99+
100+
async def send_request(self, request: Request) -> Response:
101+
"""Send an HTTP request to the ArangoDB server."""
102+
response = await self.process_request(request)
103+
self.prep_response(response)
104+
return response

arangoasync/exceptions.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
class ArangoError(Exception):
2+
"""Base class for all exceptions in python-arango-async."""
3+
4+
5+
class ArangoClientError(ArangoError):
6+
"""Base class for all client-related exceptions.
7+
8+
Args:
9+
msg (str): Error message.
10+
11+
Attributes:
12+
source (str): Source of the error (always set to "client")
13+
message (str): Error message.
14+
"""
15+
16+
source = "client"
17+
18+
def __init__(self, msg: str) -> None:
19+
super().__init__(msg)
20+
self.message = msg
21+
22+
23+
class JWTExpiredError(ArangoClientError):
24+
"""JWT token has expired."""

arangoasync/http.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from aiohttp import BaseConnector, BasicAuth, ClientSession, ClientTimeout, TCPConnector
1111

12+
from arangoasync.auth import Auth
1213
from arangoasync.request import Request
1314
from arangoasync.response import Response
1415

@@ -74,7 +75,7 @@ class AioHTTPClient(HTTPClient):
7475
timeout (aiohttp.ClientTimeout | None): Client timeout settings.
7576
300s total timeout by default for a complete request/response operation.
7677
read_bufsize (int): Size of read buffer (64KB default).
77-
auth (aiohttp.BasicAuth | None): HTTP authentication helper.
78+
auth (Auth | None): HTTP authentication helper.
7879
Should be used for specifying authorization data in client API.
7980
compression_threshold (int): Will compress requests to the server if the size
8081
of the request body (in bytes) is at least the value of this option.
@@ -88,7 +89,7 @@ def __init__(
8889
connector: Optional[BaseConnector] = None,
8990
timeout: Optional[ClientTimeout] = None,
9091
read_bufsize: int = 2**16,
91-
auth: Optional[BasicAuth] = None,
92+
auth: Optional[Auth] = None,
9293
compression_threshold: int = 1024,
9394
) -> None:
9495
self._connector = connector or TCPConnector(
@@ -100,7 +101,13 @@ def __init__(
100101
connect=60, # max number of seconds for acquiring a pool connection
101102
)
102103
self._read_bufsize = read_bufsize
103-
self._auth = auth
104+
self._auth = (
105+
BasicAuth(
106+
login=auth.username, password=auth.password, encoding=auth.encoding
107+
)
108+
if auth
109+
else None
110+
)
104111
self._compression_threshold = compression_threshold
105112

106113
def create_session(self, host: str) -> ClientSession:

arangoasync/resolver.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
__all__ = [
2+
"HostResolver",
3+
"SingleHostResolver",
4+
"RoundRobinHostResolver",
5+
"DefaultHostResolver",
6+
"get_resolver",
7+
]
8+
9+
from abc import ABC, abstractmethod
10+
from typing import List, Optional
11+
12+
13+
class HostResolver(ABC):
14+
"""Abstract base class for host resolvers.
15+
16+
Args:
17+
host_count (int): Number of hosts.
18+
max_tries (int): Maximum number of attempts to try a host.
19+
20+
Raises:
21+
ValueError: If max_tries is less than host_count.
22+
"""
23+
24+
def __init__(self, host_count: int = 1, max_tries: Optional[int] = None) -> None:
25+
max_tries = max_tries or host_count * 3
26+
if max_tries < host_count:
27+
raise ValueError(
28+
"The maximum number of attempts cannot be "
29+
"lower than the number of hosts."
30+
)
31+
self._host_count = host_count
32+
self._max_tries = max_tries
33+
self._index = 0
34+
35+
@abstractmethod
36+
def get_host_index(self) -> int: # pragma: no cover
37+
"""Return the index of the host to use.
38+
39+
Returns:
40+
int: Index of the host.
41+
"""
42+
raise NotImplementedError
43+
44+
def change_host(self) -> None:
45+
"""If there aer multiple hosts available, switch to the next one."""
46+
self._index = (self._index + 1) % self.host_count
47+
48+
@property
49+
def host_count(self) -> int:
50+
"""Return the number of hosts."""
51+
return self._host_count
52+
53+
@property
54+
def max_tries(self) -> int:
55+
"""Return the maximum number of attempts."""
56+
return self._max_tries
57+
58+
59+
class SingleHostResolver(HostResolver):
60+
"""Single host resolver. Always returns the same host index."""
61+
62+
def __init__(self, host_count: int, max_tries: Optional[int] = None) -> None:
63+
super().__init__(host_count, max_tries)
64+
65+
def get_host_index(self) -> int:
66+
return self._index
67+
68+
69+
class RoundRobinHostResolver(HostResolver):
70+
"""Round-robin host resolver. Changes host every time.
71+
72+
Useful for bulk inserts or updates.
73+
74+
Note:
75+
Do not use this resolver for stream transactions.
76+
Transaction IDs cannot be shared across different coordinators.
77+
"""
78+
79+
def __init__(self, host_count: int, max_tries: Optional[int] = None) -> None:
80+
super().__init__(host_count, max_tries)
81+
self._index = -1
82+
83+
def get_host_index(self, indexes_to_filter: Optional[List[int]] = None) -> int:
84+
self.change_host()
85+
return self._index
86+
87+
88+
DefaultHostResolver = SingleHostResolver
89+
90+
91+
def get_resolver(
92+
strategy: str,
93+
host_count: int,
94+
max_tries: Optional[int] = None,
95+
) -> HostResolver:
96+
"""Return a host resolver based on the strategy.
97+
98+
Args:
99+
strategy (str): Resolver strategy.
100+
host_count (int): Number of hosts.
101+
max_tries (int): Maximum number of attempts to try a host.
102+
103+
Returns:
104+
HostResolver: Host resolver.
105+
106+
Raises:
107+
ValueError: If the strategy is not supported.
108+
"""
109+
if strategy == "roundrobin":
110+
return RoundRobinHostResolver(host_count, max_tries)
111+
if strategy == "single":
112+
return SingleHostResolver(host_count, max_tries)
113+
if strategy == "default":
114+
return DefaultHostResolver(host_count, max_tries)
115+
raise ValueError(f"Unsupported host resolver strategy: {strategy}")

docs/specs.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,23 @@ API Specification
44
This page contains the specification for all classes and methods available in
55
python-arango-async.
66

7+
.. automodule:: arangoasync.auth
8+
:members:
9+
10+
.. automodule:: arangoasync.connection
11+
:members:
12+
13+
.. automodule:: arangoasync.exceptions
14+
:members: ArangoError, ArangoClientError
15+
716
.. automodule:: arangoasync.http
817
:members:
918

1019
.. automodule:: arangoasync.request
1120
:members:
1221

22+
.. automodule:: arangoasync.resolver
23+
:members:
24+
1325
.. automodule:: arangoasync.response
1426
:members:

0 commit comments

Comments
 (0)