Skip to content

Add support for optional enum queryables #390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 8, 2025
Merged
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

## [Unreleased]

### Added

- Added support for enum queryables [#390](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/390)

### Changed

- Optimize data_loader.py script [#395](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/395)
Expand Down
6 changes: 2 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@ APP_HOST ?= 0.0.0.0
EXTERNAL_APP_PORT ?= 8080

ES_APP_PORT ?= 8080
OS_APP_PORT ?= 8082

ES_HOST ?= docker.for.mac.localhost
ES_PORT ?= 9200

OS_APP_PORT ?= 8082
OS_HOST ?= docker.for.mac.localhost
OS_PORT ?= 9202

run_es = docker compose \
run \
-p ${EXTERNAL_APP_PORT}:${ES_APP_PORT} \
Expand Down
14 changes: 13 additions & 1 deletion stac_fastapi/core/stac_fastapi/core/base_database_logic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Base database logic."""

import abc
from typing import Any, Dict, Iterable, Optional
from typing import Any, Dict, Iterable, List, Optional


class BaseDatabaseLogic(abc.ABC):
Expand Down Expand Up @@ -36,6 +36,18 @@ async def delete_item(
"""Delete an item from the database."""
pass

@abc.abstractmethod
async def get_items_mapping(self, collection_id: str) -> Dict[str, Dict[str, Any]]:
"""Get the mapping for the items in the collection."""
pass

@abc.abstractmethod
async def get_items_unique_values(
self, collection_id: str, field_names: Iterable[str], *, limit: int = ...
) -> Dict[str, List[str]]:
"""Get the unique values for the given fields in the collection."""
pass

@abc.abstractmethod
async def create_collection(self, collection: Dict, refresh: bool = False) -> None:
"""Create a collection in the database."""
Expand Down
11 changes: 11 additions & 0 deletions stac_fastapi/core/stac_fastapi/core/extensions/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@
"maximum": 100,
},
}
"""Queryables that are present in all collections."""

OPTIONAL_QUERYABLES: Dict[str, Dict[str, Any]] = {
"platform": {
"$enum": True,
"description": "Satellite platform identifier",
},
}
"""Queryables that are present in some collections."""

ALL_QUERYABLES: Dict[str, Dict[str, Any]] = DEFAULT_QUERYABLES | OPTIONAL_QUERYABLES


class LogicalOp(str, Enum):
Expand Down
3 changes: 2 additions & 1 deletion stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
TokenPaginationExtension,
TransactionExtension,
)
from stac_fastapi.extensions.core.filter import FilterConformanceClasses
from stac_fastapi.extensions.third_party import BulkTransactionExtension
from stac_fastapi.sfeos_helpers.aggregation import EsAsyncBaseAggregationClient
from stac_fastapi.sfeos_helpers.filter import EsAsyncBaseFiltersClient
Expand All @@ -56,7 +57,7 @@
client=EsAsyncBaseFiltersClient(database=database_logic)
)
filter_extension.conformance_classes.append(
"http://www.opengis.net/spec/cql2/1.0/conf/advanced-comparison-operators"
FilterConformanceClasses.ADVANCED_COMPARISON_OPERATORS
)

aggregation_extension = AggregationExtension(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,37 @@ async def get_items_mapping(self, collection_id: str) -> Dict[str, Any]:
except ESNotFoundError:
raise NotFoundError(f"Mapping for index {index_name} not found")

async def get_items_unique_values(
self, collection_id: str, field_names: Iterable[str], *, limit: int = 100
) -> Dict[str, List[str]]:
"""Get the unique values for the given fields in the collection."""
limit_plus_one = limit + 1
index_name = index_alias_by_collection_id(collection_id)

query = await self.client.search(
index=index_name,
body={
"size": 0,
"aggs": {
field: {"terms": {"field": field, "size": limit_plus_one}}
for field in field_names
},
},
)

result: Dict[str, List[str]] = {}
for field, agg in query["aggregations"].items():
if len(agg["buckets"]) > limit:
logger.warning(
"Skipping enum field %s: exceeds limit of %d unique values. "
"Consider excluding this field from enumeration or increase the limit.",
field,
limit,
)
continue
result[field] = [bucket["key"] for bucket in agg["buckets"]]
return result

async def create_collection(self, collection: Collection, **kwargs: Any):
"""Create a single collection in the database.

Expand Down
3 changes: 2 additions & 1 deletion stac_fastapi/opensearch/stac_fastapi/opensearch/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
TokenPaginationExtension,
TransactionExtension,
)
from stac_fastapi.extensions.core.filter import FilterConformanceClasses
from stac_fastapi.extensions.third_party import BulkTransactionExtension
from stac_fastapi.opensearch.config import OpensearchSettings
from stac_fastapi.opensearch.database_logic import (
Expand All @@ -56,7 +57,7 @@
client=EsAsyncBaseFiltersClient(database=database_logic)
)
filter_extension.conformance_classes.append(
"http://www.opengis.net/spec/cql2/1.0/conf/advanced-comparison-operators"
FilterConformanceClasses.ADVANCED_COMPARISON_OPERATORS
)

aggregation_extension = AggregationExtension(
Expand Down
31 changes: 31 additions & 0 deletions stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,37 @@ async def get_items_mapping(self, collection_id: str) -> Dict[str, Any]:
except exceptions.NotFoundError:
raise NotFoundError(f"Mapping for index {index_name} not found")

async def get_items_unique_values(
self, collection_id: str, field_names: Iterable[str], *, limit: int = 100
) -> Dict[str, List[str]]:
"""Get the unique values for the given fields in the collection."""
limit_plus_one = limit + 1
index_name = index_alias_by_collection_id(collection_id)

query = await self.client.search(
index=index_name,
body={
"size": 0,
"aggs": {
field: {"terms": {"field": field, "size": limit_plus_one}}
for field in field_names
},
},
)

result: Dict[str, List[str]] = {}
for field, agg in query["aggregations"].items():
if len(agg["buckets"]) > limit:
logger.warning(
"Skipping enum field %s: exceeds limit of %d unique values. "
"Consider excluding this field from enumeration or increase the limit.",
field,
limit,
)
continue
result[field] = [bucket["key"] for bucket in agg["buckets"]]
return result

async def create_collection(self, collection: Collection, **kwargs: Any):
"""Create a single collection in the database.

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Filter client implementation for Elasticsearch/OpenSearch."""

from collections import deque
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple

import attr

from stac_fastapi.core.base_database_logic import BaseDatabaseLogic
from stac_fastapi.core.extensions.filter import DEFAULT_QUERYABLES
from stac_fastapi.core.extensions.filter import ALL_QUERYABLES, DEFAULT_QUERYABLES
from stac_fastapi.extensions.core.filter.client import AsyncBaseFiltersClient
from stac_fastapi.sfeos_helpers.mappings import ES_MAPPING_TYPE_TO_JSON

Expand Down Expand Up @@ -59,31 +59,31 @@ async def get_queryables(

mapping_data = await self.database.get_items_mapping(collection_id)
mapping_properties = next(iter(mapping_data.values()))["mappings"]["properties"]
stack = deque(mapping_properties.items())
stack: deque[Tuple[str, Dict[str, Any]]] = deque(mapping_properties.items())
enum_fields: Dict[str, Dict[str, Any]] = {}

while stack:
field_name, field_def = stack.popleft()
field_fqn, field_def = stack.popleft()

# Iterate over nested fields
field_properties = field_def.get("properties")
if field_properties:
# Fields in Item Properties should be exposed with their un-prefixed names,
# and not require expressions to prefix them with properties,
# e.g., eo:cloud_cover instead of properties.eo:cloud_cover.
if field_name == "properties":
stack.extend(field_properties.items())
else:
stack.extend(
(f"{field_name}.{k}", v) for k, v in field_properties.items()
)
stack.extend(
(f"{field_fqn}.{k}", v) for k, v in field_properties.items()
)

# Skip non-indexed or disabled fields
field_type = field_def.get("type")
if not field_type or not field_def.get("enabled", True):
continue

# Fields in Item Properties should be exposed with their un-prefixed names,
# and not require expressions to prefix them with properties,
# e.g., eo:cloud_cover instead of properties.eo:cloud_cover.
field_name = field_fqn.removeprefix("properties.")

# Generate field properties
field_result = DEFAULT_QUERYABLES.get(field_name, {})
field_result = ALL_QUERYABLES.get(field_name, {})
properties[field_name] = field_result

field_name_human = field_name.replace("_", " ").title()
Expand All @@ -95,4 +95,13 @@ async def get_queryables(
if field_type in {"date", "date_nanos"}:
field_result.setdefault("format", "date-time")

if field_result.pop("$enum", False):
enum_fields[field_fqn] = field_result

if enum_fields:
for field_fqn, unique_values in (
await self.database.get_items_unique_values(collection_id, enum_fields)
).items():
enum_fields[field_fqn]["enum"] = unique_values

return queryables
16 changes: 13 additions & 3 deletions stac_fastapi/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
from stac_fastapi.core.rate_limit import setup_rate_limit
from stac_fastapi.core.route_dependencies import get_route_dependencies
from stac_fastapi.core.utilities import get_bool_env
from stac_fastapi.extensions.core.filter import FilterConformanceClasses
from stac_fastapi.sfeos_helpers.aggregation import EsAsyncBaseAggregationClient
from stac_fastapi.sfeos_helpers.filter import EsAsyncBaseFiltersClient

if os.getenv("BACKEND", "elasticsearch").lower() == "opensearch":
from stac_fastapi.opensearch.config import AsyncOpensearchSettings as AsyncSettings
Expand All @@ -39,9 +41,11 @@
)
else:
from stac_fastapi.elasticsearch.config import (
ElasticsearchSettings as SearchSettings,
AsyncElasticsearchSettings as AsyncSettings,
)
from stac_fastapi.elasticsearch.config import (
ElasticsearchSettings as SearchSettings,
)
from stac_fastapi.elasticsearch.database_logic import (
DatabaseLogic,
create_collection_index,
Expand Down Expand Up @@ -198,6 +202,13 @@ def bulk_txn_client():
async def app():
settings = AsyncSettings()

filter_extension = FilterExtension(
client=EsAsyncBaseFiltersClient(database=database)
)
filter_extension.conformance_classes.append(
FilterConformanceClasses.ADVANCED_COMPARISON_OPERATORS
)

aggregation_extension = AggregationExtension(
client=EsAsyncBaseAggregationClient(
database=database, session=None, settings=settings
Expand All @@ -217,7 +228,7 @@ async def app():
FieldsExtension(),
QueryExtension(),
TokenPaginationExtension(),
FilterExtension(),
filter_extension,
FreeTextExtension(),
]

Expand Down Expand Up @@ -313,7 +324,6 @@ async def app_client_rate_limit(app_rate_limit):

@pytest_asyncio.fixture(scope="session")
async def app_basic_auth():

stac_fastapi_route_dependencies = """[
{
"routes":[{"method":"*","path":"*"}],
Expand Down
51 changes: 50 additions & 1 deletion stac_fastapi/tests/extensions/test_filter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import json
import logging
import os
import uuid
from os import listdir
from os.path import isfile, join
from typing import Callable, Dict

import pytest
from httpx import AsyncClient

THIS_DIR = os.path.dirname(os.path.abspath(__file__))

Expand Down Expand Up @@ -40,7 +43,6 @@ async def test_filter_extension_collection_link(app_client, load_test_data):

@pytest.mark.asyncio
async def test_search_filters_post(app_client, ctx):

filters = []
pwd = f"{THIS_DIR}/cql2"
for fn in [fn for f in listdir(pwd) if isfile(fn := join(pwd, f))]:
Expand Down Expand Up @@ -625,3 +627,50 @@ async def test_search_filter_extension_cql2text_s_disjoint_property(app_client,
assert resp.status_code == 200
resp_json = resp.json()
assert len(resp_json["features"]) == 1


@pytest.mark.asyncio
async def test_queryables_enum_platform(
app_client: AsyncClient,
load_test_data: Callable[[str], Dict],
monkeypatch: pytest.MonkeyPatch,
):
# Arrange
# Enforce instant database refresh
# TODO: Is there a better way to do this?
monkeypatch.setenv("DATABASE_REFRESH", "true")

# Create collection
collection_data = load_test_data("test_collection.json")
collection_id = collection_data["id"] = f"enum-test-collection-{uuid.uuid4()}"
r = await app_client.post("/collections", json=collection_data)
r.raise_for_status()

# Create items with different platform values
NUM_ITEMS = 3
for i in range(1, NUM_ITEMS + 1):
item_data = load_test_data("test_item.json")
item_data["id"] = f"enum-test-item-{i}"
item_data["collection"] = collection_id
item_data["properties"]["platform"] = "landsat-8" if i % 2 else "sentinel-2"
r = await app_client.post(f"/collections/{collection_id}/items", json=item_data)
r.raise_for_status()

# Act
# Test queryables endpoint
queryables = (
(await app_client.get(f"/collections/{collection_data['id']}/queryables"))
.raise_for_status()
.json()
)

# Assert
# Verify distinct values (should only have 2 unique values despite 3 items)
properties = queryables["properties"]
platform_info = properties["platform"]
platform_values = platform_info["enum"]
assert set(platform_values) == {"landsat-8", "sentinel-2"}

# Clean up
r = await app_client.delete(f"/collections/{collection_id}")
r.raise_for_status()