diff --git a/docker-compose.yml b/docker-compose.yml index 74785a90..7dd9fdec 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -11,7 +11,7 @@ services: environment: - APP_HOST=0.0.0.0 - APP_PORT=8080 - - RELOAD=false + - RELOAD=true - ENVIRONMENT=local - WEB_CONCURRENCY=10 - ES_HOST=172.17.0.1 diff --git a/stac_fastapi/elasticsearch/setup.py b/stac_fastapi/elasticsearch/setup.py index 923ad975..9ac4f66a 100644 --- a/stac_fastapi/elasticsearch/setup.py +++ b/stac_fastapi/elasticsearch/setup.py @@ -19,6 +19,7 @@ "pystac[validation]", "uvicorn", "overrides", + "starlette", ] extra_reqs = { diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py index f62e5038..a55096c3 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py @@ -13,10 +13,12 @@ from pydantic import ValidationError from stac_pydantic.links import Relations from stac_pydantic.shared import MimeTypes +from starlette.requests import Request from stac_fastapi.elasticsearch import serializers from stac_fastapi.elasticsearch.config import ElasticsearchSettings from stac_fastapi.elasticsearch.database_logic import DatabaseLogic +from stac_fastapi.elasticsearch.models.links import PagingLinks from stac_fastapi.elasticsearch.serializers import CollectionSerializer, ItemSerializer from stac_fastapi.elasticsearch.session import Session from stac_fastapi.extensions.third_party.bulk_transactions import ( @@ -84,11 +86,17 @@ async def item_collection( self, collection_id: str, limit: int = 10, token: str = None, **kwargs ) -> ItemCollection: """Read an item collection from the database.""" - links = [] - base_url = str(kwargs["request"].base_url) - - items, maybe_count = await self.database.get_collection_items( - collection_id=collection_id, limit=limit, base_url=base_url + request: Request = kwargs["request"] + base_url = str(request.base_url) + + items, maybe_count, next_token = await self.database.execute_search( + search=self.database.apply_collections_filter( + self.database.make_search(), [collection_id] + ), + limit=limit, + token=token, + sort=None, + base_url=base_url, ) context_obj = None @@ -100,6 +108,10 @@ async def item_collection( if maybe_count is not None: context_obj["matched"] = maybe_count + links = [] + if next_token: + links = await PagingLinks(request=request, next=next_token).get_links() + return ItemCollection( type="FeatureCollection", features=items, @@ -203,16 +215,10 @@ async def post_search( self, search_request: stac_pydantic.api.Search, **kwargs ) -> ItemCollection: """POST search catalog.""" - base_url = str(kwargs["request"].base_url) - search = self.database.make_search() + request: Request = kwargs["request"] + base_url = str(request.base_url) - if search_request.query: - for (field_name, expr) in search_request.query.items(): - field = "properties__" + field_name - for (op, value) in expr.items(): - search = self.database.apply_stacql_filter( - search=search, op=op, field=field, value=value - ) + search = self.database.make_search() if search_request.ids: search = self.database.apply_ids_filter( @@ -242,16 +248,28 @@ async def post_search( search=search, intersects=search_request.intersects ) + if search_request.query: + for (field_name, expr) in search_request.query.items(): + field = "properties__" + field_name + for (op, value) in expr.items(): + search = self.database.apply_stacql_filter( + search=search, op=op, 
field=field, value=value
+                    )
+
+        sort = None
         if search_request.sortby:
-            for sort in search_request.sortby:
-                if sort.field == "datetime":
-                    sort.field = "properties__datetime"
-                search = self.database.apply_sort(
-                    search=search, field=sort.field, direction=sort.direction
-                )
+            sort = self.database.populate_sort(search_request.sortby)
+
+        limit = 10
+        if search_request.limit:
+            limit = search_request.limit
 
-        response_features, maybe_count = await self.database.execute_search(
-            search=search, limit=search_request.limit, base_url=base_url
+        items, maybe_count, next_token = await self.database.execute_search(
+            search=search,
+            limit=limit,
+            token=search_request.token,  # type: ignore
+            sort=sort,
+            base_url=base_url,
         )
 
         # if self.extension_is_enabled("FieldsExtension"):
@@ -274,26 +292,23 @@
 #         for feat in response_features
 #     ]
 
-        if search_request.limit:
-            limit = search_request.limit
-            response_features = response_features[0:limit]
-        else:
-            limit = 10
-            response_features = response_features[0:limit]
-
         context_obj = None
         if self.extension_is_enabled("ContextExtension"):
             context_obj = {
-                "returned": len(response_features),
+                "returned": len(items),
                 "limit": limit,
             }
             if maybe_count is not None:
                 context_obj["matched"] = maybe_count
 
+        links = []
+        if next_token:
+            links = await PagingLinks(request=request, next=next_token).get_links()
+
         return ItemCollection(
             type="FeatureCollection",
-            features=response_features,
-            links=[],
+            features=items,
+            links=links,
             context=context_obj,
         )
diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py
index 16d1ce21..7319b962 100644
--- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py
+++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py
@@ -1,5 +1,6 @@
 """Database logic."""
 import logging
+from base64 import urlsafe_b64decode, urlsafe_b64encode
 from typing import Dict, List, Optional, Tuple, Type, Union
 
 import attr
@@ -75,17 +76,6 @@ async def get_all_collections(self, base_url: str) -> List[Collection]:
             for c in collections["hits"]["hits"]
         ]
 
-    async def get_collection_items(
-        self, collection_id: str, limit: int, base_url: str
-    ) -> Tuple[List[Item], Optional[int]]:
-        """Database logic to retrieve an ItemCollection and a count of items contained."""
-        search = self.apply_collections_filter(Search(), [collection_id])
-        items, maybe_count = await self.execute_search(
-            search=search, limit=limit, base_url=base_url
-        )
-
-        return items, maybe_count
-
     async def get_one_item(self, collection_id: str, item_id: str) -> Dict:
         """Database logic to retrieve a single item."""
         try:
@@ -190,31 +180,53 @@ def apply_stacql_filter(search: Search, op: str, field: str, value: float):
         return search
 
     @staticmethod
-    def apply_sort(search: Search, field, direction):
-        """Database logic to sort search instance."""
-        return search.sort({field: {"order": direction}})
+    def populate_sort(sortby: List) -> Optional[Dict[str, Dict[str, str]]]:
+        """Database logic to build a sort configuration from a sortby list."""
+        if sortby:
+            return {s.field: {"order": s.direction} for s in sortby}
+        else:
+            return None
 
     async def execute_search(
-        self, search, limit: int, base_url: str
-    ) -> Tuple[List[Item], Optional[int]]:
-        """Database logic to execute search with limit."""
-        search = search[0:limit]
+        self,
+        search: Search,
+        limit: int,
+        token: Optional[str],
+        sort: Optional[Dict[str, Dict[str, str]]],
+        base_url: str,
+    ) -> Tuple[List[Item], Optional[int], Optional[str]]:
+        """Database logic to execute a search with limit, sort, and paging token."""
         body = 
search.to_dict() maybe_count = ( await self.client.count(index=ITEMS_INDEX, body=search.to_dict(count=True)) ).get("count") + search_after = None + if token: + search_after = urlsafe_b64decode(token.encode()).decode().split(",") + es_response = await self.client.search( - index=ITEMS_INDEX, query=body.get("query"), sort=body.get("sort") + index=ITEMS_INDEX, + query=body.get("query"), + sort=sort or DEFAULT_SORT, + search_after=search_after, + size=limit, ) + hits = es_response["hits"]["hits"] items = [ self.item_serializer.db_to_stac(hit["_source"], base_url=base_url) - for hit in es_response["hits"]["hits"] + for hit in hits ] - return items, maybe_count + next_token = None + if hits and (sort_array := hits[-1].get("sort")): + next_token = urlsafe_b64encode( + ",".join([str(x) for x in sort_array]).encode() + ).decode() + + return items, maybe_count, next_token """ TRANSACTION LOGIC """ diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/models/links.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/models/links.py new file mode 100644 index 00000000..3941a149 --- /dev/null +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/models/links.py @@ -0,0 +1,138 @@ +"""link helpers.""" + +from typing import Any, Dict, List, Optional +from urllib.parse import ParseResult, parse_qs, unquote, urlencode, urljoin, urlparse + +import attr +from stac_pydantic.links import Relations +from stac_pydantic.shared import MimeTypes +from starlette.requests import Request + +# Copied from pgstac links + +# These can be inferred from the item/collection, so they aren't included in the database +# Instead they are dynamically generated when querying the database using the classes defined below +INFERRED_LINK_RELS = ["self", "item", "parent", "collection", "root"] + + +def merge_params(url: str, newparams: Dict) -> str: + """Merge url parameters.""" + u = urlparse(url) + params = parse_qs(u.query) + params.update(newparams) + param_string = unquote(urlencode(params, True)) + + href = ParseResult( + scheme=u.scheme, + netloc=u.netloc, + path=u.path, + params=u.params, + query=param_string, + fragment=u.fragment, + ).geturl() + return href + + +@attr.s +class BaseLinks: + """Create inferred links common to collections and items.""" + + request: Request = attr.ib() + + @property + def base_url(self): + """Get the base url.""" + return str(self.request.base_url) + + @property + def url(self): + """Get the current request url.""" + return str(self.request.url) + + def resolve(self, url): + """Resolve url to the current request url.""" + return urljoin(str(self.base_url), str(url)) + + def link_self(self) -> Dict: + """Return the self link.""" + return dict(rel=Relations.self.value, type=MimeTypes.json.value, href=self.url) + + def link_root(self) -> Dict: + """Return the catalog root.""" + return dict( + rel=Relations.root.value, type=MimeTypes.json.value, href=self.base_url + ) + + def create_links(self) -> List[Dict[str, Any]]: + """Return all inferred links.""" + links = [] + for name in dir(self): + if name.startswith("link_") and callable(getattr(self, name)): + link = getattr(self, name)() + if link is not None: + links.append(link) + return links + + async def get_links( + self, extra_links: Optional[List[Dict[str, Any]]] = None + ) -> List[Dict[str, Any]]: + """ + Generate all the links. + + Get the links object for a stac resource by iterating through + available methods on this class that start with link_. 
+ """ + # TODO: Pass request.json() into function so this doesn't need to be coroutine + if self.request.method == "POST": + self.request.postbody = await self.request.json() + # join passed in links with generated links + # and update relative paths + links = self.create_links() + + if extra_links: + # For extra links passed in, + # add links modified with a resolved href. + # Drop any links that are dynamically + # determined by the server (e.g. self, parent, etc.) + # Resolving the href allows for relative paths + # to be stored in pgstac and for the hrefs in the + # links of response STAC objects to be resolved + # to the request url. + links += [ + {**link, "href": self.resolve(link["href"])} + for link in extra_links + if link["rel"] not in INFERRED_LINK_RELS + ] + + return links + + +@attr.s +class PagingLinks(BaseLinks): + """Create links for paging.""" + + next: Optional[str] = attr.ib(kw_only=True, default=None) + + def link_next(self) -> Optional[Dict[str, Any]]: + """Create link for next page.""" + if self.next is not None: + method = self.request.method + if method == "GET": + href = merge_params(self.url, {"token": self.next}) + link = dict( + rel=Relations.next.value, + type=MimeTypes.json.value, + method=method, + href=href, + ) + return link + if method == "POST": + return { + "rel": Relations.next, + "type": MimeTypes.json, + "method": method, + "href": f"{self.request.url}", + "body": {**self.request.postbody, "token": self.next}, + } + + return None diff --git a/stac_fastapi/elasticsearch/tests/api/test_api.py b/stac_fastapi/elasticsearch/tests/api/test_api.py index 1ee27aa7..c43577bd 100644 --- a/stac_fastapi/elasticsearch/tests/api/test_api.py +++ b/stac_fastapi/elasticsearch/tests/api/test_api.py @@ -6,31 +6,29 @@ from ..conftest import MockRequest, create_collection, create_item -ROUTES = set( - [ - "GET /_mgmt/ping", - "GET /docs/oauth2-redirect", - "HEAD /docs/oauth2-redirect", - "GET /", - "GET /conformance", - "GET /api", - "GET /api.html", - "HEAD /api", - "HEAD /api.html", - "GET /collections", - "GET /collections/{collection_id}", - "GET /collections/{collection_id}/items", - "GET /collections/{collection_id}/items/{item_id}", - "GET /search", - "POST /search", - "DELETE /collections/{collection_id}", - "DELETE /collections/{collection_id}/items/{item_id}", - "POST /collections", - "POST /collections/{collection_id}/items", - "PUT /collections", - "PUT /collections/{collection_id}/items", - ] -) +ROUTES = { + "GET /_mgmt/ping", + "GET /docs/oauth2-redirect", + "HEAD /docs/oauth2-redirect", + "GET /", + "GET /conformance", + "GET /api", + "GET /api.html", + "HEAD /api", + "HEAD /api.html", + "GET /collections", + "GET /collections/{collection_id}", + "GET /collections/{collection_id}/items", + "GET /collections/{collection_id}/items/{item_id}", + "GET /search", + "POST /search", + "DELETE /collections/{collection_id}", + "DELETE /collections/{collection_id}/items/{item_id}", + "POST /collections", + "POST /collections/{collection_id}/items", + "PUT /collections", + "PUT /collections/{collection_id}/items", +} async def test_post_search_content_type(app_client, ctx): @@ -137,15 +135,15 @@ async def test_app_query_extension_gte(app_client, ctx): assert len(resp.json()["features"]) == 1 -async def test_app_query_extension_limit_lt0(app_client, ctx): +async def test_app_query_extension_limit_lt0(app_client): assert (await app_client.post("/search", json={"limit": -1})).status_code == 400 -async def test_app_query_extension_limit_gt10000(app_client, ctx): +async 
def test_app_query_extension_limit_gt10000(app_client): assert (await app_client.post("/search", json={"limit": 10001})).status_code == 400 -async def test_app_query_extension_limit_10000(app_client, ctx): +async def test_app_query_extension_limit_10000(app_client): params = {"limit": 10000} resp = await app_client.post("/search", json=params) assert resp.status_code == 200 diff --git a/stac_fastapi/elasticsearch/tests/clients/test_elasticsearch.py b/stac_fastapi/elasticsearch/tests/clients/test_elasticsearch.py index db7972a7..97fed121 100644 --- a/stac_fastapi/elasticsearch/tests/clients/test_elasticsearch.py +++ b/stac_fastapi/elasticsearch/tests/clients/test_elasticsearch.py @@ -94,7 +94,7 @@ async def test_get_collection_items(app_client, ctx, core_client, txn_client): item["id"] = str(uuid.uuid4()) await txn_client.create_item(item, request=MockRequest, refresh=True) - fc = await core_client.item_collection(coll["id"], request=MockRequest) + fc = await core_client.item_collection(coll["id"], request=MockRequest()) assert len(fc["features"]) == num_of_items_to_create + 1 # ctx.item for item in fc["features"]: @@ -166,7 +166,7 @@ async def test_bulk_item_insert(ctx, core_client, txn_client, bulk_txn_client): bulk_txn_client.bulk_item_insert(Items(items=items), refresh=True) - fc = await core_client.item_collection(ctx.collection["id"], request=MockRequest) + fc = await core_client.item_collection(ctx.collection["id"], request=MockRequest()) assert len(fc["features"]) >= 10 # for item in items: @@ -190,20 +190,16 @@ async def test_feature_collection_insert( await create_item(txn_client, feature_collection) - fc = await core_client.item_collection(ctx.collection["id"], request=MockRequest) + fc = await core_client.item_collection(ctx.collection["id"], request=MockRequest()) assert len(fc["features"]) >= 10 -@pytest.mark.skip(reason="app fixture isn't injected or something?") async def test_landing_page_no_collection_title(ctx, core_client, txn_client, app): - class MockRequestWithApp(MockRequest): - app = app - ctx.collection["id"] = "new_id" del ctx.collection["title"] await txn_client.create_collection(ctx.collection, request=MockRequest) - landing_page = await core_client.landing_page(request=MockRequestWithApp) + landing_page = await core_client.landing_page(request=MockRequest(app=app)) for link in landing_page["links"]: if link["href"].split("/")[-1] == ctx.collection["id"]: assert link["title"] diff --git a/stac_fastapi/elasticsearch/tests/conftest.py b/stac_fastapi/elasticsearch/tests/conftest.py index a207836d..6906096a 100644 --- a/stac_fastapi/elasticsearch/tests/conftest.py +++ b/stac_fastapi/elasticsearch/tests/conftest.py @@ -2,7 +2,7 @@ import copy import json import os -from typing import Callable, Dict +from typing import Any, Callable, Dict, Optional import pytest import pytest_asyncio @@ -41,6 +41,13 @@ def __init__(self, item, collection): class MockRequest: base_url = "http://test-server" + def __init__( + self, method: str = "GET", url: str = "XXXX", app: Optional[Any] = None + ): + self.method = method + self.url = url + self.app = app + class TestSettings(AsyncElasticsearchSettings): class Config: @@ -181,5 +188,5 @@ async def app(): async def app_client(app): await IndexesClient().create_indexes() - async with AsyncClient(app=app, base_url="http://test") as c: + async with AsyncClient(app=app, base_url="http://test-server") as c: yield c diff --git a/stac_fastapi/elasticsearch/tests/resources/test_item.py b/stac_fastapi/elasticsearch/tests/resources/test_item.py 
index 6869e60d..d6d48e12 100644
--- a/stac_fastapi/elasticsearch/tests/resources/test_item.py
+++ b/stac_fastapi/elasticsearch/tests/resources/test_item.py
@@ -450,7 +450,7 @@ async def test_item_search_properties_es(app_client, ctx):
     assert len(resp_json["features"]) == 0
 
 
-async def test_item_search_properties_field(app_client, ctx):
+async def test_item_search_properties_field(app_client):
     """Test POST search indexed field with query (query extension)"""
 
     # Orientation is an indexed field
@@ -493,102 +493,86 @@ async def test_get_missing_item_collection(app_client):
     assert resp.status_code == 200
 
 
-@pytest.mark.skip(reason="Pagination extension not implemented")
-async def test_pagination_item_collection(app_client, load_test_data):
+async def test_pagination_item_collection(app_client, ctx, txn_client):
     """Test item collection pagination links (paging extension)"""
-    test_item = load_test_data("test_item.json")
-    ids = []
+    ids = [ctx.item["id"]]
 
     # Ingest 5 items
-    for idx in range(5):
-        uid = str(uuid.uuid4())
-        test_item["id"] = uid
-        resp = await app_client.post(
-            f"/collections/{test_item['collection']}/items", json=test_item
-        )
-        assert resp.status_code == 200
-        ids.append(uid)
+    for _ in range(5):
+        ctx.item["id"] = str(uuid.uuid4())
+        await create_item(txn_client, ctx.item)
+        ids.append(ctx.item["id"])
 
-    # Paginate through all 5 items with a limit of 1 (expecting 5 requests)
+    # Paginate through all 6 items with a limit of 1 (expecting 7 requests)
     page = await app_client.get(
-        f"/collections/{test_item['collection']}/items", params={"limit": 1}
+        f"/collections/{ctx.item['collection']}/items", params={"limit": 1}
     )
-    idx = 0
+
     item_ids = []
-    while True:
-        idx += 1
+    idx = 0
+    for idx in range(100):
         page_data = page.json()
-        item_ids.append(page_data["features"][0]["id"])
         next_link = list(filter(lambda l: l["rel"] == "next", page_data["links"]))
         if not next_link:
+            assert not page_data["features"]
             break
-        query_params = parse_qs(urlparse(next_link[0]["href"]).query)
-        page = await app_client.get(
-            f"/collections/{test_item['collection']}/items",
-            params=query_params,
-        )
-    # Our limit is 1 so we expect len(ids) number of requests before we run out of pages
+
+        assert len(page_data["features"]) == 1
+        item_ids.append(page_data["features"][0]["id"])
+
+        href = next_link[0]["href"][len("http://test-server") :]
+        page = await app_client.get(href)
+
     assert idx == len(ids)
 
     # Confirm we have paginated through all items
     assert not set(item_ids) - set(ids)
 
 
-@pytest.mark.skip(reason="Pagination extension not implemented")
-async def test_pagination_post(app_client, load_test_data):
+async def test_pagination_post(app_client, ctx, txn_client):
     """Test POST pagination (paging extension)"""
-    test_item = load_test_data("test_item.json")
-    ids = []
+    ids = [ctx.item["id"]]
 
     # Ingest 5 items
-    for idx in range(5):
-        uid = str(uuid.uuid4())
-        test_item["id"] = uid
-        resp = await app_client.post(
-            f"/collections/{test_item['collection']}/items", json=test_item
-        )
-        assert resp.status_code == 200
-        ids.append(uid)
+    for _ in range(5):
+        ctx.item["id"] = str(uuid.uuid4())
+        await create_item(txn_client, ctx.item)
+        ids.append(ctx.item["id"])
 
-    # Paginate through all 5 items with a limit of 1 (expecting 5 requests)
+    # Paginate through all 6 items with a limit of 1 (expecting 7 requests)
     request_body = {"ids": ids, "limit": 1}
     page = await app_client.post("/search", json=request_body)
     idx = 0
     item_ids = []
-    while True:
+    for _ in range(100):
         idx += 1
         page_data = page.json()
-        item_ids.append(page_data["features"][0]["id"])
         next_link = list(filter(lambda l: l["rel"] == 
"next", page_data["links"])) if not next_link: break + + item_ids.append(page_data["features"][0]["id"]) + # Merge request bodies request_body.update(next_link[0]["body"]) page = await app_client.post("/search", json=request_body) - # Our limit is 1 so we expect len(ids) number of requests before we run out of pages - assert idx == len(ids) + # Our limit is 1, so we expect len(ids) number of requests before we run out of pages + assert idx == len(ids) + 1 # Confirm we have paginated through all items assert not set(item_ids) - set(ids) -@pytest.mark.skip(reason="Pagination extension not implemented") -async def test_pagination_token_idempotent(app_client, load_test_data): +async def test_pagination_token_idempotent(app_client, ctx, txn_client): """Test that pagination tokens are idempotent (paging extension)""" - test_item = load_test_data("test_item.json") - ids = [] + ids = [ctx.item["id"]] # Ingest 5 items - for idx in range(5): - uid = str(uuid.uuid4()) - test_item["id"] = uid - resp = await app_client.post( - f"/collections/{test_item['collection']}/items", json=test_item - ) - assert resp.status_code == 200 - ids.append(uid) + for _ in range(5): + ctx.item["id"] = str(uuid.uuid4()) + await create_item(txn_client, ctx.item) + ids.append(ctx.item["id"]) page = await app_client.get("/search", params={"ids": ",".join(ids), "limit": 3}) page_data = page.json()