|
2 | 2 | import json
|
3 | 3 | import logging
|
4 | 4 | from datetime import datetime as datetime_type
|
| 5 | +from datetime import timezone |
5 | 6 | from typing import List, Optional, Type, Union
|
6 | 7 | from urllib.parse import urljoin
|
7 | 8 |
|
8 | 9 | import attr
|
9 | 10 | from fastapi import HTTPException
|
10 | 11 | from overrides import overrides
|
11 |
| - |
12 |
| -# from geojson_pydantic.geometries import Polygon |
13 | 12 | from pydantic import ValidationError
|
14 | 13 | from stac_pydantic.links import Relations
|
15 | 14 | from stac_pydantic.shared import MimeTypes
|
16 | 15 |
|
17 | 16 | from stac_fastapi.elasticsearch import serializers
|
| 17 | +from stac_fastapi.elasticsearch.config import ElasticsearchSettings |
18 | 18 | from stac_fastapi.elasticsearch.database_logic import DatabaseLogic
|
| 19 | +from stac_fastapi.elasticsearch.serializers import CollectionSerializer, ItemSerializer |
19 | 20 | from stac_fastapi.elasticsearch.session import Session
|
20 |
| - |
21 |
| -# from stac_fastapi.elasticsearch.types.error_checks import ErrorChecks |
22 |
| -from stac_fastapi.types.core import BaseCoreClient |
| 21 | +from stac_fastapi.extensions.third_party.bulk_transactions import ( |
| 22 | + BaseBulkTransactionsClient, |
| 23 | + Items, |
| 24 | +) |
| 25 | +from stac_fastapi.types import stac as stac_types |
| 26 | +from stac_fastapi.types.core import BaseCoreClient, BaseTransactionsClient |
| 27 | +from stac_fastapi.types.links import CollectionLinks |
23 | 28 | from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection
|
24 | 29 |
|
25 | 30 | logger = logging.getLogger(__name__)
|
@@ -291,3 +296,123 @@ def post_search(self, search_request, **kwargs) -> ItemCollection:
|
291 | 296 | links=links,
|
292 | 297 | context=context_obj,
|
293 | 298 | )
|
| 299 | + |
| 300 | + |
| 301 | +@attr.s |
| 302 | +class TransactionsClient(BaseTransactionsClient): |
| 303 | + """Transactions extension specific CRUD operations.""" |
| 304 | + |
| 305 | + session: Session = attr.ib(default=attr.Factory(Session.create_from_env)) |
| 306 | + database = DatabaseLogic() |
| 307 | + |
| 308 | + @overrides |
| 309 | + def create_item(self, item: stac_types.Item, **kwargs) -> stac_types.Item: |
| 310 | + """Create item.""" |
| 311 | + base_url = str(kwargs["request"].base_url) |
| 312 | + |
| 313 | + # If a feature collection is posted |
| 314 | + if item["type"] == "FeatureCollection": |
| 315 | + bulk_client = BulkTransactionsClient() |
| 316 | + processed_items = [ |
| 317 | + bulk_client.preprocess_item(item, base_url) for item in item["features"] |
| 318 | + ] |
| 319 | + return_msg = f"Successfully added {len(processed_items)} items." |
| 320 | + self.database.bulk_sync(processed_items) |
| 321 | + |
| 322 | + return return_msg |
| 323 | + else: |
| 324 | + item = self.database.prep_create_item(item=item, base_url=base_url) |
| 325 | + self.database.create_item(item=item, base_url=base_url) |
| 326 | + return item |
| 327 | + |
| 328 | + @overrides |
| 329 | + def update_item(self, item: stac_types.Item, **kwargs) -> stac_types.Item: |
| 330 | + """Update item.""" |
| 331 | + base_url = str(kwargs["request"].base_url) |
| 332 | + now = datetime_type.now(timezone.utc).isoformat().replace("+00:00", "Z") |
| 333 | + item["properties"]["updated"] = str(now) |
| 334 | + |
| 335 | + self.database.check_collection_exists(collection_id=item["collection"]) |
| 336 | + # todo: index instead of delete and create |
| 337 | + self.delete_item(item_id=item["id"], collection_id=item["collection"]) |
| 338 | + self.create_item(item=item, **kwargs) |
| 339 | + |
| 340 | + return ItemSerializer.db_to_stac(item, base_url) |
| 341 | + |
| 342 | + @overrides |
| 343 | + def delete_item( |
| 344 | + self, item_id: str, collection_id: str, **kwargs |
| 345 | + ) -> stac_types.Item: |
| 346 | + """Delete item.""" |
| 347 | + self.database.delete_item(item_id=item_id, collection_id=collection_id) |
| 348 | + return None |
| 349 | + |
| 350 | + @overrides |
| 351 | + def create_collection( |
| 352 | + self, collection: stac_types.Collection, **kwargs |
| 353 | + ) -> stac_types.Collection: |
| 354 | + """Create collection.""" |
| 355 | + base_url = str(kwargs["request"].base_url) |
| 356 | + collection_links = CollectionLinks( |
| 357 | + collection_id=collection["id"], base_url=base_url |
| 358 | + ).create_links() |
| 359 | + collection["links"] = collection_links |
| 360 | + self.database.create_collection(collection=collection) |
| 361 | + |
| 362 | + return CollectionSerializer.db_to_stac(collection, base_url) |
| 363 | + |
| 364 | + @overrides |
| 365 | + def update_collection( |
| 366 | + self, collection: stac_types.Collection, **kwargs |
| 367 | + ) -> stac_types.Collection: |
| 368 | + """Update collection.""" |
| 369 | + base_url = str(kwargs["request"].base_url) |
| 370 | + |
| 371 | + self.database.find_collection(collection_id=collection["id"]) |
| 372 | + self.delete_collection(collection["id"]) |
| 373 | + self.create_collection(collection, **kwargs) |
| 374 | + |
| 375 | + return CollectionSerializer.db_to_stac(collection, base_url) |
| 376 | + |
| 377 | + @overrides |
| 378 | + def delete_collection(self, collection_id: str, **kwargs) -> stac_types.Collection: |
| 379 | + """Delete collection.""" |
| 380 | + self.database.delete_collection(collection_id=collection_id) |
| 381 | + return None |
| 382 | + |
| 383 | + |
| 384 | +@attr.s |
| 385 | +class BulkTransactionsClient(BaseBulkTransactionsClient): |
| 386 | + """Postgres bulk transactions.""" |
| 387 | + |
| 388 | + session: Session = attr.ib(default=attr.Factory(Session.create_from_env)) |
| 389 | + database = DatabaseLogic() |
| 390 | + |
| 391 | + def __attrs_post_init__(self): |
| 392 | + """Create es engine.""" |
| 393 | + settings = ElasticsearchSettings() |
| 394 | + self.client = settings.create_client |
| 395 | + |
| 396 | + def preprocess_item(self, item: stac_types.Item, base_url) -> stac_types.Item: |
| 397 | + """Preprocess items to match data model.""" |
| 398 | + item = self.database.prep_create_item(item=item, base_url=base_url) |
| 399 | + return item |
| 400 | + |
| 401 | + @overrides |
| 402 | + def bulk_item_insert( |
| 403 | + self, items: Items, chunk_size: Optional[int] = None, **kwargs |
| 404 | + ) -> str: |
| 405 | + """Bulk item insertion using es.""" |
| 406 | + request = kwargs.get("request") |
| 407 | + if request: |
| 408 | + base_url = str(request.base_url) |
| 409 | + else: |
| 410 | + base_url = "" |
| 411 | + |
| 412 | + processed_items = [ |
| 413 | + self.preprocess_item(item, base_url) for item in items.items.values() |
| 414 | + ] |
| 415 | + |
| 416 | + self.database.bulk_sync(processed_items) |
| 417 | + |
| 418 | + return f"Successfully added {len(processed_items)} Items." |
0 commit comments