diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d9df98a..cc3aa19f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] +### Added + +- Added support for enum queryables [#390](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/390) + ### Changed - Optimize data_loader.py script [#395](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/395) diff --git a/Makefile b/Makefile index 5896e734..c23ca951 100644 --- a/Makefile +++ b/Makefile @@ -3,13 +3,11 @@ APP_HOST ?= 0.0.0.0 EXTERNAL_APP_PORT ?= 8080 ES_APP_PORT ?= 8080 +OS_APP_PORT ?= 8082 + ES_HOST ?= docker.for.mac.localhost ES_PORT ?= 9200 -OS_APP_PORT ?= 8082 -OS_HOST ?= docker.for.mac.localhost -OS_PORT ?= 9202 - run_es = docker compose \ run \ -p ${EXTERNAL_APP_PORT}:${ES_APP_PORT} \ diff --git a/stac_fastapi/core/stac_fastapi/core/base_database_logic.py b/stac_fastapi/core/stac_fastapi/core/base_database_logic.py index 0043cfb8..57ca9437 100644 --- a/stac_fastapi/core/stac_fastapi/core/base_database_logic.py +++ b/stac_fastapi/core/stac_fastapi/core/base_database_logic.py @@ -1,7 +1,7 @@ """Base database logic.""" import abc -from typing import Any, Dict, Iterable, Optional +from typing import Any, Dict, Iterable, List, Optional class BaseDatabaseLogic(abc.ABC): @@ -36,6 +36,18 @@ async def delete_item( """Delete an item from the database.""" pass + @abc.abstractmethod + async def get_items_mapping(self, collection_id: str) -> Dict[str, Dict[str, Any]]: + """Get the mapping for the items in the collection.""" + pass + + @abc.abstractmethod + async def get_items_unique_values( + self, collection_id: str, field_names: Iterable[str], *, limit: int = ... + ) -> Dict[str, List[str]]: + """Get the unique values for the given fields in the collection.""" + pass + @abc.abstractmethod async def create_collection(self, collection: Dict, refresh: bool = False) -> None: """Create a collection in the database.""" diff --git a/stac_fastapi/core/stac_fastapi/core/extensions/filter.py b/stac_fastapi/core/stac_fastapi/core/extensions/filter.py index 08200030..c6859672 100644 --- a/stac_fastapi/core/stac_fastapi/core/extensions/filter.py +++ b/stac_fastapi/core/stac_fastapi/core/extensions/filter.py @@ -60,6 +60,17 @@ "maximum": 100, }, } +"""Queryables that are present in all collections.""" + +OPTIONAL_QUERYABLES: Dict[str, Dict[str, Any]] = { + "platform": { + "$enum": True, + "description": "Satellite platform identifier", + }, +} +"""Queryables that are present in some collections.""" + +ALL_QUERYABLES: Dict[str, Dict[str, Any]] = DEFAULT_QUERYABLES | OPTIONAL_QUERYABLES class LogicalOp(str, Enum): diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py index 3f22d7ab..cda5a464 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py @@ -37,6 +37,7 @@ TokenPaginationExtension, TransactionExtension, ) +from stac_fastapi.extensions.core.filter import FilterConformanceClasses from stac_fastapi.extensions.third_party import BulkTransactionExtension from stac_fastapi.sfeos_helpers.aggregation import EsAsyncBaseAggregationClient from stac_fastapi.sfeos_helpers.filter import EsAsyncBaseFiltersClient @@ -56,7 +57,7 @@ client=EsAsyncBaseFiltersClient(database=database_logic) ) filter_extension.conformance_classes.append( - "http://www.opengis.net/spec/cql2/1.0/conf/advanced-comparison-operators" + FilterConformanceClasses.ADVANCED_COMPARISON_OPERATORS ) aggregation_extension = AggregationExtension( diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index d529ce01..a1ca6250 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -895,6 +895,37 @@ async def get_items_mapping(self, collection_id: str) -> Dict[str, Any]: except ESNotFoundError: raise NotFoundError(f"Mapping for index {index_name} not found") + async def get_items_unique_values( + self, collection_id: str, field_names: Iterable[str], *, limit: int = 100 + ) -> Dict[str, List[str]]: + """Get the unique values for the given fields in the collection.""" + limit_plus_one = limit + 1 + index_name = index_alias_by_collection_id(collection_id) + + query = await self.client.search( + index=index_name, + body={ + "size": 0, + "aggs": { + field: {"terms": {"field": field, "size": limit_plus_one}} + for field in field_names + }, + }, + ) + + result: Dict[str, List[str]] = {} + for field, agg in query["aggregations"].items(): + if len(agg["buckets"]) > limit: + logger.warning( + "Skipping enum field %s: exceeds limit of %d unique values. " + "Consider excluding this field from enumeration or increase the limit.", + field, + limit, + ) + continue + result[field] = [bucket["key"] for bucket in agg["buckets"]] + return result + async def create_collection(self, collection: Collection, **kwargs: Any): """Create a single collection in the database. diff --git a/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py b/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py index 0d11369e..66671ff0 100644 --- a/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py +++ b/stac_fastapi/opensearch/stac_fastapi/opensearch/app.py @@ -31,6 +31,7 @@ TokenPaginationExtension, TransactionExtension, ) +from stac_fastapi.extensions.core.filter import FilterConformanceClasses from stac_fastapi.extensions.third_party import BulkTransactionExtension from stac_fastapi.opensearch.config import OpensearchSettings from stac_fastapi.opensearch.database_logic import ( @@ -56,7 +57,7 @@ client=EsAsyncBaseFiltersClient(database=database_logic) ) filter_extension.conformance_classes.append( - "http://www.opengis.net/spec/cql2/1.0/conf/advanced-comparison-operators" + FilterConformanceClasses.ADVANCED_COMPARISON_OPERATORS ) aggregation_extension = AggregationExtension( diff --git a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py index f93311f9..88c7fcdc 100644 --- a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py +++ b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py @@ -904,6 +904,37 @@ async def get_items_mapping(self, collection_id: str) -> Dict[str, Any]: except exceptions.NotFoundError: raise NotFoundError(f"Mapping for index {index_name} not found") + async def get_items_unique_values( + self, collection_id: str, field_names: Iterable[str], *, limit: int = 100 + ) -> Dict[str, List[str]]: + """Get the unique values for the given fields in the collection.""" + limit_plus_one = limit + 1 + index_name = index_alias_by_collection_id(collection_id) + + query = await self.client.search( + index=index_name, + body={ + "size": 0, + "aggs": { + field: {"terms": {"field": field, "size": limit_plus_one}} + for field in field_names + }, + }, + ) + + result: Dict[str, List[str]] = {} + for field, agg in query["aggregations"].items(): + if len(agg["buckets"]) > limit: + logger.warning( + "Skipping enum field %s: exceeds limit of %d unique values. " + "Consider excluding this field from enumeration or increase the limit.", + field, + limit, + ) + continue + result[field] = [bucket["key"] for bucket in agg["buckets"]] + return result + async def create_collection(self, collection: Collection, **kwargs: Any): """Create a single collection in the database. diff --git a/stac_fastapi/sfeos_helpers/stac_fastapi/sfeos_helpers/filter/client.py b/stac_fastapi/sfeos_helpers/stac_fastapi/sfeos_helpers/filter/client.py index 4b2a1a71..9d0eb69b 100644 --- a/stac_fastapi/sfeos_helpers/stac_fastapi/sfeos_helpers/filter/client.py +++ b/stac_fastapi/sfeos_helpers/stac_fastapi/sfeos_helpers/filter/client.py @@ -1,12 +1,12 @@ """Filter client implementation for Elasticsearch/OpenSearch.""" from collections import deque -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple import attr from stac_fastapi.core.base_database_logic import BaseDatabaseLogic -from stac_fastapi.core.extensions.filter import DEFAULT_QUERYABLES +from stac_fastapi.core.extensions.filter import ALL_QUERYABLES, DEFAULT_QUERYABLES from stac_fastapi.extensions.core.filter.client import AsyncBaseFiltersClient from stac_fastapi.sfeos_helpers.mappings import ES_MAPPING_TYPE_TO_JSON @@ -59,31 +59,31 @@ async def get_queryables( mapping_data = await self.database.get_items_mapping(collection_id) mapping_properties = next(iter(mapping_data.values()))["mappings"]["properties"] - stack = deque(mapping_properties.items()) + stack: deque[Tuple[str, Dict[str, Any]]] = deque(mapping_properties.items()) + enum_fields: Dict[str, Dict[str, Any]] = {} while stack: - field_name, field_def = stack.popleft() + field_fqn, field_def = stack.popleft() # Iterate over nested fields field_properties = field_def.get("properties") if field_properties: - # Fields in Item Properties should be exposed with their un-prefixed names, - # and not require expressions to prefix them with properties, - # e.g., eo:cloud_cover instead of properties.eo:cloud_cover. - if field_name == "properties": - stack.extend(field_properties.items()) - else: - stack.extend( - (f"{field_name}.{k}", v) for k, v in field_properties.items() - ) + stack.extend( + (f"{field_fqn}.{k}", v) for k, v in field_properties.items() + ) # Skip non-indexed or disabled fields field_type = field_def.get("type") if not field_type or not field_def.get("enabled", True): continue + # Fields in Item Properties should be exposed with their un-prefixed names, + # and not require expressions to prefix them with properties, + # e.g., eo:cloud_cover instead of properties.eo:cloud_cover. + field_name = field_fqn.removeprefix("properties.") + # Generate field properties - field_result = DEFAULT_QUERYABLES.get(field_name, {}) + field_result = ALL_QUERYABLES.get(field_name, {}) properties[field_name] = field_result field_name_human = field_name.replace("_", " ").title() @@ -95,4 +95,13 @@ async def get_queryables( if field_type in {"date", "date_nanos"}: field_result.setdefault("format", "date-time") + if field_result.pop("$enum", False): + enum_fields[field_fqn] = field_result + + if enum_fields: + for field_fqn, unique_values in ( + await self.database.get_items_unique_values(collection_id, enum_fields) + ).items(): + enum_fields[field_fqn]["enum"] = unique_values + return queryables diff --git a/stac_fastapi/tests/conftest.py b/stac_fastapi/tests/conftest.py index afb9ac9b..7d8c1113 100644 --- a/stac_fastapi/tests/conftest.py +++ b/stac_fastapi/tests/conftest.py @@ -27,7 +27,9 @@ from stac_fastapi.core.rate_limit import setup_rate_limit from stac_fastapi.core.route_dependencies import get_route_dependencies from stac_fastapi.core.utilities import get_bool_env +from stac_fastapi.extensions.core.filter import FilterConformanceClasses from stac_fastapi.sfeos_helpers.aggregation import EsAsyncBaseAggregationClient +from stac_fastapi.sfeos_helpers.filter import EsAsyncBaseFiltersClient if os.getenv("BACKEND", "elasticsearch").lower() == "opensearch": from stac_fastapi.opensearch.config import AsyncOpensearchSettings as AsyncSettings @@ -39,9 +41,11 @@ ) else: from stac_fastapi.elasticsearch.config import ( - ElasticsearchSettings as SearchSettings, AsyncElasticsearchSettings as AsyncSettings, ) + from stac_fastapi.elasticsearch.config import ( + ElasticsearchSettings as SearchSettings, + ) from stac_fastapi.elasticsearch.database_logic import ( DatabaseLogic, create_collection_index, @@ -198,6 +202,13 @@ def bulk_txn_client(): async def app(): settings = AsyncSettings() + filter_extension = FilterExtension( + client=EsAsyncBaseFiltersClient(database=database) + ) + filter_extension.conformance_classes.append( + FilterConformanceClasses.ADVANCED_COMPARISON_OPERATORS + ) + aggregation_extension = AggregationExtension( client=EsAsyncBaseAggregationClient( database=database, session=None, settings=settings @@ -217,7 +228,7 @@ async def app(): FieldsExtension(), QueryExtension(), TokenPaginationExtension(), - FilterExtension(), + filter_extension, FreeTextExtension(), ] @@ -313,7 +324,6 @@ async def app_client_rate_limit(app_rate_limit): @pytest_asyncio.fixture(scope="session") async def app_basic_auth(): - stac_fastapi_route_dependencies = """[ { "routes":[{"method":"*","path":"*"}], diff --git a/stac_fastapi/tests/extensions/test_filter.py b/stac_fastapi/tests/extensions/test_filter.py index fb6bc850..e54d198e 100644 --- a/stac_fastapi/tests/extensions/test_filter.py +++ b/stac_fastapi/tests/extensions/test_filter.py @@ -1,10 +1,13 @@ import json import logging import os +import uuid from os import listdir from os.path import isfile, join +from typing import Callable, Dict import pytest +from httpx import AsyncClient THIS_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -40,7 +43,6 @@ async def test_filter_extension_collection_link(app_client, load_test_data): @pytest.mark.asyncio async def test_search_filters_post(app_client, ctx): - filters = [] pwd = f"{THIS_DIR}/cql2" for fn in [fn for f in listdir(pwd) if isfile(fn := join(pwd, f))]: @@ -625,3 +627,50 @@ async def test_search_filter_extension_cql2text_s_disjoint_property(app_client, assert resp.status_code == 200 resp_json = resp.json() assert len(resp_json["features"]) == 1 + + +@pytest.mark.asyncio +async def test_queryables_enum_platform( + app_client: AsyncClient, + load_test_data: Callable[[str], Dict], + monkeypatch: pytest.MonkeyPatch, +): + # Arrange + # Enforce instant database refresh + # TODO: Is there a better way to do this? + monkeypatch.setenv("DATABASE_REFRESH", "true") + + # Create collection + collection_data = load_test_data("test_collection.json") + collection_id = collection_data["id"] = f"enum-test-collection-{uuid.uuid4()}" + r = await app_client.post("/collections", json=collection_data) + r.raise_for_status() + + # Create items with different platform values + NUM_ITEMS = 3 + for i in range(1, NUM_ITEMS + 1): + item_data = load_test_data("test_item.json") + item_data["id"] = f"enum-test-item-{i}" + item_data["collection"] = collection_id + item_data["properties"]["platform"] = "landsat-8" if i % 2 else "sentinel-2" + r = await app_client.post(f"/collections/{collection_id}/items", json=item_data) + r.raise_for_status() + + # Act + # Test queryables endpoint + queryables = ( + (await app_client.get(f"/collections/{collection_data['id']}/queryables")) + .raise_for_status() + .json() + ) + + # Assert + # Verify distinct values (should only have 2 unique values despite 3 items) + properties = queryables["properties"] + platform_info = properties["platform"] + platform_values = platform_info["enum"] + assert set(platform_values) == {"landsat-8", "sentinel-2"} + + # Clean up + r = await app_client.delete(f"/collections/{collection_id}") + r.raise_for_status()