diff --git a/CHANGELOG.md b/CHANGELOG.md
index c60614e5..67882e2d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
 
 ### Changed
 
+- Removed database logic from core.py all_collections [#196](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/196)
 - Changed OpenSearch config ssl_version to SSLv23 [#200](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/200)
 
 ### Fixed
diff --git a/stac_fastapi/core/stac_fastapi/core/core.py b/stac_fastapi/core/stac_fastapi/core/core.py
index 63c43944..99b58e16 100644
--- a/stac_fastapi/core/stac_fastapi/core/core.py
+++ b/stac_fastapi/core/stac_fastapi/core/core.py
@@ -1,7 +1,6 @@
 """Item crud client."""
 import logging
 import re
-from base64 import urlsafe_b64encode
 from datetime import datetime as datetime_type
 from datetime import timezone
 from typing import Any, Dict, List, Optional, Set, Type, Union
@@ -193,49 +192,24 @@ async def landing_page(self, **kwargs) -> stac_types.LandingPage:
     async def all_collections(self, **kwargs) -> Collections:
         """Read all collections from the database.
 
-        Returns:
-            Collections: A `Collections` object containing all the collections in the database and
-                links to various resources.
+        Args:
+            **kwargs: Keyword arguments from the request.
 
-        Raises:
-            Exception: If any error occurs while reading the collections from the database.
+        Returns:
+            A Collections object containing all the collections in the database and links to various resources.
         """
-        request: Request = kwargs["request"]
-        base_url = str(kwargs["request"].base_url)
+        request = kwargs["request"]
+        base_url = str(request.base_url)
+        limit = int(request.query_params.get("limit", 10))
+        token = request.query_params.get("token")
 
-        limit = (
-            int(request.query_params["limit"])
-            if "limit" in request.query_params
-            else 10
-        )
-        token = (
-            request.query_params["token"] if "token" in request.query_params else None
+        collections, next_token = await self.database.get_all_collections(
+            token=token, limit=limit, base_url=base_url
         )
 
-        hits = await self.database.get_all_collections(limit=limit, token=token)
-
-        next_search_after = None
-        next_link = None
-        if len(hits) == limit:
-            last_hit = hits[-1]
-            next_search_after = last_hit["sort"]
-            next_token = urlsafe_b64encode(
-                ",".join(map(str, next_search_after)).encode()
-            ).decode()
-            paging_links = PagingLinks(next=next_token, request=request)
-            next_link = paging_links.link_next()
-
         links = [
-            {
-                "rel": Relations.root.value,
-                "type": MimeTypes.json,
-                "href": base_url,
-            },
-            {
-                "rel": Relations.parent.value,
-                "type": MimeTypes.json,
-                "href": base_url,
-            },
+            {"rel": Relations.root.value, "type": MimeTypes.json, "href": base_url},
+            {"rel": Relations.parent.value, "type": MimeTypes.json, "href": base_url},
             {
                 "rel": Relations.self.value,
                 "type": MimeTypes.json,
@@ -243,16 +217,11 @@ async def all_collections(self, **kwargs) -> Collections:
             },
         ]
 
-        if next_link:
+        if next_token:
+            next_link = PagingLinks(next=next_token, request=request).link_next()
             links.append(next_link)
 
-        return Collections(
-            collections=[
-                self.collection_serializer.db_to_stac(c["_source"], base_url=base_url)
-                for c in hits
-            ],
-            links=links,
-        )
+        return Collections(collections=collections, links=links)
 
     async def get_collection(self, collection_id: str, **kwargs) -> Collection:
         """Get a collection from the database by its id.
@@ -269,7 +238,9 @@ async def get_collection(self, collection_id: str, **kwargs) -> Collection:
         """
         base_url = str(kwargs["request"].base_url)
         collection = await self.database.find_collection(collection_id=collection_id)
-        return self.collection_serializer.db_to_stac(collection, base_url)
+        return self.collection_serializer.db_to_stac(
+            collection=collection, base_url=base_url
+        )
 
     async def item_collection(
         self,
diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py
index ed0434f5..87ca8916 100644
--- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py
+++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py
@@ -291,33 +291,43 @@ class DatabaseLogic:
     """CORE LOGIC"""
 
     async def get_all_collections(
-        self, token: Optional[str], limit: int
-    ) -> Iterable[Dict[str, Any]]:
-        """Retrieve a list of all collections from the database.
+        self, token: Optional[str], limit: int, base_url: str
+    ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
+        """Retrieve a list of all collections from Elasticsearch, supporting pagination.
 
         Args:
-            token (Optional[str]): The token used to return the next set of results.
-            limit (int): Number of results to return
+            token (Optional[str]): The pagination token.
+            limit (int): The number of results to return.
 
         Returns:
-            collections (Iterable[Dict[str, Any]]): A list of dictionaries containing the source data for each collection.
-
-        Notes:
-            The collections are retrieved from the Elasticsearch database using the `client.search` method,
-            with the `COLLECTIONS_INDEX` as the target index and `size=limit` to retrieve records.
-            The result is a generator of dictionaries containing the source data for each collection.
+            A tuple of (collections, next pagination token if any).
         """
         search_after = None
         if token:
-            search_after = urlsafe_b64decode(token.encode()).decode().split(",")
-        collections = await self.client.search(
+            search_after = [token]
+
+        response = await self.client.search(
             index=COLLECTIONS_INDEX,
-            search_after=search_after,
-            size=limit,
-            sort={"id": {"order": "asc"}},
+            body={
+                "sort": [{"id": {"order": "asc"}}],
+                "size": limit,
+                "search_after": search_after,
+            },
         )
-        hits = collections["hits"]["hits"]
-        return hits
+
+        hits = response["hits"]["hits"]
+        collections = [
+            self.collection_serializer.db_to_stac(
+                collection=hit["_source"], base_url=base_url
+            )
+            for hit in hits
+        ]
+
+        next_token = None
+        if len(hits) == limit:
+            next_token = hits[-1]["sort"][0]
+
+        return collections, next_token
 
     async def get_one_item(self, collection_id: str, item_id: str) -> Dict:
         """Retrieve a single item from the database.
diff --git a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py
index efedbd5f..0f4bf9cf 100644
--- a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py
+++ b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py
@@ -312,36 +312,49 @@ class DatabaseLogic:
     """CORE LOGIC"""
 
     async def get_all_collections(
-        self,
-        token: Optional[str],
-        limit: int,
-    ) -> Iterable[Dict[str, Any]]:
-        """Retrieve a list of all collections from the database.
+        self, token: Optional[str], limit: int, base_url: str
+    ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
+        """
+        Retrieve a list of all collections from Opensearch, supporting pagination.
 
         Args:
-            token (Optional[str]): The token used to return the next set of results.
-            limit (int): Number of results to return
+            token (Optional[str]): The pagination token.
+            limit (int): The number of results to return.
 
         Returns:
-            collections (Iterable[Dict[str, Any]]): A list of dictionaries containing the source data for each collection.
-
-        Notes:
-            The collections are retrieved from the Elasticsearch database using the `client.search` method,
-            with the `COLLECTIONS_INDEX` as the target index and `size=limit` to retrieve records.
-            The result is a generator of dictionaries containing the source data for each collection.
+            A tuple of (collections, next pagination token if any).
         """
-        search_body: Dict[str, Any] = {}
+        search_body = {
+            "sort": [{"id": {"order": "asc"}}],
+            "size": limit,
+        }
+
+        # Only add search_after to the query if token is not None and not empty
         if token:
-            search_after = urlsafe_b64decode(token.encode()).decode().split(",")
+            search_after = [token]
             search_body["search_after"] = search_after
 
-        search_body["sort"] = {"id": {"order": "asc"}}
-
-        collections = await self.client.search(
-            index=COLLECTIONS_INDEX, body=search_body, size=limit
+        response = await self.client.search(
+            index="collections",
+            body=search_body,
         )
-        hits = collections["hits"]["hits"]
-        return hits
+
+        hits = response["hits"]["hits"]
+        collections = [
+            self.collection_serializer.db_to_stac(
+                collection=hit["_source"], base_url=base_url
+            )
+            for hit in hits
+        ]
+
+        next_token = None
+        if len(hits) == limit:
+            # Ensure we have a valid sort value for next_token
+            next_token_values = hits[-1].get("sort")
+            if next_token_values:
+                next_token = next_token_values[0]
+
+        return collections, next_token
 
     async def get_one_item(self, collection_id: str, item_id: str) -> Dict:
         """Retrieve a single item from the database.
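# A standalone sketch (not part of the diff above) of the search_after pattern the new
# get_all_collections implementations follow: sort on `id` ascending, fetch `limit`
# hits, and hand back the last hit's sort value as the next pagination token.
# The host, the "collections" index name, and the page_collections helper are
# illustrative assumptions, not project API.
import asyncio
from typing import Any, Dict, List, Optional, Tuple

from elasticsearch import AsyncElasticsearch


async def page_collections(
    client: AsyncElasticsearch, token: Optional[str], limit: int
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
    """Return one page of raw collection documents plus the next token, if any."""
    body: Dict[str, Any] = {"sort": [{"id": {"order": "asc"}}], "size": limit}
    if token:
        # The token is simply the `id` sort value of the last hit on the previous page.
        body["search_after"] = [token]

    response = await client.search(index="collections", body=body)
    hits = response["hits"]["hits"]

    # Mirrors the PR: a full page implies there may be more results, so expose a token.
    next_token = hits[-1]["sort"][0] if len(hits) == limit else None
    return [hit["_source"] for hit in hits], next_token


async def main() -> None:
    client = AsyncElasticsearch("http://localhost:9200")
    token: Optional[str] = None
    page = 0
    while True:
        docs, token = await page_collections(client, token, limit=10)
        page += 1
        print(f"page {page}: {[doc['id'] for doc in docs]}")
        if not token:
            break
    await client.close()


asyncio.run(main())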
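# The three hunks above amount to one refactor: core.all_collections now only reads
# `limit` and `token` from the query string and delegates to
# database.get_all_collections, which returns already-serialized collections plus a
# next token that core turns into a "next" link via PagingLinks. The sketch below
# (also not part of the PR) shows how a client would walk those pages; it assumes a
# local instance at http://localhost:8080 and that the "next" link carries the token
# as a `token` query parameter.
from typing import Iterator
from urllib.parse import parse_qs, urlparse

import requests


def iter_all_collections(
    api_url: str = "http://localhost:8080", limit: int = 10
) -> Iterator[dict]:
    """Yield every collection, following the server-issued pagination token."""
    token = None
    while True:
        params = {"limit": limit}
        if token:
            params["token"] = token
        page = requests.get(f"{api_url}/collections", params=params, timeout=30)
        page.raise_for_status()
        body = page.json()
        yield from body["collections"]

        # A "next" link is only appended when the page was full (len(hits) == limit).
        next_link = next(
            (link for link in body["links"] if link.get("rel") == "next"), None
        )
        if not next_link:
            break
        token = parse_qs(urlparse(next_link["href"]).query).get("token", [None])[0]
        if not token:
            break


for collection in iter_all_collections():
    print(collection["id"])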