diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index 7319b962..1ca72015 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -1,4 +1,5 @@ """Database logic.""" +import asyncio import logging from base64 import urlsafe_b64decode, urlsafe_b64encode from typing import Dict, List, Optional, Tuple, Type, Union @@ -196,24 +197,28 @@ async def execute_search( base_url: str, ) -> Tuple[List[Item], Optional[int], Optional[str]]: """Database logic to execute search with limit.""" - body = search.to_dict() - - maybe_count = ( - await self.client.count(index=ITEMS_INDEX, body=search.to_dict(count=True)) - ).get("count") - search_after = None if token: search_after = urlsafe_b64decode(token.encode()).decode().split(",") - es_response = await self.client.search( - index=ITEMS_INDEX, - query=body.get("query"), - sort=sort or DEFAULT_SORT, - search_after=search_after, - size=limit, + query = search.query.to_dict() if search.query else None + + search_task = asyncio.create_task( + self.client.search( + index=ITEMS_INDEX, + query=query, + sort=sort or DEFAULT_SORT, + search_after=search_after, + size=limit, + ) + ) + + count_task = asyncio.create_task( + self.client.count(index=ITEMS_INDEX, body=search.to_dict(count=True)) ) + es_response = await search_task + hits = es_response["hits"]["hits"] items = [ self.item_serializer.db_to_stac(hit["_source"], base_url=base_url) @@ -226,6 +231,15 @@ async def execute_search( ",".join([str(x) for x in sort_array]).encode() ).decode() + # (1) count should not block returning results, so don't wait for it to be done + # (2) don't cancel the task so that it will populate the ES cache for subsequent counts + maybe_count = None + if count_task.done(): + try: + maybe_count = count_task.result().get("count") + except Exception as e: # type: ignore + logger.error(f"Count task failed: {e}") + return items, maybe_count, next_token """ TRANSACTION LOGIC """ diff --git a/stac_fastapi/elasticsearch/tests/api/test_api.py b/stac_fastapi/elasticsearch/tests/api/test_api.py index c43577bd..5b0ab8f4 100644 --- a/stac_fastapi/elasticsearch/tests/api/test_api.py +++ b/stac_fastapi/elasticsearch/tests/api/test_api.py @@ -103,7 +103,9 @@ async def test_app_context_extension(app_client, ctx, txn_client): resp_json = resp.json() assert len(resp_json["features"]) == 1 assert "context" in resp_json - assert resp_json["context"]["returned"] == resp_json["context"]["matched"] == 1 + assert resp_json["context"]["returned"] == 1 + if matched := resp_json["context"].get("matched"): + assert matched == 1 @pytest.mark.skip(reason="fields not implemented yet") diff --git a/stac_fastapi/elasticsearch/tests/resources/test_item.py b/stac_fastapi/elasticsearch/tests/resources/test_item.py index d6d48e12..8f7ea0c4 100644 --- a/stac_fastapi/elasticsearch/tests/resources/test_item.py +++ b/stac_fastapi/elasticsearch/tests/resources/test_item.py @@ -184,7 +184,8 @@ async def test_get_item_collection(app_client, ctx, txn_client): assert resp.status_code == 200 item_collection = resp.json() - assert item_collection["context"]["matched"] == item_count + 1 + if matched := item_collection["context"].get("matched"): + assert matched == item_count + 1 @pytest.mark.skip(reason="Pagination extension not implemented")