Skip to content

Commit a466f75

Browse files
author
Phil Varner
authored
refactor tests to not need sleep, use fixture for setup/teardown (#69)
* refactor tests to not need sleep, use fixture for setup/teardown of most tests, remove use of execute methods in ES DSL
1 parent f354e61 commit a466f75

File tree

9 files changed

+413
-679
lines changed

9 files changed

+413
-679
lines changed

stac_fastapi/elasticsearch/pytest.ini

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
[pytest]
22
testpaths = tests
3-
addopts = -sv
3+
addopts = -sv
4+
asyncio_mode = auto

stac_fastapi/elasticsearch/setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
"requests",
3030
"ciso8601",
3131
"overrides",
32+
"starlette",
3233
],
3334
"docs": ["mkdocs", "mkdocs-material", "pdocs"],
3435
"server": ["uvicorn[standard]>=0.12.0,<0.14.0"],

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def item_collection(
9595
context_obj = None
9696
if self.extension_is_enabled("ContextExtension"):
9797
context_obj = {
98-
"returned": count if count < limit else limit,
98+
"returned": count if count is not None and count < limit else limit,
9999
"limit": limit,
100100
"matched": count,
101101
}
@@ -243,9 +243,8 @@ def post_search(self, search_request, **kwargs) -> ItemCollection:
243243
for sort in search_request.sortby:
244244
if sort.field == "datetime":
245245
sort.field = "properties__datetime"
246-
field = sort.field + ".keyword"
247246
search = self.database.sort_field(
248-
search=search, field=field, direction=sort.direction
247+
search=search, field=sort.field, direction=sort.direction
249248
)
250249

251250
count = self.database.search_count(search=search)
@@ -316,13 +315,14 @@ def create_item(self, item: stac_types.Item, **kwargs) -> stac_types.Item:
316315
processed_items = [
317316
bulk_client.preprocess_item(item, base_url) for item in item["features"]
318317
]
319-
return_msg = f"Successfully added {len(processed_items)} items."
320-
self.database.bulk_sync(processed_items)
318+
self.database.bulk_sync(
319+
processed_items, refresh=kwargs.get("refresh", False)
320+
)
321321

322-
return return_msg
322+
return None
323323
else:
324324
item = self.database.prep_create_item(item=item, base_url=base_url)
325-
self.database.create_item(item=item, base_url=base_url)
325+
self.database.create_item(item, refresh=kwargs.get("refresh", False))
326326
return item
327327

328328
@overrides
@@ -413,6 +413,6 @@ def bulk_item_insert(
413413
self.preprocess_item(item, base_url) for item in items.items.values()
414414
]
415415

416-
self.database.bulk_sync(processed_items)
416+
self.database.bulk_sync(processed_items, refresh=kwargs.get("refresh", False))
417417

418418
return f"Successfully added {len(processed_items)} Items."

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py

Lines changed: 98 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,25 @@
11
"""Database logic."""
22
import logging
3-
from typing import List, Type, Union
3+
from typing import List, Optional, Tuple, Type, Union
44

55
import attr
66
import elasticsearch
77
from elasticsearch import helpers
88
from elasticsearch_dsl import Q, Search
9+
from geojson_pydantic.geometries import (
10+
GeometryCollection,
11+
LineString,
12+
MultiLineString,
13+
MultiPoint,
14+
MultiPolygon,
15+
Point,
16+
Polygon,
17+
)
918

1019
from stac_fastapi.elasticsearch import serializers
1120
from stac_fastapi.elasticsearch.config import ElasticsearchSettings
1221
from stac_fastapi.types.errors import ConflictError, ForeignKeyError, NotFoundError
13-
from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection
22+
from stac_fastapi.types.stac import Collection, Item
1423

1524
logger = logging.getLogger(__name__)
1625

@@ -31,10 +40,10 @@ class DatabaseLogic:
3140

3241
settings = ElasticsearchSettings()
3342
client = settings.create_client
34-
item_serializer: Type[serializers.Serializer] = attr.ib(
43+
item_serializer: Type[serializers.ItemSerializer] = attr.ib(
3544
default=serializers.ItemSerializer
3645
)
37-
collection_serializer: Type[serializers.Serializer] = attr.ib(
46+
collection_serializer: Type[serializers.CollectionSerializer] = attr.ib(
3847
default=serializers.CollectionSerializer
3948
)
4049

@@ -46,7 +55,7 @@ def bbox2poly(b0, b1, b2, b3):
4655

4756
"""CORE LOGIC"""
4857

49-
def get_all_collections(self, base_url: str) -> Collections:
58+
def get_all_collections(self, base_url: str) -> List[Collection]:
5059
"""Database logic to retrieve a list of all collections."""
5160
try:
5261
collections = self.client.search(
@@ -66,9 +75,10 @@ def get_all_collections(self, base_url: str) -> Collections:
6675

6776
def get_item_collection(
6877
self, collection_id: str, limit: int, base_url: str
69-
) -> ItemCollection:
78+
) -> Tuple[List[Item], Optional[int]]:
7079
"""Database logic to retrieve an ItemCollection and a count of items contained."""
71-
search = Search(using=self.client, index="stac_items")
80+
search = self.create_search_object()
81+
search = self.search_collections(search, [collection_id])
7282

7383
collection_filter = Q(
7484
"bool", should=[Q("match_phrase", **{"collection": collection_id})]
@@ -79,7 +89,11 @@ def get_item_collection(
7989

8090
# search = search.sort({"id.keyword" : {"order" : "asc"}})
8191
search = search.query()[0:limit]
82-
collection_children = search.execute().to_dict()
92+
93+
body = search.to_dict()
94+
collection_children = self.client.search(
95+
index=ITEMS_INDEX, query=body["query"], sort=body.get("sort")
96+
)
8397

8498
serialized_children = [
8599
self.item_serializer.db_to_stac(item["_source"], base_url=base_url)
@@ -100,21 +114,17 @@ def get_one_item(self, collection_id: str, item_id: str) -> Item:
100114
)
101115
return item["_source"]
102116

103-
def create_search_object(self):
117+
@staticmethod
118+
def create_search_object():
104119
"""Database logic to create a nosql Search instance."""
105-
search = (
106-
Search()
107-
.using(self.client)
108-
.index(ITEMS_INDEX)
109-
.sort(
110-
{"properties.datetime": {"order": "desc"}},
111-
{"id": {"order": "desc"}},
112-
{"collection": {"order": "desc"}},
113-
)
120+
return Search().sort(
121+
{"properties.datetime": {"order": "desc"}},
122+
{"id": {"order": "desc"}},
123+
{"collection": {"order": "desc"}},
114124
)
115-
return search
116125

117-
def create_query_filter(self, search, op: str, field: str, value: float):
126+
@staticmethod
127+
def create_query_filter(search: Search, op: str, field: str, value: float):
118128
"""Database logic to perform query for search endpoint."""
119129
if op != "eq":
120130
key_filter = {field: {f"{op}": value}}
@@ -124,7 +134,8 @@ def create_query_filter(self, search, op: str, field: str, value: float):
124134

125135
return search
126136

127-
def search_ids(self, search, item_ids: List):
137+
@staticmethod
138+
def search_ids(search: Search, item_ids: List):
128139
"""Database logic to search a list of STAC item ids."""
129140
id_list = []
130141
for item_id in item_ids:
@@ -134,17 +145,14 @@ def search_ids(self, search, item_ids: List):
134145

135146
return search
136147

137-
def search_collections(self, search, collection_ids: List):
148+
@staticmethod
149+
def search_collections(search: Search, collection_ids: List):
138150
"""Database logic to search a list of STAC collection ids."""
139-
collection_list = []
140-
for collection_id in collection_ids:
141-
collection_list.append(Q("match_phrase", **{"collection": collection_id}))
142-
collection_filter = Q("bool", should=collection_list)
143-
search = search.query(collection_filter)
144-
145-
return search
151+
collections_query = [Q("term", **{"collection": cid}) for cid in collection_ids]
152+
return search.query(Q("bool", should=collections_query))
146153

147-
def search_datetime(self, search, datetime_search):
154+
@staticmethod
155+
def search_datetime(search: Search, datetime_search):
148156
"""Database logic to search datetime field."""
149157
if "eq" in datetime_search:
150158
search = search.query(
@@ -159,9 +167,10 @@ def search_datetime(self, search, datetime_search):
159167
)
160168
return search
161169

162-
def search_bbox(self, search, bbox: List):
170+
@staticmethod
171+
def search_bbox(search: Search, bbox: List):
163172
"""Database logic to search on bounding box."""
164-
poly = self.bbox2poly(bbox[0], bbox[1], bbox[2], bbox[3])
173+
poly = DatabaseLogic.bbox2poly(bbox[0], bbox[1], bbox[2], bbox[3])
165174
bbox_filter = Q(
166175
{
167176
"geo_shape": {
@@ -175,7 +184,19 @@ def search_bbox(self, search, bbox: List):
175184
search = search.query(bbox_filter)
176185
return search
177186

178-
def search_intersects(self, search, intersects: dict):
187+
@staticmethod
188+
def search_intersects(
189+
search: Search,
190+
intersects: Union[
191+
Point,
192+
MultiPoint,
193+
LineString,
194+
MultiLineString,
195+
Polygon,
196+
MultiPolygon,
197+
GeometryCollection,
198+
],
199+
):
179200
"""Database logic to search a geojson object."""
180201
intersect_filter = Q(
181202
{
@@ -193,24 +214,27 @@ def search_intersects(self, search, intersects: dict):
193214
search = search.query(intersect_filter)
194215
return search
195216

196-
def sort_field(self, search, field, direction):
197-
"""Database logic to sort nosql search instance."""
198-
search = search.sort({field: {"order": direction}})
199-
return search
217+
@staticmethod
218+
def sort_field(search: Search, field, direction):
219+
"""Database logic to sort search instance."""
220+
return search.sort({field: {"order": direction}})
200221

201-
def search_count(self, search) -> int:
222+
def search_count(self, search: Search) -> int:
202223
"""Database logic to count search results."""
203224
try:
204-
count = search.count()
225+
return self.client.count(
226+
index=ITEMS_INDEX, body=search.to_dict(count=True)
227+
).get("count")
205228
except elasticsearch.exceptions.NotFoundError:
206229
raise NotFoundError("No items exist")
207230

208-
return count
209-
210231
def execute_search(self, search, limit: int, base_url: str) -> List:
211232
"""Database logic to execute search with limit."""
212233
search = search.query()[0:limit]
213-
response = search.execute().to_dict()
234+
body = search.to_dict()
235+
response = self.client.search(
236+
index=ITEMS_INDEX, query=body["query"], sort=body.get("sort")
237+
)
214238

215239
if len(response["hits"]["hits"]) > 0:
216240
response_features = [
@@ -242,30 +266,35 @@ def prep_create_item(self, item: Item, base_url: str) -> Item:
242266

243267
return self.item_serializer.stac_to_db(item, base_url)
244268

245-
def create_item(self, item: Item, base_url: str):
269+
def create_item(self, item: Item, refresh: bool = False):
246270
"""Database logic for creating one item."""
247271
# todo: check if collection exists, but cache
248272
es_resp = self.client.index(
249273
index=ITEMS_INDEX,
250274
id=mk_item_id(item["id"], item["collection"]),
251275
document=item,
276+
refresh=refresh,
252277
)
253278

254279
if (meta := es_resp.get("meta")) and meta.get("status") == 409:
255280
raise ConflictError(
256281
f"Item {item['id']} in collection {item['collection']} already exists"
257282
)
258283

259-
def delete_item(self, item_id: str, collection_id: str):
284+
def delete_item(self, item_id: str, collection_id: str, refresh: bool = False):
260285
"""Database logic for deleting one item."""
261286
try:
262-
self.client.delete(index=ITEMS_INDEX, id=mk_item_id(item_id, collection_id))
287+
self.client.delete(
288+
index=ITEMS_INDEX,
289+
id=mk_item_id(item_id, collection_id),
290+
refresh=refresh,
291+
)
263292
except elasticsearch.exceptions.NotFoundError:
264293
raise NotFoundError(
265294
f"Item {item_id} in collection {collection_id} not found"
266295
)
267296

268-
def create_collection(self, collection: Collection):
297+
def create_collection(self, collection: Collection, refresh: bool = False):
269298
"""Database logic for creating one collection."""
270299
if self.client.exists(index=COLLECTIONS_INDEX, id=collection["id"]):
271300
raise ConflictError(f"Collection {collection['id']} already exists")
@@ -274,6 +303,7 @@ def create_collection(self, collection: Collection):
274303
index=COLLECTIONS_INDEX,
275304
id=collection["id"],
276305
document=collection,
306+
refresh=refresh,
277307
)
278308

279309
def find_collection(self, collection_id: str) -> Collection:
@@ -285,12 +315,12 @@ def find_collection(self, collection_id: str) -> Collection:
285315

286316
return collection["_source"]
287317

288-
def delete_collection(self, collection_id: str):
318+
def delete_collection(self, collection_id: str, refresh: bool = False):
289319
"""Database logic for deleting one collection."""
290320
_ = self.find_collection(collection_id=collection_id)
291-
self.client.delete(index=COLLECTIONS_INDEX, id=collection_id)
321+
self.client.delete(index=COLLECTIONS_INDEX, id=collection_id, refresh=refresh)
292322

293-
def bulk_sync(self, processed_items):
323+
def bulk_sync(self, processed_items, refresh: bool = False):
294324
"""Database logic for bulk item insertion."""
295325
actions = [
296326
{
@@ -300,4 +330,22 @@ def bulk_sync(self, processed_items):
300330
}
301331
for item in processed_items
302332
]
303-
helpers.bulk(self.client, actions)
333+
helpers.bulk(self.client, actions, refresh=refresh)
334+
335+
# DANGER
336+
def delete_items(self) -> None:
337+
"""Danger. this is only for tests."""
338+
self.client.delete_by_query(
339+
index=ITEMS_INDEX,
340+
body={"query": {"match_all": {}}},
341+
wait_for_completion=True,
342+
)
343+
344+
# DANGER
345+
def delete_collections(self) -> None:
346+
"""Danger. this is only for tests."""
347+
self.client.delete_by_query(
348+
index=COLLECTIONS_INDEX,
349+
body={"query": {"match_all": {}}},
350+
wait_for_completion=True,
351+
)

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/indexes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,11 @@ def create_indexes(self):
105105
"""Create the index for Items and Collections."""
106106
self.client.indices.create(
107107
index=ITEMS_INDEX,
108-
body={"mappings": self.ES_ITEMS_MAPPINGS},
108+
mappings=self.ES_ITEMS_MAPPINGS,
109109
ignore=400, # ignore 400 already exists code
110110
)
111111
self.client.indices.create(
112112
index=COLLECTIONS_INDEX,
113-
body={"mappings": self.ES_COLLECTIONS_MAPPINGS},
113+
mappings=self.ES_COLLECTIONS_MAPPINGS,
114114
ignore=400, # ignore 400 already exists code
115115
)

0 commit comments

Comments
 (0)