From d74e386536564e7504edae176fc09ddd796845ba Mon Sep 17 00:00:00 2001 From: Phil Varner Date: Thu, 24 Mar 2022 10:40:58 -0400 Subject: [PATCH 1/3] convert all uses of query to filter --- README.md | 3 +- .../stac_fastapi/elasticsearch/core.py | 33 ++-- .../elasticsearch/database_logic.py | 163 ++++++++---------- 3 files changed, 94 insertions(+), 105 deletions(-) diff --git a/README.md b/README.md index f253448d..a58faa53 100644 --- a/README.md +++ b/README.md @@ -11,8 +11,7 @@ For changes, see the [Changelog](CHANGELOG.md). To install the classes in your local Python env, run: ```shell -cd stac_fastapi/elasticsearch -pip install -e '.[dev]' +pip install -e 'stac_fastapi/elasticsearch[dev]' ``` ### Pre-commit diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py index 4479dc6a..f786f8d2 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py @@ -91,17 +91,18 @@ async def item_collection( links = [] base_url = str(kwargs["request"].base_url) - serialized_children, count = await self.database.get_item_collection( + serialized_children, maybe_count = await self.database.get_collection_items( collection_id=collection_id, limit=limit, base_url=base_url ) context_obj = None if self.extension_is_enabled("ContextExtension"): context_obj = { - "returned": count if count is not None and count < limit else limit, + "returned": len(serialized_children), "limit": limit, - "matched": count, } + if maybe_count is not None: + context_obj["matched"] = maybe_count return ItemCollection( type="FeatureCollection", @@ -207,29 +208,29 @@ async def post_search( ) -> ItemCollection: """POST search catalog.""" base_url = str(kwargs["request"].base_url) - search = self.database.create_search() + search = self.database.make_search() if search_request.query: for (field_name, expr) in search_request.query.items(): field = "properties__" + field_name for (op, value) in expr.items(): - search = self.database.create_query_filter( + search = self.database.apply_stacql_filter( search=search, op=op, field=field, value=value ) if search_request.ids: - search = self.database.search_ids( + search = self.database.apply_ids_filter( search=search, item_ids=search_request.ids ) if search_request.collections: - search = self.database.filter_collections( + search = self.database.apply_collections_filter( search=search, collection_ids=search_request.collections ) if search_request.datetime: datetime_search = self._return_date(search_request.datetime) - search = self.database.search_datetime( + search = self.database.apply_datetime_filter( search=search, datetime_search=datetime_search ) @@ -238,10 +239,10 @@ async def post_search( if len(bbox) == 6: bbox = [bbox[0], bbox[1], bbox[3], bbox[4]] - search = self.database.search_bbox(search=search, bbox=bbox) + search = self.database.apply_bbox_filter(search=search, bbox=bbox) if search_request.intersects: - self.database.search_intersects( + self.database.apply_intersects_filter( search=search, intersects=search_request.intersects ) @@ -249,11 +250,11 @@ async def post_search( for sort in search_request.sortby: if sort.field == "datetime": sort.field = "properties__datetime" - search = self.database.sort_field( + search = self.database.apply_sort( search=search, field=sort.field, direction=sort.direction ) - count = await self.database.search_count(search=search) + maybe_count = await self.database.search_count(search=search) response_features = await self.database.execute_search( search=search, limit=search_request.limit, base_url=base_url @@ -289,16 +290,16 @@ async def post_search( context_obj = None if self.extension_is_enabled("ContextExtension"): context_obj = { - "returned": count if count < limit else limit, + "returned": len(response_features), "limit": limit, - "matched": count, } + if maybe_count is not None: + context_obj["matched"] = maybe_count - links = [] return ItemCollection( type="FeatureCollection", features=response_features, - links=links, + links=[], context=context_obj, ) diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index 27aa3dce..370551dd 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -31,6 +31,17 @@ ITEMS_INDEX = "stac_items" COLLECTIONS_INDEX = "stac_collections" +DEFAULT_SORT = { + "properties.datetime": {"order": "desc"}, + "id": {"order": "desc"}, + "collection": {"order": "desc"}, +} + + +def bbox2polygon(b0, b1, b2, b3): + """Transform bbox to polygon.""" + return [[[b0, b1], [b2, b1], [b2, b3], [b0, b3], [b0, b1]]] + def mk_item_id(item_id: str, collection_id: str): """Make the Elasticsearch document _id value from the Item id and collection.""" @@ -43,6 +54,7 @@ class DatabaseLogic: client = AsyncElasticsearchSettings().create_client sync_client = SyncElasticsearchSettings().create_client + item_serializer: Type[serializers.ItemSerializer] = attr.ib( default=serializers.ItemSerializer ) @@ -54,43 +66,30 @@ class DatabaseLogic: async def get_all_collections(self, base_url: str) -> List[Collection]: """Database logic to retrieve a list of all collections.""" - collections = await self.client.search( - index=COLLECTIONS_INDEX, query={"match_all": {}} - ) + # https://github.com/stac-utils/stac-fastapi-elasticsearch/issues/65 + # collections should be paginated, but at least return more than the default 10 for now + collections = await self.client.search(index=COLLECTIONS_INDEX, size=1000) return [ self.collection_serializer.db_to_stac(c["_source"], base_url=base_url) for c in collections["hits"]["hits"] ] - async def search_count(self, search: Search) -> int: + async def search_count(self, search: Search) -> Optional[int]: """Database logic to count search results.""" return ( await self.client.count(index=ITEMS_INDEX, body=search.to_dict(count=True)) ).get("count") - async def get_item_collection( + async def get_collection_items( self, collection_id: str, limit: int, base_url: str ) -> Tuple[List[Item], Optional[int]]: """Database logic to retrieve an ItemCollection and a count of items contained.""" - search = self.create_search() - search = search.filter("term", collection=collection_id) - - count = await self.search_count(search) - - search = search.query()[0:limit] - - body = search.to_dict() - es_response = await self.client.search( - index=ITEMS_INDEX, query=body["query"], sort=body.get("sort") - ) - - serialized_children = [ - self.item_serializer.db_to_stac(item["_source"], base_url=base_url) - for item in es_response["hits"]["hits"] - ] + search = self.apply_collections_filter(Search(), [collection_id]) + items = await self.execute_search(search=search, limit=limit, base_url=base_url) + maybe_count = await self.search_count(search) - return serialized_children, count + return items, maybe_count async def get_one_item(self, collection_id: str, item_id: str) -> Dict: """Database logic to retrieve a single item.""" @@ -105,48 +104,26 @@ async def get_one_item(self, collection_id: str, item_id: str) -> Dict: return item["_source"] @staticmethod - def create_search(): - """Database logic to create a nosql Search instance.""" - return Search().sort( - {"properties.datetime": {"order": "desc"}}, - {"id": {"order": "desc"}}, - {"collection": {"order": "desc"}}, - ) - - @staticmethod - def create_query_filter(search: Search, op: str, field: str, value: float): - """Database logic to perform query for search endpoint.""" - if op != "eq": - key_filter = {field: {f"{op}": value}} - search = search.query(Q("range", **key_filter)) - else: - search = search.query("match_phrase", **{field: value}) - - return search + def make_search(): + """Database logic to create a Search instance.""" + return Search().sort(*DEFAULT_SORT) @staticmethod - def search_ids(search: Search, item_ids: List[str]): + def apply_ids_filter(search: Search, item_ids: List[str]): """Database logic to search a list of STAC item ids.""" - return search.query( - Q("bool", should=[Q("term", **{"id": i_id}) for i_id in item_ids]) - ) + return search.filter("terms", id=item_ids) @staticmethod - def filter_collections(search: Search, collection_ids: List): + def apply_collections_filter(search: Search, collection_ids: List[str]): """Database logic to search a list of STAC collection ids.""" - return search.query( - Q( - "bool", - should=[Q("term", **{"collection": c_id}) for c_id in collection_ids], - ) - ) + return search.filter("terms", collection=collection_ids) @staticmethod - def search_datetime(search: Search, datetime_search): + def apply_datetime_filter(search: Search, datetime_search): """Database logic to search datetime field.""" if "eq" in datetime_search: - search = search.query( - "match_phrase", **{"properties__datetime": datetime_search["eq"]} + search = search.filter( + "term", **{"properties__datetime": datetime_search["eq"]} ) else: search = search.filter( @@ -158,24 +135,28 @@ def search_datetime(search: Search, datetime_search): return search @staticmethod - def search_bbox(search: Search, bbox: List): + def apply_bbox_filter(search: Search, bbox: List): """Database logic to search on bounding box.""" - polygon = DatabaseLogic.bbox2polygon(bbox[0], bbox[1], bbox[2], bbox[3]) - bbox_filter = Q( - { - "geo_shape": { - "geometry": { - "shape": {"type": "polygon", "coordinates": polygon}, - "relation": "intersects", + return search.filter( + Q( + { + "geo_shape": { + "geometry": { + "shape": { + "type": "polygon", + "coordinates": bbox2polygon( + bbox[0], bbox[1], bbox[2], bbox[3] + ), + }, + "relation": "intersects", + } } } - } + ) ) - search = search.query(bbox_filter) - return search @staticmethod - def search_intersects( + def apply_intersects_filter( search: Search, intersects: Union[ Point, @@ -188,33 +169,45 @@ def search_intersects( ], ): """Database logic to search a geojson object.""" - intersect_filter = Q( - { - "geo_shape": { - "geometry": { - "shape": { - "type": intersects.type.lower(), - "coordinates": intersects.coordinates, - }, - "relation": "intersects", + return search.filter( + Q( + { + "geo_shape": { + "geometry": { + "shape": { + "type": intersects.type.lower(), + "coordinates": intersects.coordinates, + }, + "relation": "intersects", + } } } - } + ) ) - search = search.query(intersect_filter) + + @staticmethod + def apply_stacql_filter(search: Search, op: str, field: str, value: float): + """Database logic to perform query for search endpoint.""" + if op != "eq": + key_filter = {field: {f"{op}": value}} + search = search.filter(Q("range", **key_filter)) + else: + search = search.filter("term", **{field: value}) + return search @staticmethod - def sort_field(search: Search, field, direction): + def apply_sort(search: Search, field, direction): """Database logic to sort search instance.""" return search.sort({field: {"order": direction}}) - async def execute_search(self, search, limit: int, base_url: str) -> List: + async def execute_search(self, search, limit: int, base_url: str) -> List[Item]: """Database logic to execute search with limit.""" - search = search.query()[0:limit] + search = search[0:limit] body = search.to_dict() + es_response = await self.client.search( - index=ITEMS_INDEX, query=body["query"], sort=body.get("sort") + index=ITEMS_INDEX, query=body.get("query"), sort=body.get("sort") ) return [ @@ -330,7 +323,8 @@ def bulk_sync(self, processed_items, refresh: bool = False): self.sync_client, self._mk_actions(processed_items), refresh=refresh ) - def _mk_actions(self, processed_items): + @staticmethod + def _mk_actions(processed_items): return [ { "_index": ITEMS_INDEX, @@ -357,8 +351,3 @@ async def delete_collections(self) -> None: body={"query": {"match_all": {}}}, wait_for_completion=True, ) - - @staticmethod - def bbox2polygon(b0, b1, b2, b3): - """Transform bbox to polygon.""" - return [[[b0, b1], [b2, b1], [b2, b3], [b0, b3], [b0, b1]]] From 561f3fe1dd5c173638ba6c21e2767c60976d2b30 Mon Sep 17 00:00:00 2001 From: Phil Varner Date: Thu, 24 Mar 2022 11:36:28 -0400 Subject: [PATCH 2/3] refactor bbox handling --- .../stac_fastapi/elasticsearch/core.py | 16 ++++++---------- .../stac_fastapi/elasticsearch/database_logic.py | 4 +--- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py index f786f8d2..267bfd74 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py @@ -50,9 +50,7 @@ class CoreClient(AsyncBaseCoreClient): async def all_collections(self, **kwargs) -> Collections: """Read all collections from the database.""" base_url = str(kwargs["request"].base_url) - serialized_collections = await self.database.get_all_collections( - base_url=base_url - ) + collection_list = await self.database.get_all_collections(base_url=base_url) links = [ { @@ -71,10 +69,8 @@ async def all_collections(self, **kwargs) -> Collections: "href": urljoin(base_url, "collections"), }, ] - collection_list = Collections( - collections=serialized_collections or [], links=links - ) - return collection_list + + return Collections(collections=collection_list, links=links) @overrides async def get_collection(self, collection_id: str, **kwargs) -> Collection: @@ -91,14 +87,14 @@ async def item_collection( links = [] base_url = str(kwargs["request"].base_url) - serialized_children, maybe_count = await self.database.get_collection_items( + items, maybe_count = await self.database.get_collection_items( collection_id=collection_id, limit=limit, base_url=base_url ) context_obj = None if self.extension_is_enabled("ContextExtension"): context_obj = { - "returned": len(serialized_children), + "returned": len(items), "limit": limit, } if maybe_count is not None: @@ -106,7 +102,7 @@ async def item_collection( return ItemCollection( type="FeatureCollection", - features=serialized_children, + features=items, links=links, context=context_obj, ) diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index 370551dd..e0373b31 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -144,9 +144,7 @@ def apply_bbox_filter(search: Search, bbox: List): "geometry": { "shape": { "type": "polygon", - "coordinates": bbox2polygon( - bbox[0], bbox[1], bbox[2], bbox[3] - ), + "coordinates": bbox2polygon(*bbox), }, "relation": "intersects", } From da30e88a61b6592a0118373f10e9d31a1b6c89de Mon Sep 17 00:00:00 2001 From: Phil Varner Date: Thu, 24 Mar 2022 11:44:26 -0400 Subject: [PATCH 3/3] refactor execute_search to return count also --- .../stac_fastapi/elasticsearch/core.py | 4 +--- .../elasticsearch/database_logic.py | 23 +++++++++++-------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py index 267bfd74..f62e5038 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py @@ -250,9 +250,7 @@ async def post_search( search=search, field=sort.field, direction=sort.direction ) - maybe_count = await self.database.search_count(search=search) - - response_features = await self.database.execute_search( + response_features, maybe_count = await self.database.execute_search( search=search, limit=search_request.limit, base_url=base_url ) diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index e0373b31..16d1ce21 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -75,19 +75,14 @@ async def get_all_collections(self, base_url: str) -> List[Collection]: for c in collections["hits"]["hits"] ] - async def search_count(self, search: Search) -> Optional[int]: - """Database logic to count search results.""" - return ( - await self.client.count(index=ITEMS_INDEX, body=search.to_dict(count=True)) - ).get("count") - async def get_collection_items( self, collection_id: str, limit: int, base_url: str ) -> Tuple[List[Item], Optional[int]]: """Database logic to retrieve an ItemCollection and a count of items contained.""" search = self.apply_collections_filter(Search(), [collection_id]) - items = await self.execute_search(search=search, limit=limit, base_url=base_url) - maybe_count = await self.search_count(search) + items, maybe_count = await self.execute_search( + search=search, limit=limit, base_url=base_url + ) return items, maybe_count @@ -199,20 +194,28 @@ def apply_sort(search: Search, field, direction): """Database logic to sort search instance.""" return search.sort({field: {"order": direction}}) - async def execute_search(self, search, limit: int, base_url: str) -> List[Item]: + async def execute_search( + self, search, limit: int, base_url: str + ) -> Tuple[List[Item], Optional[int]]: """Database logic to execute search with limit.""" search = search[0:limit] body = search.to_dict() + maybe_count = ( + await self.client.count(index=ITEMS_INDEX, body=search.to_dict(count=True)) + ).get("count") + es_response = await self.client.search( index=ITEMS_INDEX, query=body.get("query"), sort=body.get("sort") ) - return [ + items = [ self.item_serializer.db_to_stac(hit["_source"], base_url=base_url) for hit in es_response["hits"]["hits"] ] + return items, maybe_count + """ TRANSACTION LOGIC """ async def check_collection_exists(self, collection_id: str):