diff --git a/arangoasync/aql.py b/arangoasync/aql.py new file mode 100644 index 0000000..072fdbe --- /dev/null +++ b/arangoasync/aql.py @@ -0,0 +1,115 @@ +__all__ = ["AQL"] + + +from typing import Optional + +from arangoasync.cursor import Cursor +from arangoasync.exceptions import AQLQueryExecuteError +from arangoasync.executor import ApiExecutor +from arangoasync.request import Method, Request +from arangoasync.response import Response +from arangoasync.serialization import Deserializer, Serializer +from arangoasync.typings import Json, Jsons, QueryProperties, Result + + +class AQL: + """AQL (ArangoDB Query Language) API wrapper. + + Allows you to execute, track, kill, explain, and validate queries written + in ArangoDB’s query language. + + Args: + executor: API executor. Required to execute the API requests. + """ + + def __init__(self, executor: ApiExecutor) -> None: + self._executor = executor + + @property + def name(self) -> str: + """Return the name of the current database.""" + return self._executor.db_name + + @property + def serializer(self) -> Serializer[Json]: + """Return the serializer.""" + return self._executor.serializer + + @property + def deserializer(self) -> Deserializer[Json, Jsons]: + """Return the deserializer.""" + return self._executor.deserializer + + def __repr__(self) -> str: + return f"" + + async def execute( + self, + query: str, + count: Optional[bool] = None, + batch_size: Optional[int] = None, + bind_vars: Optional[Json] = None, + cache: Optional[bool] = None, + memory_limit: Optional[int] = None, + ttl: Optional[int] = None, + allow_dirty_read: Optional[bool] = None, + options: Optional[QueryProperties | Json] = None, + ) -> Result[Cursor]: + """Execute the query and return the result cursor. + + Args: + query (str): Query string to be executed. + count (bool | None): If set to `True`, the total document count is + calculated and included in the result cursor. + batch_size (int | None): Maximum number of result documents to be + transferred from the server to the client in one roundtrip. + bind_vars (dict | None): An object with key/value pairs representing + the bind parameters. + cache (bool | None): Flag to determine whether the AQL query results + cache shall be used. + memory_limit (int | None): Maximum memory (in bytes) that the query is + allowed to use. + ttl (int | None): The time-to-live for the cursor (in seconds). The cursor + will be removed on the server automatically after the specified amount + of time. + allow_dirty_read (bool | None): Allow reads from followers in a cluster. + options (QueryProperties | dict | None): Extra options for the query. 
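+
+            Example:
+                A minimal usage sketch. The database handle ``db`` and the
+                collection name ``"students"`` are illustrative assumptions,
+                not part of this API.
+
+                .. code-block:: python
+
+                    cursor = await db.aql.execute(
+                        "FOR doc IN @@col FILTER doc.age > @value RETURN doc",
+                        bind_vars={"@col": "students", "value": 18},
+                        count=True,
+                        batch_size=10,
+                    )
+                    async for doc in cursor:
+                        print(doc)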
+ + References: + - `create-a-cursor `__ + """ # noqa: E501 + data: Json = dict(query=query) + if count is not None: + data["count"] = count + if batch_size is not None: + data["batchSize"] = batch_size + if bind_vars is not None: + data["bindVars"] = bind_vars + if cache is not None: + data["cache"] = cache + if memory_limit is not None: + data["memoryLimit"] = memory_limit + if ttl is not None: + data["ttl"] = ttl + if options is not None: + if isinstance(options, QueryProperties): + options = options.to_dict() + data["options"] = options + + headers = dict() + if allow_dirty_read is not None: + headers["x-arango-allow-dirty-read"] = str(allow_dirty_read).lower() + + request = Request( + method=Method.POST, + endpoint="/_api/cursor", + data=self.serializer.dumps(data), + headers=headers, + ) + + def response_handler(resp: Response) -> Cursor: + if not resp.is_success: + raise AQLQueryExecuteError(resp, request) + return Cursor(self._executor, self.deserializer.loads(resp.raw_body)) + + return await self._executor.execute(request, response_handler) diff --git a/arangoasync/cursor.py b/arangoasync/cursor.py new file mode 100644 index 0000000..55ba40a --- /dev/null +++ b/arangoasync/cursor.py @@ -0,0 +1,262 @@ +__all__ = ["Cursor"] + + +from collections import deque +from typing import Any, Deque, List, Optional + +from arangoasync.errno import HTTP_NOT_FOUND +from arangoasync.exceptions import ( + CursorCloseError, + CursorCountError, + CursorEmptyError, + CursorNextError, + CursorStateError, +) +from arangoasync.executor import ApiExecutor +from arangoasync.request import Method, Request +from arangoasync.response import Response +from arangoasync.serialization import Deserializer, Serializer +from arangoasync.typings import ( + Json, + Jsons, + QueryExecutionExtra, + QueryExecutionPlan, + QueryExecutionProfile, + QueryExecutionStats, +) + + +class Cursor: + """Cursor API wrapper. + + Cursors fetch query results from ArangoDB server in batches. Cursor objects + are *stateful* as they store the fetched items in-memory. They must not be + shared across threads without a proper locking mechanism. + + Args: + executor: Required to execute the API requests. + data: Cursor initialization data. Returned by the server when the query + is created. 
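+
+    Example:
+        A hedged sketch of consuming a cursor; ``db`` is assumed to be an
+        existing database object and the query is illustrative.
+
+        .. code-block:: python
+
+            cursor = await db.aql.execute("FOR doc IN students RETURN doc")
+            async with cursor as ctx:       # closes the cursor on exit
+                async for doc in ctx:       # fetches new batches as needed
+                    print(doc)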
+ """ + + def __init__(self, executor: ApiExecutor, data: Json) -> None: + self._executor = executor + self._cached: Optional[bool] = None + self._count: Optional[int] = None + self._extra = QueryExecutionExtra({}) + self._has_more: Optional[bool] = None + self._id: Optional[str] = None + self._next_batch_id: Optional[str] = None + self._batch: Deque[Any] = deque() + self._update(data) + + def __aiter__(self) -> "Cursor": + return self + + async def __anext__(self) -> Any: + return await self.next() + + async def __aenter__(self) -> "Cursor": + return self + + async def __aexit__(self, *_: Any) -> None: + await self.close(ignore_missing=True) + + def __len__(self) -> int: + if self._count is None: + raise CursorCountError("Cursor count not enabled") + return self._count + + def __repr__(self) -> str: + return f"" if self._id else "" + + @property + def cached(self) -> Optional[bool]: + """Whether the result was served from the query cache or not.""" + return self._cached + + @property + def count(self) -> Optional[int]: + """The total number of result documents available.""" + return self._count + + @property + def extra(self) -> QueryExecutionExtra: + """Extra information about the query execution.""" + return self._extra + + @property + def has_more(self) -> Optional[bool]: + """Whether there are more results available on the server.""" + return self._has_more + + @property + def id(self) -> Optional[str]: + """Cursor ID.""" + return self._id + + @property + def next_batch_id(self) -> Optional[str]: + """ID of the batch after current one.""" + return self._next_batch_id + + @property + def batch(self) -> Deque[Any]: + """Return the current batch of results.""" + return self._batch + + @property + def serializer(self) -> Serializer[Json]: + """Return the serializer.""" + return self._executor.serializer + + @property + def deserializer(self) -> Deserializer[Json, Jsons]: + """Return the deserializer.""" + return self._executor.deserializer + + @property + def statistics(self) -> QueryExecutionStats: + """Query statistics.""" + return self.extra.stats + + @property + def profile(self) -> QueryExecutionProfile: + """Query profiling information.""" + return self.extra.profile + + @property + def plan(self) -> QueryExecutionPlan: + """Execution plan for the query.""" + return self.extra.plan + + @property + def warnings(self) -> List[Json]: + """Warnings generated during query execution.""" + return self.extra.warnings + + def empty(self) -> bool: + """Check if the current batch is empty.""" + return len(self._batch) == 0 + + async def next(self) -> Any: + """Retrieve and pop the next item. + + If current batch is empty/depleted, an API request is automatically + sent to fetch the next batch from the server and update the cursor. + + Returns: + Any: Next item. + + Raises: + StopAsyncIteration: If there are no more items to retrieve. + CursorNextError: If the cursor failed to fetch the next batch. + CursorStateError: If the cursor ID is not set. + """ + if self.empty(): + if not self.has_more: + raise StopAsyncIteration + await self.fetch() + return self.pop() + + def pop(self) -> Any: + """Pop the next item from the current batch. + + If current batch is empty/depleted, an exception is raised. You must + call :func:`arangoasync.cursor.Cursor.fetch` to manually fetch the next + batch from server. + + Returns: + Any: Next item from the current batch. + + Raises: + CursorEmptyError: If the current batch is empty. 
+ """ + try: + return self._batch.popleft() + except IndexError: + raise CursorEmptyError("Current batch is empty") + + async def fetch(self, batch_id: Optional[str] = None) -> List[Any]: + """Fetch the next batch from the server and update the cursor. + + Args: + batch_id (str | None): ID of the batch to fetch. If not set, the + next batch after the current one is fetched. + + Returns: + List[Any]: New batch results. + + Raises: + CursorNextError: If the cursor is empty. + CursorStateError: If the cursor ID is not set. + + References: + - `read-the-next-batch-from-a-cursor `__ + - `read-a-batch-from-the-cursor-again `__ + """ # noqa: E501 + if self._id is None: + raise CursorStateError("Cursor ID is not set") + + endpoint = f"/_api/cursor/{self._id}" + if batch_id is not None: + endpoint += f"/{batch_id}" + + request = Request( + method=Method.POST, + endpoint=endpoint, + ) + + def response_handler(resp: Response) -> List[Any]: + if not resp.is_success: + raise CursorNextError(resp, request) + return self._update(self.deserializer.loads(resp.raw_body)) + + return await self._executor.execute(request, response_handler) + + async def close(self, ignore_missing: bool = False) -> bool: + """Close the cursor and free any server resources associated with it. + + Args: + ignore_missing (bool): Do not raise an exception on missing cursor. + + Returns: + bool: `True` if the cursor was closed successfully. `False` if there + was no cursor to close. If there is no cursor associated with the + query, `False` is returned. + + Raises: + CursorCloseError: If the cursor failed to close. + + References: + - `delete-a-cursor `__ + """ # noqa: E501 + if self._id is None: + return False + + request = Request( + method=Method.DELETE, + endpoint=f"/_api/cursor/{self._id}", + ) + + def response_handler(resp: Response) -> bool: + if resp.is_success: + return True + if resp.status_code == HTTP_NOT_FOUND and ignore_missing: + return False + raise CursorCloseError(resp, request) + + return await self._executor.execute(request, response_handler) + + def _update(self, data: Json) -> List[Any]: + """Update the cursor with the new data.""" + if "id" in data: + self._id = data.get("id") + self._cached = data.get("cached") + self._count = data.get("count") + self._extra = QueryExecutionExtra(data.get("extra", dict())) + self._has_more = data.get("hasMore") + self._next_batch_id = data.get("nextBatchId") + result: List[Any] = data.get("result", list()) + self._batch.extend(result) + return result diff --git a/arangoasync/database.py b/arangoasync/database.py index 277e1a9..3f91c56 100644 --- a/arangoasync/database.py +++ b/arangoasync/database.py @@ -8,6 +8,7 @@ from typing import Any, List, Optional, Sequence, TypeVar, cast from warnings import warn +from arangoasync.aql import AQL from arangoasync.collection import StandardCollection from arangoasync.connection import Connection from arangoasync.errno import HTTP_FORBIDDEN, HTTP_NOT_FOUND @@ -80,7 +81,7 @@ def connection(self) -> Connection: @property def name(self) -> str: """Return the name of the current database.""" - return self.connection.db_name + return self._executor.db_name @property def serializer(self) -> Serializer[Json]: @@ -98,10 +99,18 @@ def context(self) -> str: Returns: str: API execution context. Possible values are "default", "transaction". - :rtype: str """ return self._executor.context + @property + def aql(self) -> AQL: + """Return the AQL API wrapper. + + Returns: + arangoasync.aql.AQL: AQL API wrapper. 
+ """ + return AQL(self._executor) + async def properties(self) -> Result[DatabaseProperties]: """Return database properties. diff --git a/arangoasync/exceptions.py b/arangoasync/exceptions.py index a65a13e..7095b26 100644 --- a/arangoasync/exceptions.py +++ b/arangoasync/exceptions.py @@ -71,6 +71,10 @@ def __init__( self.http_headers = resp.headers +class AQLQueryExecuteError(ArangoServerError): + """Failed to execute query.""" + + class AuthHeaderError(ArangoClientError): """The authentication header could not be determined.""" @@ -99,6 +103,26 @@ class ClientConnectionError(ArangoClientError): """The request was unable to reach the server.""" +class CursorCloseError(ArangoServerError): + """Failed to delete the cursor result from server.""" + + +class CursorCountError(ArangoClientError, TypeError): + """The cursor count was not enabled.""" + + +class CursorEmptyError(ArangoClientError): + """The current batch in cursor was empty.""" + + +class CursorNextError(ArangoServerError): + """Failed to retrieve the next result batch from server.""" + + +class CursorStateError(ArangoClientError): + """The cursor object was in a bad state.""" + + class DatabaseCreateError(ArangoServerError): """Failed to create database.""" diff --git a/arangoasync/typings.py b/arangoasync/typings.py index aefe43c..2a0fec0 100644 --- a/arangoasync/typings.py +++ b/arangoasync/typings.py @@ -1,4 +1,5 @@ from enum import Enum +from numbers import Number from typing import ( Any, Callable, @@ -1046,3 +1047,447 @@ def format(self, formatter: Optional[Formatter] = None) -> Json: if formatter is not None: return super().format(formatter) return self.compatibility_formatter(self._data) + + +class QueryProperties(JsonWrapper): + """Extra options for AQL queries. + + Args: + allow_dirty_reads (bool | None): If set to `True`, when executing the query + against a cluster deployment, the Coordinator is allowed to read from any + shard replica and not only from the leader. + allow_retry (bool | None): Setting it to `True` makes it possible to retry + fetching the latest batch from a cursor. + fail_on_warning (bool | None): If set to `True`, the query will throw an + exception and abort instead of producing a warning. + fill_block_cache (bool | None): If set to `True`, it will make the query + store the data it reads via the RocksDB storage engine in the RocksDB + block cache. + full_count (bool | None): If set to `True` and the query contains a LIMIT + clause, then the result will have some extra attributes. + intermediate_commit_count (int | None): The maximum number of operations + after which an intermediate commit is performed automatically. + intermediate_commit_size (int | None): The maximum total size of operations + after which an intermediate commit is performed automatically. + max_dnf_condition_members (int | None): A threshold for the maximum number of + OR sub-nodes in the internal representation of an AQL FILTER condition. + max_nodes_per_callstack (int | None): The number of execution nodes in the + query plan after that stack splitting is performed to avoid a potential + stack overflow. + max_number_of_plans (int | None): Limits the maximum number of plans that + are created by the AQL query optimizer. + max_runtime (float | None): The query has to be executed within the given + runtime or it is killed. The value is specified in seconds. If unspecified, + there will be no timeout. + max_transaction_size (int | None): The maximum transaction size in bytes. 
+ max_warning_count (int | None): Limits the maximum number of warnings a + query will return. + optimizer (dict | None): Options related to the query optimizer. + profile (int | None): Return additional profiling information in the query + result. Can be set to 1 or 2 (for more detailed profiling information). + satellite_sync_wait (flat | None): How long a DB-Server has time to bring + the SatelliteCollections involved in the query into sync (in seconds). + skip_inaccessible_collections (bool | None): Treat collections to which a user + has no access rights for as if these collections are empty. + spill_over_threshold_memory_usage (int | None): This option allows queries to + store intermediate and final results temporarily on disk if the amount of + memory used (in bytes) exceeds the specified value. + spill_over_threshold_num_rows (int | None): This option allows queries to + store intermediate and final results temporarily on disk if the number + of rows produced by the query exceeds the specified value. + stream (bool | None): Can be enabled to execute the query lazily. + + Example: + .. code-block:: json + + { + "maxPlans": 1, + "optimizer": { + "rules": [ + "-all", + "+remove-unnecessary-filters" + ] + } + } + + References: + - `create-a-cursor `__ + """ # noqa: E501 + + def __init__( + self, + allow_dirty_reads: Optional[bool] = None, + allow_retry: Optional[bool] = None, + fail_on_warning: Optional[bool] = None, + fill_block_cache: Optional[bool] = None, + full_count: Optional[bool] = None, + intermediate_commit_count: Optional[int] = None, + intermediate_commit_size: Optional[int] = None, + max_dnf_condition_members: Optional[int] = None, + max_nodes_per_callstack: Optional[int] = None, + max_number_of_plans: Optional[int] = None, + max_runtime: Optional[Number] = None, + max_transaction_size: Optional[int] = None, + max_warning_count: Optional[int] = None, + optimizer: Optional[Json] = None, + profile: Optional[int] = None, + satellite_sync_wait: Optional[Number] = None, + skip_inaccessible_collections: Optional[bool] = None, + spill_over_threshold_memory_usage: Optional[int] = None, + spill_over_threshold_num_rows: Optional[int] = None, + stream: Optional[bool] = None, + ) -> None: + data: Json = dict() + if allow_dirty_reads is not None: + data["allowDirtyReads"] = allow_dirty_reads + if allow_retry is not None: + data["allowRetry"] = allow_retry + if fail_on_warning is not None: + data["failOnWarning"] = fail_on_warning + if fill_block_cache is not None: + data["fillBlockCache"] = fill_block_cache + if full_count is not None: + data["fullCount"] = full_count + if intermediate_commit_count is not None: + data["intermediateCommitCount"] = intermediate_commit_count + if intermediate_commit_size is not None: + data["intermediateCommitSize"] = intermediate_commit_size + if max_dnf_condition_members is not None: + data["maxDNFConditionMembers"] = max_dnf_condition_members + if max_nodes_per_callstack is not None: + data["maxNodesPerCallstack"] = max_nodes_per_callstack + if max_number_of_plans is not None: + data["maxNumberOfPlans"] = max_number_of_plans + if max_runtime is not None: + data["maxRuntime"] = max_runtime + if max_transaction_size is not None: + data["maxTransactionSize"] = max_transaction_size + if max_warning_count is not None: + data["maxWarningCount"] = max_warning_count + if optimizer is not None: + data["optimizer"] = optimizer + if profile is not None: + data["profile"] = profile + if satellite_sync_wait is not None: + data["satelliteSyncWait"] = 
satellite_sync_wait + if skip_inaccessible_collections is not None: + data["skipInaccessibleCollections"] = skip_inaccessible_collections + if spill_over_threshold_memory_usage is not None: + data["spillOverThresholdMemoryUsage"] = spill_over_threshold_memory_usage + if spill_over_threshold_num_rows is not None: + data["spillOverThresholdNumRows"] = spill_over_threshold_num_rows + if stream is not None: + data["stream"] = stream + super().__init__(data) + + @property + def allow_dirty_reads(self) -> Optional[bool]: + return self._data.get("allowDirtyReads") + + @property + def allow_retry(self) -> Optional[bool]: + return self._data.get("allowRetry") + + @property + def fail_on_warning(self) -> Optional[bool]: + return self._data.get("failOnWarning") + + @property + def fill_block_cache(self) -> Optional[bool]: + return self._data.get("fillBlockCache") + + @property + def full_count(self) -> Optional[bool]: + return self._data.get("fullCount") + + @property + def intermediate_commit_count(self) -> Optional[int]: + return self._data.get("intermediateCommitCount") + + @property + def intermediate_commit_size(self) -> Optional[int]: + return self._data.get("intermediateCommitSize") + + @property + def max_dnf_condition_members(self) -> Optional[int]: + return self._data.get("maxDNFConditionMembers") + + @property + def max_nodes_per_callstack(self) -> Optional[int]: + return self._data.get("maxNodesPerCallstack") + + @property + def max_number_of_plans(self) -> Optional[int]: + return self._data.get("maxNumberOfPlans") + + @property + def max_runtime(self) -> Optional[Number]: + return self._data.get("maxRuntime") + + @property + def max_transaction_size(self) -> Optional[int]: + return self._data.get("maxTransactionSize") + + @property + def max_warning_count(self) -> Optional[int]: + return self._data.get("maxWarningCount") + + @property + def optimizer(self) -> Optional[Json]: + return self._data.get("optimizer") + + @property + def profile(self) -> Optional[int]: + return self._data.get("profile") + + @property + def satellite_sync_wait(self) -> Optional[Number]: + return self._data.get("satelliteSyncWait") + + @property + def skip_inaccessible_collections(self) -> Optional[bool]: + return self._data.get("skipInaccessibleCollections") + + @property + def spill_over_threshold_memory_usage(self) -> Optional[int]: + return self._data.get("spillOverThresholdMemoryUsage") + + @property + def spill_over_threshold_num_rows(self) -> Optional[int]: + return self._data.get("spillOverThresholdNumRows") + + @property + def stream(self) -> Optional[bool]: + return self._data.get("stream") + + +class QueryExecutionPlan(JsonWrapper): + """The execution plan of an AQL query. 
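+
+    Example:
+        Usually reached through a cursor rather than constructed directly; a
+        hedged sketch, assuming profiling was requested when the query ran and
+        that ``db`` is an existing database handle.
+
+        .. code-block:: python
+
+            from arangoasync.typings import QueryProperties
+
+            cursor = await db.aql.execute(
+                "FOR doc IN students RETURN doc",
+                options=QueryProperties(profile=2),
+            )
+            print(cursor.plan.estimated_cost, cursor.plan.rules)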
+ + References: + - `plan `__ + """ # noqa: E501 + + def __init__(self, data: Json) -> None: + super().__init__(data) + + @property + def collections(self) -> Optional[Jsons]: + return self._data.get("collections") + + @property + def estimated_cost(self) -> Optional[float]: + return self._data.get("estimatedCost") + + @property + def estimated_nr_items(self) -> Optional[int]: + return self._data.get("estimatedNrItems") + + @property + def is_modification_query(self) -> Optional[bool]: + return self._data.get("isModificationQuery") + + @property + def nodes(self) -> Optional[Jsons]: + return self._data.get("nodes") + + @property + def rules(self) -> Optional[List[str]]: + return self._data.get("rules") + + @property + def variables(self) -> Optional[Jsons]: + return self._data.get("variables") + + +class QueryExecutionProfile(JsonWrapper): + """The duration of the different query execution phases in seconds. + + Example: + .. code-block:: json + + { + "initializing" : 0.0000028529999838156073, + "parsing" : 0.000029285000010759177, + "optimizing ast" : 0.0000040699999885873694, + "loading collections" : 0.000012807000018710823, + "instantiating plan" : 0.00002348999998957879, + "optimizing plan" : 0.00006598600000984334, + "instantiating executors" : 0.000027471999999306718, + "executing" : 0.7550992429999894, + "finalizing" : 0.00004103500000951499 + } + + References: + - `profile `__ + """ # noqa: E501 + + def __init__(self, data: Json) -> None: + super().__init__(data) + + @property + def executing(self) -> Optional[float]: + return self._data.get("executing") + + @property + def finalizing(self) -> Optional[float]: + return self._data.get("finalizing") + + @property + def initializing(self) -> Optional[float]: + return self._data.get("initializing") + + @property + def instantiating_executors(self) -> Optional[float]: + return self._data.get("instantiating executors") + + @property + def instantiating_plan(self) -> Optional[float]: + return self._data.get("instantiating plan") + + @property + def loading_collections(self) -> Optional[float]: + return self._data.get("loading collections") + + @property + def optimizing_ast(self) -> Optional[float]: + return self._data.get("optimizing ast") + + @property + def optimizing_plan(self) -> Optional[float]: + return self._data.get("optimizing plan") + + @property + def parsing(self) -> Optional[float]: + return self._data.get("parsing") + + +class QueryExecutionStats(JsonWrapper): + """Statistics of an AQL query. + + Example: + .. 
code-block:: json + + { + "writesExecuted" : 0, + "writesIgnored" : 0, + "documentLookups" : 0, + "seeks" : 0, + "scannedFull" : 2, + "scannedIndex" : 0, + "cursorsCreated" : 0, + "cursorsRearmed" : 0, + "cacheHits" : 0, + "cacheMisses" : 0, + "filtered" : 0, + "httpRequests" : 0, + "executionTime" : 0.00019362399999067748, + "peakMemoryUsage" : 0, + "intermediateCommits" : 0 + } + + References: + - `stats `__ + """ # noqa: E501 + + def __init__(self, data: Json) -> None: + super().__init__(data) + + @property + def cache_hits(self) -> Optional[int]: + return self._data.get("cacheHits") + + @property + def cache_misses(self) -> Optional[int]: + return self._data.get("cacheMisses") + + @property + def cursors_created(self) -> Optional[int]: + return self._data.get("cursorsCreated") + + @property + def cursors_rearmed(self) -> Optional[int]: + return self._data.get("cursorsRearmed") + + @property + def document_lookups(self) -> Optional[int]: + return self._data.get("documentLookups") + + @property + def execution_time(self) -> Optional[float]: + return self._data.get("executionTime") + + @property + def filtered(self) -> Optional[int]: + return self._data.get("filtered") + + @property + def full_count(self) -> Optional[int]: + return self._data.get("fullCount") + + @property + def http_requests(self) -> Optional[int]: + return self._data.get("httpRequests") + + @property + def intermediate_commits(self) -> Optional[int]: + return self._data.get("intermediateCommits") + + @property + def nodes(self) -> Optional[Jsons]: + return self._data.get("nodes") + + @property + def peak_memory_usage(self) -> Optional[int]: + return self._data.get("peakMemoryUsage") + + @property + def scanned_full(self) -> Optional[int]: + return self._data.get("scannedFull") + + @property + def scanned_index(self) -> Optional[int]: + return self._data.get("scannedIndex") + + @property + def seeks(self) -> Optional[int]: + return self._data.get("seeks") + + @property + def writes_executed(self) -> Optional[int]: + return self._data.get("writesExecuted") + + @property + def writes_ignored(self) -> Optional[int]: + return self._data.get("writesIgnored") + + +class QueryExecutionExtra(JsonWrapper): + """Extra information about the query result. + + References: + - `extra `__ + """ # noqa: E501 + + def __init__(self, data: Json) -> None: + super().__init__(data) + self._plan = QueryExecutionPlan(data.get("plan", dict())) + self._profile = QueryExecutionProfile(data.get("profile", dict())) + self._stats = QueryExecutionStats(data.get("stats", dict())) + self._warnings: Jsons = data.get("warnings", list()) + + @property + def plan(self) -> QueryExecutionPlan: + return self._plan + + @property + def profile(self) -> QueryExecutionProfile: + return self._profile + + @property + def stats(self) -> QueryExecutionStats: + return self._stats + + @property + def warnings(self) -> Jsons: + return self._warnings diff --git a/docs/specs.rst b/docs/specs.rst index 13fbdbd..db326e4 100644 --- a/docs/specs.rst +++ b/docs/specs.rst @@ -13,6 +13,15 @@ python-arango-async. .. automodule:: arangoasync.database :members: +.. automodule:: arangoasync.collection + :members: + +.. automodule:: arangoasync.aql + :members: + +.. automodule:: arangoasync.cursor + :members: + .. 
automodule:: arangoasync.compression :members: diff --git a/tests/test_aql.py b/tests/test_aql.py new file mode 100644 index 0000000..23a2fd3 --- /dev/null +++ b/tests/test_aql.py @@ -0,0 +1,11 @@ +import pytest + + +@pytest.mark.asyncio +async def test_simple_query(db, doc_col, docs): + await doc_col.insert(docs[0]) + aql = db.aql + _ = await aql.execute( + query="FOR doc IN @@collection RETURN doc", + bind_vars={"@collection": doc_col.name}, + ) diff --git a/tests/test_cursor.py b/tests/test_cursor.py new file mode 100644 index 0000000..f998836 --- /dev/null +++ b/tests/test_cursor.py @@ -0,0 +1,394 @@ +import asyncio + +import pytest + +from arangoasync.aql import AQL +from arangoasync.errno import CURSOR_NOT_FOUND, HTTP_BAD_PARAMETER +from arangoasync.exceptions import ( + CursorCloseError, + CursorCountError, + CursorEmptyError, + CursorNextError, + CursorStateError, +) +from arangoasync.typings import QueryExecutionStats, QueryProperties + + +@pytest.mark.asyncio +async def test_cursor_basic_query(db, doc_col, docs, cluster): + # Insert documents + await asyncio.gather(*[doc_col.insert(doc) for doc in docs]) + + # Execute query + aql: AQL = db.aql + options = QueryProperties(optimizer={"rules": ["+all"]}, profile=2) + cursor = await aql.execute( + query=f"FOR doc IN {doc_col.name} SORT doc.val RETURN doc", + count=True, + batch_size=2, + ttl=1000, + options=options, + ) + + # Check cursor attributes + cursor_id = cursor.id + assert "Cursor" in repr(cursor) + assert cursor.has_more is True + assert cursor.cached is False + assert cursor.warnings == [] + assert cursor.count == len(cursor) == 6 + assert cursor.empty() is False + batch = cursor.batch + assert len(batch) == 2 + for idx in range(2): + assert batch[idx]["val"] == docs[idx]["val"] + + # Check cursor statistics + statistics: QueryExecutionStats = cursor.statistics + assert statistics.writes_executed == 0 + assert statistics.filtered == 0 + assert statistics.writes_ignored == 0 + assert statistics.execution_time > 0 + if cluster: + assert statistics.http_requests > 0 + assert statistics.scanned_full > 0 + assert "nodes" in statistics + + # Check cursor warnings + assert cursor.warnings == [] + + # Check cursor profile + profile = cursor.profile + assert profile.initializing > 0 + assert profile.parsing > 0 + + # Check query execution plan + plan = cursor.plan + assert "nodes" in plan + assert plan.collections[0]["name"] == doc_col.name + assert plan.is_modification_query is False + + # Retrieve the next document (should be already in the batch) + assert (await cursor.next())["val"] == docs[0]["val"] + assert cursor.id == cursor_id + assert cursor.has_more is True + assert cursor.cached is False + assert cursor.statistics == statistics + assert cursor.profile == profile + assert cursor.warnings == [] + assert cursor.count == len(cursor) == 6 + assert len(cursor.batch) == 1 + assert cursor.batch[0]["val"] == docs[1]["val"] + + # Retrieve the next document (should be already in the batch) + assert (await cursor.next())["val"] == docs[1]["val"] + assert cursor.id == cursor_id + assert cursor.has_more is True + assert cursor.cached is False + assert cursor.statistics == statistics + assert cursor.profile == profile + assert cursor.warnings == [] + assert cursor.count == len(cursor) == 6 + assert cursor.empty() is True + + # Retrieve the next document (should be fetched from the server) + assert (await cursor.next())["val"] == docs[2]["val"] + assert cursor.id == cursor_id + assert cursor.has_more is True + assert cursor.cached is 
False + assert cursor.statistics == statistics + assert cursor.profile == profile + assert cursor.warnings == [] + assert cursor.count == len(cursor) == 6 + assert cursor.batch[0]["val"] == docs[3]["val"] + assert cursor.empty() is False + + # Retrieve the rest of the documents + for idx in range(3, 6): + assert (await cursor.next())["val"] == docs[idx]["val"] + + # There should be no longer any documents to retrieve + assert cursor.empty() is True + assert cursor.has_more is False + with pytest.raises(StopAsyncIteration): + await cursor.next() + + # Close the cursor (should be already gone because it has been consumed) + assert await cursor.close(ignore_missing=True) is False + + +@pytest.mark.asyncio +async def test_cursor_write_query(db, doc_col, docs): + # Insert documents + await asyncio.gather(*[doc_col.insert(doc) for doc in docs]) + + # Execute query, updating some documents + aql: AQL = db.aql + options = QueryProperties(optimizer={"rules": ["+all"]}, profile=1, max_runtime=0.0) + cursor = await aql.execute( + """ + FOR d IN {col} FILTER d.val == @first OR d.val == @second + UPDATE {{_key: d._key, _val: @val }} IN {col} + RETURN NEW + """.format( + col=doc_col.name + ), + bind_vars={"first": 1, "second": 2, "val": 42}, + count=True, + batch_size=1, + ttl=1000, + options=options, + ) + + # Check cursor attributes + cursor_id = cursor.id + assert cursor.has_more is True + assert cursor.cached is False + assert cursor.warnings == [] + assert cursor.count == len(cursor) == 2 + assert cursor.batch[0]["val"] == docs[0]["val"] + assert cursor.empty() is False + + statistics = cursor.statistics + assert statistics.writes_executed == 2 + assert statistics.filtered == 4 # 2 docs matched, 4 docs ignored + assert statistics.writes_ignored == 0 + assert statistics.execution_time > 0 + + profile = cursor.profile + assert profile.initializing > 0 + assert profile.parsing > 0 + + # First document + assert (await cursor.next())["val"] == docs[0]["val"] + assert cursor.id == cursor_id + assert cursor.has_more is True + assert cursor.cached is False + assert cursor.statistics == statistics + assert cursor.profile == profile + assert cursor.warnings == [] + assert cursor.count == len(cursor) == 2 + assert cursor.empty() is True + assert len(cursor.batch) == 0 + + # Second document, this is fetched from the server + assert (await cursor.next())["val"] == docs[1]["val"] + assert cursor.id == cursor_id + assert cursor.has_more is False + assert cursor.cached is False + assert cursor.statistics == statistics + assert cursor.profile == profile + assert cursor.warnings == [] + assert cursor.count == len(cursor) == 2 + assert cursor.empty() is True + + # There should be no longer any documents to retrieve, hence the cursor is closed + with pytest.raises(CursorCloseError) as err: + await cursor.close(ignore_missing=False) + assert err.value.error_code == CURSOR_NOT_FOUND + assert await cursor.close(ignore_missing=True) is False + + +@pytest.mark.asyncio +async def test_cursor_invalid_id(db, doc_col, docs): + # Insert documents + await asyncio.gather(*[doc_col.insert(doc) for doc in docs]) + + aql: AQL = db.aql + cursor = await aql.execute( + f"FOR d IN {doc_col.name} SORT d._key RETURN d", + count=True, + batch_size=2, + ttl=1000, + options={"optimizer": {"rules": ["+all"]}, "profile": 1}, + ) + + # Set the cursor ID to "invalid" and assert errors + setattr(cursor, "_id", "invalid") + + # Cursor should not be found + with pytest.raises(CursorNextError) as err: + async for _ in cursor: + pass + assert 
err.value.error_code == CURSOR_NOT_FOUND + with pytest.raises(CursorCloseError) as err: + await cursor.close(ignore_missing=False) + assert err.value.error_code == CURSOR_NOT_FOUND + assert await cursor.close(ignore_missing=True) is False + + # Set the cursor ID to None and assert errors + setattr(cursor, "_id", None) + with pytest.raises(CursorStateError): + print(await cursor.next()) + with pytest.raises(CursorStateError): + await cursor.fetch() + assert await cursor.close() is False + + +@pytest.mark.asyncio +async def test_cursor_premature_close(db, doc_col, docs): + # Insert documents + await asyncio.gather(*[doc_col.insert(doc) for doc in docs]) + + aql: AQL = db.aql + cursor = await aql.execute( + f"FOR d IN {doc_col.name} SORT d._key RETURN d", + count=True, + batch_size=2, + ttl=1000, + ) + assert len(cursor.batch) == 2 + assert await cursor.close() is True + + # Cursor should be already closed + with pytest.raises(CursorCloseError) as err: + await cursor.close(ignore_missing=False) + assert err.value.error_code == CURSOR_NOT_FOUND + assert await cursor.close(ignore_missing=True) is False + + +@pytest.mark.asyncio +async def test_cursor_context_manager(db, doc_col, docs): + # Insert documents + await asyncio.gather(*[doc_col.insert(doc) for doc in docs]) + + aql: AQL = db.aql + cursor = await aql.execute( + f"FOR d IN {doc_col.name} SORT d._key RETURN d", + count=True, + batch_size=2, + ttl=1000, + ) + async with cursor as ctx: + assert (await ctx.next())["val"] == docs[0]["val"] + + # Cursor should be already closed + with pytest.raises(CursorCloseError) as err: + await cursor.close(ignore_missing=False) + assert err.value.error_code == CURSOR_NOT_FOUND + assert await cursor.close(ignore_missing=True) is False + + +@pytest.mark.asyncio +async def test_cursor_manual_fetch_and_pop(db, doc_col, docs): + # Insert documents + await asyncio.gather(*[doc_col.insert(doc) for doc in docs]) + + aql: AQL = db.aql + cursor = await aql.execute( + f"FOR d IN {doc_col.name} SORT d._key RETURN d", + count=True, + batch_size=1, + ttl=1000, + options={"allowRetry": True}, + ) + + # Fetch documents manually + for idx in range(2, len(docs)): + result = await cursor.fetch() + assert len(result) == 1 + assert cursor.count == len(docs) + assert cursor.has_more + assert len(cursor.batch) == idx + assert result[0]["val"] == docs[idx - 1]["val"] + result = await cursor.fetch() + assert result[0]["val"] == docs[len(docs) - 1]["val"] + assert len(cursor.batch) == len(docs) + assert not cursor.has_more + + # Pop documents manually + idx = 0 + while not cursor.empty(): + doc = cursor.pop() + assert doc["val"] == docs[idx]["val"] + idx += 1 + assert len(cursor.batch) == 0 + + # Cursor should be empty + with pytest.raises(CursorEmptyError): + await cursor.pop() + + +@pytest.mark.asyncio +async def test_cursor_retry(db, doc_col, docs): + # Insert documents + await asyncio.gather(*[doc_col.insert(doc) for doc in docs]) + + # Do not allow retries + aql: AQL = db.aql + cursor = await aql.execute( + f"FOR d IN {doc_col.name} SORT d._key RETURN d", + count=True, + batch_size=1, + ttl=1000, + options={"allowRetry": False}, + ) + + # Increase the batch id by doing a fetch + await cursor.fetch() + while not cursor.empty(): + cursor.pop() + next_batch_id = cursor.next_batch_id + + # Fetch the next batch + await cursor.fetch() + # Retry is not allowed + with pytest.raises(CursorNextError) as err: + await cursor.fetch(batch_id=next_batch_id) + assert err.value.error_code == HTTP_BAD_PARAMETER + + await cursor.close() + + # 
Now let's allow retries + cursor = await aql.execute( + f"FOR d IN {doc_col.name} SORT d._key RETURN d", + count=True, + batch_size=1, + ttl=1000, + options={"allowRetry": True}, + ) + + # Increase the batch id by doing a fetch + await cursor.fetch() + while not cursor.empty(): + cursor.pop() + next_batch_id = cursor.next_batch_id + + # Fetch the next batch + prev_batch = await cursor.fetch() + next_next_batch_id = cursor.next_batch_id + # Should fetch the same batch again + next_batch = await cursor.fetch(batch_id=next_batch_id) + assert next_batch == prev_batch + # Next batch id should be the same + assert cursor.next_batch_id == next_next_batch_id + + # Fetch the next batch + next_next_batch = await cursor.fetch() + assert next_next_batch != next_batch + + assert await cursor.close() + + +@pytest.mark.asyncio +async def test_cursor_no_count(db, doc_col, docs): + # Insert documents + await asyncio.gather(*[doc_col.insert(doc) for doc in docs]) + + aql: AQL = db.aql + cursor = await aql.execute( + f"FOR d IN {doc_col.name} SORT d._key RETURN d", + count=False, + batch_size=2, + ttl=1000, + ) + + # Cursor count is not enabled + with pytest.raises(CursorCountError): + _ = len(cursor) + with pytest.raises(CursorCountError): + _ = bool(cursor) + + while cursor.has_more: + assert cursor.count is None + assert await cursor.fetch() diff --git a/tests/test_typings.py b/tests/test_typings.py index 9f00f89..166a0ce 100644 --- a/tests/test_typings.py +++ b/tests/test_typings.py @@ -6,6 +6,11 @@ CollectionType, JsonWrapper, KeyOptions, + QueryExecutionExtra, + QueryExecutionPlan, + QueryExecutionProfile, + QueryExecutionStats, + QueryProperties, UserInfo, ) @@ -125,3 +130,155 @@ def test_UserInfo(): assert user_info.active is True assert user_info.extra == {"role": "admin"} assert user_info.to_dict() == data + + +def test_QueryProperties(): + properties = QueryProperties( + allow_dirty_reads=True, + allow_retry=False, + fail_on_warning=True, + fill_block_cache=False, + full_count=True, + intermediate_commit_count=1000, + intermediate_commit_size=1048576, + max_dnf_condition_members=10, + max_nodes_per_callstack=100, + max_number_of_plans=5, + max_runtime=60.0, + max_transaction_size=10485760, + max_warning_count=10, + optimizer={"rules": ["-all", "+use-indexes"]}, + profile=1, + satellite_sync_wait=10.0, + skip_inaccessible_collections=True, + spill_over_threshold_memory_usage=10485760, + spill_over_threshold_num_rows=100000, + stream=True, + ) + assert properties.allow_dirty_reads is True + assert properties.allow_retry is False + assert properties.fail_on_warning is True + assert properties.fill_block_cache is False + assert properties.full_count is True + assert properties.intermediate_commit_count == 1000 + assert properties.intermediate_commit_size == 1048576 + assert properties.max_dnf_condition_members == 10 + assert properties.max_nodes_per_callstack == 100 + assert properties.max_number_of_plans == 5 + assert properties.max_runtime == 60.0 + assert properties.max_transaction_size == 10485760 + assert properties.max_warning_count == 10 + assert properties.optimizer == {"rules": ["-all", "+use-indexes"]} + assert properties.profile == 1 + assert properties.satellite_sync_wait == 10.0 + assert properties.skip_inaccessible_collections is True + assert properties.spill_over_threshold_memory_usage == 10485760 + assert properties.spill_over_threshold_num_rows == 100000 + assert properties.stream is True + + +def test_QueryExecutionPlan(): + data = { + "collections": [{"name": "test_collection"}], + 
"estimatedCost": 10.5, + "estimatedNrItems": 100, + "isModificationQuery": False, + "nodes": [{"type": "SingletonNode"}], + "rules": ["rule1", "rule2"], + "variables": [{"name": "var1"}], + } + plan = QueryExecutionPlan(data) + assert plan.collections == [{"name": "test_collection"}] + assert plan.estimated_cost == 10.5 + assert plan.estimated_nr_items == 100 + assert plan.is_modification_query is False + assert plan.nodes == [{"type": "SingletonNode"}] + assert plan.rules == ["rule1", "rule2"] + assert plan.variables == [{"name": "var1"}] + + +def test_QueryExecutionProfile(): + data = { + "initializing": 0.0000028529999838156073, + "parsing": 0.000029285000010759177, + "optimizing ast": 0.0000040699999885873694, + "loading collections": 0.000012807000018710823, + "instantiating plan": 0.00002348999998957879, + "optimizing plan": 0.00006598600000984334, + "instantiating executors": 0.000027471999999306718, + "executing": 0.7550992429999894, + "finalizing": 0.00004103500000951499, + } + profile = QueryExecutionProfile(data) + assert profile.initializing == 0.0000028529999838156073 + assert profile.parsing == 0.000029285000010759177 + assert profile.optimizing_ast == 0.0000040699999885873694 + assert profile.loading_collections == 0.000012807000018710823 + assert profile.instantiating_plan == 0.00002348999998957879 + assert profile.optimizing_plan == 0.00006598600000984334 + assert profile.instantiating_executors == 0.000027471999999306718 + assert profile.executing == 0.7550992429999894 + assert profile.finalizing == 0.00004103500000951499 + + +def test_QueryExecutionStats(): + data = { + "writesExecuted": 10, + "writesIgnored": 2, + "scannedFull": 100, + "scannedIndex": 50, + "filtered": 20, + "httpRequests": 5, + "executionTime": 0.123, + "peakMemoryUsage": 1024, + } + stats = QueryExecutionStats(data) + assert stats.writes_executed == 10 + assert stats.writes_ignored == 2 + assert stats.scanned_full == 100 + assert stats.scanned_index == 50 + assert stats.filtered == 20 + assert stats.http_requests == 5 + assert stats.execution_time == 0.123 + assert stats.peak_memory_usage == 1024 + + +def test_QueryExecutionExtra(): + data = { + "plan": { + "collections": [{"name": "test_collection"}], + "estimatedCost": 10.5, + "estimatedNrItems": 100, + "isModificationQuery": False, + "nodes": [{"type": "SingletonNode"}], + "rules": ["rule1", "rule2"], + "variables": [{"name": "var1"}], + }, + "profile": { + "initializing": 0.0000028529999838156073, + "parsing": 0.000029285000010759177, + "optimizing ast": 0.0000040699999885873694, + "loading collections": 0.000012807000018710823, + "instantiating plan": 0.00002348999998957879, + "optimizing plan": 0.00006598600000984334, + "instantiating executors": 0.000027471999999306718, + "executing": 0.7550992429999894, + "finalizing": 0.00004103500000951499, + }, + "stats": { + "writesExecuted": 10, + "writesIgnored": 2, + "scannedFull": 100, + "scannedIndex": 50, + "filtered": 20, + "httpRequests": 5, + "executionTime": 0.123, + "peakMemoryUsage": 1024, + }, + "warnings": [{"code": 123, "message": "test warning"}], + } + extra = QueryExecutionExtra(data) + assert isinstance(extra.plan, QueryExecutionPlan) + assert isinstance(extra.profile, QueryExecutionProfile) + assert isinstance(extra.stats, QueryExecutionStats) + assert extra.warnings == [{"code": 123, "message": "test warning"}]