7
7
from urllib .parse import urljoin
8
8
9
9
import attr
10
+ import stac_pydantic .api
10
11
from fastapi import HTTPException
11
12
from overrides import overrides
12
13
from pydantic import ValidationError
23
24
Items ,
24
25
)
25
26
from stac_fastapi .types import stac as stac_types
26
- from stac_fastapi .types .core import BaseCoreClient , BaseTransactionsClient
27
+ from stac_fastapi .types .core import AsyncBaseCoreClient , AsyncBaseTransactionsClient
27
28
from stac_fastapi .types .links import CollectionLinks
28
29
from stac_fastapi .types .stac import Collection , Collections , Item , ItemCollection
29
30
33
34
34
35
35
36
@attr .s
36
- class CoreCrudClient ( BaseCoreClient ):
37
+ class CoreClient ( AsyncBaseCoreClient ):
37
38
"""Client for core endpoints defined by stac."""
38
39
39
40
session : Session = attr .ib (default = attr .Factory (Session .create_from_env ))
40
- item_serializer : Type [serializers .Serializer ] = attr .ib (
41
+ item_serializer : Type [serializers .ItemSerializer ] = attr .ib (
41
42
default = serializers .ItemSerializer
42
43
)
43
- collection_serializer : Type [serializers .Serializer ] = attr .ib (
44
+ collection_serializer : Type [serializers .CollectionSerializer ] = attr .ib (
44
45
default = serializers .CollectionSerializer
45
46
)
46
47
database = DatabaseLogic ()
47
48
48
49
@overrides
49
- def all_collections (self , ** kwargs ) -> Collections :
50
+ async def all_collections (self , ** kwargs ) -> Collections :
50
51
"""Read all collections from the database."""
51
52
base_url = str (kwargs ["request" ].base_url )
52
- serialized_collections = self .database .get_all_collections (base_url = base_url )
53
+ serialized_collections = await self .database .get_all_collections (
54
+ base_url = base_url
55
+ )
53
56
54
57
links = [
55
58
{
@@ -74,21 +77,21 @@ def all_collections(self, **kwargs) -> Collections:
74
77
return collection_list
75
78
76
79
@overrides
77
- def get_collection (self , collection_id : str , ** kwargs ) -> Collection :
80
+ async def get_collection (self , collection_id : str , ** kwargs ) -> Collection :
78
81
"""Get collection by id."""
79
82
base_url = str (kwargs ["request" ].base_url )
80
- collection = self .database .find_collection (collection_id = collection_id )
83
+ collection = await self .database .find_collection (collection_id = collection_id )
81
84
return self .collection_serializer .db_to_stac (collection , base_url )
82
85
83
86
@overrides
84
- def item_collection (
87
+ async def item_collection (
85
88
self , collection_id : str , limit : int = 10 , token : str = None , ** kwargs
86
89
) -> ItemCollection :
87
90
"""Read an item collection from the database."""
88
91
links = []
89
92
base_url = str (kwargs ["request" ].base_url )
90
93
91
- serialized_children , count = self .database .get_item_collection (
94
+ serialized_children , count = await self .database .get_item_collection (
92
95
collection_id = collection_id , limit = limit , base_url = base_url
93
96
)
94
97
@@ -108,10 +111,12 @@ def item_collection(
108
111
)
109
112
110
113
@overrides
111
- def get_item (self , item_id : str , collection_id : str , ** kwargs ) -> Item :
114
+ async def get_item (self , item_id : str , collection_id : str , ** kwargs ) -> Item :
112
115
"""Get item by item id, collection id."""
113
116
base_url = str (kwargs ["request" ].base_url )
114
- item = self .database .get_one_item (item_id = item_id , collection_id = collection_id )
117
+ item = await self .database .get_one_item (
118
+ item_id = item_id , collection_id = collection_id
119
+ )
115
120
return self .item_serializer .db_to_stac (item , base_url )
116
121
117
122
@staticmethod
@@ -139,7 +144,7 @@ def _return_date(interval_str):
139
144
return {"lte" : end_date , "gte" : start_date }
140
145
141
146
@overrides
142
- def get_search (
147
+ async def get_search (
143
148
self ,
144
149
collections : Optional [List [str ]] = None ,
145
150
ids : Optional [List [str ]] = None ,
@@ -192,18 +197,19 @@ def get_search(
192
197
search_request = self .post_request_model (** base_args )
193
198
except ValidationError :
194
199
raise HTTPException (status_code = 400 , detail = "Invalid parameters provided" )
195
- resp = self .post_search (search_request , request = kwargs ["request" ])
200
+ resp = await self .post_search (search_request , request = kwargs ["request" ])
196
201
197
202
return resp
198
203
199
- def post_search (self , search_request , ** kwargs ) -> ItemCollection :
204
+ @overrides
205
+ async def post_search (
206
+ self , search_request : stac_pydantic .api .Search , ** kwargs
207
+ ) -> ItemCollection :
200
208
"""POST search catalog."""
201
209
base_url = str (kwargs ["request" ].base_url )
202
- search = self .database .create_search_object ()
210
+ search = self .database .create_search ()
203
211
204
212
if search_request .query :
205
- if type (search_request .query ) == str :
206
- search_request .query = json .loads (search_request .query )
207
213
for (field_name , expr ) in search_request .query .items ():
208
214
field = "properties__" + field_name
209
215
for (op , value ) in expr .items ():
@@ -217,7 +223,7 @@ def post_search(self, search_request, **kwargs) -> ItemCollection:
217
223
)
218
224
219
225
if search_request .collections :
220
- search = self .database .search_collections (
226
+ search = self .database .filter_collections (
221
227
search = search , collection_ids = search_request .collections
222
228
)
223
229
@@ -247,9 +253,9 @@ def post_search(self, search_request, **kwargs) -> ItemCollection:
247
253
search = search , field = sort .field , direction = sort .direction
248
254
)
249
255
250
- count = self .database .search_count (search = search )
256
+ count = await self .database .search_count (search = search )
251
257
252
- response_features = self .database .execute_search (
258
+ response_features = await self .database .execute_search (
253
259
search = search , limit = search_request .limit , base_url = base_url
254
260
)
255
261
@@ -298,57 +304,57 @@ def post_search(self, search_request, **kwargs) -> ItemCollection:
298
304
299
305
300
306
@attr .s
301
- class TransactionsClient (BaseTransactionsClient ):
307
+ class TransactionsClient (AsyncBaseTransactionsClient ):
302
308
"""Transactions extension specific CRUD operations."""
303
309
304
310
session : Session = attr .ib (default = attr .Factory (Session .create_from_env ))
305
311
database = DatabaseLogic ()
306
312
307
313
@overrides
308
- def create_item (self , item : stac_types .Item , ** kwargs ) -> stac_types .Item :
314
+ async def create_item (self , item : stac_types .Item , ** kwargs ) -> stac_types .Item :
309
315
"""Create item."""
310
316
base_url = str (kwargs ["request" ].base_url )
311
317
312
318
# If a feature collection is posted
313
319
if item ["type" ] == "FeatureCollection" :
314
320
bulk_client = BulkTransactionsClient ()
315
321
processed_items = [
316
- bulk_client .preprocess_item (item , base_url ) for item in item ["features" ]
322
+ bulk_client .preprocess_item (item , base_url ) for item in item ["features" ] # type: ignore
317
323
]
318
- self .database .bulk_sync (
324
+ await self .database .bulk_async (
319
325
processed_items , refresh = kwargs .get ("refresh" , False )
320
326
)
321
327
322
- return None
328
+ return None # type: ignore
323
329
else :
324
- item = self .database .prep_create_item (item = item , base_url = base_url )
325
- self .database .create_item (item , refresh = kwargs .get ("refresh" , False ))
330
+ item = await self .database .prep_create_item (item = item , base_url = base_url )
331
+ await self .database .create_item (item , refresh = kwargs .get ("refresh" , False ))
326
332
return item
327
333
328
334
@overrides
329
- def update_item (self , item : stac_types .Item , ** kwargs ) -> stac_types .Item :
335
+ async def update_item (self , item : stac_types .Item , ** kwargs ) -> stac_types .Item :
330
336
"""Update item."""
331
337
base_url = str (kwargs ["request" ].base_url )
332
338
now = datetime_type .now (timezone .utc ).isoformat ().replace ("+00:00" , "Z" )
333
339
item ["properties" ]["updated" ] = str (now )
334
340
335
- self .database .check_collection_exists (collection_id = item ["collection" ])
341
+ await self .database .check_collection_exists (collection_id = item ["collection" ])
336
342
# todo: index instead of delete and create
337
- self .delete_item (item_id = item ["id" ], collection_id = item ["collection" ])
338
- self .create_item (item = item , ** kwargs )
343
+ await self .delete_item (item_id = item ["id" ], collection_id = item ["collection" ])
344
+ await self .create_item (item = item , ** kwargs )
339
345
340
346
return ItemSerializer .db_to_stac (item , base_url )
341
347
342
348
@overrides
343
- def delete_item (
349
+ async def delete_item (
344
350
self , item_id : str , collection_id : str , ** kwargs
345
351
) -> stac_types .Item :
346
352
"""Delete item."""
347
- self .database .delete_item (item_id = item_id , collection_id = collection_id )
348
- return None
353
+ await self .database .delete_item (item_id = item_id , collection_id = collection_id )
354
+ return None # type: ignore
349
355
350
356
@overrides
351
- def create_collection (
357
+ async def create_collection (
352
358
self , collection : stac_types .Collection , ** kwargs
353
359
) -> stac_types .Collection :
354
360
"""Create collection."""
@@ -357,28 +363,30 @@ def create_collection(
357
363
collection_id = collection ["id" ], base_url = base_url
358
364
).create_links ()
359
365
collection ["links" ] = collection_links
360
- self .database .create_collection (collection = collection )
366
+ await self .database .create_collection (collection = collection )
361
367
362
368
return CollectionSerializer .db_to_stac (collection , base_url )
363
369
364
370
@overrides
365
- def update_collection (
371
+ async def update_collection (
366
372
self , collection : stac_types .Collection , ** kwargs
367
373
) -> stac_types .Collection :
368
374
"""Update collection."""
369
375
base_url = str (kwargs ["request" ].base_url )
370
376
371
- self .database .find_collection (collection_id = collection ["id" ])
372
- self .delete_collection (collection ["id" ])
373
- self .create_collection (collection , ** kwargs )
377
+ await self .database .find_collection (collection_id = collection ["id" ])
378
+ await self .delete_collection (collection ["id" ])
379
+ await self .create_collection (collection , ** kwargs )
374
380
375
381
return CollectionSerializer .db_to_stac (collection , base_url )
376
382
377
383
@overrides
378
- def delete_collection (self , collection_id : str , ** kwargs ) -> stac_types .Collection :
384
+ async def delete_collection (
385
+ self , collection_id : str , ** kwargs
386
+ ) -> stac_types .Collection :
379
387
"""Delete collection."""
380
- self .database .delete_collection (collection_id = collection_id )
381
- return None
388
+ await self .database .delete_collection (collection_id = collection_id )
389
+ return None # type: ignore
382
390
383
391
384
392
@attr .s
@@ -395,8 +403,7 @@ def __attrs_post_init__(self):
395
403
396
404
def preprocess_item (self , item : stac_types .Item , base_url ) -> stac_types .Item :
397
405
"""Preprocess items to match data model."""
398
- item = self .database .prep_create_item (item = item , base_url = base_url )
399
- return item
406
+ return self .database .sync_prep_create_item (item = item , base_url = base_url )
400
407
401
408
@overrides
402
409
def bulk_item_insert (
0 commit comments