Skip to content

Commit f10ffb0

Browse files
author
Phil Varner
authored
Merge pull request #74 from stac-utils/pv/async-clients
convert to async and refactor tests
2 parents a466f75 + d427b39 commit f10ffb0

File tree

13 files changed

+572
-661
lines changed

13 files changed

+572
-661
lines changed

stac_fastapi/elasticsearch/setup.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,12 @@
2929
"requests",
3030
"ciso8601",
3131
"overrides",
32-
"starlette",
32+
"httpx",
3333
],
3434
"docs": ["mkdocs", "mkdocs-material", "pdocs"],
3535
"server": ["uvicorn[standard]>=0.12.0,<0.14.0"],
3636
}
3737

38-
3938
setup(
4039
name="stac-fastapi.elasticsearch",
4140
description="An implementation of STAC API based on the FastAPI framework.",

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from stac_fastapi.elasticsearch.config import ElasticsearchSettings
55
from stac_fastapi.elasticsearch.core import (
66
BulkTransactionsClient,
7-
CoreCrudClient,
7+
CoreClient,
88
TransactionsClient,
99
)
1010
from stac_fastapi.elasticsearch.extensions import QueryExtension
@@ -37,7 +37,7 @@
3737
api = StacApi(
3838
settings=settings,
3939
extensions=extensions,
40-
client=CoreCrudClient(session=session, post_request_model=post_request_model),
40+
client=CoreClient(session=session, post_request_model=post_request_model),
4141
search_get_request_model=create_get_request_model(extensions),
4242
search_post_request_model=post_request_model,
4343
)

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/config.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
from typing import Set
44

5-
from elasticsearch import Elasticsearch
5+
from elasticsearch import AsyncElasticsearch, Elasticsearch
66

77
from stac_fastapi.types.config import ApiSettings
88

@@ -16,13 +16,25 @@ class ElasticsearchSettings(ApiSettings):
1616
# Fields which are defined by STAC but not included in the database model
1717
forbidden_fields: Set[str] = {"type"}
1818

19-
# Fields which are item properties but indexed as distinct fields in the database model
20-
indexed_fields: Set[str] = {"datetime"}
21-
2219
@property
2320
def create_client(self):
2421
"""Create es client."""
2522
return Elasticsearch(
2623
[{"host": str(DOMAIN), "port": str(PORT)}],
2724
headers={"accept": "application/vnd.elasticsearch+json; compatible-with=7"},
2825
)
26+
27+
28+
class AsyncElasticsearchSettings(ApiSettings):
29+
"""API settings."""
30+
31+
# Fields which are defined by STAC but not included in the database model
32+
forbidden_fields: Set[str] = {"type"}
33+
34+
@property
35+
def create_client(self):
36+
"""Create async elasticsearch client."""
37+
return AsyncElasticsearch(
38+
[{"host": str(DOMAIN), "port": str(PORT)}],
39+
headers={"accept": "application/vnd.elasticsearch+json; compatible-with=7"},
40+
)

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py

Lines changed: 53 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from urllib.parse import urljoin
88

99
import attr
10+
import stac_pydantic.api
1011
from fastapi import HTTPException
1112
from overrides import overrides
1213
from pydantic import ValidationError
@@ -23,7 +24,7 @@
2324
Items,
2425
)
2526
from stac_fastapi.types import stac as stac_types
26-
from stac_fastapi.types.core import BaseCoreClient, BaseTransactionsClient
27+
from stac_fastapi.types.core import AsyncBaseCoreClient, AsyncBaseTransactionsClient
2728
from stac_fastapi.types.links import CollectionLinks
2829
from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection
2930

@@ -33,23 +34,25 @@
3334

3435

3536
@attr.s
36-
class CoreCrudClient(BaseCoreClient):
37+
class CoreClient(AsyncBaseCoreClient):
3738
"""Client for core endpoints defined by stac."""
3839

