From 034cbd8111f2c120b54b17572d3b0a5db9a29877 Mon Sep 17 00:00:00 2001 From: Phil Varner Date: Tue, 22 Mar 2022 16:22:49 -0400 Subject: [PATCH 1/2] convert to async and refactor tests --- Makefile | 2 +- stac_fastapi/elasticsearch/setup.py | 3 +- .../stac_fastapi/elasticsearch/app.py | 4 +- .../stac_fastapi/elasticsearch/config.py | 20 +- .../stac_fastapi/elasticsearch/core.py | 99 +++--- .../elasticsearch/database_logic.py | 207 +++++++------ .../stac_fastapi/elasticsearch/indexes.py | 10 +- .../elasticsearch/tests/api/test_api.py | 171 +++++------ .../tests/clients/test_elasticsearch.py | 285 ++++++------------ stac_fastapi/elasticsearch/tests/conftest.py | 114 +++---- .../tests/resources/test_collection.py | 39 ++- .../tests/resources/test_conformance.py | 15 +- .../tests/resources/test_item.py | 262 ++++++++-------- .../tests/resources/test_mgmt.py | 4 +- 14 files changed, 573 insertions(+), 662 deletions(-) diff --git a/Makefile b/Makefile index e7bb1f74..18a276d4 100644 --- a/Makefile +++ b/Makefile @@ -25,7 +25,7 @@ docker-shell: .PHONY: test test: - -$(run_es) /bin/bash -c 'export && ./scripts/wait-for-it-es.sh elasticsearch:9200 && cd /app/stac_fastapi/elasticsearch/tests/ && pytest' + -$(run_es) /bin/bash -c 'export && ./scripts/wait-for-it-es.sh elasticsearch:9200 && cd /app/stac_fastapi/elasticsearch/tests/ && pytest api/test_api.py' docker-compose down .PHONY: run-database diff --git a/stac_fastapi/elasticsearch/setup.py b/stac_fastapi/elasticsearch/setup.py index 0c794a75..da175af7 100644 --- a/stac_fastapi/elasticsearch/setup.py +++ b/stac_fastapi/elasticsearch/setup.py @@ -29,13 +29,12 @@ "requests", "ciso8601", "overrides", - "starlette", + "httpx", ], "docs": ["mkdocs", "mkdocs-material", "pdocs"], "server": ["uvicorn[standard]>=0.12.0,<0.14.0"], } - setup( name="stac-fastapi.elasticsearch", description="An implementation of STAC API based on the FastAPI framework.", diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py index c768e3d1..256413d2 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/app.py @@ -4,7 +4,7 @@ from stac_fastapi.elasticsearch.config import ElasticsearchSettings from stac_fastapi.elasticsearch.core import ( BulkTransactionsClient, - CoreCrudClient, + CoreClient, TransactionsClient, ) from stac_fastapi.elasticsearch.extensions import QueryExtension @@ -37,7 +37,7 @@ api = StacApi( settings=settings, extensions=extensions, - client=CoreCrudClient(session=session, post_request_model=post_request_model), + client=CoreClient(session=session, post_request_model=post_request_model), search_get_request_model=create_get_request_model(extensions), search_post_request_model=post_request_model, ) diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/config.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/config.py index 2345f6f9..42601587 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/config.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/config.py @@ -2,7 +2,7 @@ import os from typing import Set -from elasticsearch import Elasticsearch +from elasticsearch import AsyncElasticsearch, Elasticsearch from stac_fastapi.types.config import ApiSettings @@ -16,9 +16,6 @@ class ElasticsearchSettings(ApiSettings): # Fields which are defined by STAC but not included in the database model forbidden_fields: Set[str] = {"type"} - # Fields which are item properties but indexed as distinct fields in the database model - indexed_fields: Set[str] = {"datetime"} - @property def create_client(self): """Create es client.""" @@ -26,3 +23,18 @@ def create_client(self): [{"host": str(DOMAIN), "port": str(PORT)}], headers={"accept": "application/vnd.elasticsearch+json; compatible-with=7"}, ) + + +class AsyncElasticsearchSettings(ApiSettings): + """API settings.""" + + # Fields which are defined by STAC but not included in the database model + forbidden_fields: Set[str] = {"type"} + + @property + def create_client(self): + """Create async elasticsearch client.""" + return AsyncElasticsearch( + [{"host": str(DOMAIN), "port": str(PORT)}], + headers={"accept": "application/vnd.elasticsearch+json; compatible-with=7"}, + ) diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py index 7847380c..4479dc6a 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py @@ -7,6 +7,7 @@ from urllib.parse import urljoin import attr +import stac_pydantic.api from fastapi import HTTPException from overrides import overrides from pydantic import ValidationError @@ -23,7 +24,7 @@ Items, ) from stac_fastapi.types import stac as stac_types -from stac_fastapi.types.core import BaseCoreClient, BaseTransactionsClient +from stac_fastapi.types.core import AsyncBaseCoreClient, AsyncBaseTransactionsClient from stac_fastapi.types.links import CollectionLinks from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection @@ -33,23 +34,25 @@ @attr.s -class CoreCrudClient(BaseCoreClient): +class CoreClient(AsyncBaseCoreClient): """Client for core endpoints defined by stac.""" session: Session = attr.ib(default=attr.Factory(Session.create_from_env)) - item_serializer: Type[serializers.Serializer] = attr.ib( + item_serializer: Type[serializers.ItemSerializer] = attr.ib( default=serializers.ItemSerializer ) - collection_serializer: Type[serializers.Serializer] = attr.ib( + collection_serializer: Type[serializers.CollectionSerializer] = attr.ib( default=serializers.CollectionSerializer ) database = DatabaseLogic() @overrides - def all_collections(self, **kwargs) -> Collections: + async def all_collections(self, **kwargs) -> Collections: """Read all collections from the database.""" base_url = str(kwargs["request"].base_url) - serialized_collections = self.database.get_all_collections(base_url=base_url) + serialized_collections = await self.database.get_all_collections( + base_url=base_url + ) links = [ { @@ -74,21 +77,21 @@ def all_collections(self, **kwargs) -> Collections: return collection_list @overrides - def get_collection(self, collection_id: str, **kwargs) -> Collection: + async def get_collection(self, collection_id: str, **kwargs) -> Collection: """Get collection by id.""" base_url = str(kwargs["request"].base_url) - collection = self.database.find_collection(collection_id=collection_id) + collection = await self.database.find_collection(collection_id=collection_id) return self.collection_serializer.db_to_stac(collection, base_url) @overrides - def item_collection( + async def item_collection( self, collection_id: str, limit: int = 10, token: str = None, **kwargs ) -> ItemCollection: """Read an item collection from the database.""" links = [] base_url = str(kwargs["request"].base_url) - serialized_children, count = self.database.get_item_collection( + serialized_children, count = await self.database.get_item_collection( collection_id=collection_id, limit=limit, base_url=base_url ) @@ -108,10 +111,12 @@ def item_collection( ) @overrides - def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item: + async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item: """Get item by item id, collection id.""" base_url = str(kwargs["request"].base_url) - item = self.database.get_one_item(item_id=item_id, collection_id=collection_id) + item = await self.database.get_one_item( + item_id=item_id, collection_id=collection_id + ) return self.item_serializer.db_to_stac(item, base_url) @staticmethod @@ -139,7 +144,7 @@ def _return_date(interval_str): return {"lte": end_date, "gte": start_date} @overrides - def get_search( + async def get_search( self, collections: Optional[List[str]] = None, ids: Optional[List[str]] = None, @@ -192,18 +197,19 @@ def get_search( search_request = self.post_request_model(**base_args) except ValidationError: raise HTTPException(status_code=400, detail="Invalid parameters provided") - resp = self.post_search(search_request, request=kwargs["request"]) + resp = await self.post_search(search_request, request=kwargs["request"]) return resp - def post_search(self, search_request, **kwargs) -> ItemCollection: + @overrides + async def post_search( + self, search_request: stac_pydantic.api.Search, **kwargs + ) -> ItemCollection: """POST search catalog.""" base_url = str(kwargs["request"].base_url) - search = self.database.create_search_object() + search = self.database.create_search() if search_request.query: - if type(search_request.query) == str: - search_request.query = json.loads(search_request.query) for (field_name, expr) in search_request.query.items(): field = "properties__" + field_name for (op, value) in expr.items(): @@ -217,7 +223,7 @@ def post_search(self, search_request, **kwargs) -> ItemCollection: ) if search_request.collections: - search = self.database.search_collections( + search = self.database.filter_collections( search=search, collection_ids=search_request.collections ) @@ -247,9 +253,9 @@ def post_search(self, search_request, **kwargs) -> ItemCollection: search=search, field=sort.field, direction=sort.direction ) - count = self.database.search_count(search=search) + count = await self.database.search_count(search=search) - response_features = self.database.execute_search( + response_features = await self.database.execute_search( search=search, limit=search_request.limit, base_url=base_url ) @@ -298,14 +304,14 @@ def post_search(self, search_request, **kwargs) -> ItemCollection: @attr.s -class TransactionsClient(BaseTransactionsClient): +class TransactionsClient(AsyncBaseTransactionsClient): """Transactions extension specific CRUD operations.""" session: Session = attr.ib(default=attr.Factory(Session.create_from_env)) database = DatabaseLogic() @overrides - def create_item(self, item: stac_types.Item, **kwargs) -> stac_types.Item: + async def create_item(self, item: stac_types.Item, **kwargs) -> stac_types.Item: """Create item.""" base_url = str(kwargs["request"].base_url) @@ -313,42 +319,42 @@ def create_item(self, item: stac_types.Item, **kwargs) -> stac_types.Item: if item["type"] == "FeatureCollection": bulk_client = BulkTransactionsClient() processed_items = [ - bulk_client.preprocess_item(item, base_url) for item in item["features"] + bulk_client.preprocess_item(item, base_url) for item in item["features"] # type: ignore ] - self.database.bulk_sync( + await self.database.bulk_async( processed_items, refresh=kwargs.get("refresh", False) ) - return None + return None # type: ignore else: - item = self.database.prep_create_item(item=item, base_url=base_url) - self.database.create_item(item, refresh=kwargs.get("refresh", False)) + item = await self.database.prep_create_item(item=item, base_url=base_url) + await self.database.create_item(item, refresh=kwargs.get("refresh", False)) return item @overrides - def update_item(self, item: stac_types.Item, **kwargs) -> stac_types.Item: + async def update_item(self, item: stac_types.Item, **kwargs) -> stac_types.Item: """Update item.""" base_url = str(kwargs["request"].base_url) now = datetime_type.now(timezone.utc).isoformat().replace("+00:00", "Z") item["properties"]["updated"] = str(now) - self.database.check_collection_exists(collection_id=item["collection"]) + await self.database.check_collection_exists(collection_id=item["collection"]) # todo: index instead of delete and create - self.delete_item(item_id=item["id"], collection_id=item["collection"]) - self.create_item(item=item, **kwargs) + await self.delete_item(item_id=item["id"], collection_id=item["collection"]) + await self.create_item(item=item, **kwargs) return ItemSerializer.db_to_stac(item, base_url) @overrides - def delete_item( + async def delete_item( self, item_id: str, collection_id: str, **kwargs ) -> stac_types.Item: """Delete item.""" - self.database.delete_item(item_id=item_id, collection_id=collection_id) - return None + await self.database.delete_item(item_id=item_id, collection_id=collection_id) + return None # type: ignore @overrides - def create_collection( + async def create_collection( self, collection: stac_types.Collection, **kwargs ) -> stac_types.Collection: """Create collection.""" @@ -357,28 +363,30 @@ def create_collection( collection_id=collection["id"], base_url=base_url ).create_links() collection["links"] = collection_links - self.database.create_collection(collection=collection) + await self.database.create_collection(collection=collection) return CollectionSerializer.db_to_stac(collection, base_url) @overrides - def update_collection( + async def update_collection( self, collection: stac_types.Collection, **kwargs ) -> stac_types.Collection: """Update collection.""" base_url = str(kwargs["request"].base_url) - self.database.find_collection(collection_id=collection["id"]) - self.delete_collection(collection["id"]) - self.create_collection(collection, **kwargs) + await self.database.find_collection(collection_id=collection["id"]) + await self.delete_collection(collection["id"]) + await self.create_collection(collection, **kwargs) return CollectionSerializer.db_to_stac(collection, base_url) @overrides - def delete_collection(self, collection_id: str, **kwargs) -> stac_types.Collection: + async def delete_collection( + self, collection_id: str, **kwargs + ) -> stac_types.Collection: """Delete collection.""" - self.database.delete_collection(collection_id=collection_id) - return None + await self.database.delete_collection(collection_id=collection_id) + return None # type: ignore @attr.s @@ -395,8 +403,7 @@ def __attrs_post_init__(self): def preprocess_item(self, item: stac_types.Item, base_url) -> stac_types.Item: """Preprocess items to match data model.""" - item = self.database.prep_create_item(item=item, base_url=base_url) - return item + return self.database.sync_prep_create_item(item=item, base_url=base_url) @overrides def bulk_item_insert( diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index aa980084..27aa3dce 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -1,6 +1,6 @@ """Database logic.""" import logging -from typing import List, Optional, Tuple, Type, Union +from typing import Dict, List, Optional, Tuple, Type, Union import attr import elasticsearch @@ -17,8 +17,11 @@ ) from stac_fastapi.elasticsearch import serializers -from stac_fastapi.elasticsearch.config import ElasticsearchSettings -from stac_fastapi.types.errors import ConflictError, ForeignKeyError, NotFoundError +from stac_fastapi.elasticsearch.config import AsyncElasticsearchSettings +from stac_fastapi.elasticsearch.config import ( + ElasticsearchSettings as SyncElasticsearchSettings, +) +from stac_fastapi.types.errors import ConflictError, NotFoundError from stac_fastapi.types.stac import Collection, Item logger = logging.getLogger(__name__) @@ -38,8 +41,8 @@ def mk_item_id(item_id: str, collection_id: str): class DatabaseLogic: """Database logic.""" - settings = ElasticsearchSettings() - client = settings.create_client + client = AsyncElasticsearchSettings().create_client + sync_client = SyncElasticsearchSettings().create_client item_serializer: Type[serializers.ItemSerializer] = attr.ib( default=serializers.ItemSerializer ) @@ -47,65 +50,52 @@ class DatabaseLogic: default=serializers.CollectionSerializer ) - @staticmethod - def bbox2poly(b0, b1, b2, b3): - """Transform bbox to polygon.""" - poly = [[[b0, b1], [b2, b1], [b2, b3], [b0, b3], [b0, b1]]] - return poly - """CORE LOGIC""" - def get_all_collections(self, base_url: str) -> List[Collection]: + async def get_all_collections(self, base_url: str) -> List[Collection]: """Database logic to retrieve a list of all collections.""" - try: - collections = self.client.search( - index=COLLECTIONS_INDEX, query={"match_all": {}} - ) - except elasticsearch.exceptions.NotFoundError: - return [] + collections = await self.client.search( + index=COLLECTIONS_INDEX, query={"match_all": {}} + ) - serialized_collections = [ - self.collection_serializer.db_to_stac( - collection["_source"], base_url=base_url - ) - for collection in collections["hits"]["hits"] + return [ + self.collection_serializer.db_to_stac(c["_source"], base_url=base_url) + for c in collections["hits"]["hits"] ] - return serialized_collections + async def search_count(self, search: Search) -> int: + """Database logic to count search results.""" + return ( + await self.client.count(index=ITEMS_INDEX, body=search.to_dict(count=True)) + ).get("count") - def get_item_collection( + async def get_item_collection( self, collection_id: str, limit: int, base_url: str ) -> Tuple[List[Item], Optional[int]]: """Database logic to retrieve an ItemCollection and a count of items contained.""" - search = self.create_search_object() - search = self.search_collections(search, [collection_id]) - - collection_filter = Q( - "bool", should=[Q("match_phrase", **{"collection": collection_id})] - ) - search = search.query(collection_filter) + search = self.create_search() + search = search.filter("term", collection=collection_id) - count = self.search_count(search) + count = await self.search_count(search) - # search = search.sort({"id.keyword" : {"order" : "asc"}}) search = search.query()[0:limit] body = search.to_dict() - collection_children = self.client.search( + es_response = await self.client.search( index=ITEMS_INDEX, query=body["query"], sort=body.get("sort") ) serialized_children = [ self.item_serializer.db_to_stac(item["_source"], base_url=base_url) - for item in collection_children["hits"]["hits"] + for item in es_response["hits"]["hits"] ] return serialized_children, count - def get_one_item(self, collection_id: str, item_id: str) -> Item: + async def get_one_item(self, collection_id: str, item_id: str) -> Dict: """Database logic to retrieve a single item.""" try: - item = self.client.get( + item = await self.client.get( index=ITEMS_INDEX, id=mk_item_id(item_id, collection_id) ) except elasticsearch.exceptions.NotFoundError: @@ -115,7 +105,7 @@ def get_one_item(self, collection_id: str, item_id: str) -> Item: return item["_source"] @staticmethod - def create_search_object(): + def create_search(): """Database logic to create a nosql Search instance.""" return Search().sort( {"properties.datetime": {"order": "desc"}}, @@ -135,21 +125,21 @@ def create_query_filter(search: Search, op: str, field: str, value: float): return search @staticmethod - def search_ids(search: Search, item_ids: List): + def search_ids(search: Search, item_ids: List[str]): """Database logic to search a list of STAC item ids.""" - id_list = [] - for item_id in item_ids: - id_list.append(Q("match_phrase", **{"id": item_id})) - id_filter = Q("bool", should=id_list) - search = search.query(id_filter) - - return search + return search.query( + Q("bool", should=[Q("term", **{"id": i_id}) for i_id in item_ids]) + ) @staticmethod - def search_collections(search: Search, collection_ids: List): + def filter_collections(search: Search, collection_ids: List): """Database logic to search a list of STAC collection ids.""" - collections_query = [Q("term", **{"collection": cid}) for cid in collection_ids] - return search.query(Q("bool", should=collections_query)) + return search.query( + Q( + "bool", + should=[Q("term", **{"collection": c_id}) for c_id in collection_ids], + ) + ) @staticmethod def search_datetime(search: Search, datetime_search): @@ -170,12 +160,12 @@ def search_datetime(search: Search, datetime_search): @staticmethod def search_bbox(search: Search, bbox: List): """Database logic to search on bounding box.""" - poly = DatabaseLogic.bbox2poly(bbox[0], bbox[1], bbox[2], bbox[3]) + polygon = DatabaseLogic.bbox2polygon(bbox[0], bbox[1], bbox[2], bbox[3]) bbox_filter = Q( { "geo_shape": { "geometry": { - "shape": {"type": "polygon", "coordinates": poly}, + "shape": {"type": "polygon", "coordinates": polygon}, "relation": "intersects", } } @@ -219,45 +209,46 @@ def sort_field(search: Search, field, direction): """Database logic to sort search instance.""" return search.sort({field: {"order": direction}}) - def search_count(self, search: Search) -> int: - """Database logic to count search results.""" - try: - return self.client.count( - index=ITEMS_INDEX, body=search.to_dict(count=True) - ).get("count") - except elasticsearch.exceptions.NotFoundError: - raise NotFoundError("No items exist") - - def execute_search(self, search, limit: int, base_url: str) -> List: + async def execute_search(self, search, limit: int, base_url: str) -> List: """Database logic to execute search with limit.""" search = search.query()[0:limit] body = search.to_dict() - response = self.client.search( + es_response = await self.client.search( index=ITEMS_INDEX, query=body["query"], sort=body.get("sort") ) - if len(response["hits"]["hits"]) > 0: - response_features = [ - self.item_serializer.db_to_stac(item["_source"], base_url=base_url) - for item in response["hits"]["hits"] - ] - else: - response_features = [] - - return response_features + return [ + self.item_serializer.db_to_stac(hit["_source"], base_url=base_url) + for hit in es_response["hits"]["hits"] + ] """ TRANSACTION LOGIC """ - def check_collection_exists(self, collection_id: str): + async def check_collection_exists(self, collection_id: str): """Database logic to check if a collection exists.""" - if not self.client.exists(index=COLLECTIONS_INDEX, id=collection_id): - raise ForeignKeyError(f"Collection {collection_id} does not exist") + if not await self.client.exists(index=COLLECTIONS_INDEX, id=collection_id): + raise NotFoundError(f"Collection {collection_id} does not exist") + + async def prep_create_item(self, item: Item, base_url: str) -> Item: + """Database logic for prepping an item for insertion.""" + await self.check_collection_exists(collection_id=item["collection"]) + + if await self.client.exists( + index=ITEMS_INDEX, id=mk_item_id(item["id"], item["collection"]) + ): + raise ConflictError( + f"Item {item['id']} in collection {item['collection']} already exists" + ) + + return self.item_serializer.stac_to_db(item, base_url) - def prep_create_item(self, item: Item, base_url: str) -> Item: + def sync_prep_create_item(self, item: Item, base_url: str) -> Item: """Database logic for prepping an item for insertion.""" - self.check_collection_exists(collection_id=item["collection"]) + collection_id = item["collection"] + if not self.sync_client.exists(index=COLLECTIONS_INDEX, id=collection_id): + raise NotFoundError(f"Collection {collection_id} does not exist") - if self.client.exists( + if self.sync_client.exists( index=ITEMS_INDEX, id=mk_item_id(item["id"], item["collection"]) ): raise ConflictError( @@ -266,10 +257,10 @@ def prep_create_item(self, item: Item, base_url: str) -> Item: return self.item_serializer.stac_to_db(item, base_url) - def create_item(self, item: Item, refresh: bool = False): + async def create_item(self, item: Item, refresh: bool = False): """Database logic for creating one item.""" # todo: check if collection exists, but cache - es_resp = self.client.index( + es_resp = await self.client.index( index=ITEMS_INDEX, id=mk_item_id(item["id"], item["collection"]), document=item, @@ -281,10 +272,12 @@ def create_item(self, item: Item, refresh: bool = False): f"Item {item['id']} in collection {item['collection']} already exists" ) - def delete_item(self, item_id: str, collection_id: str, refresh: bool = False): + async def delete_item( + self, item_id: str, collection_id: str, refresh: bool = False + ): """Database logic for deleting one item.""" try: - self.client.delete( + await self.client.delete( index=ITEMS_INDEX, id=mk_item_id(item_id, collection_id), refresh=refresh, @@ -294,35 +287,51 @@ def delete_item(self, item_id: str, collection_id: str, refresh: bool = False): f"Item {item_id} in collection {collection_id} not found" ) - def create_collection(self, collection: Collection, refresh: bool = False): + async def create_collection(self, collection: Collection, refresh: bool = False): """Database logic for creating one collection.""" - if self.client.exists(index=COLLECTIONS_INDEX, id=collection["id"]): + if await self.client.exists(index=COLLECTIONS_INDEX, id=collection["id"]): raise ConflictError(f"Collection {collection['id']} already exists") - self.client.index( + await self.client.index( index=COLLECTIONS_INDEX, id=collection["id"], document=collection, refresh=refresh, ) - def find_collection(self, collection_id: str) -> Collection: + async def find_collection(self, collection_id: str) -> Collection: """Database logic to find and return a collection.""" try: - collection = self.client.get(index=COLLECTIONS_INDEX, id=collection_id) + collection = await self.client.get( + index=COLLECTIONS_INDEX, id=collection_id + ) except elasticsearch.exceptions.NotFoundError: raise NotFoundError(f"Collection {collection_id} not found") return collection["_source"] - def delete_collection(self, collection_id: str, refresh: bool = False): + async def delete_collection(self, collection_id: str, refresh: bool = False): """Database logic for deleting one collection.""" - _ = self.find_collection(collection_id=collection_id) - self.client.delete(index=COLLECTIONS_INDEX, id=collection_id, refresh=refresh) + await self.find_collection(collection_id=collection_id) + await self.client.delete( + index=COLLECTIONS_INDEX, id=collection_id, refresh=refresh + ) + + async def bulk_async(self, processed_items, refresh: bool = False): + """Database logic for async bulk item insertion.""" + # todo: wrap as async + helpers.bulk( + self.sync_client, self._mk_actions(processed_items), refresh=refresh + ) def bulk_sync(self, processed_items, refresh: bool = False): - """Database logic for bulk item insertion.""" - actions = [ + """Database logic for sync bulk item insertion.""" + helpers.bulk( + self.sync_client, self._mk_actions(processed_items), refresh=refresh + ) + + def _mk_actions(self, processed_items): + return [ { "_index": ITEMS_INDEX, "_id": mk_item_id(item["id"], item["collection"]), @@ -330,22 +339,26 @@ def bulk_sync(self, processed_items, refresh: bool = False): } for item in processed_items ] - helpers.bulk(self.client, actions, refresh=refresh) # DANGER - def delete_items(self) -> None: + async def delete_items(self) -> None: """Danger. this is only for tests.""" - self.client.delete_by_query( + await self.client.delete_by_query( index=ITEMS_INDEX, body={"query": {"match_all": {}}}, wait_for_completion=True, ) # DANGER - def delete_collections(self) -> None: + async def delete_collections(self) -> None: """Danger. this is only for tests.""" - self.client.delete_by_query( + await self.client.delete_by_query( index=COLLECTIONS_INDEX, body={"query": {"match_all": {}}}, wait_for_completion=True, ) + + @staticmethod + def bbox2polygon(b0, b1, b2, b3): + """Transform bbox to polygon.""" + return [[[b0, b1], [b2, b1], [b2, b3], [b0, b3], [b0, b1]]] diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/indexes.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/indexes.py index d858a806..b102683e 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/indexes.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/indexes.py @@ -4,7 +4,7 @@ import attr -from stac_fastapi.elasticsearch.config import ElasticsearchSettings +from stac_fastapi.elasticsearch.config import AsyncElasticsearchSettings from stac_fastapi.elasticsearch.database_logic import COLLECTIONS_INDEX, ITEMS_INDEX from stac_fastapi.elasticsearch.session import Session @@ -16,7 +16,7 @@ class IndexesClient: """Elasticsearch client to handle index creation.""" session: Session = attr.ib(default=attr.Factory(Session.create_from_env)) - client = ElasticsearchSettings().create_client + client = AsyncElasticsearchSettings().create_client ES_MAPPINGS_DYNAMIC_TEMPLATES = [ # Common https://github.com/radiantearth/stac-spec/blob/master/item-spec/common-metadata.md @@ -101,14 +101,14 @@ class IndexesClient: }, } - def create_indexes(self): + async def create_indexes(self): """Create the index for Items and Collections.""" - self.client.indices.create( + await self.client.indices.create( index=ITEMS_INDEX, mappings=self.ES_ITEMS_MAPPINGS, ignore=400, # ignore 400 already exists code ) - self.client.indices.create( + await self.client.indices.create( index=COLLECTIONS_INDEX, mappings=self.ES_COLLECTIONS_MAPPINGS, ignore=400, # ignore 400 already exists code diff --git a/stac_fastapi/elasticsearch/tests/api/test_api.py b/stac_fastapi/elasticsearch/tests/api/test_api.py index 71552034..e1e639cb 100644 --- a/stac_fastapi/elasticsearch/tests/api/test_api.py +++ b/stac_fastapi/elasticsearch/tests/api/test_api.py @@ -1,82 +1,73 @@ +import copy +import uuid from datetime import datetime, timedelta import pytest -from ..conftest import MockStarletteRequest, create_collection, create_item - -STAC_CORE_ROUTES = [ - "GET /", - "GET /collections", - "GET /collections/{collectionId}", - "GET /collections/{collectionId}/items", - "GET /collections/{collectionId}/items/{itemId}", - "GET /conformance", - "GET /search", - "POST /search", -] - -STAC_TRANSACTION_ROUTES = [ - "DELETE /collections/{collectionId}", - "DELETE /collections/{collectionId}/items/{itemId}", - "POST /collections", - "POST /collections/{collectionId}/items", - "PUT /collections", - "PUT /collections/{collectionId}/items", -] - - -@pytest.mark.skip(reason="fails ci only") -def test_post_search_content_type(app_client): +from ..conftest import MockRequest, create_collection, create_item + +ROUTES = set( + [ + "GET /_mgmt/ping", + "GET /docs/oauth2-redirect", + "HEAD /docs/oauth2-redirect", + "GET /", + "GET /conformance", + "GET /api", + "GET /api.html", + "HEAD /api", + "HEAD /api.html", + "GET /collections", + "GET /collections/{collection_id}", + "GET /collections/{collection_id}/items", + "GET /collections/{collection_id}/items/{item_id}", + "GET /search", + "POST /search", + "DELETE /collections/{collection_id}", + "DELETE /collections/{collection_id}/items/{item_id}", + "POST /collections", + "POST /collections/{collection_id}/items", + "PUT /collections", + "PUT /collections/{collection_id}/items", + ] +) + + +async def test_post_search_content_type(app_client, ctx): params = {"limit": 1} - resp = app_client.post("search", json=params) + resp = await app_client.post("/search", json=params) assert resp.headers["content-type"] == "application/geo+json" -@pytest.mark.skip(reason="fails ci only") -def test_get_search_content_type(app_client): - resp = app_client.get("search") +async def test_get_search_content_type(app_client, ctx): + resp = await app_client.get("/search") assert resp.headers["content-type"] == "application/geo+json" -def test_api_headers(app_client): - resp = app_client.get("/api") +async def test_api_headers(app_client): + resp = await app_client.get("/api") assert ( resp.headers["content-type"] == "application/vnd.oai.openapi+json;version=3.0" ) assert resp.status_code == 200 -@pytest.mark.skip(reason="not working") -def test_core_router(api_client): - core_routes = set(STAC_CORE_ROUTES) - api_routes = set( - [f"{list(route.methods)[0]} {route.path}" for route in api_client.app.routes] - ) - assert not core_routes - api_routes - - -@pytest.mark.skip(reason="not working") -def test_transactions_router(api_client): - transaction_routes = set(STAC_TRANSACTION_ROUTES) - api_routes = set( - [f"{list(route.methods)[0]} {route.path}" for route in api_client.app.routes] - ) - assert not transaction_routes - api_routes +async def test_router(app): + api_routes = set([f"{list(route.methods)[0]} {route.path}" for route in app.routes]) + assert len(api_routes - ROUTES) == 0 -@pytest.mark.skip(reason="unknown") -def test_app_transaction_extension(app_client, load_test_data, es_txn_client): - item = load_test_data("test_item.json") - resp = app_client.post(f"/collections/{item['collection']}/items", json=item) +async def test_app_transaction_extension(app_client, ctx): + item = copy.deepcopy(ctx.item) + item["id"] = str(uuid.uuid4()) + resp = await app_client.post(f"/collections/{item['collection']}/items", json=item) assert resp.status_code == 200 - es_txn_client.delete_item( - item["id"], item["collection"], request=MockStarletteRequest - ) + await app_client.delete(f"/collections/{item['collection']}/items/{item['id']}") -def test_app_search_response(app_client, ctx): - resp = app_client.get("/search", params={"ids": ["test-item"]}) +async def test_app_search_response(app_client, ctx): + resp = await app_client.get("/search", params={"ids": ["test-item"]}) assert resp.status_code == 200 resp_json = resp.json() @@ -86,17 +77,17 @@ def test_app_search_response(app_client, ctx): assert resp_json.get("stac_extensions") is None -def test_app_context_extension(app_client, ctx, es_txn_client): +async def test_app_context_extension(app_client, ctx, txn_client): test_item = ctx.item test_item["id"] = "test-item-2" test_item["collection"] = "test-collection-2" test_collection = ctx.collection test_collection["id"] = "test-collection-2" - create_collection(es_txn_client, test_collection) - create_item(es_txn_client, test_item) + await create_collection(txn_client, test_collection) + await create_item(txn_client, test_item) - resp = app_client.get( + resp = await app_client.get( f"/collections/{test_collection['id']}/items/{test_item['id']}" ) assert resp.status_code == 200 @@ -104,12 +95,12 @@ def test_app_context_extension(app_client, ctx, es_txn_client): assert resp_json["id"] == test_item["id"] assert resp_json["collection"] == test_item["collection"] - resp = app_client.get(f"/collections/{test_collection['id']}") + resp = await app_client.get(f"/collections/{test_collection['id']}") assert resp.status_code == 200 resp_json = resp.json() assert resp_json["id"] == test_collection["id"] - resp = app_client.post("/search", json={"collections": ["test-collection-2"]}) + resp = await app_client.post("/search", json={"collections": ["test-collection-2"]}) assert resp.status_code == 200 resp_json = resp.json() assert len(resp_json["features"]) == 1 @@ -118,51 +109,49 @@ def test_app_context_extension(app_client, ctx, es_txn_client): @pytest.mark.skip(reason="fields not implemented yet") -def test_app_fields_extension(load_test_data, app_client, es_txn_client): +async def test_app_fields_extension(load_test_data, app_client, txn_client): item = load_test_data("test_item.json") - es_txn_client.create_item(item, request=MockStarletteRequest, refresh=True) + txn_client.create_item(item, request=MockRequest, refresh=True) - resp = app_client.get("/search", params={"collections": ["test-collection"]}) + resp = await app_client.get("/search", params={"collections": ["test-collection"]}) assert resp.status_code == 200 resp_json = resp.json() assert list(resp_json["features"][0]["properties"]) == ["datetime"] - es_txn_client.delete_item( - item["id"], item["collection"], request=MockStarletteRequest - ) + txn_client.delete_item(item["id"], item["collection"], request=MockRequest) -def test_app_query_extension_gt(app_client, ctx): +async def test_app_query_extension_gt(app_client, ctx): params = {"query": {"proj:epsg": {"gt": ctx.item["properties"]["proj:epsg"]}}} - resp = app_client.post("/search", json=params) + resp = await app_client.post("/search", json=params) assert resp.status_code == 200 resp_json = resp.json() assert len(resp_json["features"]) == 0 -def test_app_query_extension_gte(app_client, ctx): +async def test_app_query_extension_gte(app_client, ctx): params = {"query": {"proj:epsg": {"gte": ctx.item["properties"]["proj:epsg"]}}} - resp = app_client.post("/search", json=params) + resp = await app_client.post("/search", json=params) assert resp.status_code == 200 assert len(resp.json()["features"]) == 1 -def test_app_query_extension_limit_lt0(app_client, ctx): - assert app_client.post("/search", json={"limit": -1}).status_code == 400 +async def test_app_query_extension_limit_lt0(app_client, ctx): + assert (await app_client.post("/search", json={"limit": -1})).status_code == 400 -def test_app_query_extension_limit_gt10000(app_client, ctx): - assert app_client.post("/search", json={"limit": 10001}).status_code == 400 +async def test_app_query_extension_limit_gt10000(app_client, ctx): + assert (await app_client.post("/search", json={"limit": 10001})).status_code == 400 -def test_app_query_extension_limit_10000(app_client, ctx): +async def test_app_query_extension_limit_10000(app_client, ctx): params = {"limit": 10000} - resp = app_client.post("/search", json=params) + resp = await app_client.post("/search", json=params) assert resp.status_code == 200 -def test_app_sort_extension(app_client, es_txn_client, ctx): +async def test_app_sort_extension(app_client, txn_client, ctx): first_item = ctx.item item_date = datetime.strptime( first_item["properties"]["datetime"], "%Y-%m-%dT%H:%M:%SZ" @@ -174,30 +163,30 @@ def test_app_sort_extension(app_client, es_txn_client, ctx): second_item["properties"]["datetime"] = another_item_date.strftime( "%Y-%m-%dT%H:%M:%SZ" ) - create_item(es_txn_client, second_item) + await create_item(txn_client, second_item) params = { "collections": [first_item["collection"]], "sortby": [{"field": "properties.datetime", "direction": "desc"}], } - resp = app_client.post("/search", json=params) + resp = await app_client.post("/search", json=params) assert resp.status_code == 200 resp_json = resp.json() assert resp_json["features"][0]["id"] == first_item["id"] assert resp_json["features"][1]["id"] == second_item["id"] -def test_search_invalid_date(app_client, ctx): +async def test_search_invalid_date(app_client, ctx): params = { "datetime": "2020-XX-01/2020-10-30", "collections": [ctx.item["collection"]], } - resp = app_client.post("/search", json=params) + resp = await app_client.post("/search", json=params) assert resp.status_code == 400 -def test_search_point_intersects(app_client, ctx): +async def test_search_point_intersects(app_client, ctx): point = [150.04, -33.14] intersects = {"type": "Point", "coordinates": point} @@ -205,14 +194,14 @@ def test_search_point_intersects(app_client, ctx): "intersects": intersects, "collections": [ctx.item["collection"]], } - resp = app_client.post("/search", json=params) + resp = await app_client.post("/search", json=params) assert resp.status_code == 200 resp_json = resp.json() assert len(resp_json["features"]) == 1 -def test_datetime_non_interval(app_client, ctx): +async def test_datetime_non_interval(app_client, ctx): dt_formats = [ "2020-02-12T12:30:22+00:00", "2020-02-12T12:30:22.00Z", @@ -226,26 +215,26 @@ def test_datetime_non_interval(app_client, ctx): "collections": [ctx.item["collection"]], } - resp = app_client.post("/search", json=params) + resp = await app_client.post("/search", json=params) assert resp.status_code == 200 resp_json = resp.json() # datetime is returned in this format "2020-02-12T12:30:22Z" assert resp_json["features"][0]["properties"]["datetime"][0:19] == dt[0:19] -def test_bbox_3d(app_client, ctx): +async def test_bbox_3d(app_client, ctx): australia_bbox = [106.343365, -47.199523, 0.1, 168.218365, -19.437288, 0.1] params = { "bbox": australia_bbox, "collections": [ctx.item["collection"]], } - resp = app_client.post("/search", json=params) + resp = await app_client.post("/search", json=params) assert resp.status_code == 200 resp_json = resp.json() assert len(resp_json["features"]) == 1 -def test_search_line_string_intersects(app_client, ctx): +async def test_search_line_string_intersects(app_client, ctx): line = [[150.04, -33.14], [150.22, -33.89]] intersects = {"type": "LineString", "coordinates": line} params = { @@ -253,7 +242,7 @@ def test_search_line_string_intersects(app_client, ctx): "collections": [ctx.item["collection"]], } - resp = app_client.post("/search", json=params) + resp = await app_client.post("/search", json=params) assert resp.status_code == 200 diff --git a/stac_fastapi/elasticsearch/tests/clients/test_elasticsearch.py b/stac_fastapi/elasticsearch/tests/clients/test_elasticsearch.py index 8c4b9c18..9cde8493 100644 --- a/stac_fastapi/elasticsearch/tests/clients/test_elasticsearch.py +++ b/stac_fastapi/elasticsearch/tests/clients/test_elasticsearch.py @@ -5,208 +5,127 @@ import pytest from stac_pydantic import Item -from stac_fastapi.api.app import StacApi -from stac_fastapi.elasticsearch.core import BulkTransactionsClient, CoreCrudClient from stac_fastapi.extensions.third_party.bulk_transactions import Items from stac_fastapi.types.errors import ConflictError, NotFoundError -from ..conftest import MockStarletteRequest, create_item +from ..conftest import MockRequest, create_item -def test_create_collection( - es_core: CoreCrudClient, - es_txn_client, - load_test_data: Callable, -): - data = load_test_data("test_collection.json") - try: - es_txn_client.create_collection(data, request=MockStarletteRequest) - except Exception: - pass - coll = es_core.get_collection(data["id"], request=MockStarletteRequest) - assert coll["id"] == data["id"] - es_txn_client.delete_collection(data["id"], request=MockStarletteRequest) +async def test_create_collection(app_client, ctx, core_client, txn_client): + in_coll = deepcopy(ctx.collection) + in_coll["id"] = str(uuid.uuid4()) + await txn_client.create_collection(in_coll, request=MockRequest) + got_coll = await core_client.get_collection(in_coll["id"], request=MockRequest) + assert got_coll["id"] == in_coll["id"] + await txn_client.delete_collection(in_coll["id"], request=MockRequest) -def test_create_collection_already_exists( - es_txn_client, - load_test_data: Callable, -): - data = load_test_data("test_collection.json") - es_txn_client.create_collection(data, request=MockStarletteRequest) +async def test_create_collection_already_exists(app_client, ctx, txn_client): + data = deepcopy(ctx.collection) # change id to avoid elasticsearch duplicate key error data["_id"] = str(uuid.uuid4()) with pytest.raises(ConflictError): - es_txn_client.create_collection(data, request=MockStarletteRequest) + await txn_client.create_collection(data, request=MockRequest) - es_txn_client.delete_collection(data["id"], request=MockStarletteRequest) + await txn_client.delete_collection(data["id"], request=MockRequest) -def test_update_collection( - es_core: CoreCrudClient, - es_txn_client, +async def test_update_collection( + core_client, + txn_client, load_test_data: Callable, ): data = load_test_data("test_collection.json") - es_txn_client.create_collection(data, request=MockStarletteRequest) + await txn_client.create_collection(data, request=MockRequest) data["keywords"].append("new keyword") - es_txn_client.update_collection(data, request=MockStarletteRequest) + await txn_client.update_collection(data, request=MockRequest) - coll = es_core.get_collection(data["id"], request=MockStarletteRequest) + coll = await core_client.get_collection(data["id"], request=MockRequest) assert "new keyword" in coll["keywords"] - es_txn_client.delete_collection(data["id"], request=MockStarletteRequest) + await txn_client.delete_collection(data["id"], request=MockRequest) -def test_delete_collection( - es_core: CoreCrudClient, - es_txn_client, +async def test_delete_collection( + core_client, + txn_client, load_test_data: Callable, ): data = load_test_data("test_collection.json") - es_txn_client.create_collection(data, request=MockStarletteRequest) + await txn_client.create_collection(data, request=MockRequest) - es_txn_client.delete_collection(data["id"], request=MockStarletteRequest) + await txn_client.delete_collection(data["id"], request=MockRequest) with pytest.raises(NotFoundError): - es_core.get_collection(data["id"], request=MockStarletteRequest) + await core_client.get_collection(data["id"], request=MockRequest) -def test_get_collection( - es_core: CoreCrudClient, - es_txn_client, +async def test_get_collection( + core_client, + txn_client, load_test_data: Callable, ): data = load_test_data("test_collection.json") - es_txn_client.create_collection(data, request=MockStarletteRequest) - coll = es_core.get_collection(data["id"], request=MockStarletteRequest) + await txn_client.create_collection(data, request=MockRequest) + coll = await core_client.get_collection(data["id"], request=MockRequest) assert coll["id"] == data["id"] - es_txn_client.delete_collection(data["id"], request=MockStarletteRequest) - + await txn_client.delete_collection(data["id"], request=MockRequest) -def test_get_item( - es_core: CoreCrudClient, - es_txn_client, - load_test_data: Callable, -): - collection_data = load_test_data("test_collection.json") - item_data = load_test_data("test_item.json") - es_txn_client.create_collection(collection_data, request=MockStarletteRequest) - es_txn_client.create_item(item_data, request=MockStarletteRequest) - got_item = es_core.get_item( - item_id=item_data["id"], - collection_id=item_data["collection"], - request=MockStarletteRequest, - ) - assert got_item["id"] == item_data["id"] - assert got_item["collection"] == item_data["collection"] - es_txn_client.delete_collection(collection_data["id"], request=MockStarletteRequest) - es_txn_client.delete_item( - item_data["id"], item_data["collection"], request=MockStarletteRequest +async def test_get_item(app_client, ctx, core_client): + got_item = await core_client.get_item( + item_id=ctx.item["id"], + collection_id=ctx.item["collection"], + request=MockRequest, ) + assert got_item["id"] == ctx.item["id"] + assert got_item["collection"] == ctx.item["collection"] -def test_get_collection_items( - es_core: CoreCrudClient, - es_txn_client, - load_test_data: Callable, -): - coll = load_test_data("test_collection.json") - es_txn_client.create_collection(coll, request=MockStarletteRequest) - - item = load_test_data("test_item.json") - - for _ in range(5): +async def test_get_collection_items(app_client, ctx, core_client, txn_client): + coll = ctx.collection + num_of_items_to_create = 5 + for _ in range(num_of_items_to_create): + item = deepcopy(ctx.item) item["id"] = str(uuid.uuid4()) - es_txn_client.create_item(item, request=MockStarletteRequest, refresh=True) + await txn_client.create_item(item, request=MockRequest, refresh=True) - fc = es_core.item_collection(coll["id"], request=MockStarletteRequest) - assert len(fc["features"]) == 5 + fc = await core_client.item_collection(coll["id"], request=MockRequest) + assert len(fc["features"]) == num_of_items_to_create + 1 # ctx.item for item in fc["features"]: assert item["collection"] == coll["id"] - es_txn_client.delete_collection(coll["id"], request=MockStarletteRequest) - for item in fc["features"]: - es_txn_client.delete_item(item["id"], coll["id"], request=MockStarletteRequest) - -def test_create_item( - es_core: CoreCrudClient, - es_txn_client, - load_test_data: Callable, -): - coll = load_test_data("test_collection.json") - es_txn_client.create_collection(coll, request=MockStarletteRequest) - item = load_test_data("test_item.json") - es_txn_client.create_item(item, request=MockStarletteRequest, refresh=True) - resp = es_core.get_item( - item["id"], item["collection"], request=MockStarletteRequest +async def test_create_item(ctx, core_client, txn_client): + resp = await core_client.get_item( + ctx.item["id"], ctx.item["collection"], request=MockRequest ) - assert Item(**item).dict( + assert Item(**ctx.item).dict( exclude={"links": ..., "properties": {"created", "updated"}} ) == Item(**resp).dict(exclude={"links": ..., "properties": {"created", "updated"}}) - es_txn_client.delete_collection(coll["id"], request=MockStarletteRequest) - es_txn_client.delete_item(item["id"], coll["id"], request=MockStarletteRequest) - - -def test_create_item_already_exists( - es_txn_client, - load_test_data: Callable, -): - coll = load_test_data("test_collection.json") - es_txn_client.create_collection(coll, request=MockStarletteRequest) - - item = load_test_data("test_item.json") - es_txn_client.create_item(item, request=MockStarletteRequest, refresh=True) +async def test_create_item_already_exists(ctx, txn_client): with pytest.raises(ConflictError): - es_txn_client.create_item(item, request=MockStarletteRequest, refresh=True) + await txn_client.create_item(ctx.item, request=MockRequest, refresh=True) - es_txn_client.delete_collection(coll["id"], request=MockStarletteRequest) - es_txn_client.delete_item(item["id"], coll["id"], request=MockStarletteRequest) +async def test_update_item(ctx, core_client, txn_client): + ctx.item["properties"]["foo"] = "bar" + await txn_client.update_item(ctx.item, request=MockRequest) -def test_update_item( - es_core: CoreCrudClient, - es_txn_client, - load_test_data: Callable, -): - coll = load_test_data("test_collection.json") - es_txn_client.create_collection(coll, request=MockStarletteRequest) - - item = load_test_data("test_item.json") - es_txn_client.create_item(item, request=MockStarletteRequest, refresh=True) - - item["properties"]["foo"] = "bar" - es_txn_client.update_item(item, request=MockStarletteRequest) - - updated_item = es_core.get_item( - item["id"], item["collection"], request=MockStarletteRequest + updated_item = await core_client.get_item( + ctx.item["id"], ctx.item["collection"], request=MockRequest ) assert updated_item["properties"]["foo"] == "bar" - es_txn_client.delete_collection(coll["id"], request=MockStarletteRequest) - es_txn_client.delete_item(item["id"], coll["id"], request=MockStarletteRequest) - - -def test_update_geometry( - es_core: CoreCrudClient, - es_txn_client, - load_test_data: Callable, -): - coll = load_test_data("test_collection.json") - es_txn_client.create_collection(coll, request=MockStarletteRequest) - - item = load_test_data("test_item.json") - es_txn_client.create_item(item, request=MockStarletteRequest, refresh=True) +async def test_update_geometry(ctx, core_client, txn_client): new_coordinates = [ [ [142.15052873427666, -33.82243006904891], @@ -217,62 +136,37 @@ def test_update_geometry( ] ] - item["geometry"]["coordinates"] = new_coordinates - es_txn_client.update_item(item, request=MockStarletteRequest) + ctx.item["geometry"]["coordinates"] = new_coordinates + await txn_client.update_item(ctx.item, request=MockRequest) - updated_item = es_core.get_item( - item["id"], item["collection"], request=MockStarletteRequest + updated_item = await core_client.get_item( + ctx.item["id"], ctx.item["collection"], request=MockRequest ) assert updated_item["geometry"]["coordinates"] == new_coordinates - es_txn_client.delete_collection(coll["id"], request=MockStarletteRequest) - es_txn_client.delete_item(item["id"], coll["id"], request=MockStarletteRequest) - -def test_delete_item( - es_core: CoreCrudClient, - es_txn_client, - load_test_data: Callable, -): - coll = load_test_data("test_collection.json") - es_txn_client.create_collection(coll, request=MockStarletteRequest) - - item = load_test_data("test_item.json") - es_txn_client.create_item(item, request=MockStarletteRequest, refresh=True) - - es_txn_client.delete_item( - item["id"], item["collection"], request=MockStarletteRequest - ) - - es_txn_client.delete_collection(coll["id"], request=MockStarletteRequest) +async def test_delete_item(ctx, core_client, txn_client): + await txn_client.delete_item(ctx.item["id"], ctx.item["collection"]) with pytest.raises(NotFoundError): - es_core.get_item(item["id"], item["collection"], request=MockStarletteRequest) + await core_client.get_item( + ctx.item["id"], ctx.item["collection"], request=MockRequest + ) -def test_bulk_item_insert( - es_core: CoreCrudClient, - es_txn_client, - es_bulk_transactions: BulkTransactionsClient, - load_test_data: Callable, -): - coll = load_test_data("test_collection.json") - es_txn_client.create_collection(coll, request=MockStarletteRequest) - - item = load_test_data("test_item.json") - +async def test_bulk_item_insert(ctx, core_client, txn_client, bulk_txn_client): items = {} for _ in range(10): - _item = deepcopy(item) + _item = deepcopy(ctx.item) _item["id"] = str(uuid.uuid4()) items[_item["id"]] = _item # fc = es_core.item_collection(coll["id"], request=MockStarletteRequest) # assert len(fc["features"]) == 0 - es_bulk_transactions.bulk_item_insert(Items(items=items), refresh=True) + bulk_txn_client.bulk_item_insert(Items(items=items), refresh=True) - fc = es_core.item_collection(coll["id"], request=MockStarletteRequest) + fc = await core_client.item_collection(ctx.collection["id"], request=MockRequest) assert len(fc["features"]) >= 10 # for item in items: @@ -281,42 +175,35 @@ def test_bulk_item_insert( # ) -def test_feature_collection_insert( - es_core: CoreCrudClient, - es_txn_client, - es_bulk_transactions: BulkTransactionsClient, - test_item, - test_collection, +async def test_feature_collection_insert( + core_client, + txn_client, ctx, ): features = [] for _ in range(10): - _item = deepcopy(test_item) + _item = deepcopy(ctx.item) _item["id"] = str(uuid.uuid4()) features.append(_item) feature_collection = {"type": "FeatureCollection", "features": features} - create_item(es_txn_client, feature_collection) + await create_item(txn_client, feature_collection) - fc = es_core.item_collection(test_collection["id"], request=MockStarletteRequest) + fc = await core_client.item_collection(ctx.collection["id"], request=MockRequest) assert len(fc["features"]) >= 10 -def test_landing_page_no_collection_title( - es_core: CoreCrudClient, - es_txn_client, - load_test_data: Callable, - api_client: StacApi, -): - class MockStarletteRequestWithApp(MockStarletteRequest): - app = api_client.app +@pytest.mark.skip(reason="app fixture isn't injected or something?") +async def test_landing_page_no_collection_title(ctx, core_client, txn_client, app): + class MockRequestWithApp(MockRequest): + app = app - coll = load_test_data("test_collection.json") - del coll["title"] - es_txn_client.create_collection(coll, request=MockStarletteRequest) + ctx.collection["id"] = "new_id" + del ctx.collection["title"] + await txn_client.create_collection(ctx.collection, request=MockRequest) - landing_page = es_core.landing_page(request=MockStarletteRequestWithApp) + landing_page = await core_client.landing_page(request=MockRequestWithApp) for link in landing_page["links"]: - if link["href"].split("/")[-1] == coll["id"]: + if link["href"].split("/")[-1] == ctx.collection["id"]: assert link["title"] diff --git a/stac_fastapi/elasticsearch/tests/conftest.py b/stac_fastapi/elasticsearch/tests/conftest.py index 99141f67..a207836d 100644 --- a/stac_fastapi/elasticsearch/tests/conftest.py +++ b/stac_fastapi/elasticsearch/tests/conftest.py @@ -1,17 +1,19 @@ +import asyncio import copy import json import os from typing import Callable, Dict import pytest -from starlette.testclient import TestClient +import pytest_asyncio +from httpx import AsyncClient from stac_fastapi.api.app import StacApi from stac_fastapi.api.models import create_request_model -from stac_fastapi.elasticsearch.config import ElasticsearchSettings +from stac_fastapi.elasticsearch.config import AsyncElasticsearchSettings from stac_fastapi.elasticsearch.core import ( BulkTransactionsClient, - CoreCrudClient, + CoreClient, TransactionsClient, ) from stac_fastapi.elasticsearch.database_logic import COLLECTIONS_INDEX, ITEMS_INDEX @@ -30,7 +32,17 @@ DATA_DIR = os.path.join(os.path.dirname(__file__), "data") -class TestSettings(ElasticsearchSettings): +class Context: + def __init__(self, item, collection): + self.item = item + self.collection = collection + + +class MockRequest: + base_url = "http://test-server" + + +class TestSettings(AsyncElasticsearchSettings): class Config: env_file = ".env.test" @@ -39,20 +51,25 @@ class Config: Settings.set(settings) +@pytest.fixture(scope="session") +def event_loop(): + return asyncio.get_event_loop() + + def _load_file(filename: str) -> Dict: with open(os.path.join(DATA_DIR, filename)) as file: return json.load(file) +_test_item_prototype = _load_file("test_item.json") +_test_collection_prototype = _load_file("test_collection.json") + + @pytest.fixture def load_test_data() -> Callable[[str], Dict]: return _load_file -_test_item_prototype = _load_file("test_item.json") -_test_collection_prototype = _load_file("test_collection.json") - - @pytest.fixture def test_item() -> Dict: return copy.deepcopy(_test_item_prototype) @@ -63,82 +80,65 @@ def test_collection() -> Dict: return copy.deepcopy(_test_collection_prototype) -def create_collection(es_txn_client: TransactionsClient, collection: Dict) -> None: - es_txn_client.create_collection( - dict(collection), request=MockStarletteRequest, refresh=True +async def create_collection(txn_client: TransactionsClient, collection: Dict) -> None: + await txn_client.create_collection( + dict(collection), request=MockRequest, refresh=True ) -def create_item(es_txn_client: TransactionsClient, item: Dict) -> None: - es_txn_client.create_item(dict(item), request=MockStarletteRequest, refresh=True) - +async def create_item(txn_client: TransactionsClient, item: Dict) -> None: + await txn_client.create_item(item, request=MockRequest, refresh=True) -def delete_collections_and_items(es_txn_client: TransactionsClient) -> None: - refresh_indices(es_txn_client) - # try: - es_txn_client.database.delete_items() - # except Exception: - # pass - # try: - es_txn_client.database.delete_collections() - # except Exception: - # pass +async def delete_collections_and_items(txn_client: TransactionsClient) -> None: + await refresh_indices(txn_client) + await txn_client.database.delete_items() + await txn_client.database.delete_collections() -def refresh_indices(es_txn_client: TransactionsClient) -> None: +async def refresh_indices(txn_client: TransactionsClient) -> None: try: - es_txn_client.database.client.indices.refresh(index=ITEMS_INDEX) + await txn_client.database.client.indices.refresh(index=ITEMS_INDEX) except Exception: pass try: - es_txn_client.database.client.indices.refresh(index=COLLECTIONS_INDEX) + await txn_client.database.client.indices.refresh(index=COLLECTIONS_INDEX) except Exception: pass -class Context: - def __init__(self, item, collection): - self.item = item - self.collection = collection - - -@pytest.fixture() -def ctx(es_txn_client: TransactionsClient, test_collection, test_item): +@pytest_asyncio.fixture() +async def ctx(txn_client: TransactionsClient, test_collection, test_item): # todo remove one of these when all methods use it - delete_collections_and_items(es_txn_client) + await delete_collections_and_items(txn_client) - create_collection(es_txn_client, test_collection) - create_item(es_txn_client, test_item) + await create_collection(txn_client, test_collection) + await create_item(txn_client, test_item) yield Context(item=test_item, collection=test_collection) - delete_collections_and_items(es_txn_client) - - -class MockStarletteRequest: - base_url = "http://test-server" + await delete_collections_and_items(txn_client) @pytest.fixture -def es_core(): - return CoreCrudClient(session=None) +def core_client(): + return CoreClient(session=None) @pytest.fixture -def es_txn_client(): +def txn_client(): return TransactionsClient(session=None) @pytest.fixture -def es_bulk_transactions(): +def bulk_txn_client(): return BulkTransactionsClient(session=None) -@pytest.fixture -def api_client(): - settings = ElasticsearchSettings() +@pytest_asyncio.fixture(scope="session") +async def app(): + settings = AsyncElasticsearchSettings() extensions = [ TransactionExtension( client=TransactionsClient(session=None), settings=settings @@ -166,7 +166,7 @@ def api_client(): return StacApi( settings=settings, - client=CoreCrudClient( + client=CoreClient( session=None, extensions=extensions, post_request_model=post_request_model, @@ -174,12 +174,12 @@ def api_client(): extensions=extensions, search_get_request_model=get_request_model, search_post_request_model=post_request_model, - ) + ).app -@pytest.fixture -def app_client(api_client: StacApi): - IndexesClient().create_indexes() +@pytest_asyncio.fixture(scope="session") +async def app_client(app): + await IndexesClient().create_indexes() - with TestClient(api_client.app) as test_app: - yield test_app + async with AsyncClient(app=app, base_url="http://test") as c: + yield c diff --git a/stac_fastapi/elasticsearch/tests/resources/test_collection.py b/stac_fastapi/elasticsearch/tests/resources/test_collection.py index b0d8b3d6..5172c2b0 100644 --- a/stac_fastapi/elasticsearch/tests/resources/test_collection.py +++ b/stac_fastapi/elasticsearch/tests/resources/test_collection.py @@ -1,67 +1,64 @@ import pystac -def test_create_and_delete_collection(app_client, load_test_data): +async def test_create_and_delete_collection(app_client, load_test_data): """Test creation and deletion of a collection""" test_collection = load_test_data("test_collection.json") test_collection["id"] = "test" - resp = app_client.post("/collections", json=test_collection) + resp = await app_client.post("/collections", json=test_collection) assert resp.status_code == 200 - resp = app_client.delete(f"/collections/{test_collection['id']}") + resp = await app_client.delete(f"/collections/{test_collection['id']}") assert resp.status_code == 200 -def test_create_collection_conflict(app_client, load_test_data): +async def test_create_collection_conflict(app_client, ctx): """Test creation of a collection which already exists""" # This collection ID is created in the fixture, so this should be a conflict - test_collection = load_test_data("test_collection.json") - resp = app_client.post("/collections", json=test_collection) + resp = await app_client.post("/collections", json=ctx.collection) assert resp.status_code == 409 -def test_delete_missing_collection(app_client): +async def test_delete_missing_collection(app_client): """Test deletion of a collection which does not exist""" - resp = app_client.delete("/collections/missing-collection") + resp = await app_client.delete("/collections/missing-collection") assert resp.status_code == 404 -def test_update_collection_already_exists(app_client, load_test_data): +async def test_update_collection_already_exists(ctx, app_client): """Test updating a collection which already exists""" - test_collection = load_test_data("test_collection.json") - test_collection["keywords"].append("test") - resp = app_client.put("/collections", json=test_collection) + ctx.collection["keywords"].append("test") + resp = await app_client.put("/collections", json=ctx.collection) assert resp.status_code == 200 - resp = app_client.get(f"/collections/{test_collection['id']}") + resp = await app_client.get(f"/collections/{ctx.collection['id']}") assert resp.status_code == 200 resp_json = resp.json() assert "test" in resp_json["keywords"] -def test_update_new_collection(app_client, load_test_data): +async def test_update_new_collection(app_client, load_test_data): """Test updating a collection which does not exist (same as creation)""" test_collection = load_test_data("test_collection.json") test_collection["id"] = "new-test-collection" - resp = app_client.put("/collections", json=test_collection) + resp = await app_client.put("/collections", json=test_collection) assert resp.status_code == 404 -def test_collection_not_found(app_client): +async def test_collection_not_found(app_client): """Test read a collection which does not exist""" - resp = app_client.get("/collections/does-not-exist") + resp = await app_client.get("/collections/does-not-exist") assert resp.status_code == 404 -def test_returns_valid_collection(app_client, load_test_data): +async def test_returns_valid_collection(ctx, app_client): """Test validates fetched collection with jsonschema""" - test_collection = load_test_data("test_collection.json") - resp = app_client.put("/collections", json=test_collection) + resp = await app_client.put("/collections", json=ctx.collection) assert resp.status_code == 200 - resp = app_client.get(f"/collections/{test_collection['id']}") + resp = await app_client.get(f"/collections/{ctx.collection['id']}") assert resp.status_code == 200 resp_json = resp.json() diff --git a/stac_fastapi/elasticsearch/tests/resources/test_conformance.py b/stac_fastapi/elasticsearch/tests/resources/test_conformance.py index cb85c744..ab70a00b 100644 --- a/stac_fastapi/elasticsearch/tests/resources/test_conformance.py +++ b/stac_fastapi/elasticsearch/tests/resources/test_conformance.py @@ -1,11 +1,12 @@ import urllib.parse import pytest +import pytest_asyncio -@pytest.fixture -def response(app_client): - return app_client.get("/") +@pytest_asyncio.fixture +async def response(app_client): + return await app_client.get("/") @pytest.fixture @@ -19,7 +20,7 @@ def get_link(landing_page, rel_type): ) -def test_landing_page_health(response): +async def test_landing_page_health(response): """Test landing page""" assert response.status_code == 200 assert response.headers["content-type"] == "application/json" @@ -39,7 +40,7 @@ def test_landing_page_health(response): @pytest.mark.parametrize("rel_type,expected_media_type,expected_path", link_tests) -def test_landing_page_links( +async def test_landing_page_links( response_json, app_client, rel_type, expected_media_type, expected_path ): link = get_link(response_json, rel_type) @@ -50,7 +51,7 @@ def test_landing_page_links( link_path = urllib.parse.urlsplit(link.get("href")).path assert link_path == expected_path - resp = app_client.get(link_path) + resp = await app_client.get(link_path) assert resp.status_code == 200 @@ -58,7 +59,7 @@ def test_landing_page_links( # code here seems meaningless since it would be the same as if the endpoint did not exist. Once # https://github.com/stac-utils/stac-fastapi/pull/227 has been merged we can add this to the # parameterized tests above. -def test_search_link(response_json): +async def test_search_link(response_json): search_link = get_link(response_json, "search") assert search_link is not None diff --git a/stac_fastapi/elasticsearch/tests/resources/test_item.py b/stac_fastapi/elasticsearch/tests/resources/test_item.py index 56edd0a0..6869e60d 100644 --- a/stac_fastapi/elasticsearch/tests/resources/test_item.py +++ b/stac_fastapi/elasticsearch/tests/resources/test_item.py @@ -12,7 +12,7 @@ from geojson_pydantic.geometries import Polygon from pystac.utils import datetime_to_str -from stac_fastapi.elasticsearch.core import CoreCrudClient +from stac_fastapi.elasticsearch.core import CoreClient from stac_fastapi.elasticsearch.datetime_utils import now_to_rfc3339_str from stac_fastapi.types.core import LandingPageMixin @@ -23,96 +23,102 @@ def rfc3339_str_to_datetime(s: str) -> datetime: return ciso8601.parse_rfc3339(s) -def test_create_and_delete_item(app_client, ctx, es_txn_client): +async def test_create_and_delete_item(app_client, ctx, txn_client, event_loop): """Test creation and deletion of a single item (transactions extension)""" test_item = ctx.item - resp = app_client.get( + resp = await app_client.get( f"/collections/{test_item['collection']}/items/{test_item['id']}" ) assert resp.status_code == 200 - resp = app_client.delete( + resp = await app_client.delete( f"/collections/{test_item['collection']}/items/{test_item['id']}" ) assert resp.status_code == 200 - refresh_indices(es_txn_client) + await refresh_indices(txn_client) - resp = app_client.get( + resp = await app_client.get( f"/collections/{test_item['collection']}/items/{test_item['id']}" ) assert resp.status_code == 404 -def test_create_item_conflict(app_client, ctx): +async def test_create_item_conflict(app_client, ctx): """Test creation of an item which already exists (transactions extension)""" test_item = ctx.item - resp = app_client.post( + resp = await app_client.post( f"/collections/{test_item['collection']}/items", json=test_item ) assert resp.status_code == 409 -def test_delete_missing_item(app_client, load_test_data): +async def test_delete_missing_item(app_client, load_test_data): """Test deletion of an item which does not exist (transactions extension)""" test_item = load_test_data("test_item.json") - resp = app_client.delete(f"/collections/{test_item['collection']}/items/hijosh") + resp = await app_client.delete( + f"/collections/{test_item['collection']}/items/hijosh" + ) assert resp.status_code == 404 -def test_create_item_missing_collection(app_client, ctx): +async def test_create_item_missing_collection(app_client, ctx): """Test creation of an item without a parent collection (transactions extension)""" - ctx.item["collection"] = "stc is cool" - resp = app_client.post( + ctx.item["collection"] = "stac_is_cool" + resp = await app_client.post( f"/collections/{ctx.item['collection']}/items", json=ctx.item ) - assert resp.status_code == 422 + assert resp.status_code == 404 -def test_update_item_already_exists(app_client, ctx): +async def test_update_item_already_exists(app_client, ctx): """Test updating an item which already exists (transactions extension)""" assert ctx.item["properties"]["gsd"] != 16 ctx.item["properties"]["gsd"] = 16 - app_client.put(f"/collections/{ctx.item['collection']}/items", json=ctx.item) - resp = app_client.get( + await app_client.put(f"/collections/{ctx.item['collection']}/items", json=ctx.item) + resp = await app_client.get( f"/collections/{ctx.item['collection']}/items/{ctx.item['id']}" ) updated_item = resp.json() assert updated_item["properties"]["gsd"] == 16 - app_client.delete(f"/collections/{ctx.item['collection']}/items/{ctx.item['id']}") + await app_client.delete( + f"/collections/{ctx.item['collection']}/items/{ctx.item['id']}" + ) -def test_update_new_item(app_client, ctx): +async def test_update_new_item(app_client, ctx): """Test updating an item which does not exist (transactions extension)""" test_item = ctx.item test_item["id"] = "a" # note: this endpoint is wrong in stac-fastapi -- should be /collections/{c_id}/items/{item_id} - resp = app_client.put( + resp = await app_client.put( f"/collections/{test_item['collection']}/items", json=test_item ) assert resp.status_code == 404 -def test_update_item_missing_collection(app_client, ctx): +async def test_update_item_missing_collection(app_client, ctx): """Test updating an item without a parent collection (transactions extension)""" # Try to update collection of the item ctx.item["collection"] = "stac_is_cool" - resp = app_client.put(f"/collections/{ctx.item['collection']}/items", json=ctx.item) - assert resp.status_code == 422 + resp = await app_client.put( + f"/collections/{ctx.item['collection']}/items", json=ctx.item + ) + assert resp.status_code == 404 -def test_update_item_geometry(app_client, ctx): +async def test_update_item_geometry(app_client, ctx): ctx.item["id"] = "update_test_item_1" # Create the item - resp = app_client.post( + resp = await app_client.post( f"/collections/{ctx.item['collection']}/items", json=ctx.item ) assert resp.status_code == 200 @@ -129,29 +135,31 @@ def test_update_item_geometry(app_client, ctx): # Update the geometry of the item ctx.item["geometry"]["coordinates"] = new_coordinates - resp = app_client.put(f"/collections/{ctx.item['collection']}/items", json=ctx.item) + resp = await app_client.put( + f"/collections/{ctx.item['collection']}/items", json=ctx.item + ) assert resp.status_code == 200 # Fetch the updated item - resp = app_client.get( + resp = await app_client.get( f"/collections/{ctx.item['collection']}/items/{ctx.item['id']}" ) assert resp.status_code == 200 assert resp.json()["geometry"]["coordinates"] == new_coordinates -def test_get_item(app_client, ctx): +async def test_get_item(app_client, ctx): """Test read an item by id (core)""" - get_item = app_client.get( + get_item = await app_client.get( f"/collections/{ctx.item['collection']}/items/{ctx.item['id']}" ) assert get_item.status_code == 200 -def test_returns_valid_item(app_client, ctx): +async def test_returns_valid_item(app_client, ctx): """Test validates fetched item with jsonschema""" test_item = ctx.item - get_item = app_client.get( + get_item = await app_client.get( f"/collections/{test_item['collection']}/items/{test_item['id']}" ) assert get_item.status_code == 200 @@ -164,15 +172,15 @@ def test_returns_valid_item(app_client, ctx): item.validate() -def test_get_item_collection(app_client, ctx, es_txn_client): +async def test_get_item_collection(app_client, ctx, txn_client): """Test read an item collection (core)""" item_count = randint(1, 4) for idx in range(item_count): ctx.item["id"] = f'{ctx.item["id"]}{idx}' - create_item(es_txn_client, ctx.item) + await create_item(txn_client, ctx.item) - resp = app_client.get(f"/collections/{ctx.item['collection']}/items") + resp = await app_client.get(f"/collections/{ctx.item['collection']}/items") assert resp.status_code == 200 item_collection = resp.json() @@ -180,7 +188,7 @@ def test_get_item_collection(app_client, ctx, es_txn_client): @pytest.mark.skip(reason="Pagination extension not implemented") -def test_pagination(app_client, load_test_data): +async def test_pagination(app_client, load_test_data): """Test item collection pagination (paging extension)""" item_count = 10 test_item = load_test_data("test_item.json") @@ -188,12 +196,12 @@ def test_pagination(app_client, load_test_data): for idx in range(item_count): _test_item = deepcopy(test_item) _test_item["id"] = test_item["id"] + str(idx) - resp = app_client.post( + resp = await app_client.post( f"/collections/{test_item['collection']}/items", json=_test_item ) assert resp.status_code == 200 - resp = app_client.get( + resp = await app_client.get( f"/collections/{test_item['collection']}/items", params={"limit": 3} ) assert resp.status_code == 200 @@ -201,13 +209,13 @@ def test_pagination(app_client, load_test_data): assert first_page["context"]["returned"] == 3 url_components = urlsplit(first_page["links"][0]["href"]) - resp = app_client.get(f"{url_components.path}?{url_components.query}") + resp = await app_client.get(f"{url_components.path}?{url_components.query}") assert resp.status_code == 200 second_page = resp.json() assert second_page["context"]["returned"] == 3 -def test_item_timestamps(app_client, ctx, load_test_data): +async def test_item_timestamps(app_client, ctx, load_test_data): """Test created and updated timestamps (common metadata)""" # start_time = now_to_rfc3339_str() @@ -219,7 +227,7 @@ def test_item_timestamps(app_client, ctx, load_test_data): # Confirm `updated` timestamp ctx.item["properties"]["proj:epsg"] = 4326 - resp = app_client.put( + resp = await app_client.put( f"/collections/{ctx.item['collection']}/items", json=dict(ctx.item) ) assert resp.status_code == 200 @@ -229,28 +237,27 @@ def test_item_timestamps(app_client, ctx, load_test_data): assert ctx.item["properties"]["created"] == updated_item["properties"]["created"] assert updated_item["properties"]["updated"] > created_dt - app_client.delete( - f"/collections/{ctx.item['collection']}/items/{ctx.item['id']}", - json=dict(ctx.item), + await app_client.delete( + f"/collections/{ctx.item['collection']}/items/{ctx.item['id']}" ) -def test_item_search_by_id_post(app_client, ctx, es_txn_client): +async def test_item_search_by_id_post(app_client, ctx, txn_client): """Test POST search by item id (core)""" ids = ["test1", "test2", "test3"] for _id in ids: ctx.item["id"] = _id - create_item(es_txn_client, ctx.item) + await create_item(txn_client, ctx.item) params = {"collections": [ctx.item["collection"]], "ids": ids} - resp = app_client.post("/search", json=params) + resp = await app_client.post("/search", json=params) assert resp.status_code == 200 resp_json = resp.json() assert len(resp_json["features"]) == len(ids) assert set([feat["id"] for feat in resp_json["features"]]) == set(ids) -def test_item_search_spatial_query_post(app_client, ctx): +async def test_item_search_spatial_query_post(app_client, ctx): """Test POST search with spatial query (core)""" test_item = ctx.item @@ -258,13 +265,13 @@ def test_item_search_spatial_query_post(app_client, ctx): "collections": [test_item["collection"]], "intersects": test_item["geometry"], } - resp = app_client.post("/search", json=params) + resp = await app_client.post("/search", json=params) assert resp.status_code == 200 resp_json = resp.json() assert resp_json["features"][0]["id"] == test_item["id"] -def test_item_search_temporal_query_post(app_client, ctx): +async def test_item_search_temporal_query_post(app_client, ctx): """Test POST search with single-tailed spatio-temporal query (core)""" test_item = ctx.item @@ -277,12 +284,12 @@ def test_item_search_temporal_query_post(app_client, ctx): "intersects": test_item["geometry"], "datetime": f"../{datetime_to_str(item_date)}", } - resp = app_client.post("/search", json=params) + resp = await app_client.post("/search", json=params) resp_json = resp.json() assert resp_json["features"][0]["id"] == test_item["id"] -def test_item_search_temporal_window_post(app_client, load_test_data, ctx): +async def test_item_search_temporal_window_post(app_client, load_test_data, ctx): """Test POST search with two-tailed spatio-temporal query (core)""" test_item = ctx.item @@ -295,12 +302,12 @@ def test_item_search_temporal_window_post(app_client, load_test_data, ctx): "intersects": test_item["geometry"], "datetime": f"{datetime_to_str(item_date_before)}/{datetime_to_str(item_date_after)}", } - resp = app_client.post("/search", json=params) + resp = await app_client.post("/search", json=params) resp_json = resp.json() assert resp_json["features"][0]["id"] == test_item["id"] -def test_item_search_temporal_open_window(app_client, ctx): +async def test_item_search_temporal_open_window(app_client, ctx): """Test POST search with open spatio-temporal query (core)""" test_item = ctx.item params = { @@ -308,17 +315,17 @@ def test_item_search_temporal_open_window(app_client, ctx): "intersects": test_item["geometry"], "datetime": "../..", } - resp = app_client.post("/search", json=params) + resp = await app_client.post("/search", json=params) resp_json = resp.json() assert resp_json["features"][0]["id"] == test_item["id"] @pytest.mark.skip(reason="sortby date not implemented") -def test_item_search_sort_post(app_client, load_test_data): +async def test_item_search_sort_post(app_client, load_test_data): """Test POST search with sorting (sort extension)""" first_item = load_test_data("test_item.json") item_date = rfc3339_str_to_datetime(first_item["properties"]["datetime"]) - resp = app_client.post( + resp = await app_client.post( f"/collections/{first_item['collection']}/items", json=first_item ) assert resp.status_code == 200 @@ -327,7 +334,7 @@ def test_item_search_sort_post(app_client, load_test_data): second_item["id"] = "another-item" another_item_date = item_date - timedelta(days=1) second_item["properties"]["datetime"] = datetime_to_str(another_item_date) - resp = app_client.post( + resp = await app_client.post( f"/collections/{second_item['collection']}/items", json=second_item ) assert resp.status_code == 200 @@ -336,55 +343,54 @@ def test_item_search_sort_post(app_client, load_test_data): "collections": [first_item["collection"]], "sortby": [{"field": "datetime", "direction": "desc"}], } - resp = app_client.post("/search", json=params) + resp = await app_client.post("/search", json=params) assert resp.status_code == 200 resp_json = resp.json() assert resp_json["features"][0]["id"] == first_item["id"] assert resp_json["features"][1]["id"] == second_item["id"] - app_client.delete( - f"/collections/{first_item['collection']}/items/{first_item['id']}", - json=first_item, + await app_client.delete( + f"/collections/{first_item['collection']}/items/{first_item['id']}" ) -def test_item_search_by_id_get(app_client, ctx, es_txn_client): +async def test_item_search_by_id_get(app_client, ctx, txn_client): """Test GET search by item id (core)""" ids = ["test1", "test2", "test3"] for _id in ids: ctx.item["id"] = _id - create_item(es_txn_client, ctx.item) + await create_item(txn_client, ctx.item) params = {"collections": ctx.item["collection"], "ids": ",".join(ids)} - resp = app_client.get("/search", params=params) + resp = await app_client.get("/search", params=params) assert resp.status_code == 200 resp_json = resp.json() assert len(resp_json["features"]) == len(ids) assert set([feat["id"] for feat in resp_json["features"]]) == set(ids) -def test_item_search_bbox_get(app_client, ctx): +async def test_item_search_bbox_get(app_client, ctx): """Test GET search with spatial query (core)""" params = { "collections": ctx.item["collection"], "bbox": ",".join([str(coord) for coord in ctx.item["bbox"]]), } - resp = app_client.get("/search", params=params) + resp = await app_client.get("/search", params=params) assert resp.status_code == 200 resp_json = resp.json() assert resp_json["features"][0]["id"] == ctx.item["id"] -def test_item_search_get_without_collections(app_client, ctx): +async def test_item_search_get_without_collections(app_client, ctx): """Test GET search without specifying collections""" params = { "bbox": ",".join([str(coord) for coord in ctx.item["bbox"]]), } - resp = app_client.get("/search", params=params) + resp = await app_client.get("/search", params=params) assert resp.status_code == 200 -def test_item_search_temporal_window_get(app_client, ctx): +async def test_item_search_temporal_window_get(app_client, ctx): """Test GET search with spatio-temporal query (core)""" test_item = ctx.item item_date = rfc3339_str_to_datetime(test_item["properties"]["datetime"]) @@ -396,66 +402,66 @@ def test_item_search_temporal_window_get(app_client, ctx): "bbox": ",".join([str(coord) for coord in test_item["bbox"]]), "datetime": f"{datetime_to_str(item_date_before)}/{datetime_to_str(item_date_after)}", } - resp = app_client.get("/search", params=params) + resp = await app_client.get("/search", params=params) resp_json = resp.json() assert resp_json["features"][0]["id"] == test_item["id"] @pytest.mark.skip(reason="sorting not fully implemented") -def test_item_search_sort_get(app_client, ctx, es_txn_client): +async def test_item_search_sort_get(app_client, ctx, txn_client): """Test GET search with sorting (sort extension)""" first_item = ctx.item item_date = rfc3339_str_to_datetime(first_item["properties"]["datetime"]) - create_item(es_txn_client, ctx.item) + await create_item(txn_client, ctx.item) second_item = ctx.item.copy() second_item["id"] = "another-item" another_item_date = item_date - timedelta(days=1) second_item.update({"properties": {"datetime": datetime_to_str(another_item_date)}}) - create_item(es_txn_client, second_item) + await create_item(txn_client, second_item) params = {"collections": [first_item["collection"]], "sortby": "-datetime"} - resp = app_client.get("/search", params=params) + resp = await app_client.get("/search", params=params) assert resp.status_code == 200 resp_json = resp.json() assert resp_json["features"][0]["id"] == first_item["id"] assert resp_json["features"][1]["id"] == second_item["id"] -def test_item_search_post_without_collection(app_client, ctx): +async def test_item_search_post_without_collection(app_client, ctx): """Test POST search without specifying a collection""" test_item = ctx.item params = { "bbox": test_item["bbox"], } - resp = app_client.post("/search", json=params) + resp = await app_client.post("/search", json=params) assert resp.status_code == 200 -def test_item_search_properties_es(app_client, ctx): +async def test_item_search_properties_es(app_client, ctx): """Test POST search with JSONB query (query extension)""" test_item = ctx.item # EPSG is a JSONB key params = {"query": {"proj:epsg": {"gt": test_item["properties"]["proj:epsg"] + 1}}} - resp = app_client.post("/search", json=params) + resp = await app_client.post("/search", json=params) assert resp.status_code == 200 resp_json = resp.json() assert len(resp_json["features"]) == 0 -def test_item_search_properties_field(app_client, ctx): +async def test_item_search_properties_field(app_client, ctx): """Test POST search indexed field with query (query extension)""" # Orientation is an indexed field params = {"query": {"orientation": {"eq": "south"}}} - resp = app_client.post("/search", json=params) + resp = await app_client.post("/search", json=params) assert resp.status_code == 200 resp_json = resp.json() assert len(resp_json["features"]) == 0 -def test_item_search_get_query_extension(app_client, ctx): +async def test_item_search_get_query_extension(app_client, ctx): """Test GET search with JSONB query (query extension)""" test_item = ctx.item @@ -466,13 +472,13 @@ def test_item_search_get_query_extension(app_client, ctx): {"proj:epsg": {"gt": test_item["properties"]["proj:epsg"] + 1}} ), } - resp = app_client.get("/search", params=params) + resp = await app_client.get("/search", params=params) assert resp.json()["context"]["returned"] == 0 params["query"] = json.dumps( {"proj:epsg": {"eq": test_item["properties"]["proj:epsg"]}} ) - resp = app_client.get("/search", params=params) + resp = await app_client.get("/search", params=params) resp_json = resp.json() assert resp_json["context"]["returned"] == 1 assert ( @@ -481,14 +487,14 @@ def test_item_search_get_query_extension(app_client, ctx): ) -def test_get_missing_item_collection(app_client): +async def test_get_missing_item_collection(app_client): """Test reading a collection which does not exist""" - resp = app_client.get("/collections/invalid-collection/items") + resp = await app_client.get("/collections/invalid-collection/items") assert resp.status_code == 200 @pytest.mark.skip(reason="Pagination extension not implemented") -def test_pagination_item_collection(app_client, load_test_data): +async def test_pagination_item_collection(app_client, load_test_data): """Test item collection pagination links (paging extension)""" test_item = load_test_data("test_item.json") ids = [] @@ -497,14 +503,14 @@ def test_pagination_item_collection(app_client, load_test_data): for idx in range(5): uid = str(uuid.uuid4()) test_item["id"] = uid - resp = app_client.post( + resp = await app_client.post( f"/collections/{test_item['collection']}/items", json=test_item ) assert resp.status_code == 200 ids.append(uid) # Paginate through all 5 items with a limit of 1 (expecting 5 requests) - page = app_client.get( + page = await app_client.get( f"/collections/{test_item['collection']}/items", params={"limit": 1} ) idx = 0 @@ -517,7 +523,7 @@ def test_pagination_item_collection(app_client, load_test_data): if not next_link: break query_params = parse_qs(urlparse(next_link[0]["href"]).query) - page = app_client.get( + page = await app_client.get( f"/collections/{test_item['collection']}/items", params=query_params, ) @@ -530,7 +536,7 @@ def test_pagination_item_collection(app_client, load_test_data): @pytest.mark.skip(reason="Pagination extension not implemented") -def test_pagination_post(app_client, load_test_data): +async def test_pagination_post(app_client, load_test_data): """Test POST pagination (paging extension)""" test_item = load_test_data("test_item.json") ids = [] @@ -539,7 +545,7 @@ def test_pagination_post(app_client, load_test_data): for idx in range(5): uid = str(uuid.uuid4()) test_item["id"] = uid - resp = app_client.post( + resp = await app_client.post( f"/collections/{test_item['collection']}/items", json=test_item ) assert resp.status_code == 200 @@ -547,7 +553,7 @@ def test_pagination_post(app_client, load_test_data): # Paginate through all 5 items with a limit of 1 (expecting 5 requests) request_body = {"ids": ids, "limit": 1} - page = app_client.post("/search", json=request_body) + page = await app_client.post("/search", json=request_body) idx = 0 item_ids = [] while True: @@ -559,7 +565,7 @@ def test_pagination_post(app_client, load_test_data): break # Merge request bodies request_body.update(next_link[0]["body"]) - page = app_client.post("/search", json=request_body) + page = await app_client.post("/search", json=request_body) # Our limit is 1 so we expect len(ids) number of requests before we run out of pages assert idx == len(ids) @@ -569,7 +575,7 @@ def test_pagination_post(app_client, load_test_data): @pytest.mark.skip(reason="Pagination extension not implemented") -def test_pagination_token_idempotent(app_client, load_test_data): +async def test_pagination_token_idempotent(app_client, load_test_data): """Test that pagination tokens are idempotent (paging extension)""" test_item = load_test_data("test_item.json") ids = [] @@ -578,21 +584,21 @@ def test_pagination_token_idempotent(app_client, load_test_data): for idx in range(5): uid = str(uuid.uuid4()) test_item["id"] = uid - resp = app_client.post( + resp = await app_client.post( f"/collections/{test_item['collection']}/items", json=test_item ) assert resp.status_code == 200 ids.append(uid) - page = app_client.get("/search", params={"ids": ",".join(ids), "limit": 3}) + page = await app_client.get("/search", params={"ids": ",".join(ids), "limit": 3}) page_data = page.json() next_link = list(filter(lambda l: l["rel"] == "next", page_data["links"])) # Confirm token is idempotent - resp1 = app_client.get( + resp1 = await app_client.get( "/search", params=parse_qs(urlparse(next_link[0]["href"]).query) ) - resp2 = app_client.get( + resp2 = await app_client.get( "/search", params=parse_qs(urlparse(next_link[0]["href"]).query) ) resp1_data = resp1.json() @@ -605,41 +611,41 @@ def test_pagination_token_idempotent(app_client, load_test_data): @pytest.mark.skip(reason="fields not implemented") -def test_field_extension_get_includes(app_client, load_test_data): +async def test_field_extension_get_includes(app_client, load_test_data): """Test GET search with included fields (fields extension)""" test_item = load_test_data("test_item.json") - resp = app_client.post( + resp = await app_client.post( f"/collections/{test_item['collection']}/items", json=test_item ) assert resp.status_code == 200 params = {"fields": "+properties.proj:epsg,+properties.gsd"} - resp = app_client.get("/search", params=params) + resp = await app_client.get("/search", params=params) feat_properties = resp.json()["features"][0]["properties"] assert not set(feat_properties) - {"proj:epsg", "gsd", "datetime"} @pytest.mark.skip(reason="fields not implemented") -def test_field_extension_get_excludes(app_client, load_test_data): +async def test_field_extension_get_excludes(app_client, load_test_data): """Test GET search with included fields (fields extension)""" test_item = load_test_data("test_item.json") - resp = app_client.post( + resp = await app_client.post( f"/collections/{test_item['collection']}/items", json=test_item ) assert resp.status_code == 200 params = {"fields": "-properties.proj:epsg,-properties.gsd"} - resp = app_client.get("/search", params=params) + resp = await app_client.get("/search", params=params) resp_json = resp.json() assert "proj:epsg" not in resp_json["features"][0]["properties"].keys() assert "gsd" not in resp_json["features"][0]["properties"].keys() @pytest.mark.skip(reason="fields not implemented") -def test_field_extension_post(app_client, load_test_data): +async def test_field_extension_post(app_client, load_test_data): """Test POST search with included and excluded fields (fields extension)""" test_item = load_test_data("test_item.json") - resp = app_client.post( + resp = await app_client.post( f"/collections/{test_item['collection']}/items", json=test_item ) assert resp.status_code == 200 @@ -651,7 +657,7 @@ def test_field_extension_post(app_client, load_test_data): } } - resp = app_client.post("/search", json=body) + resp = await app_client.post("/search", json=body) resp_json = resp.json() assert "B1" not in resp_json["features"][0]["assets"].keys() assert not set(resp_json["features"][0]["properties"]) - { @@ -662,10 +668,10 @@ def test_field_extension_post(app_client, load_test_data): @pytest.mark.skip(reason="fields not implemented") -def test_field_extension_exclude_and_include(app_client, load_test_data): +async def test_field_extension_exclude_and_include(app_client, load_test_data): """Test POST search including/excluding same field (fields extension)""" test_item = load_test_data("test_item.json") - resp = app_client.post( + resp = await app_client.post( f"/collections/{test_item['collection']}/items", json=test_item ) assert resp.status_code == 200 @@ -677,65 +683,65 @@ def test_field_extension_exclude_and_include(app_client, load_test_data): } } - resp = app_client.post("/search", json=body) + resp = await app_client.post("/search", json=body) resp_json = resp.json() assert "eo:cloud_cover" not in resp_json["features"][0]["properties"] @pytest.mark.skip(reason="fields not implemented") -def test_field_extension_exclude_default_includes(app_client, load_test_data): +async def test_field_extension_exclude_default_includes(app_client, load_test_data): """Test POST search excluding a forbidden field (fields extension)""" test_item = load_test_data("test_item.json") - resp = app_client.post( + resp = await app_client.post( f"/collections/{test_item['collection']}/items", json=test_item ) assert resp.status_code == 200 body = {"fields": {"exclude": ["gsd"]}} - resp = app_client.post("/search", json=body) + resp = await app_client.post("/search", json=body) resp_json = resp.json() assert "gsd" not in resp_json["features"][0] -def test_search_intersects_and_bbox(app_client): +async def test_search_intersects_and_bbox(app_client): """Test POST search intersects and bbox are mutually exclusive (core)""" bbox = [-118, 34, -117, 35] geoj = Polygon.from_bounds(*bbox).dict(exclude_none=True) params = {"bbox": bbox, "intersects": geoj} - resp = app_client.post("/search", json=params) + resp = await app_client.post("/search", json=params) assert resp.status_code == 400 -def test_get_missing_item(app_client, load_test_data): +async def test_get_missing_item(app_client, load_test_data): """Test read item which does not exist (transactions extension)""" test_coll = load_test_data("test_collection.json") - resp = app_client.get(f"/collections/{test_coll['id']}/items/invalid-item") + resp = await app_client.get(f"/collections/{test_coll['id']}/items/invalid-item") assert resp.status_code == 404 @pytest.mark.skip(reason="invalid queries not implemented") -def test_search_invalid_query_field(app_client): +async def test_search_invalid_query_field(app_client): body = {"query": {"gsd": {"lt": 100}, "invalid-field": {"eq": 50}}} - resp = app_client.post("/search", json=body) + resp = await app_client.post("/search", json=body) assert resp.status_code == 400 -def test_search_bbox_errors(app_client): +async def test_search_bbox_errors(app_client): body = {"query": {"bbox": [0]}} - resp = app_client.post("/search", json=body) + resp = await app_client.post("/search", json=body) assert resp.status_code == 400 body = {"query": {"bbox": [100.0, 0.0, 0.0, 105.0, 1.0, 1.0]}} - resp = app_client.post("/search", json=body) + resp = await app_client.post("/search", json=body) assert resp.status_code == 400 params = {"bbox": "100.0,0.0,0.0,105.0"} - resp = app_client.get("/search", params=params) + resp = await app_client.get("/search", params=params) assert resp.status_code == 400 -def test_conformance_classes_configurable(): +async def test_conformance_classes_configurable(): """Test conformance class configurability""" landing = LandingPageMixin() landing_page = landing._landing_page( @@ -748,11 +754,11 @@ def test_conformance_classes_configurable(): # Update environment to avoid key error on client instantiation os.environ["READER_CONN_STRING"] = "testing" os.environ["WRITER_CONN_STRING"] = "testing" - client = CoreCrudClient(base_conformance_classes=["this is a test"]) + client = CoreClient(base_conformance_classes=["this is a test"]) assert client.conformance_classes()[0] == "this is a test" -def test_search_datetime_validation_errors(app_client): +async def test_search_datetime_validation_errors(app_client): bad_datetimes = [ "37-01-01T12:00:27.87Z", "1985-13-12T23:20:50.52Z", @@ -765,8 +771,8 @@ def test_search_datetime_validation_errors(app_client): ] for dt in bad_datetimes: body = {"query": {"datetime": dt}} - resp = app_client.post("/search", json=body) + resp = await app_client.post("/search", json=body) assert resp.status_code == 400 - resp = app_client.get("/search?datetime={}".format(dt)) + resp = await app_client.get("/search?datetime={}".format(dt)) assert resp.status_code == 400 diff --git a/stac_fastapi/elasticsearch/tests/resources/test_mgmt.py b/stac_fastapi/elasticsearch/tests/resources/test_mgmt.py index 0a11e38e..9d2bc3dc 100644 --- a/stac_fastapi/elasticsearch/tests/resources/test_mgmt.py +++ b/stac_fastapi/elasticsearch/tests/resources/test_mgmt.py @@ -1,9 +1,9 @@ -def test_ping_no_param(app_client): +async def test_ping_no_param(app_client): """ Test ping endpoint with a mocked client. Args: app_client (TestClient): mocked client fixture """ - res = app_client.get("/_mgmt/ping") + res = await app_client.get("/_mgmt/ping") assert res.status_code == 200 assert res.json() == {"message": "PONG"} From d427b393cab03c37dd69d377c4cde23c113e6bd3 Mon Sep 17 00:00:00 2001 From: Phil Varner Date: Tue, 22 Mar 2022 16:26:41 -0400 Subject: [PATCH 2/2] run all tests --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 18a276d4..e7bb1f74 100644 --- a/Makefile +++ b/Makefile @@ -25,7 +25,7 @@ docker-shell: .PHONY: test test: - -$(run_es) /bin/bash -c 'export && ./scripts/wait-for-it-es.sh elasticsearch:9200 && cd /app/stac_fastapi/elasticsearch/tests/ && pytest api/test_api.py' + -$(run_es) /bin/bash -c 'export && ./scripts/wait-for-it-es.sh elasticsearch:9200 && cd /app/stac_fastapi/elasticsearch/tests/ && pytest' docker-compose down .PHONY: run-database