diff --git a/arangoasync/aql.py b/arangoasync/aql.py index 072fdbe..021c054 100644 --- a/arangoasync/aql.py +++ b/arangoasync/aql.py @@ -4,12 +4,30 @@ from typing import Optional from arangoasync.cursor import Cursor -from arangoasync.exceptions import AQLQueryExecuteError +from arangoasync.errno import HTTP_NOT_FOUND +from arangoasync.exceptions import ( + AQLQueryClearError, + AQLQueryExecuteError, + AQLQueryExplainError, + AQLQueryKillError, + AQLQueryListError, + AQLQueryRulesGetError, + AQLQueryTrackingGetError, + AQLQueryTrackingSetError, + AQLQueryValidateError, +) from arangoasync.executor import ApiExecutor from arangoasync.request import Method, Request from arangoasync.response import Response from arangoasync.serialization import Deserializer, Serializer -from arangoasync.typings import Json, Jsons, QueryProperties, Result +from arangoasync.typings import ( + Json, + Jsons, + QueryExplainOptions, + QueryProperties, + QueryTrackingConfiguration, + Result, +) class AQL: @@ -75,6 +93,9 @@ async def execute( allow_dirty_read (bool | None): Allow reads from followers in a cluster. options (QueryProperties | dict | None): Extra options for the query. + Returns: + Cursor: Result cursor. + References: - `create-a-cursor `__ """ # noqa: E501 @@ -113,3 +134,311 @@ def response_handler(resp: Response) -> Cursor: return Cursor(self._executor, self.deserializer.loads(resp.raw_body)) return await self._executor.execute(request, response_handler) + + async def tracking(self) -> Result[QueryTrackingConfiguration]: + """Returns the current query tracking configuration. + + Returns: + QueryTrackingConfiguration: Returns the current query tracking configuration. + + Raises: + AQLQueryTrackingGetError: If retrieval fails. + + References: + - `get-the-aql-query-tracking-configuration `__ + """ # noqa: E501 + request = Request(method=Method.GET, endpoint="/_api/query/properties") + + def response_handler(resp: Response) -> QueryTrackingConfiguration: + if not resp.is_success: + raise AQLQueryTrackingGetError(resp, request) + return QueryTrackingConfiguration(self.deserializer.loads(resp.raw_body)) + + return await self._executor.execute(request, response_handler) + + async def set_tracking( + self, + enabled: Optional[bool] = None, + max_slow_queries: Optional[int] = None, + slow_query_threshold: Optional[int] = None, + max_query_string_length: Optional[int] = None, + track_bind_vars: Optional[bool] = None, + track_slow_queries: Optional[int] = None, + ) -> Result[QueryTrackingConfiguration]: + """Configure AQL query tracking properties. + + Args: + enabled (bool | None): If set to `True`, then queries will be tracked. + If set to `False`, neither queries nor slow queries will be tracked. + max_slow_queries (int | None): Maximum number of slow queries to track. Oldest + entries are discarded first. + slow_query_threshold (int | None): Runtime threshold (in seconds) for treating a + query as slow. + max_query_string_length (int | None): The maximum query string length (in bytes) + to keep in the list of queries. + track_bind_vars (bool | None): If set to `True`, track bind variables used in + queries. + track_slow_queries (int | None): If set to `True`, then slow queries will be + tracked in the list of slow queries if their runtime exceeds the + value set in `slowQueryThreshold`. + + Returns: + QueryTrackingConfiguration: Returns the updated query tracking configuration. + + Raises: + AQLQueryTrackingSetError: If setting the configuration fails. + + References: + - `update-the-aql-query-tracking-configuration `__ + """ # noqa: E501 + data: Json = dict() + + if enabled is not None: + data["enabled"] = enabled + if max_slow_queries is not None: + data["maxSlowQueries"] = max_slow_queries + if max_query_string_length is not None: + data["maxQueryStringLength"] = max_query_string_length + if slow_query_threshold is not None: + data["slowQueryThreshold"] = slow_query_threshold + if track_bind_vars is not None: + data["trackBindVars"] = track_bind_vars + if track_slow_queries is not None: + data["trackSlowQueries"] = track_slow_queries + + request = Request( + method=Method.PUT, + endpoint="/_api/query/properties", + data=self.serializer.dumps(data), + ) + + def response_handler(resp: Response) -> QueryTrackingConfiguration: + if not resp.is_success: + raise AQLQueryTrackingSetError(resp, request) + return QueryTrackingConfiguration(self.deserializer.loads(resp.raw_body)) + + return await self._executor.execute(request, response_handler) + + async def queries(self, all_queries: bool = False) -> Result[Jsons]: + """Return a list of currently running queries. + + Args: + all_queries (bool): If set to `True`, will return the currently + running queries in all databases, not just the selected one. + Using the parameter is only allowed in the `_system` database + and with superuser privileges. + + Returns: + list: List of currently running queries and their properties. + + Raises: + AQLQueryListError: If retrieval fails. + + References: + - `list-the-running-queries `__ + """ # noqa: E501 + request = Request( + method=Method.GET, + endpoint="/_api/query/current", + params={"all": all_queries}, + ) + + def response_handler(resp: Response) -> Jsons: + if not resp.is_success: + raise AQLQueryListError(resp, request) + return self.deserializer.loads_many(resp.raw_body) + + return await self._executor.execute(request, response_handler) + + async def slow_queries(self, all_queries: bool = False) -> Result[Jsons]: + """Returns a list containing the last AQL queries that are finished and + have exceeded the slow query threshold in the selected database. + + Args: + all_queries (bool): If set to `True`, will return the slow queries + in all databases, not just the selected one. Using the parameter + is only allowed in the `_system` database and with superuser privileges. + + Returns: + list: List of slow queries. + + Raises: + AQLQueryListError: If retrieval fails. + + References: + - `list-the-slow-aql-queries `__ + """ # noqa: E501 + request = Request( + method=Method.GET, + endpoint="/_api/query/slow", + params={"all": all_queries}, + ) + + def response_handler(resp: Response) -> Jsons: + if not resp.is_success: + raise AQLQueryListError(resp, request) + return self.deserializer.loads_many(resp.raw_body) + + return await self._executor.execute(request, response_handler) + + async def clear_slow_queries(self, all_queries: bool = False) -> Result[None]: + """Clears the list of slow queries. + + Args: + all_queries (bool): If set to `True`, will clear the slow queries + in all databases, not just the selected one. Using the parameter + is only allowed in the `_system` database and with superuser privileges. + + Returns: + dict: Empty dictionary. + + Raises: + AQLQueryClearError: If retrieval fails. + + References: + - `clear-the-list-of-slow-aql-queries `__ + """ # noqa: E501 + request = Request( + method=Method.DELETE, + endpoint="/_api/query/slow", + params={"all": all_queries}, + ) + + def response_handler(resp: Response) -> None: + if not resp.is_success: + raise AQLQueryClearError(resp, request) + + return await self._executor.execute(request, response_handler) + + async def kill( + self, + query_id: str, + ignore_missing: bool = False, + all_queries: bool = False, + ) -> Result[bool]: + """Kill a running query. + + Args: + query_id (str): Thea ID of the query to kill. + ignore_missing (bool): If set to `True`, will not raise an exception + if the query is not found. + all_queries (bool): If set to `True`, will kill the query in all databases, + not just the selected one. Using the parameter is only allowed in the + `_system` database and with superuser privileges. + + Returns: + bool: `True` if the query was killed successfully. + + Raises: + AQLQueryKillError: If killing the query fails. + + References: + - `kill-a-running-aql-query `__ + """ # noqa: E501 + request = Request( + method=Method.DELETE, + endpoint=f"/_api/query/{query_id}", + params={"all": all_queries}, + ) + + def response_handler(resp: Response) -> bool: + if resp.is_success: + return True + if resp.status_code == HTTP_NOT_FOUND and ignore_missing: + return False + raise AQLQueryKillError(resp, request) + + return await self._executor.execute(request, response_handler) + + async def explain( + self, + query: str, + bind_vars: Optional[Json] = None, + options: Optional[QueryExplainOptions | Json] = None, + ) -> Result[Json]: + """Inspect the query and return its metadata without executing it. + + Args: + query (str): Query string to be explained. + bind_vars (dict | None): An object with key/value pairs representing + the bind parameters. + options (QueryExplainOptions | dict | None): Extra options for the query. + + Returns: + dict: Query execution plan. + + Raises: + AQLQueryExplainError: If retrieval fails. + + References: + - `explain-an-aql-query `__ + """ # noqa: E501 + data: Json = dict(query=query) + if bind_vars is not None: + data["bindVars"] = bind_vars + if options is not None: + if isinstance(options, QueryExplainOptions): + options = options.to_dict() + data["options"] = options + + request = Request( + method=Method.POST, + endpoint="/_api/explain", + data=self.serializer.dumps(data), + ) + + def response_handler(resp: Response) -> Json: + if not resp.is_success: + raise AQLQueryExplainError(resp, request) + return self.deserializer.loads(resp.raw_body) + + return await self._executor.execute(request, response_handler) + + async def validate(self, query: str) -> Result[Json]: + """Parse and validate the query without executing it. + + Args: + query (str): Query string to be validated. + + Returns: + dict: Query information. + + Raises: + AQLQueryValidateError: If validation fails. + + References: + - `parse-an-aql-query `__ + """ # noqa: E501 + request = Request( + method=Method.POST, + endpoint="/_api/query", + data=self.serializer.dumps(dict(query=query)), + ) + + def response_handler(resp: Response) -> Json: + if not resp.is_success: + raise AQLQueryValidateError(resp, request) + return self.deserializer.loads(resp.raw_body) + + return await self._executor.execute(request, response_handler) + + async def query_rules(self) -> Result[Jsons]: + """A list of all optimizer rules and their properties. + + Returns: + list: Available optimizer rules. + + Raises: + AQLQueryRulesGetError: If retrieval fails. + + References: + - `list-all-aql-optimizer-rules `__ + """ # noqa: E501 + request = Request(method=Method.GET, endpoint="/_api/query/rules") + + def response_handler(resp: Response) -> Jsons: + if not resp.is_success: + raise AQLQueryRulesGetError(resp, request) + return self.deserializer.loads_many(resp.raw_body) + + return await self._executor.execute(request, response_handler) diff --git a/arangoasync/auth.py b/arangoasync/auth.py index 5a1ab04..96e9b1b 100644 --- a/arangoasync/auth.py +++ b/arangoasync/auth.py @@ -103,7 +103,7 @@ def needs_refresh(self, leeway: int = 0) -> bool: def _validate(self) -> None: """Validate the token.""" if type(self._token) is not str: - raise TypeError("Token must be str or bytes") + raise TypeError("Token must be str") jwt_payload = jwt.decode( self._token, diff --git a/arangoasync/connection.py b/arangoasync/connection.py index 63a18c1..27ab9df 100644 --- a/arangoasync/connection.py +++ b/arangoasync/connection.py @@ -424,7 +424,6 @@ async def send_request(self, request: Request) -> Response: # If the token has expired, refresh it and retry the request await self.refresh_token() resp = await self.process_request(request) - self.raise_for_status(request, resp) return resp @@ -509,7 +508,6 @@ async def send_request(self, request: Request) -> Response: self.compress_request(request) resp = await self.process_request(request) - self.raise_for_status(request, resp) return resp diff --git a/arangoasync/exceptions.py b/arangoasync/exceptions.py index 7095b26..6cf31c5 100644 --- a/arangoasync/exceptions.py +++ b/arangoasync/exceptions.py @@ -71,10 +71,42 @@ def __init__( self.http_headers = resp.headers +class AQLQueryClearError(ArangoServerError): + """Failed to clear slow AQL queries.""" + + class AQLQueryExecuteError(ArangoServerError): """Failed to execute query.""" +class AQLQueryExplainError(ArangoServerError): + """Failed to parse and explain query.""" + + +class AQLQueryKillError(ArangoServerError): + """Failed to kill the query.""" + + +class AQLQueryListError(ArangoServerError): + """Failed to retrieve running AQL queries.""" + + +class AQLQueryRulesGetError(ArangoServerError): + """Failed to retrieve AQL query rules.""" + + +class AQLQueryTrackingGetError(ArangoServerError): + """Failed to retrieve AQL tracking properties.""" + + +class AQLQueryTrackingSetError(ArangoServerError): + """Failed to configure AQL tracking properties.""" + + +class AQLQueryValidateError(ArangoServerError): + """Failed to parse and validate query.""" + + class AuthHeaderError(ArangoClientError): """The authentication header could not be determined.""" diff --git a/arangoasync/typings.py b/arangoasync/typings.py index 2a0fec0..496e5ca 100644 --- a/arangoasync/typings.py +++ b/arangoasync/typings.py @@ -1491,3 +1491,110 @@ def stats(self) -> QueryExecutionStats: @property def warnings(self) -> Jsons: return self._warnings + + +class QueryTrackingConfiguration(JsonWrapper): + """AQL query tracking configuration. + + Example: + .. code-block:: json + + { + "enabled": true, + "trackSlowQueries": true, + "trackBindVars": true, + "maxSlowQueries": 64, + "slowQueryThreshold": 10, + "slowStreamingQueryThreshold": 10, + "maxQueryStringLength": 4096 + } + + References: + - `get-the-aql-query-tracking-configuration `__ + """ # noqa: E501 + + def __init__(self, data: Json) -> None: + super().__init__(data) + + @property + def enabled(self) -> bool: + return cast(bool, self._data["enabled"]) + + @property + def track_slow_queries(self) -> bool: + return cast(bool, self._data["trackSlowQueries"]) + + @property + def track_bind_vars(self) -> bool: + return cast(bool, self._data["trackBindVars"]) + + @property + def max_slow_queries(self) -> int: + return cast(int, self._data["maxSlowQueries"]) + + @property + def slow_query_threshold(self) -> int: + return cast(int, self._data["slowQueryThreshold"]) + + @property + def slow_streaming_query_threshold(self) -> Optional[int]: + return self._data.get("slowStreamingQueryThreshold") + + @property + def max_query_string_length(self) -> int: + return cast(int, self._data["maxQueryStringLength"]) + + +class QueryExplainOptions(JsonWrapper): + """Options for explaining an AQL query. + + Args: + all_plans (bool | None): If set to `True`, all possible execution plans are + returned. + max_plans (int | None): The maximum number of plans to return. + optimizer (dict | None): Options related to the query optimizer. + + Example: + .. code-block:: json + + { + "allPlans" : false, + "maxNumberOfPlans" : 1, + "optimizer" : { + "rules" : [ + "-all", + "+use-indexe-for-sort" + ] + } + } + + References: + - `explain-an-aql-query `__ + """ # noqa: E501 + + def __init__( + self, + all_plans: Optional[bool] = None, + max_plans: Optional[int] = None, + optimizer: Optional[Json] = None, + ) -> None: + data: Json = dict() + if all_plans is not None: + data["allPlans"] = all_plans + if max_plans is not None: + data["maxNumberOfPlans"] = max_plans + if optimizer is not None: + data["optimizer"] = optimizer + super().__init__(data) + + @property + def all_plans(self) -> Optional[bool]: + return self._data.get("allPlans") + + @property + def max_plans(self) -> Optional[int]: + return self._data.get("maxNumberOfPlans") + + @property + def optimizer(self) -> Optional[Json]: + return self._data.get("optimizer") diff --git a/pyproject.toml b/pyproject.toml index 9d6b7b5..71cedb1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,8 @@ packages = ["arangoasync"] [tool.pytest.ini_options] addopts = "-s -vv -p no:warnings" minversion = "6.0" +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" testpaths = ["tests"] [tool.coverage.run] diff --git a/tests/conftest.py b/tests/conftest.py index e997824..65846e7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ class GlobalData: root: str = None password: str = None secret: str = None - token: str = None + token: JwtToken = None sys_db_name: str = "_system" username: str = generate_username() cluster: bool = False @@ -157,6 +157,13 @@ async def sys_db(arango_client, sys_db_name, basic_auth_root): ) +@pytest_asyncio.fixture +async def superuser(arango_client, sys_db_name, basic_auth_root, token): + return await arango_client.db( + sys_db_name, auth_method="superuser", token=token, verify=False + ) + + @pytest_asyncio.fixture async def db(arango_client, sys_db, username, password, cluster): tst_db_name = generate_db_name() diff --git a/tests/test_aql.py b/tests/test_aql.py index 23a2fd3..74db7b0 100644 --- a/tests/test_aql.py +++ b/tests/test_aql.py @@ -1,11 +1,192 @@ +import asyncio +import time + import pytest +from arangoasync.errno import QUERY_PARSE +from arangoasync.exceptions import ( + AQLQueryClearError, + AQLQueryExecuteError, + AQLQueryExplainError, + AQLQueryKillError, + AQLQueryListError, + AQLQueryRulesGetError, + AQLQueryTrackingGetError, + AQLQueryTrackingSetError, + AQLQueryValidateError, +) +from arangoasync.typings import QueryExplainOptions + @pytest.mark.asyncio -async def test_simple_query(db, doc_col, docs): +async def test_simple_query(db, bad_db, doc_col, docs): await doc_col.insert(docs[0]) aql = db.aql _ = await aql.execute( query="FOR doc IN @@collection RETURN doc", bind_vars={"@collection": doc_col.name}, ) + + assert repr(db.aql) == f"" + + with pytest.raises(AQLQueryExecuteError): + _ = await bad_db.aql.execute( + query="FOR doc IN @@collection RETURN doc", + bind_vars={"@collection": doc_col.name}, + ) + + +@pytest.mark.asyncio +async def test_query_tracking(db, bad_db): + aql = db.aql + + # Get the current tracking properties. + tracking = await aql.tracking() + assert tracking.enabled is True + assert tracking.track_slow_queries is True + + # Disable tracking. + tracking = await aql.set_tracking(enabled=False) + assert tracking.enabled is False + + # Re-enable. + tracking = await aql.set_tracking(enabled=True, max_slow_queries=5) + assert tracking.enabled is True + assert tracking.max_slow_queries == 5 + + # Exceptions with bad database + with pytest.raises(AQLQueryTrackingGetError): + _ = await bad_db.aql.tracking() + with pytest.raises(AQLQueryTrackingSetError): + _ = await bad_db.aql.set_tracking(enabled=False) + + +@pytest.mark.asyncio +async def test_list_queries(superuser, db, bad_db): + aql = db.aql + + # Do not await, let it run in the background. + long_running_task = asyncio.create_task(aql.execute("RETURN SLEEP(10)")) + time.sleep(1) + + for _ in range(10): + queries = await aql.queries() + if len(queries) > 0: + break + + # Only superuser can list all queries from all databases. + all_queries = await superuser.aql.queries(all_queries=True) + assert len(all_queries) > 0 + + # Only test no-throws. + _ = await aql.slow_queries() + _ = await superuser.aql.slow_queries(all_queries=True) + await aql.clear_slow_queries() + await superuser.aql.clear_slow_queries(all_queries=True) + + with pytest.raises(AQLQueryListError): + _ = await bad_db.aql.queries() + with pytest.raises(AQLQueryListError): + _ = await bad_db.aql.slow_queries() + with pytest.raises(AQLQueryClearError): + await bad_db.aql.clear_slow_queries() + with pytest.raises(AQLQueryListError): + _ = await aql.queries(all_queries=True) + with pytest.raises(AQLQueryListError): + _ = await aql.slow_queries(all_queries=True) + with pytest.raises(AQLQueryClearError): + await aql.clear_slow_queries(all_queries=True) + + long_running_task.cancel() + + +@pytest.mark.asyncio +async def test_kill_query(db, bad_db, superuser): + aql = db.aql + + # Do not await, let it run in the background. + long_running_task = asyncio.create_task(aql.execute("RETURN SLEEP(10)")) + time.sleep(1) + + queries = list() + for _ in range(10): + queries = await aql.queries() + if len(queries) > 0: + break + + # Kill the query + query_id = queries[0]["id"] + assert await aql.kill(query_id) is True + + # Ignore missing + assert await aql.kill("fakeid", ignore_missing=True) is False + assert ( + await superuser.aql.kill("fakeid", ignore_missing=True, all_queries=True) + is False + ) + + # Check exceptions + with pytest.raises(AQLQueryKillError): + await aql.kill("fakeid") + with pytest.raises(AQLQueryKillError): + await bad_db.aql.kill(query_id) + + long_running_task.cancel() + + +@pytest.mark.asyncio +async def test_explain_query(db, doc_col, bad_db): + aql = db.aql + + # Explain a simple query + result = await aql.explain("RETURN 1") + assert "plan" in result + + # Something more complex + options = QueryExplainOptions( + all_plans=True, + max_plans=10, + optimizer={"rules": ["+all", "-use-index-range"]}, + ) + explanations = await aql.explain( + f"FOR d IN {doc_col.name} RETURN d", + options=options, + ) + assert "plans" in explanations + explanation = await aql.explain( + f"FOR d IN {doc_col.name} RETURN d", + options=options.to_dict(), + ) + assert "plans" in explanation + + # Check exceptions + with pytest.raises(AQLQueryExplainError): + _ = await bad_db.aql.explain("RETURN 1") + + +@pytest.mark.asyncio +async def test_validate_query(db, doc_col, bad_db): + aql = db.aql + + # Validate invalid query + with pytest.raises(AQLQueryValidateError) as err: + await aql.validate("INVALID QUERY") + assert err.value.error_code == QUERY_PARSE + + # Test validate valid query + result = await aql.validate(f"FOR d IN {doc_col.name} RETURN d") + assert result["parsed"] is True + + with pytest.raises(AQLQueryValidateError): + _ = await bad_db.aql.validate("RETURN 1") + + +@pytest.mark.asyncio +async def test_query_rules(db, bad_db): + aql = db.aql + + rules = await aql.query_rules() + assert len(rules) > 0 + + with pytest.raises(AQLQueryRulesGetError): + _ = await bad_db.aql.query_rules() diff --git a/tests/test_typings.py b/tests/test_typings.py index 166a0ce..218f421 100644 --- a/tests/test_typings.py +++ b/tests/test_typings.py @@ -10,7 +10,9 @@ QueryExecutionPlan, QueryExecutionProfile, QueryExecutionStats, + QueryExplainOptions, QueryProperties, + QueryTrackingConfiguration, UserInfo, ) @@ -282,3 +284,32 @@ def test_QueryExecutionExtra(): assert isinstance(extra.profile, QueryExecutionProfile) assert isinstance(extra.stats, QueryExecutionStats) assert extra.warnings == [{"code": 123, "message": "test warning"}] + + +def test_QueryTrackingConfiguration(): + data = { + "enabled": True, + "trackSlowQueries": True, + "trackBindVars": True, + "maxSlowQueries": 64, + "slowQueryThreshold": 10, + "slowStreamingQueryThreshold": 10, + "maxQueryStringLength": 4096, + } + config = QueryTrackingConfiguration(data) + assert config.enabled is True + assert config.track_slow_queries is True + assert config.track_bind_vars is True + assert config.max_slow_queries == 64 + assert config.slow_query_threshold == 10 + assert config.slow_streaming_query_threshold == 10 + assert config.max_query_string_length == 4096 + + +def test_QueryExplainOptions(): + options = QueryExplainOptions( + all_plans=True, max_plans=5, optimizer={"rules": ["-all", "+use-index-range"]} + ) + assert options.all_plans is True + assert options.max_plans == 5 + assert options.optimizer == {"rules": ["-all", "+use-index-range"]}