Skip to content

Commit 5bb8356

Browse files
committed
add wait_for, reorg es db logic
1 parent e7198f5 commit 5bb8356

File tree

6 files changed

+278
-99
lines changed

6 files changed

+278
-99
lines changed

stac_fastapi/core/stac_fastapi/core/core.py

Lines changed: 10 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -673,27 +673,6 @@ class TransactionsClient(AsyncBaseTransactionsClient):
673673
settings: ApiBaseSettings = attr.ib()
674674
session: Session = attr.ib(default=attr.Factory(Session.create_from_env))
675675

676-
def _resolve_refresh(self, **kwargs) -> bool:
677-
"""
678-
Resolve the `refresh` parameter from kwargs or the environment variable.
679-
680-
Args:
681-
**kwargs: Additional keyword arguments, including `refresh`.
682-
683-
Returns:
684-
bool: The resolved value of the `refresh` parameter.
685-
"""
686-
refresh = kwargs.get(
687-
"refresh", self.database.async_settings.database_refresh == "true"
688-
)
689-
if "refresh" in kwargs:
690-
logger.info(f"`refresh` parameter explicitly passed in kwargs: {refresh}")
691-
else:
692-
logger.info(
693-
f"`refresh` parameter derived from environment variable: {refresh}"
694-
)
695-
return refresh
696-
697676
@overrides
698677
async def create_item(
699678
self, collection_id: str, item: Union[Item, ItemCollection], **kwargs
@@ -717,9 +696,6 @@ async def create_item(
717696
request = kwargs.get("request")
718697
base_url = str(request.base_url)
719698

720-
# Resolve the `refresh` parameter
721-
refresh = self._resolve_refresh(**kwargs)
722-
723699
# Convert Pydantic model to dict for uniform processing
724700
item_dict = item.model_dump(mode="json")
725701

@@ -738,9 +714,9 @@ async def create_item(
738714
attempted = len(processed_items)
739715

740716
success, errors = await self.database.bulk_async(
741-
collection_id,
742-
processed_items,
743-
refresh=refresh,
717+
collection_id=collection_id,
718+
processed_items=processed_items,
719+
**kwargs,
744720
)
745721
if errors:
746722
logger.error(
@@ -754,10 +730,7 @@ async def create_item(
754730

755731
# Handle single item
756732
await self.database.create_item(
757-
item_dict,
758-
refresh=refresh,
759-
base_url=base_url,
760-
exist_ok=False,
733+
item_dict, base_url=base_url, exist_ok=False, **kwargs
761734
)
762735
return ItemSerializer.db_to_stac(item_dict, base_url)
763736

@@ -783,14 +756,11 @@ async def update_item(
783756
item = item.model_dump(mode="json")
784757
base_url = str(kwargs["request"].base_url)
785758

786-
# Resolve the `refresh` parameter
787-
refresh = self._resolve_refresh(**kwargs)
788-
789759
now = datetime_type.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
790760
item["properties"]["updated"] = now
791761

792762
await self.database.create_item(
793-
item, refresh=refresh, base_url=base_url, exist_ok=True
763+
item, base_url=base_url, exist_ok=True, **kwargs
794764
)
795765

796766
return ItemSerializer.db_to_stac(item, base_url)
@@ -806,11 +776,8 @@ async def delete_item(self, item_id: str, collection_id: str, **kwargs) -> None:
806776
Returns:
807777
None: Returns 204 No Content on successful deletion
808778
"""
809-
# Resolve the `refresh` parameter
810-
refresh = self._resolve_refresh(**kwargs)
811-
812779
await self.database.delete_item(
813-
item_id=item_id, collection_id=collection_id, refresh=refresh
780+
item_id=item_id, collection_id=collection_id, **kwargs
814781
)
815782
return None
816783

@@ -833,11 +800,8 @@ async def create_collection(
833800
collection = collection.model_dump(mode="json")
834801
request = kwargs["request"]
835802

836-
# Resolve the `refresh` parameter
837-
refresh = self._resolve_refresh(**kwargs)
838-
839803
collection = self.database.collection_serializer.stac_to_db(collection, request)
840-
await self.database.create_collection(collection=collection, refresh=refresh)
804+
await self.database.create_collection(collection=collection, **kwargs)
841805
return CollectionSerializer.db_to_stac(
842806
collection,
843807
request,
@@ -871,12 +835,9 @@ async def update_collection(
871835

872836
request = kwargs["request"]
873837

874-
# Resolve the `refresh` parameter
875-
refresh = self._resolve_refresh(**kwargs)
876-
877838
collection = self.database.collection_serializer.stac_to_db(collection, request)
878839
await self.database.update_collection(
879-
collection_id=collection_id, collection=collection, refresh=refresh
840+
collection_id=collection_id, collection=collection, **kwargs
880841
)
881842

882843
return CollectionSerializer.db_to_stac(
@@ -901,12 +862,7 @@ async def delete_collection(self, collection_id: str, **kwargs) -> None:
901862
Raises:
902863
NotFoundError: If the collection doesn't exist
903864
"""
904-
# Resolve the `refresh` parameter
905-
refresh = self._resolve_refresh(**kwargs)
906-
907-
await self.database.delete_collection(
908-
collection_id=collection_id, refresh=refresh
909-
)
865+
await self.database.delete_collection(collection_id=collection_id, **kwargs)
910866
return None
911867

912868

@@ -965,19 +921,6 @@ def bulk_item_insert(
965921
else:
966922
base_url = ""
967923

968-
# Use `refresh` from kwargs if provided, otherwise fall back to the environment variable
969-
refresh = kwargs.get(
970-
"refresh", self.database.sync_settings.database_refresh == "true"
971-
)
972-
973-
# Log the value of `refresh` and its source
974-
if "refresh" in kwargs:
975-
logger.info(f"`refresh` parameter explicitly passed in kwargs: {refresh}")
976-
else:
977-
logger.info(
978-
f"`refresh` parameter derived from environment variable: {refresh}"
979-
)
980-
981924
processed_items = []
982925
for item in items.items.values():
983926
try:
@@ -996,7 +939,7 @@ def bulk_item_insert(
996939
success, errors = self.database.bulk_sync(
997940
collection_id,
998941
processed_items,
999-
refresh=refresh,
942+
**kwargs,
1000943
)
1001944
if errors:
1002945
logger.error(f"Bulk sync operation encountered errors: {errors}")

stac_fastapi/core/stac_fastapi/core/utilities.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,30 @@ def get_bool_env(name: str, default: bool = False) -> bool:
3939
return default
4040

4141

42+
def resolve_refresh(refresh: str) -> str:
43+
"""
44+
Resolve the `refresh` parameter from kwargs or the environment variable.
45+
46+
Args:
47+
refresh (str): The `refresh` parameter value.
48+
49+
Returns:
50+
str: The resolved value of the `refresh` parameter, which can be "true", "false", or "wait_for".
51+
"""
52+
logger = logging.getLogger(__name__)
53+
54+
# Normalize and validate the `refresh` value
55+
refresh = refresh.lower()
56+
if refresh not in {"true", "false", "wait_for"}:
57+
raise ValueError(
58+
"Invalid value for `refresh`. Must be 'true', 'false', or 'wait_for'."
59+
)
60+
61+
# Log the resolved value
62+
logger.info(f"`refresh` parameter resolved to: {refresh}")
63+
return refresh
64+
65+
4266
def bbox2polygon(b0: float, b1: float, b2: float, b3: float) -> List[List[List[float]]]:
4367
"""Transform a bounding box represented by its four coordinates `b0`, `b1`, `b2`, and `b3` into a polygon.
4468

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/config.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
import os
55
import ssl
6-
from typing import Any, Dict, Set
6+
from typing import Any, Dict, Set, Union
77

88
import certifi
99
from elasticsearch._async.client import AsyncElasticsearch
@@ -87,7 +87,24 @@ class ElasticsearchSettings(ApiSettings, ApiBaseSettings):
8787
enable_response_models: bool = False
8888
enable_direct_response: bool = get_bool_env("ENABLE_DIRECT_RESPONSE", default=False)
8989
raise_on_bulk_error: bool = get_bool_env("RAISE_ON_BULK_ERROR", default=False)
90-
database_refresh: bool = get_bool_env("DATABASE_REFRESH", default=False)
90+
91+
@property
92+
def database_refresh(self) -> Union[bool, str]:
93+
"""
94+
Get the value of the DATABASE_REFRESH environment variable.
95+
96+
Returns:
97+
Union[bool, str]: The value of DATABASE_REFRESH, which can be True, False, or "wait_for".
98+
"""
99+
value = os.getenv("DATABASE_REFRESH", "false").lower()
100+
if value in {"true", "false"}:
101+
return value == "true"
102+
elif value == "wait_for":
103+
return "wait_for"
104+
else:
105+
raise ValueError(
106+
"Invalid value for DATABASE_REFRESH. Must be 'true', 'false', or 'wait_for'."
107+
)
91108

92109
@property
93110
def create_client(self):
@@ -109,7 +126,24 @@ class AsyncElasticsearchSettings(ApiSettings, ApiBaseSettings):
109126
enable_response_models: bool = False
110127
enable_direct_response: bool = get_bool_env("ENABLE_DIRECT_RESPONSE", default=False)
111128
raise_on_bulk_error: bool = get_bool_env("RAISE_ON_BULK_ERROR", default=False)
112-
database_refresh: bool = get_bool_env("DATABASE_REFRESH", default=False)
129+
130+
@property
131+
def database_refresh(self) -> Union[bool, str]:
132+
"""
133+
Get the value of the DATABASE_REFRESH environment variable.
134+
135+
Returns:
136+
Union[bool, str]: The value of DATABASE_REFRESH, which can be True, False, or "wait_for".
137+
"""
138+
value = os.getenv("DATABASE_REFRESH", "false").lower()
139+
if value in {"true", "false"}:
140+
return value == "true"
141+
elif value == "wait_for":
142+
return "wait_for"
143+
else:
144+
raise ValueError(
145+
"Invalid value for DATABASE_REFRESH. Must be 'true', 'false', or 'wait_for'."
146+
)
113147

114148
@property
115149
def create_client(self):

0 commit comments

Comments
 (0)