4
4
]
5
5
6
6
from abc import ABC , abstractmethod
7
- from typing import Any , List
8
-
7
+ from typing import Any , List , Optional
8
+
9
+ from arangoasync .auth import Auth
10
+ from arangoasync .compression import CompressionManager , DefaultCompressionManager
11
+ from arangoasync .exceptions import (
12
+ ClientConnectionError ,
13
+ ConnectionAbortedError ,
14
+ ServerConnectionError ,
15
+ )
9
16
from arangoasync .http import HTTPClient
10
17
from arangoasync .request import Method , Request
11
18
from arangoasync .resolver import HostResolver
@@ -20,6 +27,7 @@ class BaseConnection(ABC):
20
27
host_resolver (HostResolver): Host resolver.
21
28
http_client (HTTPClient): HTTP client.
22
29
db_name (str): Database name.
30
+ compression (CompressionManager | None): Compression manager.
23
31
"""
24
32
25
33
def __init__ (
@@ -28,39 +36,85 @@ def __init__(
28
36
host_resolver : HostResolver ,
29
37
http_client : HTTPClient ,
30
38
db_name : str ,
39
+ compression : Optional [CompressionManager ] = None ,
31
40
) -> None :
32
41
self ._sessions = sessions
33
42
self ._db_endpoint = f"/_db/{ db_name } "
34
43
self ._host_resolver = host_resolver
35
44
self ._http_client = http_client
36
45
self ._db_name = db_name
46
+ self ._compression = compression or DefaultCompressionManager ()
37
47
38
48
@property
39
49
def db_name (self ) -> str :
40
50
"""Return the database name."""
41
51
return self ._db_name
42
52
43
- def prep_response (selfs , resp : Response ) -> None :
44
- """Prepare response for return."""
45
- # TODO: Populate response fields
53
+ def prep_response (self , request : Request , resp : Response ) -> Response :
54
+ """Prepare response for return.
55
+
56
+ Args:
57
+ request (Request): Request object.
58
+ resp (Response): Response object.
59
+
60
+ Returns:
61
+ Response: Response object
62
+
63
+ Raises:
64
+ ServerConnectionError: If the response status code is not successful.
65
+ """
66
+ resp .is_success = 200 <= resp .status_code < 300
67
+ if not resp .is_success :
68
+ raise ServerConnectionError (resp , request )
69
+ return resp
46
70
47
71
async def process_request (self , request : Request ) -> Response :
48
- """Process request."""
49
- # TODO add accept-encoding header option
50
- # TODO regulate number of tries
51
- # TODO error handling
72
+ """Process request, potentially trying multiple hosts.
73
+
74
+ Args:
75
+ request (Request): Request object.
76
+
77
+ Returns:
78
+ Response: Response object.
79
+
80
+ Raises:
81
+ ConnectionAbortedError: If can't connect to host(s) within limit.
82
+ """
83
+
84
+ ex_host_index = - 1
52
85
host_index = self ._host_resolver .get_host_index ()
53
- return await self ._http_client .send_request (self ._sessions [host_index ], request )
86
+ for tries in range (self ._host_resolver .max_tries ):
87
+ try :
88
+ resp = await self ._http_client .send_request (
89
+ self ._sessions [host_index ], request
90
+ )
91
+ return self .prep_response (request , resp )
92
+ except ClientConnectionError :
93
+ ex_host_index = host_index
94
+ host_index = self ._host_resolver .get_host_index ()
95
+ if ex_host_index == host_index :
96
+ self ._host_resolver .change_host ()
97
+ host_index = self ._host_resolver .get_host_index ()
98
+
99
+ raise ConnectionAbortedError (
100
+ f"Can't connect to host(s) within limit ({ self ._host_resolver .max_tries } )"
101
+ )
54
102
55
103
async def ping (self ) -> int :
56
104
"""Ping host to check if connection is established.
57
105
58
106
Returns:
59
107
int: Response status code.
108
+
109
+ Raises:
110
+ ServerConnectionError: If the response status code is not successful.
60
111
"""
61
112
request = Request (method = Method .GET , endpoint = "/_api/collection" )
62
113
resp = await self .send_request (request )
63
- # TODO check raise ServerConnectionError
114
+ if resp .status_code in {401 , 403 }:
115
+ raise ServerConnectionError (resp , request , "Authentication failed." )
116
+ if not resp .is_success :
117
+ raise ServerConnectionError (resp , request , "Bad server response." )
64
118
return resp .status_code
65
119
66
120
@abstractmethod
@@ -86,6 +140,8 @@ class BasicConnection(BaseConnection):
86
140
host_resolver (HostResolver): Host resolver.
87
141
http_client (HTTPClient): HTTP client.
88
142
db_name (str): Database name.
143
+ compression (CompressionManager | None): Compression manager.
144
+ auth (Auth | None): Authentication information.
89
145
"""
90
146
91
147
def __init__ (
@@ -94,11 +150,23 @@ def __init__(
94
150
host_resolver : HostResolver ,
95
151
http_client : HTTPClient ,
96
152
db_name : str ,
153
+ compression : Optional [CompressionManager ] = None ,
154
+ auth : Optional [Auth ] = None ,
97
155
) -> None :
98
- super ().__init__ (sessions , host_resolver , http_client , db_name )
156
+ super ().__init__ (sessions , host_resolver , http_client , db_name , compression )
157
+ self ._auth = auth
99
158
100
159
async def send_request (self , request : Request ) -> Response :
101
160
"""Send an HTTP request to the ArangoDB server."""
102
- response = await self .process_request (request )
103
- self .prep_response (response )
104
- return response
161
+ if request .data is not None and self ._compression .needs_compression (
162
+ request .data
163
+ ):
164
+ request .data = self ._compression .compress (request .data )
165
+ request .headers ["content-encoding" ] = self ._compression .content_encoding ()
166
+ if self ._compression .accept_encoding ():
167
+ request .headers ["accept-encoding" ] = self ._compression .accept_encoding ()
168
+
169
+ if self ._auth :
170
+ request .auth = self ._auth
171
+
172
+ return await self .process_request (request )
0 commit comments