diff --git a/CHANGELOG.md b/CHANGELOG.md
index 90f35f22..bd27e75a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 
 ### Fixed
 
+- Fixed issue where paginated search queries would return a `next_token` on the last page [#243](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/243)
 - Fixed issue where searches return an empty `links` array [#241](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/241)
 
 ## [v2.4.0]
diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py
index 49c44a60..3ec81c99 100644
--- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py
+++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py
@@ -8,6 +8,7 @@
 import attr
 from elasticsearch_dsl import Q, Search
+import stac_fastapi.types.search
 
 from elasticsearch import exceptions, helpers  # type: ignore
 from stac_fastapi.core.extensions import filter
 from stac_fastapi.core.serializers import CollectionSerializer, ItemSerializer
@@ -552,6 +553,7 @@ async def execute_search(
             NotFoundError: If the collections specified in `collection_ids` do not exist.
         """
         search_after = None
+
         if token:
             search_after = urlsafe_b64decode(token.encode()).decode().split(",")
 
@@ -559,6 +561,10 @@
 
         index_param = indices(collection_ids)
 
+        max_result_window = stac_fastapi.types.search.Limit.le
+
+        size_limit = min(limit + 1, max_result_window)
+
         search_task = asyncio.create_task(
             self.client.search(
                 index=index_param,
@@ -566,7 +572,7 @@
                 query=query,
                 sort=sort or DEFAULT_SORT,
                 search_after=search_after,
-                size=limit,
+                size=size_limit,
             )
         )
 
@@ -584,24 +590,27 @@
             raise NotFoundError(f"Collections '{collection_ids}' do not exist")
 
         hits = es_response["hits"]["hits"]
-        items = (hit["_source"] for hit in hits)
+        items = (hit["_source"] for hit in hits[:limit])
 
         next_token = None
-        if hits and (sort_array := hits[-1].get("sort")):
-            next_token = urlsafe_b64encode(
-                ",".join([str(x) for x in sort_array]).encode()
-            ).decode()
-
-        # (1) count should not block returning results, so don't wait for it to be done
-        # (2) don't cancel the task so that it will populate the ES cache for subsequent counts
-        maybe_count = None
+        if len(hits) > limit and limit < max_result_window:
+            if hits and (sort_array := hits[limit - 1].get("sort")):
+                next_token = urlsafe_b64encode(
+                    ",".join([str(x) for x in sort_array]).encode()
+                ).decode()
+
+        matched = (
+            es_response["hits"]["total"]["value"]
+            if es_response["hits"]["total"]["relation"] == "eq"
+            else None
+        )
         if count_task.done():
             try:
-                maybe_count = count_task.result().get("count")
+                matched = count_task.result().get("count")
             except Exception as e:
                 logger.error(f"Count task failed: {e}")
 
-        return items, maybe_count, next_token
+        return items, matched, next_token
 
     """ TRANSACTION LOGIC """
 
diff --git a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py
index 95129f27..3775ead3 100644
--- a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py
+++ b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py
@@ -11,6 +11,7 @@
 from opensearchpy.helpers.query import Q
 from opensearchpy.helpers.search import Search
+import stac_fastapi.types.search
 
 from stac_fastapi.core import serializers
 from stac_fastapi.core.extensions import filter
 from stac_fastapi.core.utilities import bbox2polygon
@@ -582,19 +583,28 @@ async def execute_search(
         query = search.query.to_dict() if search.query else None
         if query:
             search_body["query"] = query
+
+        search_after = None
+
         if token:
             search_after = urlsafe_b64decode(token.encode()).decode().split(",")
+
         if search_after:
             search_body["search_after"] = search_after
+
         search_body["sort"] = sort if sort else DEFAULT_SORT
 
         index_param = indices(collection_ids)
 
+        max_result_window = stac_fastapi.types.search.Limit.le
+
+        size_limit = min(limit + 1, max_result_window)
+
         search_task = asyncio.create_task(
             self.client.search(
                 index=index_param,
                 ignore_unavailable=ignore_unavailable,
                 body=search_body,
-                size=limit,
+                size=size_limit,
             )
         )
@@ -612,24 +622,27 @@
             raise NotFoundError(f"Collections '{collection_ids}' do not exist")
 
         hits = es_response["hits"]["hits"]
-        items = (hit["_source"] for hit in hits)
+        items = (hit["_source"] for hit in hits[:limit])
 
         next_token = None
-        if hits and (sort_array := hits[-1].get("sort")):
-            next_token = urlsafe_b64encode(
-                ",".join([str(x) for x in sort_array]).encode()
-            ).decode()
-
-        # (1) count should not block returning results, so don't wait for it to be done
-        # (2) don't cancel the task so that it will populate the ES cache for subsequent counts
-        maybe_count = None
+        if len(hits) > limit and limit < max_result_window:
+            if hits and (sort_array := hits[limit - 1].get("sort")):
+                next_token = urlsafe_b64encode(
+                    ",".join([str(x) for x in sort_array]).encode()
+                ).decode()
+
+        matched = (
+            es_response["hits"]["total"]["value"]
+            if es_response["hits"]["total"]["relation"] == "eq"
+            else None
+        )
         if count_task.done():
             try:
-                maybe_count = count_task.result().get("count")
+                matched = count_task.result().get("count")
             except Exception as e:
                 logger.error(f"Count task failed: {e}")
 
-        return items, maybe_count, next_token
+        return items, matched, next_token
 
     """ TRANSACTION LOGIC """
 
diff --git a/stac_fastapi/tests/resources/test_item.py b/stac_fastapi/tests/resources/test_item.py
index 336becdc..2d2c6099 100644
--- a/stac_fastapi/tests/resources/test_item.py
+++ b/stac_fastapi/tests/resources/test_item.py
@@ -492,12 +492,9 @@ async def test_item_search_temporal_window_timezone_get(app_client, ctx):
         "datetime": f"{datetime_to_str(item_date_before)}/{datetime_to_str(item_date_after)}",
     }
     resp = await app_client.get("/search", params=params)
-    resp_json = resp.json()
-    next_link = next(link for link in resp_json["links"] if link["rel"] == "next")[
-        "href"
-    ]
-    resp = await app_client.get(next_link)
     assert resp.status_code == 200
+    resp_json = resp.json()
+    assert resp_json["features"][0]["id"] == test_item["id"]
 
 
 @pytest.mark.asyncio
@@ -632,18 +629,17 @@ async def test_pagination_item_collection(app_client, ctx, txn_client):
         await create_item(txn_client, item=ctx.item)
         ids.append(ctx.item["id"])
 
-    # Paginate through all 6 items with a limit of 1 (expecting 7 requests)
+    # Paginate through all 6 items with a limit of 1 (expecting 6 requests)
     page = await app_client.get(
         f"/collections/{ctx.item['collection']}/items", params={"limit": 1}
    )
 
     item_ids = []
-    idx = 0
-    for idx in range(100):
+    for idx in range(1, 100):
         page_data = page.json()
         next_link = list(filter(lambda link: link["rel"] == "next", page_data["links"]))
         if not next_link:
-            assert not page_data["features"]
+            assert idx == 6
             break
 
         assert len(page_data["features"]) == 1
@@ -672,10 +668,8 @@ async def test_pagination_post(app_client, ctx, txn_client):
     # Paginate through all 5 items with a limit of 1 (expecting 5 requests)
     request_body = {"ids": ids, "limit": 1}
     page = await app_client.post("/search", json=request_body)
-    idx = 0
     item_ids = []
-    for _ in range(100):
-        idx += 1
+    for idx in range(1, 100):
         page_data = page.json()
         next_link = list(filter(lambda link: link["rel"] == "next", page_data["links"]))
         if not next_link:
@@ -688,7 +682,7 @@
         page = await app_client.post("/search", json=request_body)
 
     # Our limit is 1, so we expect len(ids) number of requests before we run out of pages
-    assert idx == len(ids) + 1
+    assert idx == len(ids)
 
     # Confirm we have paginated through all items
     assert not set(item_ids) - set(ids)
@@ -702,8 +696,8 @@ async def test_pagination_token_idempotent(app_client, ctx, txn_client):
     # Ingest 5 items
     for _ in range(5):
         ctx.item["id"] = str(uuid.uuid4())
-    await create_item(txn_client, ctx.item)
-    ids.append(ctx.item["id"])
+        await create_item(txn_client, ctx.item)
+        ids.append(ctx.item["id"])
 
     page = await app_client.get("/search", params={"ids": ",".join(ids), "limit": 3})
     page_data = page.json()
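For reference, the pattern both `database_logic.py` changes introduce is a look-ahead page fetch: request `min(limit + 1, max_result_window)` documents, return at most `limit` of them, and only mint a `next_token` when the extra probe hit actually came back. A minimal standalone sketch of that logic under those assumptions (the `paginate` and `decode_token` helper names are illustrative, not part of the codebase):

```python
from base64 import urlsafe_b64decode, urlsafe_b64encode
from typing import Any, Dict, List, Optional, Tuple


def paginate(
    hits: List[Dict[str, Any]], limit: int, max_result_window: int
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
    """Emulate the token logic in the patch for a search issued with
    size=min(limit + 1, max_result_window)."""
    # Never return more than the requested page size; the extra hit is
    # only a probe that tells us whether another page exists.
    items = [hit["_source"] for hit in hits[:limit]]

    next_token = None
    if len(hits) > limit and limit < max_result_window:
        # The cursor comes from the last *returned* hit (index limit - 1),
        # not from hits[-1], which already belongs to the next page.
        if sort_array := hits[limit - 1].get("sort"):
            next_token = urlsafe_b64encode(
                ",".join(str(x) for x in sort_array).encode()
            ).decode()
    return items, next_token


def decode_token(token: str) -> List[str]:
    """Invert the encoding above into a search_after list."""
    return urlsafe_b64decode(token.encode()).decode().split(",")
```

Because a token is only produced when the probe hit exists, the last page's response no longer carries a `next` link, which is exactly what the updated pagination tests assert (`limit` requests for `limit`-many items rather than one extra empty page).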