From 00a57b518647c55d32a2d957e8820156ede1b59d Mon Sep 17 00:00:00 2001 From: Alex Petenchea Date: Sun, 29 Sep 2024 16:33:40 +0300 Subject: [PATCH] Introducing custom serialization --- arangoasync/client.py | 30 ++++++++ arangoasync/connection.py | 86 +++++++++++++++++---- arangoasync/database.py | 27 ++++--- arangoasync/executor.py | 9 +++ arangoasync/job.py | 12 +++ arangoasync/serialization.py | 96 +++++++++++++++++++++++ arangoasync/typings.py | 8 +- arangoasync/wrapper.py | 144 +++++++++++++++++++++++++++++++++++ docs/specs.rst | 9 +++ tests/test_database.py | 2 +- tests/test_wrapper.py | 32 ++++++++ 11 files changed, 428 insertions(+), 27 deletions(-) create mode 100644 arangoasync/job.py create mode 100644 arangoasync/serialization.py create mode 100644 arangoasync/wrapper.py create mode 100644 tests/test_wrapper.py diff --git a/arangoasync/client.py b/arangoasync/client.py index 464501e..57291c4 100644 --- a/arangoasync/client.py +++ b/arangoasync/client.py @@ -14,6 +14,12 @@ from arangoasync.database import StandardDatabase from arangoasync.http import DefaultHTTPClient, HTTPClient from arangoasync.resolver import HostResolver, get_resolver +from arangoasync.serialization import ( + DefaultDeserializer, + DefaultSerializer, + Deserializer, + Serializer, +) from arangoasync.version import __version__ @@ -45,6 +51,14 @@ class ArangoClient: ` or a custom subclass of :class:`CompressionManager `. + serializer (Serializer | None): Custom serializer implementation. + Leave as `None` to use the default serializer. + See :class:`DefaultSerializer + `. + deserializer (Deserializer | None): Custom deserializer implementation. + Leave as `None` to use the default deserializer. + See :class:`DefaultDeserializer + `. Raises: ValueError: If the `host_resolver` is not supported. @@ -56,6 +70,8 @@ def __init__( host_resolver: str | HostResolver = "default", http_client: Optional[HTTPClient] = None, compression: Optional[CompressionManager] = None, + serializer: Optional[Serializer] = None, + deserializer: Optional[Deserializer] = None, ) -> None: self._hosts = [hosts] if isinstance(hosts, str) else hosts self._host_resolver = ( @@ -68,6 +84,8 @@ def __init__( self._http_client.create_session(host) for host in self._hosts ] self._compression = compression + self._serializer = serializer or DefaultSerializer() + self._deserializer = deserializer or DefaultDeserializer() def __repr__(self) -> str: return f"" @@ -124,6 +142,8 @@ async def db( token: Optional[JwtToken] = None, verify: bool = False, compression: Optional[CompressionManager] = None, + serializer: Optional[Serializer] = None, + deserializer: Optional[Deserializer] = None, ) -> StandardDatabase: """Connects to a database and returns and API wrapper. @@ -145,6 +165,10 @@ async def db( verify (bool): Verify the connection by sending a test request. compression (CompressionManager | None): If set, supersedes the client-level compression settings. + serializer (Serializer | None): If set, supersedes the client-level + serializer. + deserializer (Deserializer | None): If set, supersedes the client-level + deserializer. Returns: StandardDatabase: Database API wrapper. @@ -163,6 +187,8 @@ async def db( http_client=self._http_client, db_name=name, compression=compression or self._compression, + serializer=serializer or self._serializer, + deserializer=deserializer or self._deserializer, auth=auth, ) elif auth_method == "jwt": @@ -176,6 +202,8 @@ async def db( http_client=self._http_client, db_name=name, compression=compression or self._compression, + serializer=serializer or self._serializer, + deserializer=deserializer or self._deserializer, auth=auth, token=token, ) @@ -190,6 +218,8 @@ async def db( http_client=self._http_client, db_name=name, compression=compression or self._compression, + serializer=serializer or self._serializer, + deserializer=deserializer or self._deserializer, token=token, ) else: diff --git a/arangoasync/connection.py b/arangoasync/connection.py index 68a021f..1b135f2 100644 --- a/arangoasync/connection.py +++ b/arangoasync/connection.py @@ -6,11 +6,11 @@ "JwtSuperuserConnection", ] -import json from abc import ABC, abstractmethod +from json import JSONDecodeError from typing import Any, List, Optional -import jwt +from jwt import ExpiredSignatureError from arangoasync import errno, logger from arangoasync.auth import Auth, JwtToken @@ -26,6 +26,12 @@ from arangoasync.request import Method, Request from arangoasync.resolver import HostResolver from arangoasync.response import Response +from arangoasync.serialization import ( + DefaultDeserializer, + DefaultSerializer, + Deserializer, + Serializer, +) class BaseConnection(ABC): @@ -37,6 +43,10 @@ class BaseConnection(ABC): http_client (HTTPClient): HTTP client. db_name (str): Database name. compression (CompressionManager | None): Compression manager. + serializer (Serializer | None): For custom serialization. + Leave `None` for default. + deserializer (Deserializer | None): For custom deserialization. + Leave `None` for default. """ def __init__( @@ -46,6 +56,8 @@ def __init__( http_client: HTTPClient, db_name: str, compression: Optional[CompressionManager] = None, + serializer: Optional[Serializer] = None, + deserializer: Optional[Deserializer] = None, ) -> None: self._sessions = sessions self._db_endpoint = f"/_db/{db_name}" @@ -53,12 +65,24 @@ def __init__( self._http_client = http_client self._db_name = db_name self._compression = compression + self._serializer = serializer or DefaultSerializer() + self._deserializer = deserializer or DefaultDeserializer() @property def db_name(self) -> str: """Return the database name.""" return self._db_name + @property + def serializer(self) -> Serializer: + """Return the serializer.""" + return self._serializer + + @property + def deserializer(self) -> Deserializer: + """Return the deserializer.""" + return self._deserializer + @staticmethod def raise_for_status(request: Request, resp: Response) -> None: """Raise an exception based on the response. @@ -75,8 +99,7 @@ def raise_for_status(request: Request, resp: Response) -> None: if not resp.is_success: raise ServerConnectionError(resp, request, "Bad server response.") - @staticmethod - def prep_response(request: Request, resp: Response) -> Response: + def prep_response(self, request: Request, resp: Response) -> Response: """Prepare response for return. Args: @@ -89,8 +112,8 @@ def prep_response(request: Request, resp: Response) -> Response: resp.is_success = 200 <= resp.status_code < 300 if not resp.is_success: try: - body = json.loads(resp.raw_body) - except json.JSONDecodeError as e: + body = self._deserializer.from_bytes(resp.raw_body) + except JSONDecodeError as e: logger.debug( f"Failed to decode response body: {e} (from request {request})" ) @@ -202,6 +225,8 @@ class BasicConnection(BaseConnection): http_client (HTTPClient): HTTP client. db_name (str): Database name. compression (CompressionManager | None): Compression manager. + serializer (Serializer | None): For custom serialization. + deserializer (Deserializer | None): For custom deserialization. auth (Auth | None): Authentication information. """ @@ -212,9 +237,19 @@ def __init__( http_client: HTTPClient, db_name: str, compression: Optional[CompressionManager] = None, + serializer: Optional[Serializer] = None, + deserializer: Optional[Deserializer] = None, auth: Optional[Auth] = None, ) -> None: - super().__init__(sessions, host_resolver, http_client, db_name, compression) + super().__init__( + sessions, + host_resolver, + http_client, + db_name, + compression, + serializer, + deserializer, + ) self._auth = auth async def send_request(self, request: Request) -> Response: @@ -249,6 +284,8 @@ class JwtConnection(BaseConnection): http_client (HTTPClient): HTTP client. db_name (str): Database name. compression (CompressionManager | None): Compression manager. + serializer (Serializer | None): For custom serialization. + deserializer (Deserializer | None): For custom deserialization. auth (Auth | None): Authentication information. token (JwtToken | None): JWT token. @@ -263,10 +300,20 @@ def __init__( http_client: HTTPClient, db_name: str, compression: Optional[CompressionManager] = None, + serializer: Optional[Serializer] = None, + deserializer: Optional[Deserializer] = None, auth: Optional[Auth] = None, token: Optional[JwtToken] = None, ) -> None: - super().__init__(sessions, host_resolver, http_client, db_name, compression) + super().__init__( + sessions, + host_resolver, + http_client, + db_name, + compression, + serializer, + deserializer, + ) self._auth = auth self._expire_leeway: int = 0 self._token: Optional[JwtToken] = token @@ -306,10 +353,8 @@ async def refresh_token(self) -> None: if self._auth is None: raise JWTRefreshError("Auth must be provided to refresh the token.") - auth_data = json.dumps( + auth_data = self._serializer.to_str( dict(username=self._auth.username, password=self._auth.password), - separators=(",", ":"), - ensure_ascii=False, ) request = Request( method=Method.POST, @@ -330,10 +375,10 @@ async def refresh_token(self) -> None: f"{resp.status_code} {resp.status_text}" ) - token = json.loads(resp.raw_body) + token = self._deserializer.from_bytes(resp.raw_body) try: self.token = JwtToken(token["jwt"]) - except jwt.ExpiredSignatureError as e: + except ExpiredSignatureError as e: raise JWTRefreshError( "Failed to refresh the JWT token: got an expired token" ) from e @@ -385,6 +430,8 @@ class JwtSuperuserConnection(BaseConnection): http_client (HTTPClient): HTTP client. db_name (str): Database name. compression (CompressionManager | None): Compression manager. + serializer (Serializer | None): For custom serialization. + deserializer (Deserializer | None): For custom deserialization. token (JwtToken | None): JWT token. """ @@ -395,10 +442,19 @@ def __init__( http_client: HTTPClient, db_name: str, compression: Optional[CompressionManager] = None, + serializer: Optional[Serializer] = None, + deserializer: Optional[Deserializer] = None, token: Optional[JwtToken] = None, ) -> None: - super().__init__(sessions, host_resolver, http_client, db_name, compression) - self._expire_leeway: int = 0 + super().__init__( + sessions, + host_resolver, + http_client, + db_name, + compression, + serializer, + deserializer, + ) self._token: Optional[JwtToken] = token self._auth_header: Optional[str] = None self.token = self._token diff --git a/arangoasync/database.py b/arangoasync/database.py index 51ac136..ac776bf 100644 --- a/arangoasync/database.py +++ b/arangoasync/database.py @@ -3,14 +3,15 @@ "StandardDatabase", ] -import json -from typing import Any from arangoasync.connection import Connection from arangoasync.exceptions import ServerStatusError from arangoasync.executor import ApiExecutor, DefaultApiExecutor from arangoasync.request import Method, Request from arangoasync.response import Response +from arangoasync.serialization import Deserializer, Serializer +from arangoasync.typings import Result +from arangoasync.wrapper import ServerStatusInformation class Database: @@ -29,25 +30,31 @@ def name(self) -> str: """Return the name of the current database.""" return self.connection.db_name - # TODO - user real return type - async def status(self) -> Any: + @property + def serializer(self) -> Serializer: + """Return the serializer.""" + return self._executor.serializer + + @property + def deserializer(self) -> Deserializer: + """Return the deserializer.""" + return self._executor.deserializer + + async def status(self) -> Result[ServerStatusInformation]: """Query the server status. Returns: - Json: Server status. + ServerStatusInformation: Server status. Raises: ServerSatusError: If retrieval fails. """ request = Request(method=Method.GET, endpoint="/_admin/status") - # TODO - # - introduce specific return type for response_handler - # - introduce specific serializer and deserializer - def response_handler(resp: Response) -> Any: + def response_handler(resp: Response) -> ServerStatusInformation: if not resp.is_success: raise ServerStatusError(resp, request) - return json.loads(resp.raw_body) + return ServerStatusInformation(self.deserializer.from_bytes(resp.raw_body)) return await self._executor.execute(request, response_handler) diff --git a/arangoasync/executor.py b/arangoasync/executor.py index a07479d..e839fce 100644 --- a/arangoasync/executor.py +++ b/arangoasync/executor.py @@ -3,6 +3,7 @@ from arangoasync.connection import Connection from arangoasync.request import Request from arangoasync.response import Response +from arangoasync.serialization import Deserializer, Serializer T = TypeVar("T") @@ -27,6 +28,14 @@ def connection(self) -> Connection: def context(self) -> str: return "default" + @property + def serializer(self) -> Serializer: + return self._conn.serializer + + @property + def deserializer(self) -> Deserializer: + return self._conn.deserializer + async def execute( self, request: Request, response_handler: Callable[[Response], T] ) -> T: diff --git a/arangoasync/job.py b/arangoasync/job.py new file mode 100644 index 0000000..eb1be0e --- /dev/null +++ b/arangoasync/job.py @@ -0,0 +1,12 @@ +__all__ = ["AsyncJob"] + + +from typing import Generic, TypeVar + +T = TypeVar("T") + + +class AsyncJob(Generic[T]): + """Job for tracking and retrieving result of an async API execution.""" + + pass diff --git a/arangoasync/serialization.py b/arangoasync/serialization.py new file mode 100644 index 0000000..9a0dfe1 --- /dev/null +++ b/arangoasync/serialization.py @@ -0,0 +1,96 @@ +__all__ = [ + "Serializer", + "Deserializer", + "JsonSerializer", + "JsonDeserializer", + "DefaultSerializer", + "DefaultDeserializer", +] + +from abc import ABC, abstractmethod +from json import dumps, loads +from typing import Any + + +class Serializer(ABC): # pragma: no cover + """Abstract base class for serialization. + + Custom serialization classes should inherit from this class. + """ + + @abstractmethod + def to_str(self, data: Any) -> str: + """Serialize any generic data. + + This method impacts all serialization operations within the client. + Please be mindful of the performance implications. + + Args: + data: Data to serialize. + + Returns: + str: Serialized data. + """ + raise NotImplementedError + + +class Deserializer(ABC): # pragma: no cover + """Abstract base class for deserialization. + + Custom deserialization classes should inherit from this class. + """ + + @abstractmethod + def from_bytes(self, data: bytes) -> Any: + """Deserialize generic response data that does not represent documents. + + This is to be used when the response is not a document, but some other + information (for example, server status). + + Args: + data (bytes): Data to deserialize. + + Returns: + Deserialized data. + + Raises: + json.JSONDecodeError: If the data cannot be deserialized. + """ + raise NotImplementedError + + @abstractmethod + def from_doc(self, data: bytes) -> Any: + """Deserialize document data. + + This is to be used when the response represents (a) document(s). + The implementation **must** support deserializing both a single documents + and a list of documents. + + Args: + data (bytes): Data to deserialize. + + Returns: + Deserialized data. + """ + raise NotImplementedError + + +class JsonSerializer(Serializer): + """JSON serializer.""" + + def to_str(self, data: Any) -> str: + return dumps(data, separators=(",", ":")) + + +class JsonDeserializer(Deserializer): + """JSON deserializer.""" + + def from_bytes(self, data: bytes) -> Any: + return loads(data) + + def from_doc(self, data: bytes) -> Any: + return loads(data) + + +DefaultSerializer = JsonSerializer +DefaultDeserializer = JsonDeserializer diff --git a/arangoasync/typings.py b/arangoasync/typings.py index b3622f9..46b0fa4 100644 --- a/arangoasync/typings.py +++ b/arangoasync/typings.py @@ -2,12 +2,15 @@ "RequestHeaders", "ResponseHeaders", "Params", + "Result", ] -from typing import MutableMapping +from typing import MutableMapping, TypeVar, Union from multidict import CIMultiDictProxy, MultiDict +from arangoasync.job import AsyncJob + RequestHeaders = MutableMapping[str, str] | MultiDict[str] RequestHeaders.__doc__ = """Type definition for request HTTP headers""" @@ -16,3 +19,6 @@ Params = MutableMapping[str, bool | int | str] Params.__doc__ = """Type definition for URL (query) parameters""" + +T = TypeVar("T") +Result = Union[T, AsyncJob[T]] diff --git a/arangoasync/wrapper.py b/arangoasync/wrapper.py new file mode 100644 index 0000000..da2e974 --- /dev/null +++ b/arangoasync/wrapper.py @@ -0,0 +1,144 @@ +from typing import Any, Dict, Iterator, Optional, Tuple + + +class Wrapper: + """Wrapper over server response objects.""" + + def __init__(self, data: Dict[str, Any]) -> None: + self._data = data + + def __getitem__(self, key: str) -> Any: + return self._data[key] + + def __setitem__(self, key: str, value: Any) -> None: + self._data[key] = value + + def __delitem__(self, key: str) -> None: + del self._data[key] + + def __iter__(self) -> Iterator[str]: + return iter(self._data) + + def __len__(self) -> int: + return len(self._data) + + def __contains__(self, item: str) -> bool: + return item in self._data + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._data})" + + def __str__(self) -> str: + return str(self._data) + + def __eq__(self, other: object) -> bool: + return self._data == other + + def get(self, key: str, default: Optional[Any] = None) -> Any: + """Return the value for key if key is in the dictionary, else default.""" + return self._data.get(key, default) + + def items(self) -> Iterator[Tuple[str, Any]]: + """Return an iterator over the dictionary’s key-value pairs.""" + return iter(self._data.items()) + + +class ServerStatusInformation(Wrapper): + """ + https://docs.arangodb.com/stable/develop/http-api/administration/#get-server-status-information + + Example: + .. code-block:: json + + { + "server" : "arango", + "version" : "3.12.2", + "pid" : 244, + "license" : "enterprise", + "mode" : "server", + "operationMode" : "server", + "foxxApi" : true, + "host" : "localhost", + "hostname" : "ebd1509c9185", + "serverInfo" : { + "progress" : { + "phase" : "in wait", + "feature" : "", + "recoveryTick" : 0 + }, + "maintenance" : false, + "role" : "COORDINATOR", + "writeOpsEnabled" : true, + "readOnly" : false, + "persistedId" : "CRDN-329cfc20-071f-4faf-9727-7e48a7aed1e5", + "rebootId" : 1, + "address" : "tcp://localhost:8529", + "serverId" : "CRDN-329cfc20-071f-4faf-9727-7e48a7aed1e5", + "state" : "SERVING" + }, + "coordinator" : { + "foxxmaster" : "CRDN-0ed76822-3e64-47ed-a61b-510f2a696175", + "isFoxxmaster" : false + }, + "agency" : { + "agencyComm" : { + "endpoints" : [ + "tcp://localhost:8551", + "tcp://localhost:8541", + "tcp://localhost:8531" + ] + } + } + } + """ + + def __init__(self, data: Dict[str, Any]) -> None: + super().__init__(data) + + @property + def server(self) -> Optional[str]: + return self._data.get("server") + + @property + def version(self) -> Optional[str]: + return self._data.get("version") + + @property + def pid(self) -> Optional[int]: + return self._data.get("pid") + + @property + def license(self) -> Optional[str]: + return self._data.get("license") + + @property + def mode(self) -> Optional[str]: + return self._data.get("mode") + + @property + def operation_mode(self) -> Optional[str]: + return self._data.get("operationMode") + + @property + def foxx_api(self) -> Optional[bool]: + return self._data.get("foxxApi") + + @property + def host(self) -> Optional[str]: + return self._data.get("host") + + @property + def hostname(self) -> Optional[str]: + return self._data.get("hostname") + + @property + def server_info(self) -> Optional[Dict[str, Any]]: + return self._data.get("serverInfo") + + @property + def coordinator(self) -> Optional[Dict[str, Any]]: + return self._data.get("coordinator") + + @property + def agency(self) -> Optional[Dict[str, Any]]: + return self._data.get("agency") diff --git a/docs/specs.rst b/docs/specs.rst index d9f6ad7..05290ad 100644 --- a/docs/specs.rst +++ b/docs/specs.rst @@ -10,9 +10,15 @@ python-arango-async. .. automodule:: arangoasync.auth :members: +.. automodule:: arangoasync.database + :members: + .. automodule:: arangoasync.compression :members: +.. automodule:: arangoasync.serialization + :members: + .. automodule:: arangoasync.connection :members: @@ -30,3 +36,6 @@ python-arango-async. .. automodule:: arangoasync.response :members: + +.. automodule:: arangoasync.wrapper + :members: diff --git a/tests/test_database.py b/tests/test_database.py index d25c42b..0ae3b3b 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -5,7 +5,7 @@ @pytest.mark.asyncio -async def test_client_basic_auth(url, sys_db_name, root, password): +async def test_database_misc_methods(url, sys_db_name, root, password): auth = Auth(username=root, password=password) # TODO create a test database and user diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py new file mode 100644 index 0000000..d2396b4 --- /dev/null +++ b/tests/test_wrapper.py @@ -0,0 +1,32 @@ +from arangoasync.wrapper import Wrapper + + +def test_basic_wrapper(): + wrapper = Wrapper({"a": 1, "b": 2}) + assert wrapper["a"] == 1 + assert wrapper["b"] == 2 + + wrapper["c"] = 3 + assert wrapper["c"] == 3 + + del wrapper["a"] + assert "a" not in wrapper + + wrapper = Wrapper({"a": 1, "b": 2}) + keys = list(iter(wrapper)) + assert keys == ["a", "b"] + assert len(wrapper) == 2 + + assert "a" in wrapper + assert "c" not in wrapper + + assert repr(wrapper) == "Wrapper({'a': 1, 'b': 2})" + wrapper = Wrapper({"a": 1, "b": 2}) + assert str(wrapper) == "{'a': 1, 'b': 2}" + assert wrapper == {"a": 1, "b": 2} + + assert wrapper.get("a") == 1 + assert wrapper.get("c", 3) == 3 + + items = list(wrapper.items()) + assert items == [("a", 1), ("b", 2)]