diff --git a/arangoasync/client.py b/arangoasync/client.py index 57291c4..1b1159f 100644 --- a/arangoasync/client.py +++ b/arangoasync/client.py @@ -20,6 +20,7 @@ Deserializer, Serializer, ) +from arangoasync.typings import Json, Jsons from arangoasync.version import __version__ @@ -51,14 +52,18 @@ class ArangoClient: ` or a custom subclass of :class:`CompressionManager `. - serializer (Serializer | None): Custom serializer implementation. + serializer (Serializer | None): Custom JSON serializer implementation. Leave as `None` to use the default serializer. See :class:`DefaultSerializer `. - deserializer (Deserializer | None): Custom deserializer implementation. + For custom serialization of collection documents, see :class:`Collection + `. + deserializer (Deserializer | None): Custom JSON deserializer implementation. Leave as `None` to use the default deserializer. See :class:`DefaultDeserializer `. + For custom deserialization of collection documents, see :class:`Collection + `. Raises: ValueError: If the `host_resolver` is not supported. @@ -70,8 +75,8 @@ def __init__( host_resolver: str | HostResolver = "default", http_client: Optional[HTTPClient] = None, compression: Optional[CompressionManager] = None, - serializer: Optional[Serializer] = None, - deserializer: Optional[Deserializer] = None, + serializer: Optional[Serializer[Json]] = None, + deserializer: Optional[Deserializer[Json, Jsons]] = None, ) -> None: self._hosts = [hosts] if isinstance(hosts, str) else hosts self._host_resolver = ( @@ -84,8 +89,10 @@ def __init__( self._http_client.create_session(host) for host in self._hosts ] self._compression = compression - self._serializer = serializer or DefaultSerializer() - self._deserializer = deserializer or DefaultDeserializer() + self._serializer: Serializer[Json] = serializer or DefaultSerializer() + self._deserializer: Deserializer[Json, Jsons] = ( + deserializer or DefaultDeserializer() + ) def __repr__(self) -> str: return f"" @@ -142,8 +149,8 @@ async def db( token: Optional[JwtToken] = None, verify: bool = False, compression: Optional[CompressionManager] = None, - serializer: Optional[Serializer] = None, - deserializer: Optional[Deserializer] = None, + serializer: Optional[Serializer[Json]] = None, + deserializer: Optional[Deserializer[Json, Jsons]] = None, ) -> StandardDatabase: """Connects to a database and returns and API wrapper. @@ -178,6 +185,7 @@ async def db( ServerConnectionError: If `verify` is `True` and the connection fails. """ connection: Connection + if auth_method == "basic": if auth is None: raise ValueError("Basic authentication requires the `auth` parameter") diff --git a/arangoasync/collection.py b/arangoasync/collection.py new file mode 100644 index 0000000..b0606a8 --- /dev/null +++ b/arangoasync/collection.py @@ -0,0 +1,205 @@ +__all__ = ["Collection", "Collection", "StandardCollection"] + + +from enum import Enum +from typing import Generic, Optional, Tuple, TypeVar + +from arangoasync.errno import HTTP_NOT_FOUND, HTTP_PRECONDITION_FAILED +from arangoasync.exceptions import ( + DocumentGetError, + DocumentParseError, + DocumentRevisionError, +) +from arangoasync.executor import ApiExecutor +from arangoasync.request import Method, Request +from arangoasync.response import Response +from arangoasync.serialization import Deserializer, Serializer +from arangoasync.typings import Json, Result + +T = TypeVar("T") +U = TypeVar("U") +V = TypeVar("V") + + +class CollectionType(Enum): + """Collection types.""" + + DOCUMENT = 2 + EDGE = 3 + + +class Collection(Generic[T, U, V]): + """Base class for collection API wrappers. + + Args: + executor (ApiExecutor): API executor. + name (str): Collection name + doc_serializer (Serializer): Document serializer. + doc_deserializer (Deserializer): Document deserializer. + """ + + def __init__( + self, + executor: ApiExecutor, + name: str, + doc_serializer: Serializer[T], + doc_deserializer: Deserializer[U, V], + ) -> None: + self._executor = executor + self._name = name + self._doc_serializer = doc_serializer + self._doc_deserializer = doc_deserializer + self._id_prefix = f"{self._name}/" + + def __repr__(self) -> str: + return f"" + + def _validate_id(self, doc_id: str) -> str: + """Check the collection name in the document ID. + + Args: + doc_id (str): Document ID. + + Returns: + str: Verified document ID. + + Raises: + DocumentParseError: On bad collection name. + """ + if not doc_id.startswith(self._id_prefix): + raise DocumentParseError(f'Bad collection name in document ID "{doc_id}"') + return doc_id + + def _extract_id(self, body: Json) -> str: + """Extract the document ID from document body. + + Args: + body (dict): Document body. + + Returns: + str: Document ID. + + Raises: + DocumentParseError: On missing ID and key. + """ + try: + if "_id" in body: + return self._validate_id(body["_id"]) + else: + key: str = body["_key"] + return self._id_prefix + key + except KeyError: + raise DocumentParseError('Field "_key" or "_id" required') + + def _prep_from_doc( + self, + document: str | Json, + rev: Optional[str] = None, + check_rev: bool = False, + ) -> Tuple[str, Json]: + """Prepare document ID, body and request headers before a query. + + Args: + document (str | dict): Document ID, key or body. + rev (str | None): Document revision. + check_rev (bool): Whether to check the revision. + + Returns: + Document ID and request headers. + + Raises: + DocumentParseError: On missing ID and key. + TypeError: On bad document type. + """ + if isinstance(document, dict): + doc_id = self._extract_id(document) + rev = rev or document.get("_rev") + elif isinstance(document, str): + if "/" in document: + doc_id = self._validate_id(document) + else: + doc_id = self._id_prefix + document + else: + raise TypeError("Document must be str or a dict") + + if not check_rev or rev is None: + return doc_id, {} + else: + return doc_id, {"If-Match": rev} + + @property + def name(self) -> str: + """Return the name of the collection. + + Returns: + str: Collection name. + """ + return self._name + + +class StandardCollection(Collection[T, U, V]): + """Standard collection API wrapper. + + Args: + executor (ApiExecutor): API executor. + name (str): Collection name + doc_serializer (Serializer): Document serializer. + doc_deserializer (Deserializer): Document deserializer. + """ + + def __init__( + self, + executor: ApiExecutor, + name: str, + doc_serializer: Serializer[T], + doc_deserializer: Deserializer[U, V], + ) -> None: + super().__init__(executor, name, doc_serializer, doc_deserializer) + + async def get( + self, + document: str | Json, + rev: Optional[str] = None, + check_rev: bool = True, + allow_dirty_read: bool = False, + ) -> Result[Optional[U]]: + """Return a document. + + Args: + document (str | dict): Document ID, key or body. + Document body must contain the "_id" or "_key" field. + rev (str | None): Expected document revision. Overrides the + value of "_rev" field in **document** if present. + check_rev (bool): If set to True, revision of **document** (if given) + is compared against the revision of target document. + allow_dirty_read (bool): Allow reads from followers in a cluster. + + Returns: + Document or None if not found. + + Raises: + DocumentRevisionError: If the revision is incorrect. + DocumentGetError: If retrieval fails. + """ + handle, headers = self._prep_from_doc(document, rev, check_rev) + + if allow_dirty_read: + headers["x-arango-allow-dirty-read"] = "true" + + request = Request( + method=Method.GET, + endpoint=f"/_api/document/{handle}", + headers=headers, + ) + + def response_handler(resp: Response) -> Optional[U]: + if resp.is_success: + return self._doc_deserializer.loads(resp.raw_body) + elif resp.error_code == HTTP_NOT_FOUND: + return None + elif resp.error_code == HTTP_PRECONDITION_FAILED: + raise DocumentRevisionError(resp, request) + else: + raise DocumentGetError(resp, request) + + return await self._executor.execute(request, response_handler) diff --git a/arangoasync/connection.py b/arangoasync/connection.py index 1b135f2..779a85f 100644 --- a/arangoasync/connection.py +++ b/arangoasync/connection.py @@ -7,7 +7,6 @@ ] from abc import ABC, abstractmethod -from json import JSONDecodeError from typing import Any, List, Optional from jwt import ExpiredSignatureError @@ -19,7 +18,9 @@ AuthHeaderError, ClientConnectionAbortedError, ClientConnectionError, + DeserializationError, JWTRefreshError, + SerializationError, ServerConnectionError, ) from arangoasync.http import HTTPClient @@ -32,6 +33,7 @@ Deserializer, Serializer, ) +from arangoasync.typings import Json, Jsons class BaseConnection(ABC): @@ -43,10 +45,10 @@ class BaseConnection(ABC): http_client (HTTPClient): HTTP client. db_name (str): Database name. compression (CompressionManager | None): Compression manager. - serializer (Serializer | None): For custom serialization. - Leave `None` for default. - deserializer (Deserializer | None): For custom deserialization. + serializer (Serializer | None): For overriding the default JSON serialization. Leave `None` for default. + deserializer (Deserializer | None): For overriding the default JSON + deserialization. Leave `None` for default. """ def __init__( @@ -56,8 +58,8 @@ def __init__( http_client: HTTPClient, db_name: str, compression: Optional[CompressionManager] = None, - serializer: Optional[Serializer] = None, - deserializer: Optional[Deserializer] = None, + serializer: Optional[Serializer[Json]] = None, + deserializer: Optional[Deserializer[Json, Jsons]] = None, ) -> None: self._sessions = sessions self._db_endpoint = f"/_db/{db_name}" @@ -65,8 +67,10 @@ def __init__( self._http_client = http_client self._db_name = db_name self._compression = compression - self._serializer = serializer or DefaultSerializer() - self._deserializer = deserializer or DefaultDeserializer() + self._serializer: Serializer[Json] = serializer or DefaultSerializer() + self._deserializer: Deserializer[Json, Jsons] = ( + deserializer or DefaultDeserializer() + ) @property def db_name(self) -> str: @@ -74,12 +78,12 @@ def db_name(self) -> str: return self._db_name @property - def serializer(self) -> Serializer: + def serializer(self) -> Serializer[Json]: """Return the serializer.""" return self._serializer @property - def deserializer(self) -> Deserializer: + def deserializer(self) -> Deserializer[Json, Jsons]: """Return the deserializer.""" return self._deserializer @@ -112,8 +116,8 @@ def prep_response(self, request: Request, resp: Response) -> Response: resp.is_success = 200 <= resp.status_code < 300 if not resp.is_success: try: - body = self._deserializer.from_bytes(resp.raw_body) - except JSONDecodeError as e: + body = self._deserializer.loads(resp.raw_body) + except DeserializationError as e: logger.debug( f"Failed to decode response body: {e} (from request {request})" ) @@ -225,8 +229,8 @@ class BasicConnection(BaseConnection): http_client (HTTPClient): HTTP client. db_name (str): Database name. compression (CompressionManager | None): Compression manager. - serializer (Serializer | None): For custom serialization. - deserializer (Deserializer | None): For custom deserialization. + serializer (Serializer | None): Override default JSON serialization. + deserializer (Deserializer | None): Override default JSON deserialization. auth (Auth | None): Authentication information. """ @@ -237,8 +241,8 @@ def __init__( http_client: HTTPClient, db_name: str, compression: Optional[CompressionManager] = None, - serializer: Optional[Serializer] = None, - deserializer: Optional[Deserializer] = None, + serializer: Optional[Serializer[Json]] = None, + deserializer: Optional[Deserializer[Json, Jsons]] = None, auth: Optional[Auth] = None, ) -> None: super().__init__( @@ -300,8 +304,8 @@ def __init__( http_client: HTTPClient, db_name: str, compression: Optional[CompressionManager] = None, - serializer: Optional[Serializer] = None, - deserializer: Optional[Deserializer] = None, + serializer: Optional[Serializer[Json]] = None, + deserializer: Optional[Deserializer[Json, Jsons]] = None, auth: Optional[Auth] = None, token: Optional[JwtToken] = None, ) -> None: @@ -353,13 +357,17 @@ async def refresh_token(self) -> None: if self._auth is None: raise JWTRefreshError("Auth must be provided to refresh the token.") - auth_data = self._serializer.to_str( - dict(username=self._auth.username, password=self._auth.password), - ) + auth_data = dict(username=self._auth.username, password=self._auth.password) + try: + auth = self._serializer.dumps(auth_data) + except SerializationError as e: + logger.debug(f"Failed to serialize auth data: {auth_data}") + raise JWTRefreshError(str(e)) from e + request = Request( method=Method.POST, endpoint="/_open/auth", - data=auth_data.encode("utf-8"), + data=auth.encode("utf-8"), ) try: @@ -375,7 +383,7 @@ async def refresh_token(self) -> None: f"{resp.status_code} {resp.status_text}" ) - token = self._deserializer.from_bytes(resp.raw_body) + token = self._deserializer.loads(resp.raw_body) try: self.token = JwtToken(token["jwt"]) except ExpiredSignatureError as e: @@ -442,8 +450,8 @@ def __init__( http_client: HTTPClient, db_name: str, compression: Optional[CompressionManager] = None, - serializer: Optional[Serializer] = None, - deserializer: Optional[Deserializer] = None, + serializer: Optional[Serializer[Json]] = None, + deserializer: Optional[Deserializer[Json, Jsons]] = None, token: Optional[JwtToken] = None, ) -> None: super().__init__( diff --git a/arangoasync/database.py b/arangoasync/database.py index ac776bf..7cba762 100644 --- a/arangoasync/database.py +++ b/arangoasync/database.py @@ -4,18 +4,36 @@ ] +from typing import Optional, Sequence, TypeVar, cast + +from arangoasync.collection import CollectionType, StandardCollection from arangoasync.connection import Connection -from arangoasync.exceptions import ServerStatusError +from arangoasync.errno import HTTP_NOT_FOUND +from arangoasync.exceptions import ( + CollectionCreateError, + CollectionDeleteError, + CollectionListError, + ServerStatusError, +) from arangoasync.executor import ApiExecutor, DefaultApiExecutor from arangoasync.request import Method, Request from arangoasync.response import Response from arangoasync.serialization import Deserializer, Serializer -from arangoasync.typings import Result +from arangoasync.typings import Json, Jsons, Params, Result from arangoasync.wrapper import ServerStatusInformation +T = TypeVar("T") +U = TypeVar("U") +V = TypeVar("V") + class Database: - """Database API.""" + """Database API. + + Args: + executor: API executor. + Responsible for executing requests and handling responses. + """ def __init__(self, executor: ApiExecutor) -> None: self._executor = executor @@ -31,12 +49,12 @@ def name(self) -> str: return self.connection.db_name @property - def serializer(self) -> Serializer: + def serializer(self) -> Serializer[Json]: """Return the serializer.""" return self._executor.serializer @property - def deserializer(self) -> Deserializer: + def deserializer(self) -> Deserializer[Json, Jsons]: """Return the deserializer.""" return self._executor.deserializer @@ -54,7 +72,242 @@ async def status(self) -> Result[ServerStatusInformation]: def response_handler(resp: Response) -> ServerStatusInformation: if not resp.is_success: raise ServerStatusError(resp, request) - return ServerStatusInformation(self.deserializer.from_bytes(resp.raw_body)) + return ServerStatusInformation(self.deserializer.loads(resp.raw_body)) + + return await self._executor.execute(request, response_handler) + + def collection( + self, + name: str, + doc_serializer: Optional[Serializer[T]] = None, + doc_deserializer: Optional[Deserializer[U, V]] = None, + ) -> StandardCollection[T, U, V]: + """Return the collection API wrapper. + + Args: + name (str): Collection name. + doc_serializer (Serializer): Custom document serializer. + This will be used only for document operations. + doc_deserializer (Deserializer): Custom document deserializer. + This will be used only for document operations. + + Returns: + StandardCollection: Collection API wrapper. + """ + if doc_serializer is None: + serializer = cast(Serializer[T], self.serializer) + else: + serializer = doc_serializer + if doc_deserializer is None: + deserializer = cast(Deserializer[U, V], self.deserializer) + else: + deserializer = doc_deserializer + + return StandardCollection[T, U, V]( + self._executor, name, serializer, deserializer + ) + + async def has_collection(self, name: str) -> Result[bool]: + """Check if a collection exists in the database. + + Args: + name (str): Collection name. + + Returns: + bool: True if the collection exists, False otherwise. + """ + request = Request(method=Method.GET, endpoint="/_api/collection") + + def response_handler(resp: Response) -> bool: + if not resp.is_success: + raise CollectionListError(resp, request) + body = self.deserializer.loads(resp.raw_body) + return any(c["name"] == name for c in body["result"]) + + return await self._executor.execute(request, response_handler) + + async def create_collection( + self, + name: str, + doc_serializer: Optional[Serializer[T]] = None, + doc_deserializer: Optional[Deserializer[U, V]] = None, + col_type: Optional[CollectionType] = None, + write_concern: Optional[int] = None, + wait_for_sync: Optional[bool] = None, + number_of_shards: Optional[int] = None, + replication_factor: Optional[int] = None, + cache_enabled: Optional[bool] = None, + computed_values: Optional[Jsons] = None, + distribute_shards_like: Optional[str] = None, + is_system: Optional[bool] = False, + key_options: Optional[Json] = None, + schema: Optional[Json] = None, + shard_keys: Optional[Sequence[str]] = None, + sharding_strategy: Optional[str] = None, + smart_graph_attribute: Optional[str] = None, + smart_join_attribute: Optional[str] = None, + wait_for_sync_replication: Optional[bool] = None, + enforce_replication_factor: Optional[bool] = None, + ) -> Result[StandardCollection[T, U, V]]: + """Create a new collection. + + Args: + name (str): Collection name. + doc_serializer (Serializer): Custom document serializer. + This will be used only for document operations. + doc_deserializer (Deserializer): Custom document deserializer. + This will be used only for document operations. + col_type (CollectionType | None): Collection type. + write_concern (int | None): Determines how many copies of each shard are + required to be in sync on the different DB-Servers. + wait_for_sync (bool | None): If `True`, the data is synchronised to disk + before returning from a document create, update, replace or removal + operation. + number_of_shards (int | None): In a cluster, this value determines the + number of shards to create for the collection. + replication_factor (int | None): In a cluster, this attribute determines + how many copies of each shard are kept on different DB-Servers. + cache_enabled (bool | None): Whether the in-memory hash cache for + documents should be enabled for this collection. + computed_values (Jsons | None): An optional list of objects, each + representing a computed value. + distribute_shards_like (str | None): The name of another collection. + If this property is set in a cluster, the collection copies the + replicationFactor, numberOfShards and shardingStrategy properties + from the specified collection (referred to as the prototype + collection) and distributes the shards of this collection in the same + way as the shards of the other collection. + is_system (bool | None): If `True`, create a system collection. + In this case, the collection name should start with an underscore. + key_options (dict | None): Additional options for key generation. + schema (dict | None): Optional object that specifies the collection + level schema for documents. + shard_keys (list | None): In a cluster, this attribute determines which + document attributes are used to determine the target shard for + documents. + sharding_strategy (str | None): Name of the sharding strategy. + smart_graph_attribute: (str | None): The attribute that is used for + sharding: vertices with the same value of this attribute are placed + in the same shard. + smart_join_attribute: (str | None): Determines an attribute of the + collection that must contain the shard key value of the referred-to + SmartJoin collection. + wait_for_sync_replication (bool | None): If `True`, the server only + reports success back to the client when all replicas have created + the collection. Set it to `False` if you want faster server + responses and don’t care about full replication. + enforce_replication_factor (bool | None): If `True`, the server checks + if there are enough replicas available at creation time and bail out + otherwise. Set it to `False` to disable this extra check. + + Returns: + StandardCollection: Collection API wrapper. + + Raises: + CollectionCreateError: If the operation fails. + """ + data: Json = {"name": name} + if col_type is not None: + data["type"] = col_type.value + if write_concern is not None: + data["writeConcern"] = write_concern + if wait_for_sync is not None: + data["waitForSync"] = wait_for_sync + if number_of_shards is not None: + data["numberOfShards"] = number_of_shards + if replication_factor is not None: + data["replicationFactor"] = replication_factor + if cache_enabled is not None: + data["cacheEnabled"] = cache_enabled + if computed_values is not None: + data["computedValues"] = computed_values + if distribute_shards_like is not None: + data["distributeShardsLike"] = distribute_shards_like + if is_system is not None: + data["isSystem"] = is_system + if key_options is not None: + data["keyOptions"] = key_options + if schema is not None: + data["schema"] = schema + if shard_keys is not None: + data["shardKeys"] = shard_keys + if sharding_strategy is not None: + data["shardingStrategy"] = sharding_strategy + if smart_graph_attribute is not None: + data["smartGraphAttribute"] = smart_graph_attribute + if smart_join_attribute is not None: + data["smartJoinAttribute"] = smart_join_attribute + + params: Params = {} + if wait_for_sync_replication is not None: + params["waitForSyncReplication"] = wait_for_sync_replication + if enforce_replication_factor is not None: + params["enforceReplicationFactor"] = enforce_replication_factor + + request = Request( + method=Method.POST, + endpoint="/_api/collection", + data=self.serializer.dumps(data), + params=params, + ) + + def response_handler(resp: Response) -> StandardCollection[T, U, V]: + nonlocal doc_serializer, doc_deserializer + if not resp.is_success: + raise CollectionCreateError(resp, request) + if doc_serializer is None: + serializer = cast(Serializer[T], self.serializer) + else: + serializer = doc_serializer + if doc_deserializer is None: + deserializer = cast(Deserializer[U, V], self.deserializer) + else: + deserializer = doc_deserializer + return StandardCollection[T, U, V]( + self._executor, name, serializer, deserializer + ) + + return await self._executor.execute(request, response_handler) + + async def delete_collection( + self, + name: str, + ignore_missing: bool = False, + is_system: Optional[bool] = None, + ) -> Result[bool]: + """Delete a collection. + + Args: + name (str): Collection name. + ignore_missing (bool): Do not raise an exception on missing collection. + is_system (bool | None): Whether to drop a system collection. This parameter + must be set to `True` in order to drop a system collection. + + Returns: + bool: True if the collection was deleted successfully, `False` if the + collection was not found but **ignore_missing** was set to `True`. + + Raises: + CollectionDeleteError: If the operation fails. + """ + params: Params = {} + if is_system is not None: + params["isSystem"] = is_system + + request = Request( + method=Method.DELETE, + endpoint=f"/_api/collection/{name}", + params=params, + ) + + def response_handler(resp: Response) -> bool: + nonlocal ignore_missing + if resp.is_success: + return True + if resp.error_code == HTTP_NOT_FOUND: + if ignore_missing: + return False + raise CollectionDeleteError(resp, request) return await self._executor.execute(request, response_handler) diff --git a/arangoasync/exceptions.py b/arangoasync/exceptions.py index e2f0100..417cc45 100644 --- a/arangoasync/exceptions.py +++ b/arangoasync/exceptions.py @@ -68,6 +68,22 @@ def __init__( self.http_headers = resp.headers +class AuthHeaderError(ArangoClientError): + """The authentication header could not be determined.""" + + +class CollectionCreateError(ArangoServerError): + """Failed to create collection.""" + + +class CollectionDeleteError(ArangoServerError): + """Failed to delete collection.""" + + +class CollectionListError(ArangoServerError): + """Failed to retrieve collections.""" + + class ClientConnectionAbortedError(ArangoClientError): """The connection was aborted.""" @@ -76,14 +92,30 @@ class ClientConnectionError(ArangoClientError): """The request was unable to reach the server.""" -class AuthHeaderError(ArangoClientError): - """The authentication header could not be determined.""" +class DeserializationError(ArangoClientError): + """Failed to deserialize the server response.""" + + +class DocumentGetError(ArangoServerError): + """Failed to retrieve document.""" + + +class DocumentParseError(ArangoClientError): + """Failed to parse document input.""" + + +class DocumentRevisionError(ArangoServerError): + """The expected and actual document revisions mismatched.""" class JWTRefreshError(ArangoClientError): """Failed to refresh the JWT token.""" +class SerializationError(ArangoClientError): + """Failed to serialize the request.""" + + class ServerConnectionError(ArangoServerError): """Failed to connect to ArangoDB server.""" diff --git a/arangoasync/executor.py b/arangoasync/executor.py index e839fce..7d9c888 100644 --- a/arangoasync/executor.py +++ b/arangoasync/executor.py @@ -4,6 +4,7 @@ from arangoasync.request import Request from arangoasync.response import Response from arangoasync.serialization import Deserializer, Serializer +from arangoasync.typings import Json, Jsons T = TypeVar("T") @@ -29,11 +30,11 @@ def context(self) -> str: return "default" @property - def serializer(self) -> Serializer: + def serializer(self) -> Serializer[Json]: return self._conn.serializer @property - def deserializer(self) -> Deserializer: + def deserializer(self) -> Deserializer[Json, Jsons]: return self._conn.deserializer async def execute( diff --git a/arangoasync/request.py b/arangoasync/request.py index 9e824f0..951c9e9 100644 --- a/arangoasync/request.py +++ b/arangoasync/request.py @@ -58,14 +58,14 @@ def __init__( endpoint: str, headers: Optional[RequestHeaders] = None, params: Optional[Params] = None, - data: Optional[bytes] = None, + data: Optional[bytes | str] = None, auth: Optional[Auth] = None, ) -> None: self.method: Method = method self.endpoint: str = endpoint self.headers: RequestHeaders = headers or dict() self.params: Params = params or dict() - self.data: Optional[bytes] = data + self.data: Optional[bytes | str] = data self.auth: Optional[Auth] = auth def normalized_headers(self) -> RequestHeaders: diff --git a/arangoasync/serialization.py b/arangoasync/serialization.py index 9a0dfe1..d17a2cd 100644 --- a/arangoasync/serialization.py +++ b/arangoasync/serialization.py @@ -7,45 +7,53 @@ "DefaultDeserializer", ] +import json from abc import ABC, abstractmethod -from json import dumps, loads -from typing import Any +from typing import Generic, TypeVar +from arangoasync.exceptions import DeserializationError, SerializationError +from arangoasync.typings import Json, Jsons -class Serializer(ABC): # pragma: no cover +T = TypeVar("T") +U = TypeVar("U") + + +class Serializer(ABC, Generic[T]): # pragma: no cover """Abstract base class for serialization. Custom serialization classes should inherit from this class. + Please be mindful of the performance implications. """ @abstractmethod - def to_str(self, data: Any) -> str: + def dumps(self, data: T) -> str: """Serialize any generic data. - This method impacts all serialization operations within the client. - Please be mindful of the performance implications. - Args: data: Data to serialize. Returns: str: Serialized data. + + Raises: + SerializationError: If the data cannot be serialized. """ raise NotImplementedError -class Deserializer(ABC): # pragma: no cover +class Deserializer(ABC, Generic[T, U]): # pragma: no cover """Abstract base class for deserialization. Custom deserialization classes should inherit from this class. + Please be mindful of the performance implications. """ @abstractmethod - def from_bytes(self, data: bytes) -> Any: - """Deserialize generic response data that does not represent documents. + def loads(self, data: bytes) -> T: + """Deserialize response data. - This is to be used when the response is not a document, but some other - information (for example, server status). + Will be called on generic server data (such as server status) and + single documents. Args: data (bytes): Data to deserialize. @@ -54,42 +62,52 @@ def from_bytes(self, data: bytes) -> Any: Deserialized data. Raises: - json.JSONDecodeError: If the data cannot be deserialized. + DeserializationError: If the data cannot be deserialized. """ raise NotImplementedError @abstractmethod - def from_doc(self, data: bytes) -> Any: - """Deserialize document data. + def loads_many(self, data: bytes) -> U: + """Deserialize response data. - This is to be used when the response represents (a) document(s). - The implementation **must** support deserializing both a single documents - and a list of documents. + Will only be called when deserializing a list of documents. Args: data (bytes): Data to deserialize. Returns: Deserialized data. + + Raises: + DeserializationError: If the data cannot be deserialized. """ raise NotImplementedError -class JsonSerializer(Serializer): +class JsonSerializer(Serializer[Json]): """JSON serializer.""" - def to_str(self, data: Any) -> str: - return dumps(data, separators=(",", ":")) + def dumps(self, data: T) -> str: + try: + return json.dumps(data, separators=(",", ":")) + except Exception as e: + raise SerializationError("Failed to serialize data to JSON.") from e -class JsonDeserializer(Deserializer): +class JsonDeserializer(Deserializer[Json, Jsons]): """JSON deserializer.""" - def from_bytes(self, data: bytes) -> Any: - return loads(data) - - def from_doc(self, data: bytes) -> Any: - return loads(data) + def loads(self, data: bytes) -> Json: + try: + return json.loads(data) # type: ignore[no-any-return] + except Exception as e: + raise DeserializationError("Failed to deserialize data from JSON.") from e + + def loads_many(self, data: bytes) -> Jsons: + try: + return json.loads(data) # type: ignore[no-any-return] + except Exception as e: + raise DeserializationError("Failed to deserialize data from JSON.") from e DefaultSerializer = JsonSerializer diff --git a/arangoasync/typings.py b/arangoasync/typings.py index 46b0fa4..b75defc 100644 --- a/arangoasync/typings.py +++ b/arangoasync/typings.py @@ -1,16 +1,24 @@ __all__ = [ + "Json", + "Jsons", "RequestHeaders", "ResponseHeaders", "Params", "Result", ] -from typing import MutableMapping, TypeVar, Union +from typing import Any, Dict, List, MutableMapping, TypeVar, Union from multidict import CIMultiDictProxy, MultiDict from arangoasync.job import AsyncJob +Json = Dict[str, Any] +Json.__doc__ = """Type definition for request/response body""" + +Jsons = List[Json] +Jsons.__doc__ = """Type definition for a list of JSON objects""" + RequestHeaders = MutableMapping[str, str] | MultiDict[str] RequestHeaders.__doc__ = """Type definition for request HTTP headers""" diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 0000000..e7220ba --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,10 @@ +from uuid import uuid4 + + +def generate_col_name(): + """Generate and return a random collection name. + + Returns: + str: Random collection name. + """ + return f"test_collection_{uuid4().hex}" diff --git a/tests/test_database.py b/tests/test_database.py index 0ae3b3b..14e615f 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -2,14 +2,33 @@ from arangoasync.auth import Auth from arangoasync.client import ArangoClient +from arangoasync.collection import StandardCollection +from tests.helpers import generate_col_name @pytest.mark.asyncio async def test_database_misc_methods(url, sys_db_name, root, password): auth = Auth(username=root, password=password) - # TODO create a test database and user + # TODO also handle exceptions async with ArangoClient(hosts=url) as client: db = await client.db(sys_db_name, auth_method="basic", auth=auth, verify=True) status = await db.status() assert status["server"] == "arango" + + +@pytest.mark.asyncio +async def test_create_drop_collection(url, sys_db_name, root, password): + auth = Auth(username=root, password=password) + + # TODO also handle exceptions + async with ArangoClient(hosts=url) as client: + db = await client.db(sys_db_name, auth_method="basic", auth=auth, verify=True) + col_name = generate_col_name() + col = await db.create_collection(col_name) + assert isinstance(col, StandardCollection) + assert await db.has_collection(col_name) + await db.delete_collection(col_name) + assert not await db.has_collection(col_name) + non_existent_col = generate_col_name() + assert await db.has_collection(non_existent_col) is False