1
1
"""Database logic."""
2
2
import logging
3
- from typing import List , Type , Union
3
+ from typing import List , Optional , Tuple , Type , Union
4
4
5
5
import attr
6
6
import elasticsearch
7
7
from elasticsearch import helpers
8
8
from elasticsearch_dsl import Q , Search
9
+ from geojson_pydantic .geometries import (
10
+ GeometryCollection ,
11
+ LineString ,
12
+ MultiLineString ,
13
+ MultiPoint ,
14
+ MultiPolygon ,
15
+ Point ,
16
+ Polygon ,
17
+ )
9
18
10
19
from stac_fastapi .elasticsearch import serializers
11
20
from stac_fastapi .elasticsearch .config import ElasticsearchSettings
12
21
from stac_fastapi .types .errors import ConflictError , ForeignKeyError , NotFoundError
13
- from stac_fastapi .types .stac import Collection , Collections , Item , ItemCollection
22
+ from stac_fastapi .types .stac import Collection , Item
14
23
15
24
logger = logging .getLogger (__name__ )
16
25
@@ -31,10 +40,10 @@ class DatabaseLogic:
31
40
32
41
settings = ElasticsearchSettings ()
33
42
client = settings .create_client
34
- item_serializer : Type [serializers .Serializer ] = attr .ib (
43
+ item_serializer : Type [serializers .ItemSerializer ] = attr .ib (
35
44
default = serializers .ItemSerializer
36
45
)
37
- collection_serializer : Type [serializers .Serializer ] = attr .ib (
46
+ collection_serializer : Type [serializers .CollectionSerializer ] = attr .ib (
38
47
default = serializers .CollectionSerializer
39
48
)
40
49
@@ -46,7 +55,7 @@ def bbox2poly(b0, b1, b2, b3):
46
55
47
56
"""CORE LOGIC"""
48
57
49
- def get_all_collections (self , base_url : str ) -> Collections :
58
+ def get_all_collections (self , base_url : str ) -> List [ Collection ] :
50
59
"""Database logic to retrieve a list of all collections."""
51
60
try :
52
61
collections = self .client .search (
@@ -66,9 +75,10 @@ def get_all_collections(self, base_url: str) -> Collections:
66
75
67
76
def get_item_collection (
68
77
self , collection_id : str , limit : int , base_url : str
69
- ) -> ItemCollection :
78
+ ) -> Tuple [ List [ Item ], Optional [ int ]] :
70
79
"""Database logic to retrieve an ItemCollection and a count of items contained."""
71
- search = Search (using = self .client , index = "stac_items" )
80
+ search = self .create_search_object ()
81
+ search = self .search_collections (search , [collection_id ])
72
82
73
83
collection_filter = Q (
74
84
"bool" , should = [Q ("match_phrase" , ** {"collection" : collection_id })]
@@ -79,7 +89,11 @@ def get_item_collection(
79
89
80
90
# search = search.sort({"id.keyword" : {"order" : "asc"}})
81
91
search = search .query ()[0 :limit ]
82
- collection_children = search .execute ().to_dict ()
92
+
93
+ body = search .to_dict ()
94
+ collection_children = self .client .search (
95
+ index = ITEMS_INDEX , query = body ["query" ], sort = body .get ("sort" )
96
+ )
83
97
84
98
serialized_children = [
85
99
self .item_serializer .db_to_stac (item ["_source" ], base_url = base_url )
@@ -100,21 +114,17 @@ def get_one_item(self, collection_id: str, item_id: str) -> Item:
100
114
)
101
115
return item ["_source" ]
102
116
103
- def create_search_object (self ):
117
+ @staticmethod
118
+ def create_search_object ():
104
119
"""Database logic to create a nosql Search instance."""
105
- search = (
106
- Search ()
107
- .using (self .client )
108
- .index (ITEMS_INDEX )
109
- .sort (
110
- {"properties.datetime" : {"order" : "desc" }},
111
- {"id" : {"order" : "desc" }},
112
- {"collection" : {"order" : "desc" }},
113
- )
120
+ return Search ().sort (
121
+ {"properties.datetime" : {"order" : "desc" }},
122
+ {"id" : {"order" : "desc" }},
123
+ {"collection" : {"order" : "desc" }},
114
124
)
115
- return search
116
125
117
- def create_query_filter (self , search , op : str , field : str , value : float ):
126
+ @staticmethod
127
+ def create_query_filter (search : Search , op : str , field : str , value : float ):
118
128
"""Database logic to perform query for search endpoint."""
119
129
if op != "eq" :
120
130
key_filter = {field : {f"{ op } " : value }}
@@ -124,7 +134,8 @@ def create_query_filter(self, search, op: str, field: str, value: float):
124
134
125
135
return search
126
136
127
- def search_ids (self , search , item_ids : List ):
137
+ @staticmethod
138
+ def search_ids (search : Search , item_ids : List ):
128
139
"""Database logic to search a list of STAC item ids."""
129
140
id_list = []
130
141
for item_id in item_ids :
@@ -134,17 +145,14 @@ def search_ids(self, search, item_ids: List):
134
145
135
146
return search
136
147
137
- def search_collections (self , search , collection_ids : List ):
148
+ @staticmethod
149
+ def search_collections (search : Search , collection_ids : List ):
138
150
"""Database logic to search a list of STAC collection ids."""
139
- collection_list = []
140
- for collection_id in collection_ids :
141
- collection_list .append (Q ("match_phrase" , ** {"collection" : collection_id }))
142
- collection_filter = Q ("bool" , should = collection_list )
143
- search = search .query (collection_filter )
144
-
145
- return search
151
+ collections_query = [Q ("term" , ** {"collection" : cid }) for cid in collection_ids ]
152
+ return search .query (Q ("bool" , should = collections_query ))
146
153
147
- def search_datetime (self , search , datetime_search ):
154
+ @staticmethod
155
+ def search_datetime (search : Search , datetime_search ):
148
156
"""Database logic to search datetime field."""
149
157
if "eq" in datetime_search :
150
158
search = search .query (
@@ -159,9 +167,10 @@ def search_datetime(self, search, datetime_search):
159
167
)
160
168
return search
161
169
162
- def search_bbox (self , search , bbox : List ):
170
+ @staticmethod
171
+ def search_bbox (search : Search , bbox : List ):
163
172
"""Database logic to search on bounding box."""
164
- poly = self .bbox2poly (bbox [0 ], bbox [1 ], bbox [2 ], bbox [3 ])
173
+ poly = DatabaseLogic .bbox2poly (bbox [0 ], bbox [1 ], bbox [2 ], bbox [3 ])
165
174
bbox_filter = Q (
166
175
{
167
176
"geo_shape" : {
@@ -175,7 +184,19 @@ def search_bbox(self, search, bbox: List):
175
184
search = search .query (bbox_filter )
176
185
return search
177
186
178
- def search_intersects (self , search , intersects : dict ):
187
+ @staticmethod
188
+ def search_intersects (
189
+ search : Search ,
190
+ intersects : Union [
191
+ Point ,
192
+ MultiPoint ,
193
+ LineString ,
194
+ MultiLineString ,
195
+ Polygon ,
196
+ MultiPolygon ,
197
+ GeometryCollection ,
198
+ ],
199
+ ):
179
200
"""Database logic to search a geojson object."""
180
201
intersect_filter = Q (
181
202
{
@@ -193,24 +214,27 @@ def search_intersects(self, search, intersects: dict):
193
214
search = search .query (intersect_filter )
194
215
return search
195
216
196
- def sort_field ( self , search , field , direction ):
197
- """Database logic to sort nosql search instance."""
198
- search = search . sort ({ field : { "order" : direction }})
199
- return search
217
+ @ staticmethod
218
+ def sort_field ( search : Search , field , direction ):
219
+ """Database logic to sort search instance."""
220
+ return search . sort ({ field : { "order" : direction }})
200
221
201
- def search_count (self , search ) -> int :
222
+ def search_count (self , search : Search ) -> int :
202
223
"""Database logic to count search results."""
203
224
try :
204
- count = search .count ()
225
+ return self .client .count (
226
+ index = ITEMS_INDEX , body = search .to_dict (count = True )
227
+ ).get ("count" )
205
228
except elasticsearch .exceptions .NotFoundError :
206
229
raise NotFoundError ("No items exist" )
207
230
208
- return count
209
-
210
231
def execute_search (self , search , limit : int , base_url : str ) -> List :
211
232
"""Database logic to execute search with limit."""
212
233
search = search .query ()[0 :limit ]
213
- response = search .execute ().to_dict ()
234
+ body = search .to_dict ()
235
+ response = self .client .search (
236
+ index = ITEMS_INDEX , query = body ["query" ], sort = body .get ("sort" )
237
+ )
214
238
215
239
if len (response ["hits" ]["hits" ]) > 0 :
216
240
response_features = [
@@ -242,30 +266,35 @@ def prep_create_item(self, item: Item, base_url: str) -> Item:
242
266
243
267
return self .item_serializer .stac_to_db (item , base_url )
244
268
245
- def create_item (self , item : Item , base_url : str ):
269
+ def create_item (self , item : Item , refresh : bool = False ):
246
270
"""Database logic for creating one item."""
247
271
# todo: check if collection exists, but cache
248
272
es_resp = self .client .index (
249
273
index = ITEMS_INDEX ,
250
274
id = mk_item_id (item ["id" ], item ["collection" ]),
251
275
document = item ,
276
+ refresh = refresh ,
252
277
)
253
278
254
279
if (meta := es_resp .get ("meta" )) and meta .get ("status" ) == 409 :
255
280
raise ConflictError (
256
281
f"Item { item ['id' ]} in collection { item ['collection' ]} already exists"
257
282
)
258
283
259
- def delete_item (self , item_id : str , collection_id : str ):
284
+ def delete_item (self , item_id : str , collection_id : str , refresh : bool = False ):
260
285
"""Database logic for deleting one item."""
261
286
try :
262
- self .client .delete (index = ITEMS_INDEX , id = mk_item_id (item_id , collection_id ))
287
+ self .client .delete (
288
+ index = ITEMS_INDEX ,
289
+ id = mk_item_id (item_id , collection_id ),
290
+ refresh = refresh ,
291
+ )
263
292
except elasticsearch .exceptions .NotFoundError :
264
293
raise NotFoundError (
265
294
f"Item { item_id } in collection { collection_id } not found"
266
295
)
267
296
268
- def create_collection (self , collection : Collection ):
297
+ def create_collection (self , collection : Collection , refresh : bool = False ):
269
298
"""Database logic for creating one collection."""
270
299
if self .client .exists (index = COLLECTIONS_INDEX , id = collection ["id" ]):
271
300
raise ConflictError (f"Collection { collection ['id' ]} already exists" )
@@ -274,6 +303,7 @@ def create_collection(self, collection: Collection):
274
303
index = COLLECTIONS_INDEX ,
275
304
id = collection ["id" ],
276
305
document = collection ,
306
+ refresh = refresh ,
277
307
)
278
308
279
309
def find_collection (self , collection_id : str ) -> Collection :
@@ -285,12 +315,12 @@ def find_collection(self, collection_id: str) -> Collection:
285
315
286
316
return collection ["_source" ]
287
317
288
- def delete_collection (self , collection_id : str ):
318
+ def delete_collection (self , collection_id : str , refresh : bool = False ):
289
319
"""Database logic for deleting one collection."""
290
320
_ = self .find_collection (collection_id = collection_id )
291
- self .client .delete (index = COLLECTIONS_INDEX , id = collection_id )
321
+ self .client .delete (index = COLLECTIONS_INDEX , id = collection_id , refresh = refresh )
292
322
293
- def bulk_sync (self , processed_items ):
323
+ def bulk_sync (self , processed_items , refresh : bool = False ):
294
324
"""Database logic for bulk item insertion."""
295
325
actions = [
296
326
{
@@ -300,4 +330,22 @@ def bulk_sync(self, processed_items):
300
330
}
301
331
for item in processed_items
302
332
]
303
- helpers .bulk (self .client , actions )
333
+ helpers .bulk (self .client , actions , refresh = refresh )
334
+
335
+ # DANGER
336
+ def delete_items (self ) -> None :
337
+ """Danger. this is only for tests."""
338
+ self .client .delete_by_query (
339
+ index = ITEMS_INDEX ,
340
+ body = {"query" : {"match_all" : {}}},
341
+ wait_for_completion = True ,
342
+ )
343
+
344
+ # DANGER
345
+ def delete_collections (self ) -> None :
346
+ """Danger. this is only for tests."""
347
+ self .client .delete_by_query (
348
+ index = COLLECTIONS_INDEX ,
349
+ body = {"query" : {"match_all" : {}}},
350
+ wait_for_completion = True ,
351
+ )
0 commit comments