3940
session: Session = attr.ib(default=attr.Factory(Session.create_from_env))
40-
item_serializer: Type[serializers.Serializer] = attr.ib(
41+
item_serializer: Type[serializers.ItemSerializer] = attr.ib(
4142
default=serializers.ItemSerializer
4243
)
43-
collection_serializer: Type[serializers.Serializer] = attr.ib(
44+
collection_serializer: Type[serializers.CollectionSerializer] = attr.ib(
4445
default=serializers.CollectionSerializer
4546
)
4647
database = DatabaseLogic()
4748

4849
@overrides
49-
def all_collections(self, **kwargs) -> Collections:
50+
async def all_collections(self, **kwargs) -> Collections:
5051
"""Read all collections from the database."""
5152
base_url = str(kwargs["request"].base_url)
52-
serialized_collections = self.database.get_all_collections(base_url=base_url)
53+
serialized_collections = await self.database.get_all_collections(
54+
base_url=base_url
55+
)
5356

5457
links = [
5558
{
@@ -74,21 +77,21 @@ def all_collections(self, **kwargs) -> Collections:
7477
return collection_list
7578

7679
@overrides
77-
def get_collection(self, collection_id: str, **kwargs) -> Collection:
80+
async def get_collection(self, collection_id: str, **kwargs) -> Collection:
7881
"""Get collection by id."""
7982
base_url = str(kwargs["request"].base_url)
80-
collection = self.database.find_collection(collection_id=collection_id)
83+
collection = await self.database.find_collection(collection_id=collection_id)
8184
return self.collection_serializer.db_to_stac(collection, base_url)
8285

8386
@overrides
84-
def item_collection(
87+
async def item_collection(
8588
self, collection_id: str, limit: int = 10, token: str = None, **kwargs
8689
) -> ItemCollection:
8790
"""Read an item collection from the database."""
8891
links = []
8992
base_url = str(kwargs["request"].base_url)
9093

91-
serialized_children, count = self.database.get_item_collection(
94+
serialized_children, count = await self.database.get_item_collection(
9295
collection_id=collection_id, limit=limit, base_url=base_url
9396
)
9497

@@ -108,10 +111,12 @@ def item_collection(
108111
)
109112

110113
@overrides
111-
def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item:
114+
async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item:
112115
"""Get item by item id, collection id."""
113116
base_url = str(kwargs["request"].base_url)
114-
item = self.database.get_one_item(item_id=item_id, collection_id=collection_id)
117+
item = await self.database.get_one_item(
118+
item_id=item_id, collection_id=collection_id
119+
)
115120
return self.item_serializer.db_to_stac(item, base_url)
116121

117122
@staticmethod
@@ -139,7 +144,7 @@ def _return_date(interval_str):
139144
return {"lte": end_date, "gte": start_date}
140145

141146
@overrides
142-
def get_search(
147+
async def get_search(
143148
self,
144149
collections: Optional[List[str]] = None,
145150
ids: Optional[List[str]] = None,
@@ -192,18 +197,19 @@ def get_search(
192197
search_request = self.post_request_model(**base_args)
193198
except ValidationError:
194199
raise HTTPException(status_code=400, detail="Invalid parameters provided")
195-
resp = self.post_search(search_request, request=kwargs["request"])
200+
resp = await self.post_search(search_request, request=kwargs["request"])
196201

197202
return resp
198203

199-
def post_search(self, search_request, **kwargs) -> ItemCollection:
204+
@overrides
205+
async def post_search(
206+
self, search_request: stac_pydantic.api.Search, **kwargs
207+
) -> ItemCollection:
200208
"""POST search catalog."""
201209
base_url = str(kwargs["request"].base_url)
202-
search = self.database.create_search_object()
210+
search = self.database.create_search()
203211

204212
if search_request.query:
205-
if type(search_request.query) == str:
206-
search_request.query = json.loads(search_request.query)
207213
for (field_name, expr) in search_request.query.items():
208214
field = "properties__" + field_name
209215
for (op, value) in expr.items():
@@ -217,7 +223,7 @@ def post_search(self, search_request, **kwargs) -> ItemCollection:
217223
)
218224

219225
if search_request.collections:
220-
search = self.database.search_collections(
226+
search = self.database.filter_collections(
221227
search=search, collection_ids=search_request.collections
222228
)
223229

@@ -247,9 +253,9 @@ def post_search(self, search_request, **kwargs) -> ItemCollection:
247253
search=search, field=sort.field, direction=sort.direction
248254
)
249255

250-
count = self.database.search_count(search=search)
256+
count = await self.database.search_count(search=search)
251257

252-
response_features = self.database.execute_search(
258+
response_features = await self.database.execute_search(
253259
search=search, limit=search_request.limit, base_url=base_url
254260
)
255261

@@ -298,57 +304,57 @@ def post_search(self, search_request, **kwargs) -> ItemCollection:
298304

299305

300306
@attr.s
301-
class TransactionsClient(BaseTransactionsClient):
307+
class TransactionsClient(AsyncBaseTransactionsClient):
302308
"""Transactions extension specific CRUD operations."""
303309

304310
session: Session = attr.ib(default=attr.Factory(Session.create_from_env))
305311
database = DatabaseLogic()
306312

307313
@overrides
308-
def create_item(self, item: stac_types.Item, **kwargs) -> stac_types.Item:
314+
async def create_item(self, item: stac_types.Item, **kwargs) -> stac_types.Item:
309315
"""Create item."""
310316
base_url = str(kwargs["request"].base_url)
311317

312318
# If a feature collection is posted
313319
if item["type"] == "FeatureCollection":
314320
bulk_client = BulkTransactionsClient()
315321
processed_items = [
316-
bulk_client.preprocess_item(item, base_url) for item in item["features"]
322+
bulk_client.preprocess_item(item, base_url) for item in item["features"] # type: ignore
317323
]
318-
self.database.bulk_sync(
324+
await self.database.bulk_async(
319325
processed_items, refresh=kwargs.get("refresh", False)
320326
)
321327

322-
return None
328+
return None # type: ignore
323329
else:
324-
item = self.database.prep_create_item(item=item, base_url=base_url)
325-
self.database.create_item(item, refresh=kwargs.get("refresh", False))
330+
item = await self.database.prep_create_item(item=item, base_url=base_url)
331+
await self.database.create_item(item, refresh=kwargs.get("refresh", False))
326332
return item
327333

328334
@overrides
329-
def update_item(self, item: stac_types.Item, **kwargs) -> stac_types.Item:
335+
async def update_item(self, item: stac_types.Item, **kwargs) -> stac_types.Item:
330336
"""Update item."""
331337
base_url = str(kwargs["request"].base_url)
332338
now = datetime_type.now(timezone.utc).isoformat().replace("+00:00", "Z")
333339
item["properties"]["updated"] = str(now)
334340

335-
self.database.check_collection_exists(collection_id=item["collection"])
341+
await self.database.check_collection_exists(collection_id=item["collection"])
336342
# todo: index instead of delete and create
337-
self.delete_item(item_id=item["id"], collection_id=item["collection"])
338-
self.create_item(item=item, **kwargs)
343+
await self.delete_item(item_id=item["id"], collection_id=item["collection"])
344+
await self.create_item(item=item, **kwargs)
339345

340346
return ItemSerializer.db_to_stac(item, base_url)
341347

342348
@overrides
343-
def delete_item(
349+
async def delete_item(
344350
self, item_id: str, collection_id: str, **kwargs
345351
) -> stac_types.Item:
346352
"""Delete item."""
347-
self.database.delete_item(item_id=item_id, collection_id=collection_id)
348-
return None
353+
await self.database.delete_item(item_id=item_id, collection_id=collection_id)
354+
return None # type: ignore
349355

350356
@overrides
351-
def create_collection(
357+
async def create_collection(
352358
self, collection: stac_types.Collection, **kwargs
353359
) -> stac_types.Collection:
354360
"""Create collection."""
@@ -357,28 +363,30 @@ def create_collection(
357363
collection_id=collection["id"], base_url=base_url
358364
).create_links()
359365
collection["links"] = collection_links
360-
self.database.create_collection(collection=collection)
366+
await self.database.create_collection(collection=collection)
361367

362368
return CollectionSerializer.db_to_stac(collection, base_url)
363369

364370
@overrides
365-
def update_collection(
371+
async def update_collection(
366372
self, collection: stac_types.Collection, **kwargs
367373
) -> stac_types.Collection:
368374
"""Update collection."""
369375
base_url = str(kwargs["request"].base_url)
370376

371-
self.database.find_collection(collection_id=collection["id"])
372-
self.delete_collection(collection["id"])
373-
self.create_collection(collection, **kwargs)
377+
await self.database.find_collection(collection_id=collection["id"])
378+
await self.delete_collection(collection["id"])
379+
await self.create_collection(collection, **kwargs)
374380

375381
return CollectionSerializer.db_to_stac(collection, base_url)
376382

377383
@overrides
378-
def delete_collection(self, collection_id: str, **kwargs) -> stac_types.Collection:
384+
async def delete_collection(
385+
self, collection_id: str, **kwargs
386+
) -> stac_types.Collection:
379387
"""Delete collection."""
380-
self.database.delete_collection(collection_id=collection_id)
381-
return None
388+
await self.database.delete_collection(collection_id=collection_id)
389+
return None # type: ignore
382390

383391

384392
@attr.s
@@ -395,8 +403,7 @@ def __attrs_post_init__(self):
395403

396404
def preprocess_item(self, item: stac_types.Item, base_url) -> stac_types.Item:
397405
"""Preprocess items to match data model."""
398-
item = self.database.prep_create_item(item=item, base_url=base_url)
399-
return item
406+
return self.database.sync_prep_create_item(item=item, base_url=base_url)
400407

401408
@overrides
402409
def bulk_item_insert(

0 commit comments

Comments
 (0)