Skip to content

Commit c25229e

Browse files
refactor aggregate in database_logic (#294)
**Related Issue(s):** N/A

**Description:** Refactor `aggregate()` in the database logic to allow extending the supported set of aggregations. The mapping of aggregation name to Elasticsearch/OpenSearch functionality was previously defined inside the `aggregate()` function, which made it difficult to alter the set of supported aggregations. I moved the mapping to a property of the database logic, so it can be modified when the database logic is instantiated.

**PR Checklist:**

- [x] Code is formatted and linted (run `pre-commit run --all-files`)
- [x] Tests pass (run `make test`)
- [x] Documentation has been updated to reflect changes, if applicable
- [x] Changes are added to the changelog

---------

Co-authored-by: Jonathan Healy <jonathan.d.healy@gmail.com>
1 parent 7b2b191 commit c25229e

File tree

4 files changed

+225
-170
lines changed

4 files changed

+225
-170
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
99

1010
### Added
1111

12+
- Added `datetime_frequency_interval` parameter for `datetime_frequency` aggregation. [#294](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/294)
13+
1214
### Changed
15+
16+
- Refactored aggregation in database logic. [#294](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/294)
1317
- Fixed the `self` link for the `/collections/{collection_id}/aggregations` endpoint. [#295](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/295)
1418

1519
## [v3.1.0] - 2024-09-02

stac_fastapi/core/stac_fastapi/core/extensions/aggregation.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class EsAggregationExtensionGetRequest(
5050
centroid_geotile_grid_frequency_precision: Optional[int] = attr.ib(default=None)
5151
geometry_geohash_grid_frequency_precision: Optional[int] = attr.ib(default=None)
5252
geometry_geotile_grid_frequency_precision: Optional[int] = attr.ib(default=None)
53+
datetime_frequency_interval: Optional[str] = attr.ib(default=None)
5354

5455

5556
class EsAggregationExtensionPostRequest(
@@ -62,6 +63,7 @@ class EsAggregationExtensionPostRequest(
6263
centroid_geotile_grid_frequency_precision: Optional[int] = None
6364
geometry_geohash_grid_frequency_precision: Optional[int] = None
6465
geometry_geotile_grid_frequency_precision: Optional[int] = None
66+
datetime_frequency_interval: Optional[str] = None
6567

6668

6769
@attr.s
@@ -124,6 +126,8 @@ class EsAsyncAggregationClient(AsyncBaseAggregationClient):
124126
MAX_GEOHASH_PRECISION = 12
125127
MAX_GEOHEX_PRECISION = 15
126128
MAX_GEOTILE_PRECISION = 29
129+
SUPPORTED_DATETIME_INTERVAL = {"day", "month", "year"}
130+
DEFAULT_DATETIME_INTERVAL = "month"
127131

128132
async def get_aggregations(self, collection_id: Optional[str] = None, **kwargs):
129133
"""Get the available aggregations for a catalog or collection defined in the STAC JSON. If no aggregations, default aggregations are used."""
@@ -182,6 +186,30 @@ def extract_precision(
182186
else:
183187
return min_value
184188

189+
def extract_date_histogram_interval(self, value: Optional[str]) -> str:
    """
    Return a valid interval for the date histogram (`datetime_frequency`) aggregation.

    If no value is provided, the default interval is returned.

    Args:
        value: interval entered by the user, or None if none was supplied.

    Returns:
        string containing the date histogram interval to use: `value` when it is
        one of `self.SUPPORTED_DATETIME_INTERVAL`, otherwise (when `value` is
        None) `self.DEFAULT_DATETIME_INTERVAL`.

    Raises:
        HTTPException: (status 400) if the supplied value is not in the supported intervals.
    """
    if value is None:
        return self.DEFAULT_DATETIME_INTERVAL

    if value not in self.SUPPORTED_DATETIME_INTERVAL:
        # The supported intervals are stored in a set, whose iteration order is
        # not stable across processes; sort so the client always sees the same,
        # readable message instead of a raw Python set repr.
        supported = ", ".join(sorted(self.SUPPORTED_DATETIME_INTERVAL))
        raise HTTPException(
            status_code=400,
            detail=f"Invalid datetime interval. Must be one of: {supported}",
        )

    return value
212+
185213
@staticmethod
186214
def _return_date(
187215
interval: Optional[Union[DateTimeType, str]]
@@ -319,6 +347,7 @@ async def aggregate(
319347
centroid_geotile_grid_frequency_precision: Optional[int] = None,
320348
geometry_geohash_grid_frequency_precision: Optional[int] = None,
321349
geometry_geotile_grid_frequency_precision: Optional[int] = None,
350+
datetime_frequency_interval: Optional[str] = None,
322351
**kwargs,
323352
) -> Union[Dict, Exception]:
324353
"""Get aggregations from the database."""
@@ -339,6 +368,7 @@ async def aggregate(
339368
"centroid_geotile_grid_frequency_precision": centroid_geotile_grid_frequency_precision,
340369
"geometry_geohash_grid_frequency_precision": geometry_geohash_grid_frequency_precision,
341370
"geometry_geotile_grid_frequency_precision": geometry_geotile_grid_frequency_precision,
371+
"datetime_frequency_interval": datetime_frequency_interval,
342372
}
343373

344374
if collection_id:
@@ -475,6 +505,10 @@ async def aggregate(
475505
self.MAX_GEOTILE_PRECISION,
476506
)
477507

508+
datetime_frequency_interval = self.extract_date_histogram_interval(
509+
aggregate_request.datetime_frequency_interval,
510+
)
511+
478512
try:
479513
db_response = await self.database.aggregate(
480514
collections,
@@ -485,6 +519,7 @@ async def aggregate(
485519
centroid_geotile_grid_precision,
486520
geometry_geohash_grid_precision,
487521
geometry_geotile_grid_precision,
522+
datetime_frequency_interval,
488523
)
489524
except Exception as error:
490525
if not isinstance(error, IndexError):

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py

Lines changed: 93 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
import os
66
from base64 import urlsafe_b64decode, urlsafe_b64encode
7+
from copy import deepcopy
78
from typing import Any, Dict, Iterable, List, Optional, Protocol, Tuple, Type, Union
89

910
import attr
@@ -316,6 +317,77 @@ class DatabaseLogic:
316317

317318
extensions: List[str] = attr.ib(default=attr.Factory(list))
318319

320+
aggregation_mapping: Dict[str, Dict[str, Any]] = {
321+
"total_count": {"value_count": {"field": "id"}},
322+
"collection_frequency": {"terms": {"field": "collection", "size": 100}},
323+
"platform_frequency": {"terms": {"field": "properties.platform", "size": 100}},
324+
"cloud_cover_frequency": {
325+
"range": {
326+
"field": "properties.eo:cloud_cover",
327+
"ranges": [
328+
{"to": 5},
329+
{"from": 5, "to": 15},
330+
{"from": 15, "to": 40},
331+
{"from": 40},
332+
],
333+
}
334+
},
335+
"datetime_frequency": {
336+
"date_histogram": {
337+
"field": "properties.datetime",
338+
"calendar_interval": "month",
339+
}
340+
},
341+
"datetime_min": {"min": {"field": "properties.datetime"}},
342+
"datetime_max": {"max": {"field": "properties.datetime"}},
343+
"grid_code_frequency": {
344+
"terms": {
345+
"field": "properties.grid:code",
346+
"missing": "none",
347+
"size": 10000,
348+
}
349+
},
350+
"sun_elevation_frequency": {
351+
"histogram": {"field": "properties.view:sun_elevation", "interval": 5}
352+
},
353+
"sun_azimuth_frequency": {
354+
"histogram": {"field": "properties.view:sun_azimuth", "interval": 5}
355+
},
356+
"off_nadir_frequency": {
357+
"histogram": {"field": "properties.view:off_nadir", "interval": 5}
358+
},
359+
"centroid_geohash_grid_frequency": {
360+
"geohash_grid": {
361+
"field": "properties.proj:centroid",
362+
"precision": 1,
363+
}
364+
},
365+
"centroid_geohex_grid_frequency": {
366+
"geohex_grid": {
367+
"field": "properties.proj:centroid",
368+
"precision": 0,
369+
}
370+
},
371+
"centroid_geotile_grid_frequency": {
372+
"geotile_grid": {
373+
"field": "properties.proj:centroid",
374+
"precision": 0,
375+
}
376+
},
377+
"geometry_geohash_grid_frequency": {
378+
"geohash_grid": {
379+
"field": "geometry",
380+
"precision": 1,
381+
}
382+
},
383+
"geometry_geotile_grid_frequency": {
384+
"geotile_grid": {
385+
"field": "geometry",
386+
"precision": 0,
387+
}
388+
},
389+
}
390+
319391
"""CORE LOGIC"""
320392

321393
async def get_all_collections(
@@ -657,104 +729,41 @@ async def aggregate(
657729
centroid_geotile_grid_precision: int,
658730
geometry_geohash_grid_precision: int,
659731
geometry_geotile_grid_precision: int,
732+
datetime_frequency_interval: str,
660733
ignore_unavailable: Optional[bool] = True,
661734
):
662735
"""Return aggregations of STAC Items."""
663-
agg_2_es = {
664-
"total_count": {"value_count": {"field": "id"}},
665-
"collection_frequency": {"terms": {"field": "collection", "size": 100}},
666-
"platform_frequency": {
667-
"terms": {"field": "properties.platform", "size": 100}
668-
},
669-
"cloud_cover_frequency": {
670-
"range": {
671-
"field": "properties.eo:cloud_cover",
672-
"ranges": [
673-
{"to": 5},
674-
{"from": 5, "to": 15},
675-
{"from": 15, "to": 40},
676-
{"from": 40},
677-
],
678-
}
679-
},
680-
"datetime_frequency": {
681-
"date_histogram": {
682-
"field": "properties.datetime",
683-
"calendar_interval": "month",
684-
}
685-
},
686-
"datetime_min": {"min": {"field": "properties.datetime"}},
687-
"datetime_max": {"max": {"field": "properties.datetime"}},
688-
"grid_code_frequency": {
689-
"terms": {
690-
"field": "properties.grid:code",
691-
"missing": "none",
692-
"size": 10000,
693-
}
694-
},
695-
"sun_elevation_frequency": {
696-
"histogram": {"field": "properties.view:sun_elevation", "interval": 5}
697-
},
698-
"sun_azimuth_frequency": {
699-
"histogram": {"field": "properties.view:sun_azimuth", "interval": 5}
700-
},
701-
"off_nadir_frequency": {
702-
"histogram": {"field": "properties.view:off_nadir", "interval": 5}
703-
},
704-
}
705-
706736
search_body: Dict[str, Any] = {}
707737
query = search.query.to_dict() if search.query else None
708738
if query:
709739
search_body["query"] = query
710740

711741
logger.debug("Aggregations: %s", aggregations)
712742

713-
# include all aggregations specified
714-
# this will ignore aggregations with the wrong names
715-
search_body["aggregations"] = {
716-
k: v for k, v in agg_2_es.items() if k in aggregations
717-
}
718-
719-
if "centroid_geohash_grid_frequency" in aggregations:
720-
search_body["aggregations"]["centroid_geohash_grid_frequency"] = {
721-
"geohash_grid": {
722-
"field": "properties.proj:centroid",
723-
"precision": centroid_geohash_grid_precision,
724-
}
725-
}
726-
727-
if "centroid_geohex_grid_frequency" in aggregations:
728-
search_body["aggregations"]["centroid_geohex_grid_frequency"] = {
729-
"geohex_grid": {
730-
"field": "properties.proj:centroid",
731-
"precision": centroid_geohex_grid_precision,
732-
}
743+
def _fill_aggregation_parameters(name: str, agg: dict) -> dict:
744+
[key] = agg.keys()
745+
agg_precision = {
746+
"centroid_geohash_grid_frequency": centroid_geohash_grid_precision,
747+
"centroid_geohex_grid_frequency": centroid_geohex_grid_precision,
748+
"centroid_geotile_grid_frequency": centroid_geotile_grid_precision,
749+
"geometry_geohash_grid_frequency": geometry_geohash_grid_precision,
750+
"geometry_geotile_grid_frequency": geometry_geotile_grid_precision,
733751
}
752+
if name in agg_precision:
753+
agg[key]["precision"] = agg_precision[name]
734754

735-
if "centroid_geotile_grid_frequency" in aggregations:
736-
search_body["aggregations"]["centroid_geotile_grid_frequency"] = {
737-
"geotile_grid": {
738-
"field": "properties.proj:centroid",
739-
"precision": centroid_geotile_grid_precision,
740-
}
741-
}
755+
if key == "date_histogram":
756+
agg[key]["calendar_interval"] = datetime_frequency_interval
742757

743-
if "geometry_geohash_grid_frequency" in aggregations:
744-
search_body["aggregations"]["geometry_geohash_grid_frequency"] = {
745-
"geohash_grid": {
746-
"field": "geometry",
747-
"precision": geometry_geohash_grid_precision,
748-
}
749-
}
758+
return agg
750759

751-
if "geometry_geotile_grid_frequency" in aggregations:
752-
search_body["aggregations"]["geometry_geotile_grid_frequency"] = {
753-
"geotile_grid": {
754-
"field": "geometry",
755-
"precision": geometry_geotile_grid_precision,
756-
}
757-
}
760+
# include all aggregations specified
761+
# this will ignore aggregations with the wrong names
762+
search_body["aggregations"] = {
763+
k: _fill_aggregation_parameters(k, deepcopy(v))
764+
for k, v in self.aggregation_mapping.items()
765+
if k in aggregations
766+
}
758767

759768
index_param = indices(collection_ids)
760769
search_task = asyncio.create_task(

0 commit comments

Comments
 (0)