Skip to content

convert to async and refactor tests #74

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions stac_fastapi/elasticsearch/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,12 @@
"requests",
"ciso8601",
"overrides",
"starlette",
"httpx",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

httpx is an async client

],
"docs": ["mkdocs", "mkdocs-material", "pdocs"],
"server": ["uvicorn[standard]>=0.12.0,<0.14.0"],
}


setup(
name="stac-fastapi.elasticsearch",
description="An implementation of STAC API based on the FastAPI framework.",
Expand Down
4 changes: 2 additions & 2 deletions stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from stac_fastapi.elasticsearch.config import ElasticsearchSettings
from stac_fastapi.elasticsearch.core import (
BulkTransactionsClient,
CoreCrudClient,
CoreClient,
TransactionsClient,
)
from stac_fastapi.elasticsearch.extensions import QueryExtension
Expand Down Expand Up @@ -37,7 +37,7 @@
api = StacApi(
settings=settings,
extensions=extensions,
client=CoreCrudClient(session=session, post_request_model=post_request_model),
client=CoreClient(session=session, post_request_model=post_request_model),
search_get_request_model=create_get_request_model(extensions),
search_post_request_model=post_request_model,
)
Expand Down
20 changes: 16 additions & 4 deletions stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from typing import Set

from elasticsearch import Elasticsearch
from elasticsearch import AsyncElasticsearch, Elasticsearch
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we still need the sync Elasticsearch client b/c the bulk api does not support async yet.


from stac_fastapi.types.config import ApiSettings

Expand All @@ -16,13 +16,25 @@ class ElasticsearchSettings(ApiSettings):
# Fields which are defined by STAC but not included in the database model
forbidden_fields: Set[str] = {"type"}

# Fields which are item properties but indexed as distinct fields in the database model
indexed_fields: Set[str] = {"datetime"}

@property
def create_client(self):
"""Create es client."""
return Elasticsearch(
[{"host": str(DOMAIN), "port": str(PORT)}],
headers={"accept": "application/vnd.elasticsearch+json; compatible-with=7"},
)


class AsyncElasticsearchSettings(ApiSettings):
"""API settings."""

# Fields which are defined by STAC but not included in the database model
forbidden_fields: Set[str] = {"type"}

@property
def create_client(self):
"""Create async elasticsearch client."""
return AsyncElasticsearch(
[{"host": str(DOMAIN), "port": str(PORT)}],
headers={"accept": "application/vnd.elasticsearch+json; compatible-with=7"},
)
99 changes: 53 additions & 46 deletions stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from urllib.parse import urljoin

import attr
import stac_pydantic.api
from fastapi import HTTPException
from overrides import overrides
from pydantic import ValidationError
Expand All @@ -23,7 +24,7 @@
Items,
)
from stac_fastapi.types import stac as stac_types
from stac_fastapi.types.core import BaseCoreClient, BaseTransactionsClient
from stac_fastapi.types.core import AsyncBaseCoreClient, AsyncBaseTransactionsClient
from stac_fastapi.types.links import CollectionLinks
from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection

Expand All @@ -33,23 +34,25 @@


@attr.s
class CoreCrudClient(BaseCoreClient):
class CoreClient(AsyncBaseCoreClient):
"""Client for core endpoints defined by stac."""

