diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ffa151b..9f3d791b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Collection-level Assets to the CollectionSerializer [#148](https://github.com/stac-utils/stac-fastapi-elasticsearch/issues/148) - Examples folder with example docker setup for running sfes from pip [#147](https://github.com/stac-utils/stac-fastapi-elasticsearch/pull/147) +- GET /search filter extension queries [#163](https://github.com/stac-utils/stac-fastapi-elasticsearch/pull/163) - Added support for GET /search intersection queries [#158](https://github.com/stac-utils/stac-fastapi-elasticsearch/issues/158) ### Changed diff --git a/stac_fastapi/elasticsearch/setup.py b/stac_fastapi/elasticsearch/setup.py index 1a7fc7cf..6ff8cd86 100644 --- a/stac_fastapi/elasticsearch/setup.py +++ b/stac_fastapi/elasticsearch/setup.py @@ -17,6 +17,7 @@ "elasticsearch-dsl==7.4.1", "pystac[validation]", "uvicorn", + "orjson", "overrides", "starlette", "geojson-pydantic", diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py index 031a0235..e4aa5846 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py @@ -1,19 +1,21 @@ """Item crud client.""" -import json import logging +import re from datetime import datetime as datetime_type from datetime import timezone from typing import Any, Dict, List, Optional, Set, Type, Union from urllib.parse import unquote_plus, urljoin import attr +import orjson import stac_pydantic -from fastapi import HTTPException +from fastapi import HTTPException, Request from overrides import overrides from pydantic import ValidationError +from pygeofilter.backends.cql2_json import to_cql2 +from pygeofilter.parsers.cql2_text import parse as parse_cql2_text from stac_pydantic.links 
import Relations from stac_pydantic.shared import MimeTypes -from starlette.requests import Request from stac_fastapi.elasticsearch import serializers from stac_fastapi.elasticsearch.config import ElasticsearchSettings @@ -274,9 +276,9 @@ def _return_date(interval_str): return {"lte": end_date, "gte": start_date} - @overrides async def get_search( self, + request: Request, collections: Optional[List[str]] = None, ids: Optional[List[str]] = None, bbox: Optional[List[NumType]] = None, @@ -287,8 +289,8 @@ async def get_search( fields: Optional[List[str]] = None, sortby: Optional[str] = None, intersects: Optional[str] = None, - # filter: Optional[str] = None, # todo: requires fastapi > 2.3 unreleased - # filter_lang: Optional[str] = None, # todo: requires fastapi > 2.3 unreleased + filter: Optional[str] = None, + filter_lang: Optional[str] = None, **kwargs, ) -> ItemCollection: """Get search results from the database. @@ -318,17 +320,24 @@ async def get_search( "bbox": bbox, "limit": limit, "token": token, - "query": json.loads(query) if query else query, + "query": orjson.loads(query) if query else query, } + # this is borrowed from stac-fastapi-pgstac + # Kludgy fix because using factory does not allow alias for filter-lang + query_params = str(request.query_params) + if filter_lang is None: + match = re.search(r"filter-lang=([a-z0-9-]+)", query_params, re.IGNORECASE) + if match: + filter_lang = match.group(1) + if datetime: base_args["datetime"] = datetime if intersects: - base_args["intersects"] = json.loads(unquote_plus(intersects)) + base_args["intersects"] = orjson.loads(unquote_plus(intersects)) if sortby: - # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form sort_param = [] for sort in sortby: sort_param.append( @@ -339,12 +348,13 @@ async def get_search( ) base_args["sortby"] = sort_param - # todo: requires fastapi > 2.3 unreleased - # if filter: - # if filter_lang == "cql2-text": - # base_args["filter-lang"] 
= "cql2-json" - # base_args["filter"] = orjson.loads(to_cql2(parse_cql2_text(filter))) - # print(f'>>> {base_args["filter"]}') + if filter: + if filter_lang == "cql2-text": + base_args["filter-lang"] = "cql2-json" + base_args["filter"] = orjson.loads(to_cql2(parse_cql2_text(filter))) + else: + base_args["filter-lang"] = "cql2-json" + base_args["filter"] = orjson.loads(unquote_plus(filter)) if fields: includes = set() @@ -363,13 +373,12 @@ async def get_search( search_request = self.post_request_model(**base_args) except ValidationError: raise HTTPException(status_code=400, detail="Invalid parameters provided") - resp = await self.post_search(search_request, request=kwargs["request"]) + resp = await self.post_search(search_request=search_request, request=request) return resp - @overrides async def post_search( - self, search_request: BaseSearchPostRequest, **kwargs + self, search_request: BaseSearchPostRequest, request: Request ) -> ItemCollection: """ Perform a POST search on the catalog. @@ -384,7 +393,6 @@ async def post_search( Raises: HTTPException: If there is an error with the cql2_json filter. 
""" - request: Request = kwargs["request"] base_url = str(request.base_url) search = self.database.make_search() @@ -471,7 +479,7 @@ async def post_search( filter_kwargs = search_request.fields.filter_fields items = [ - json.loads(stac_pydantic.Item(**feat).json(**filter_kwargs)) + orjson.loads(stac_pydantic.Item(**feat).json(**filter_kwargs)) for feat in items ] diff --git a/stac_fastapi/elasticsearch/tests/extensions/test_filter.py b/stac_fastapi/elasticsearch/tests/extensions/test_filter.py index 43aadf18..d9db48cd 100644 --- a/stac_fastapi/elasticsearch/tests/extensions/test_filter.py +++ b/stac_fastapi/elasticsearch/tests/extensions/test_filter.py @@ -3,10 +3,13 @@ from os import listdir from os.path import isfile, join +import pytest + THIS_DIR = os.path.dirname(os.path.abspath(__file__)) -async def test_search_filters(app_client, ctx): +@pytest.mark.asyncio +async def test_search_filters_post(app_client, ctx): filters = [] pwd = f"{THIS_DIR}/cql2" @@ -19,7 +22,18 @@ async def test_search_filters(app_client, ctx): assert resp.status_code == 200 -async def test_search_filter_extension_eq(app_client, ctx): +@pytest.mark.asyncio +async def test_search_filter_extension_eq_get(app_client, ctx): + resp = await app_client.get( + '/search?filter-lang=cql2-json&filter={"op":"=","args":[{"property":"id"},"test-item"]}' + ) + assert resp.status_code == 200 + resp_json = resp.json() + assert len(resp_json["features"]) == 1 + + +@pytest.mark.asyncio +async def test_search_filter_extension_eq_post(app_client, ctx): params = {"filter": {"op": "=", "args": [{"property": "id"}, ctx.item["id"]]}} resp = await app_client.post("/search", json=params) assert resp.status_code == 200 @@ -27,7 +41,26 @@ async def test_search_filter_extension_eq(app_client, ctx): assert len(resp_json["features"]) == 1 -async def test_search_filter_extension_gte(app_client, ctx): +@pytest.mark.asyncio +async def test_search_filter_extension_gte_get(app_client, ctx): + # there's one item that can 
match, so one of these queries should match it and the other shouldn't + resp = await app_client.get( + '/search?filter-lang=cql2-json&filter={"op":"<=","args":[{"property": "properties.proj:epsg"},32756]}' + ) + + assert resp.status_code == 200 + assert len(resp.json()["features"]) == 1 + + resp = await app_client.get( + '/search?filter-lang=cql2-json&filter={"op":">","args":[{"property": "properties.proj:epsg"},32756]}' + ) + + assert resp.status_code == 200 + assert len(resp.json()["features"]) == 0 + + +@pytest.mark.asyncio +async def test_search_filter_extension_gte_post(app_client, ctx): # there's one item that can match, so one of these queries should match it and the other shouldn't params = { "filter": { @@ -58,7 +91,53 @@ async def test_search_filter_extension_gte(app_client, ctx): assert len(resp.json()["features"]) == 0 -async def test_search_filter_ext_and(app_client, ctx): +@pytest.mark.asyncio +async def test_search_filter_ext_and_get(app_client, ctx): + resp = await app_client.get( + '/search?filter-lang=cql2-json&filter={"op":"and","args":[{"op":"<=","args":[{"property":"properties.proj:epsg"},32756]},{"op":"=","args":[{"property":"id"},"test-item"]}]}' + ) + + assert resp.status_code == 200 + assert len(resp.json()["features"]) == 1 + + +@pytest.mark.asyncio +async def test_search_filter_ext_and_get_cql2text_id(app_client, ctx): + collection = ctx.item["collection"] + id = ctx.item["id"] + filter = f"id='{id}' AND collection='{collection}'" + resp = await app_client.get(f"/search?filter-lang=cql2-text&filter={filter}") + + assert resp.status_code == 200 + assert len(resp.json()["features"]) == 1 + + +@pytest.mark.asyncio +async def test_search_filter_ext_and_get_cql2text_cloud_cover(app_client, ctx): + collection = ctx.item["collection"] + cloud_cover = ctx.item["properties"]["eo:cloud_cover"] + filter = f"cloud_cover={cloud_cover} AND collection='{collection}'" + resp = await app_client.get(f"/search?filter-lang=cql2-text&filter={filter}") + + 
assert resp.status_code == 200 + assert len(resp.json()["features"]) == 1 + + +@pytest.mark.asyncio +async def test_search_filter_ext_and_get_cql2text_cloud_cover_no_results( + app_client, ctx +): + collection = ctx.item["collection"] + cloud_cover = ctx.item["properties"]["eo:cloud_cover"] + 1 + filter = f"cloud_cover={cloud_cover} AND collection='{collection}'" + resp = await app_client.get(f"/search?filter-lang=cql2-text&filter={filter}") + + assert resp.status_code == 200 + assert len(resp.json()["features"]) == 0 + + +@pytest.mark.asyncio +async def test_search_filter_ext_and_post(app_client, ctx): params = { "filter": { "op": "and", @@ -80,7 +159,32 @@ async def test_search_filter_ext_and(app_client, ctx): assert len(resp.json()["features"]) == 1 -async def test_search_filter_extension_floats(app_client, ctx): +@pytest.mark.asyncio +async def test_search_filter_extension_floats_get(app_client, ctx): + resp = await app_client.get( + """/search?filter={"op":"and","args":[{"op":"=","args":[{"property":"id"},"test-item"]},{"op":">","args":[{"property":"properties.view:sun_elevation"},"-37.30891534"]},{"op":"<","args":[{"property":"properties.view:sun_elevation"},"-37.30691534"]}]}""" + ) + + assert resp.status_code == 200 + assert len(resp.json()["features"]) == 1 + + resp = await app_client.get( + """/search?filter={"op":"and","args":[{"op":"=","args":[{"property":"id"},"test-item-7"]},{"op":">","args":[{"property":"properties.view:sun_elevation"},"-37.30891534"]},{"op":"<","args":[{"property":"properties.view:sun_elevation"},"-37.30691534"]}]}""" + ) + + assert resp.status_code == 200 + assert len(resp.json()["features"]) == 0 + + resp = await app_client.get( + """/search?filter={"op":"and","args":[{"op":"=","args":[{"property":"id"},"test-item"]},{"op":">","args":[{"property":"properties.view:sun_elevation"},"-37.30591534"]},{"op":"<","args":[{"property":"properties.view:sun_elevation"},"-37.30491534"]}]}""" + ) + + assert resp.status_code == 200 + assert 
len(resp.json()["features"]) == 0 + + +@pytest.mark.asyncio +async def test_search_filter_extension_floats_post(app_client, ctx): sun_elevation = ctx.item["properties"]["view:sun_elevation"] params = {