diff --git a/arangoasync/collection.py b/arangoasync/collection.py
index ff5de47..69c271e 100644
--- a/arangoasync/collection.py
+++ b/arangoasync/collection.py
@@ -1,7 +1,7 @@
 __all__ = ["Collection", "StandardCollection"]
 
-from typing import Generic, Optional, Tuple, TypeVar, cast
+from typing import Generic, List, Optional, Tuple, TypeVar, cast
 
 from arangoasync.errno import (
     HTTP_BAD_PARAMETER,
@@ -14,12 +14,24 @@
     DocumentInsertError,
     DocumentParseError,
     DocumentRevisionError,
+    IndexCreateError,
+    IndexDeleteError,
+    IndexGetError,
+    IndexListError,
+    IndexLoadError,
 )
 from arangoasync.executor import ApiExecutor
 from arangoasync.request import Method, Request
 from arangoasync.response import Response
 from arangoasync.serialization import Deserializer, Serializer
-from arangoasync.typings import CollectionProperties, Json, Params, Result
+from arangoasync.typings import (
+    CollectionProperties,
+    IndexProperties,
+    Json,
+    Jsons,
+    Params,
+    Result,
+)
 
 T = TypeVar("T")
 U = TypeVar("U")
@@ -155,6 +167,179 @@ def db_name(self) -> str:
         """
         return self._executor.db_name
 
+    @property
+    def serializer(self) -> Serializer[Json]:
+        """Return the serializer."""
+        return self._executor.serializer
+
+    @property
+    def deserializer(self) -> Deserializer[Json, Jsons]:
+        """Return the deserializer."""
+        return self._executor.deserializer
+
+    async def indexes(self) -> Result[List[IndexProperties]]:
+        """Fetch all index descriptions for the given collection.
+
+        Returns:
+            list: List of index properties.
+
+        Raises:
+            IndexListError: If retrieval fails.
+
+        References:
+            - `list-all-indexes-of-a-collection <https://docs.arangodb.com/stable/develop/http-api/indexes/#list-all-indexes-of-a-collection>`__
+        """  # noqa: E501
+        request = Request(
+            method=Method.GET,
+            endpoint="/_api/index",
+            params=dict(collection=self._name),
+        )
+
+        def response_handler(resp: Response) -> List[IndexProperties]:
+            if not resp.is_success:
+                raise IndexListError(resp, request)
+            data = self.deserializer.loads(resp.raw_body)
+            return [IndexProperties(item) for item in data["indexes"]]
+
+        return await self._executor.execute(request, response_handler)
+
+    async def get_index(self, id: str | int) -> Result[IndexProperties]:
+        """Return the properties of an index.
+
+        Args:
+            id (str | int): Index ID. Could be either the full ID or just the index number.
+
+        Returns:
+            IndexProperties: Index properties.
+
+        Raises:
+            IndexGetError: If retrieval fails.
+
+        References:
+            - `get-an-index <https://docs.arangodb.com/stable/develop/http-api/indexes/#get-an-index>`__
+        """  # noqa: E501
+        if isinstance(id, int):
+            full_id = f"{self._name}/{id}"
+        else:
+            full_id = id if "/" in id else f"{self._name}/{id}"
+
+        request = Request(
+            method=Method.GET,
+            endpoint=f"/_api/index/{full_id}",
+        )
+
+        def response_handler(resp: Response) -> IndexProperties:
+            if not resp.is_success:
+                raise IndexGetError(resp, request)
+            return IndexProperties(self.deserializer.loads(resp.raw_body))
+
+        return await self._executor.execute(request, response_handler)
+
+    async def add_index(
+        self,
+        type: str,
+        fields: Json | List[str],
+        options: Optional[Json] = None,
+    ) -> Result[IndexProperties]:
+        """Create an index.
+
+        Args:
+            type (str): Type attribute (e.g. "persistent", "inverted", "ttl", "mdi",
+                "geo").
+            fields (dict | list): Fields to index.
+            options (dict | None): Additional index options.
+
+        Returns:
+            IndexProperties: New index properties.
+
+        Raises:
+            IndexCreateError: If index creation fails.
+
+        References:
+            - `create-an-index <https://docs.arangodb.com/stable/develop/http-api/indexes/#create-an-index>`__
+            - `create-a-persistent-index <https://docs.arangodb.com/stable/develop/http-api/indexes/#create-a-persistent-index>`__
+            - `create-an-inverted-index <https://docs.arangodb.com/stable/develop/http-api/indexes/#create-an-inverted-index>`__
+            - `create-a-ttl-index <https://docs.arangodb.com/stable/develop/http-api/indexes/#create-a-ttl-index>`__
+            - `create-a-multi-dimensional-index <https://docs.arangodb.com/stable/develop/http-api/indexes/#create-a-multi-dimensional-index>`__
+            - `create-a-geo-spatial-index <https://docs.arangodb.com/stable/develop/http-api/indexes/#create-a-geo-spatial-index>`__
+        """  # noqa: E501
+        options = options or {}
+        request = Request(
+            method=Method.POST,
+            endpoint="/_api/index",
+            data=self.serializer.dumps(dict(type=type, fields=fields) | options),
+            params=dict(collection=self._name),
+        )
+
+        def response_handler(resp: Response) -> IndexProperties:
+            if not resp.is_success:
+                raise IndexCreateError(resp, request)
+            return IndexProperties(self.deserializer.loads(resp.raw_body))
+
+        return await self._executor.execute(request, response_handler)
+
+    async def delete_index(
+        self, id: str | int, ignore_missing: bool = False
+    ) -> Result[bool]:
+        """Delete an index.
+
+        Args:
+            id (str | int): Index ID. Could be either the full ID or just the index number.
+            ignore_missing (bool): Do not raise an exception on missing index.
+
+        Returns:
+            bool: `True` if the operation was successful. `False` if the index was not
+                found and **ignore_missing** was set to `True`.
+
+        Raises:
+            IndexDeleteError: If deletion fails.
+
+        References:
+            - `delete-an-index <https://docs.arangodb.com/stable/develop/http-api/indexes/#delete-an-index>`__
+        """  # noqa: E501
+        if isinstance(id, int):
+            full_id = f"{self._name}/{id}"
+        else:
+            full_id = id if "/" in id else f"{self._name}/{id}"
+
+        request = Request(
+            method=Method.DELETE,
+            endpoint=f"/_api/index/{full_id}",
+        )
+
+        def response_handler(resp: Response) -> bool:
+            if resp.is_success:
+                return True
+            elif ignore_missing and resp.status_code == HTTP_NOT_FOUND:
+                return False
+            raise IndexDeleteError(resp, request)
+
+        return await self._executor.execute(request, response_handler)
+
+    async def load_indexes(self) -> Result[bool]:
+        """Cache this collection's index entries in the main memory.
+
+        Returns:
+            bool: `True` if the operation was successful.
+
+        Raises:
+            IndexLoadError: If loading fails.
+
+        References:
+            - `load-collection-indexes-into-memory <https://docs.arangodb.com/stable/develop/http-api/collections/#load-collection-indexes-into-memory>`__
+        """  # noqa: E501
+        request = Request(
+            method=Method.PUT,
+            endpoint=f"/_api/collection/{self._name}/loadIndexesIntoMemory",
+        )
+
+        def response_handler(resp: Response) -> bool:
+            if resp.is_success:
+                return True
+            raise IndexLoadError(resp, request)
+
+        return await self._executor.execute(request, response_handler)
+
 
 class StandardCollection(Collection[T, U, V]):
     """Standard collection API wrapper.
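Taken together, the new `Collection` methods cover the full index lifecycle: create, fetch, list, delete, and load into memory. Below is a minimal usage sketch against these additions; the client bootstrap (`ArangoClient`, `Auth`, `client.db`, `db.collection`) follows the project README, and the host, credentials, and collection name are illustrative only:

```python
import asyncio

from arangoasync import ArangoClient
from arangoasync.auth import Auth


async def main() -> None:
    # Illustrative connection details; adjust for your deployment.
    async with ArangoClient(hosts="http://localhost:8529") as client:
        db = await client.db("test", auth=Auth(username="root", password="passwd"))
        col = db.collection("students")

        # Create a persistent index; anything beyond type/fields goes in "options".
        idx = await col.add_index(
            type="persistent",
            fields=["email"],
            options={"unique": True, "name": "idx_email"},
        )

        # Fetch it back by bare number (the full "<collection>/<number>" ID
        # also works), then list every index on the collection.
        same = await col.get_index(idx.numeric_id)
        assert same.id == idx.id
        print([i.name for i in await col.indexes()])

        # Delete it; ignore_missing turns a second delete into False, not an error.
        assert await col.delete_index(idx.id) is True
        assert await col.delete_index(idx.id, ignore_missing=True) is False


asyncio.run(main())
```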
diff --git a/arangoasync/exceptions.py b/arangoasync/exceptions.py
index 9a5eb1b..62b871a 100644
--- a/arangoasync/exceptions.py
+++ b/arangoasync/exceptions.py
@@ -135,6 +135,26 @@ class DocumentRevisionError(ArangoServerError):
     """The expected and actual document revisions mismatched."""
 
 
+class IndexCreateError(ArangoServerError):
+    """Failed to create collection index."""
+
+
+class IndexDeleteError(ArangoServerError):
+    """Failed to delete collection index."""
+
+
+class IndexGetError(ArangoServerError):
+    """Failed to retrieve collection index."""
+
+
+class IndexListError(ArangoServerError):
+    """Failed to retrieve collection indexes."""
+
+
+class IndexLoadError(ArangoServerError):
+    """Failed to load indexes into memory."""
+
+
 class JWTRefreshError(ArangoClientError):
     """Failed to refresh the JWT token."""
 
diff --git a/arangoasync/typings.py b/arangoasync/typings.py
index 95f2d65..0744ada 100644
--- a/arangoasync/typings.py
+++ b/arangoasync/typings.py
@@ -815,3 +815,234 @@ def format(self, formatter: Optional[Formatter] = None) -> Json:
         if formatter is not None:
             return super().format(formatter)
         return self.compatibility_formatter(self._data)
+
+
+class IndexProperties(JsonWrapper):
+    """Properties of an index.
+
+    Example:
+        .. code-block:: json
+
+            {
+                "fields" : [
+                    "_key"
+                ],
+                "id" : "products/0",
+                "name" : "primary",
+                "selectivityEstimate" : 1,
+                "sparse" : false,
+                "type" : "primary",
+                "unique" : true
+            }
+
+    References:
+        - `get-an-index <https://docs.arangodb.com/stable/develop/http-api/indexes/#get-an-index>`__
+    """  # noqa: E501
+
+    def __init__(self, data: Json) -> None:
+        super().__init__(data)
+
+    @property
+    def id(self) -> str:
+        return self._data["id"]  # type: ignore[no-any-return]
+
+    @property
+    def numeric_id(self) -> int:
+        return int(self._data["id"].split("/", 1)[-1])
+
+    @property
+    def type(self) -> str:
+        return self._data["type"]  # type: ignore[no-any-return]
+
+    @property
+    def fields(self) -> Json | List[str]:
+        return self._data["fields"]  # type: ignore[no-any-return]
+
+    @property
+    def name(self) -> Optional[str]:
+        return self._data.get("name")
+
+    @property
+    def deduplicate(self) -> Optional[bool]:
+        return self._data.get("deduplicate")
+
+    @property
+    def sparse(self) -> Optional[bool]:
+        return self._data.get("sparse")
+
+    @property
+    def unique(self) -> Optional[bool]:
+        return self._data.get("unique")
+
+    @property
+    def geo_json(self) -> Optional[bool]:
+        return self._data.get("geoJson")
+
+    @property
+    def selectivity_estimate(self) -> Optional[float]:
+        return self._data.get("selectivityEstimate")
+
+    @property
+    def is_newly_created(self) -> Optional[bool]:
+        return self._data.get("isNewlyCreated")
+
+    @property
+    def expire_after(self) -> Optional[int]:
+        return self._data.get("expireAfter")
+
+    @property
+    def in_background(self) -> Optional[bool]:
+        return self._data.get("inBackground")
+
+    @property
+    def max_num_cover_cells(self) -> Optional[int]:
+        return self._data.get("maxNumCoverCells")
+
+    @property
+    def cache_enabled(self) -> Optional[bool]:
+        return self._data.get("cacheEnabled")
+
+    @property
+    def legacy_polygons(self) -> Optional[bool]:
+        return self._data.get("legacyPolygons")
+
+    @property
+    def estimates(self) -> Optional[bool]:
+        return self._data.get("estimates")
+
+    @property
+    def analyzer(self) -> Optional[str]:
+        return self._data.get("analyzer")
+
+    @property
+    def cleanup_interval_step(self) -> Optional[int]:
+        return self._data.get("cleanupIntervalStep")
+
+    @property
+    def commit_interval_msec(self) -> Optional[int]:
+        return self._data.get("commitIntervalMsec")
+
+    @property
+    def consolidation_interval_msec(self) -> Optional[int]:
+        return self._data.get("consolidationIntervalMsec")
+
+    @property
+    def consolidation_policy(self) -> Optional[Json]:
+        return self._data.get("consolidationPolicy")
+
+    @property
+    def primary_sort(self) -> Optional[Json]:
+        return self._data.get("primarySort")
+
+    @property
+    def stored_values(self) -> Optional[List[Any]]:
+        return self._data.get("storedValues")
+
+    @property
+    def write_buffer_active(self) -> Optional[int]:
+        return self._data.get("writeBufferActive")
+
+    @property
+    def write_buffer_idle(self) -> Optional[int]:
+        return self._data.get("writeBufferIdle")
+
+    @property
+    def write_buffer_size_max(self) -> Optional[int]:
+        return self._data.get("writeBufferSizeMax")
+
+    @property
+    def primary_key_cache(self) -> Optional[bool]:
+        return self._data.get("primaryKeyCache")
+
+    @property
+    def parallelism(self) -> Optional[int]:
+        return self._data.get("parallelism")
+
+    @property
+    def optimize_top_k(self) -> Optional[List[str]]:
+        return self._data.get("optimizeTopK")
+
+    @property
+    def track_list_positions(self) -> Optional[bool]:
+        return self._data.get("trackListPositions")
+
+    @property
+    def version(self) -> Optional[int]:
+        return self._data.get("version")
+
+    @property
+    def include_all_fields(self) -> Optional[bool]:
+        return self._data.get("includeAllFields")
+
+    @property
+    def features(self) -> Optional[List[str]]:
+        return self._data.get("features")
+
+    @staticmethod
+    def compatibility_formatter(data: Json) -> Json:
+        """python-arango compatibility formatter."""
+        result = {"id": data["id"].split("/", 1)[-1], "fields": data["fields"]}
+        if "type" in data:
+            result["type"] = data["type"]
+        if "name" in data:
+            result["name"] = data["name"]
+        if "deduplicate" in data:
+            result["deduplicate"] = data["deduplicate"]
+        if "sparse" in data:
+            result["sparse"] = data["sparse"]
+        if "unique" in data:
+            result["unique"] = data["unique"]
+        if "geoJson" in data:
+            result["geo_json"] = data["geoJson"]
+        if "selectivityEstimate" in data:
+            result["selectivity"] = data["selectivityEstimate"]
+        if "isNewlyCreated" in data:
+            result["new"] = data["isNewlyCreated"]
+        if "expireAfter" in data:
+            result["expiry_time"] = data["expireAfter"]
+        if "inBackground" in data:
+            result["in_background"] = data["inBackground"]
+        if "maxNumCoverCells" in data:
+            result["max_num_cover_cells"] = data["maxNumCoverCells"]
+        if "storedValues" in data:
+            result["storedValues"] = data["storedValues"]
+        if "legacyPolygons" in data:
+            result["legacyPolygons"] = data["legacyPolygons"]
+        if "estimates" in data:
+            result["estimates"] = data["estimates"]
+        if "analyzer" in data:
+            result["analyzer"] = data["analyzer"]
+        if "cleanupIntervalStep" in data:
+            result["cleanup_interval_step"] = data["cleanupIntervalStep"]
+        if "commitIntervalMsec" in data:
+            result["commit_interval_msec"] = data["commitIntervalMsec"]
+        if "consolidationIntervalMsec" in data:
+            result["consolidation_interval_msec"] = data["consolidationIntervalMsec"]
+        if "consolidationPolicy" in data:
+            result["consolidation_policy"] = data["consolidationPolicy"]
+        if "features" in data:
+            result["features"] = data["features"]
+        if "primarySort" in data:
+            result["primary_sort"] = data["primarySort"]
+        if "trackListPositions" in data:
+            result["track_list_positions"] = data["trackListPositions"]
+        if "version" in data:
+            result["version"] = data["version"]
+        if "writebufferIdle" in data:
+            result["writebuffer_idle"] = data["writebufferIdle"]
+        if "writebufferActive" in data:
result["writebuffer_active"] = data["writebufferActive"] + if "writebufferSizeMax" in data: + result["writebuffer_max_size"] = data["writebufferSizeMax"] + if "optimizeTopK" in data: + result["optimizeTopK"] = data["optimizeTopK"] + return result + + def format(self, formatter: Optional[Formatter] = None) -> Json: + """Apply a formatter to the data. + + By default, the python-arango compatibility formatter is applied. + """ + if formatter is not None: + return super().format(formatter) + return self.compatibility_formatter(self._data) diff --git a/tests/test_collection.py b/tests/test_collection.py index 8a3ac4b..72f6583 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -1,6 +1,16 @@ +import asyncio + import pytest -from arangoasync.exceptions import CollectionPropertiesError +from arangoasync.errno import DATA_SOURCE_NOT_FOUND, INDEX_NOT_FOUND +from arangoasync.exceptions import ( + CollectionPropertiesError, + IndexCreateError, + IndexDeleteError, + IndexGetError, + IndexListError, + IndexLoadError, +) def test_collection_attributes(db, doc_col): @@ -18,3 +28,132 @@ async def test_collection_misc_methods(doc_col, bad_col): assert len(properties.format()) > 1 with pytest.raises(CollectionPropertiesError): await bad_col.properties() + + +@pytest.mark.asyncio +async def test_collection_index(doc_col, bad_col, cluster): + # Create indexes + idx1 = await doc_col.add_index( + type="persistent", + fields=["_key"], + options={ + "unique": True, + "name": "idx1", + }, + ) + assert idx1.id is not None + assert idx1.id == f"{doc_col.name}/{idx1.numeric_id}" + assert idx1.type == "persistent" + assert idx1["type"] == "persistent" + assert idx1.fields == ["_key"] + assert idx1.name == "idx1" + assert idx1["unique"] is True + assert idx1.unique is True + assert idx1.format()["id"] == str(idx1.numeric_id) + + idx2 = await doc_col.add_index( + type="inverted", + fields=[{"name": "attr1", "cache": True}], + options={ + "unique": False, + "sparse": True, + "name": "idx2", + "storedValues": [{"fields": ["a"], "compression": "lz4", "cache": True}], + "includeAllFields": True, + "analyzer": "identity", + "primarySort": { + "cache": True, + "fields": [{"field": "a", "direction": "asc"}], + }, + }, + ) + assert idx2.id is not None + assert idx2.id == f"{doc_col.name}/{idx2.numeric_id}" + assert idx2.type == "inverted" + assert idx2["fields"][0]["name"] == "attr1" + assert idx2.name == "idx2" + assert idx2.include_all_fields is True + assert idx2.analyzer == "identity" + assert idx2.sparse is True + assert idx2.unique is False + + idx3 = await doc_col.add_index( + type="geo", + fields=["location"], + options={ + "geoJson": True, + "name": "idx3", + "inBackground": True, + }, + ) + assert idx3.id is not None + assert idx3.type == "geo" + assert idx3.fields == ["location"] + assert idx3.name == "idx3" + assert idx3.geo_json is True + if cluster: + assert idx3.in_background is True + + with pytest.raises(IndexCreateError): + await bad_col.add_index(type="persistent", fields=["_key"]) + + # List all indexes + indexes = await doc_col.indexes() + assert len(indexes) > 3, indexes + found_idx1 = found_idx2 = found_idx3 = False + for idx in indexes: + if idx.id == idx1.id: + found_idx1 = True + elif idx.id == idx2.id: + found_idx2 = True + elif idx.id == idx3.id: + found_idx3 = True + assert found_idx1 is True, indexes + assert found_idx2 is True, indexes + assert found_idx3 is True, indexes + + with pytest.raises(IndexListError) as err: + await bad_col.indexes() + assert err.value.error_code == 
+
+    # Get an index
+    get1, get2, get3 = await asyncio.gather(
+        doc_col.get_index(idx1.id),
+        doc_col.get_index(idx2.numeric_id),
+        doc_col.get_index(str(idx3.numeric_id)),
+    )
+    assert get1.id == idx1.id
+    assert get1.type == idx1.type
+    assert get1.name == idx1.name
+    assert get2.id == idx2.id
+    assert get2.type == idx2.type
+    assert get2.name == idx2.name
+    assert get3.id == idx3.id
+    assert get3.type == idx3.type
+    assert get3.name == idx3.name
+
+    with pytest.raises(IndexGetError) as err:
+        await doc_col.get_index("non-existent")
+    assert err.value.error_code == INDEX_NOT_FOUND
+
+    # Load indexes into main memory
+    assert await doc_col.load_indexes() is True
+    with pytest.raises(IndexLoadError) as err:
+        await bad_col.load_indexes()
+    assert err.value.error_code == DATA_SOURCE_NOT_FOUND
+
+    # Delete indexes
+    del1, del2, del3 = await asyncio.gather(
+        doc_col.delete_index(idx1.id),
+        doc_col.delete_index(idx2.numeric_id),
+        doc_col.delete_index(str(idx3.numeric_id)),
+    )
+    assert del1 is True
+    assert del2 is True
+    assert del3 is True
+
+    # Now, the indexes should be gone
+    with pytest.raises(IndexDeleteError) as err:
+        await doc_col.delete_index(idx1.id)
+    assert err.value.error_code == INDEX_NOT_FOUND
+    assert await doc_col.delete_index(idx2.id, ignore_missing=True) is False
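For reviewers coming from python-arango: `IndexProperties.compatibility_formatter` is what keeps the default `format()` output aligned with the old library's key names, while the typed properties expose the server's raw camelCase payload. A small sketch of that mapping, using only what this diff defines; the payload shape follows the docstring example and the field values are illustrative:

```python
from arangoasync.typings import IndexProperties

# Raw server response shape, as in the IndexProperties docstring example.
props = IndexProperties(
    {
        "id": "products/12345",
        "type": "persistent",
        "fields": ["email"],
        "name": "idx_email",
        "unique": True,
        "sparse": False,
        "selectivityEstimate": 1,
        "isNewlyCreated": True,
    }
)

# Typed accessors read the camelCase payload directly.
assert props.numeric_id == 12345
assert props.selectivity_estimate == 1

# format() with no formatter applies the python-arango mapping:
# "selectivityEstimate" -> "selectivity", "isNewlyCreated" -> "new",
# and the ID is stripped down to its numeric part.
legacy = props.format()
assert legacy["id"] == "12345"
assert legacy["selectivity"] == 1
assert legacy["new"] is True
```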