session: Session = attr.ib(default=attr.Factory(Session.create_from_env))
item_serializer: Type[serializers.Serializer] = attr.ib(
item_serializer: Type[serializers.ItemSerializer] = attr.ib(
default=serializers.ItemSerializer
)
collection_serializer: Type[serializers.Serializer] = attr.ib(
collection_serializer: Type[serializers.CollectionSerializer] = attr.ib(
default=serializers.CollectionSerializer
)
database = DatabaseLogic()

@overrides
def all_collections(self, **kwargs) -> Collections:
async def all_collections(self, **kwargs) -> Collections:
"""Read all collections from the database."""
base_url = str(kwargs["request"].base_url)
serialized_collections = self.database.get_all_collections(base_url=base_url)
serialized_collections = await self.database.get_all_collections(
base_url=base_url
)

links = [
{
Expand All @@ -74,21 +77,21 @@ def all_collections(self, **kwargs) -> Collections:
return collection_list

@overrides
def get_collection(self, collection_id: str, **kwargs) -> Collection:
async def get_collection(self, collection_id: str, **kwargs) -> Collection:
"""Get collection by id."""
base_url = str(kwargs["request"].base_url)
collection = self.database.find_collection(collection_id=collection_id)
collection = await self.database.find_collection(collection_id=collection_id)
return self.collection_serializer.db_to_stac(collection, base_url)

@overrides
def item_collection(
async def item_collection(
self, collection_id: str, limit: int = 10, token: str = None, **kwargs
) -> ItemCollection:
"""Read an item collection from the database."""
links = []
base_url = str(kwargs["request"].base_url)

serialized_children, count = self.database.get_item_collection(
serialized_children, count = await self.database.get_item_collection(
collection_id=collection_id, limit=limit, base_url=base_url
)

Expand All @@ -108,10 +111,12 @@ def item_collection(
)

@overrides
def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item:
async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item:
"""Get item by item id, collection id."""
base_url = str(kwargs["request"].base_url)
item = self.database.get_one_item(item_id=item_id, collection_id=collection_id)
item = await self.database.get_one_item(
item_id=item_id, collection_id=collection_id
)
return self.item_serializer.db_to_stac(item, base_url)

@staticmethod
Expand Down Expand Up @@ -139,7 +144,7 @@ def _return_date(interval_str):
return {"lte": end_date, "gte": start_date}

@overrides
def get_search(
async def get_search(
self,
collections: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
Expand Down Expand Up @@ -192,18 +197,19 @@ def get_search(
search_request = self.post_request_model(**base_args)
except ValidationError:
raise HTTPException(status_code=400, detail="Invalid parameters provided")
resp = self.post_search(search_request, request=kwargs["request"])
resp = await self.post_search(search_request, request=kwargs["request"])

return resp

def post_search(self, search_request, **kwargs) -> ItemCollection:
@overrides
async def post_search(
self, search_request: stac_pydantic.api.Search, **kwargs
) -> ItemCollection:
"""POST search catalog."""
base_url = str(kwargs["request"].base_url)
search = self.database.create_search_object()
search = self.database.create_search()

if search_request.query:
if type(search_request.query) == str:
search_request.query = json.loads(search_request.query)
for (field_name, expr) in search_request.query.items():
field = "properties__" + field_name
for (op, value) in expr.items():
Expand All @@ -217,7 +223,7 @@ def post_search(self, search_request, **kwargs) -> ItemCollection:
)

if search_request.collections:
search = self.database.search_collections(
search = self.database.filter_collections(
search=search, collection_ids=search_request.collections
)

Expand Down Expand Up @@ -247,9 +253,9 @@ def post_search(self, search_request, **kwargs) -> ItemCollection:
search=search, field=sort.field, direction=sort.direction
)

count = self.database.search_count(search=search)
count = await self.database.search_count(search=search)

response_features = self.database.execute_search(
response_features = await self.database.execute_search(
search=search, limit=search_request.limit, base_url=base_url
)

Expand Down Expand Up @@ -298,57 +304,57 @@ def post_search(self, search_request, **kwargs) -> ItemCollection:


@attr.s
class TransactionsClient(BaseTransactionsClient):
class TransactionsClient(AsyncBaseTransactionsClient):
"""Transactions extension specific CRUD operations."""

session: Session = attr.ib(default=attr.Factory(Session.create_from_env))
database = DatabaseLogic()

@overrides
def create_item(self, item: stac_types.Item, **kwargs) -> stac_types.Item:
async def create_item(self, item: stac_types.Item, **kwargs) -> stac_types.Item:
"""Create item."""
base_url = str(kwargs["request"].base_url)

# If a feature collection is posted
if item["type"] == "FeatureCollection":
bulk_client = BulkTransactionsClient()
processed_items = [
bulk_client.preprocess_item(item, base_url) for item in item["features"]
bulk_client.preprocess_item(item, base_url) for item in item["features"] # type: ignore
]
self.database.bulk_sync(
await self.database.bulk_async(
processed_items, refresh=kwargs.get("refresh", False)
)

return None
return None # type: ignore
else:
item = self.database.prep_create_item(item=item, base_url=base_url)
self.database.create_item(item, refresh=kwargs.get("refresh", False))
item = await self.database.prep_create_item(item=item, base_url=base_url)
await self.database.create_item(item, refresh=kwargs.get("refresh", False))
return item

@overrides
def update_item(self, item: stac_types.Item, **kwargs) -> stac_types.Item:
async def update_item(self, item: stac_types.Item, **kwargs) -> stac_types.Item:
"""Update item."""
base_url = str(kwargs["request"].base_url)
now = datetime_type.now(timezone.utc).isoformat().replace("+00:00", "Z")
item["properties"]["updated"] = str(now)

self.database.check_collection_exists(collection_id=item["collection"])
await self.database.check_collection_exists(collection_id=item["collection"])
# todo: index instead of delete and create
self.delete_item(item_id=item["id"], collection_id=item["collection"])
self.create_item(item=item, **kwargs)
await self.delete_item(item_id=item["id"], collection_id=item["collection"])
await self.create_item(item=item, **kwargs)

return ItemSerializer.db_to_stac(item, base_url)

@overrides
def delete_item(
async def delete_item(
self, item_id: str, collection_id: str, **kwargs
) -> stac_types.Item:
"""Delete item."""
self.database.delete_item(item_id=item_id, collection_id=collection_id)
return None
await self.database.delete_item(item_id=item_id, collection_id=collection_id)
return None # type: ignore

@overrides
def create_collection(
async def create_collection(
self, collection: stac_types.Collection, **kwargs
) -> stac_types.Collection:
"""Create collection."""
Expand All @@ -357,28 +363,30 @@ def create_collection(
collection_id=collection["id"], base_url=base_url
).create_links()
collection["links"] = collection_links
self.database.create_collection(collection=collection)
await self.database.create_collection(collection=collection)

return CollectionSerializer.db_to_stac(collection, base_url)

@overrides
def update_collection(
async def update_collection(
self, collection: stac_types.Collection, **kwargs
) -> stac_types.Collection:
"""Update collection."""
base_url = str(kwargs["request"].base_url)

self.database.find_collection(collection_id=collection["id"])
self.delete_collection(collection["id"])
self.create_collection(collection, **kwargs)
await self.database.find_collection(collection_id=collection["id"])
await self.delete_collection(collection["id"])
await self.create_collection(collection, **kwargs)

return CollectionSerializer.db_to_stac(collection, base_url)

@overrides
def delete_collection(self, collection_id: str, **kwargs) -> stac_types.Collection:
async def delete_collection(
self, collection_id: str, **kwargs
) -> stac_types.Collection:
"""Delete collection."""
self.database.delete_collection(collection_id=collection_id)
return None
await self.database.delete_collection(collection_id=collection_id)
return None # type: ignore


@attr.s
Expand All @@ -395,8 +403,7 @@ def __attrs_post_init__(self):

def preprocess_item(self, item: stac_types.Item, base_url) -> stac_types.Item:
"""Preprocess items to match data model."""
item = self.database.prep_create_item(item=item, base_url=base_url)
return item
return self.database.sync_prep_create_item(item=item, base_url=base_url)

@overrides
def bulk_item_insert(
Expand Down
Loading