diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py index 421cf91f..e2af91a6 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py @@ -43,7 +43,22 @@ @attr.s class CoreClient(AsyncBaseCoreClient): - """Client for core endpoints defined by stac.""" + """Client for core endpoints defined by the STAC specification. + + This class is a implementation of `AsyncBaseCoreClient` that implements the core endpoints + defined by the STAC specification. It uses the `DatabaseLogic` class to interact with the + database, and `ItemSerializer` and `CollectionSerializer` to convert between STAC objects and + database records. + + Attributes: + session (Session): A requests session instance to be used for all HTTP requests. + item_serializer (Type[serializers.ItemSerializer]): A serializer class to be used to convert + between STAC items and database records. + collection_serializer (Type[serializers.CollectionSerializer]): A serializer class to be + used to convert between STAC collections and database records. + database (DatabaseLogic): An instance of the `DatabaseLogic` class that is used to interact + with the database. + """ session: Session = attr.ib(default=attr.Factory(Session.create_from_env)) item_serializer: Type[serializers.ItemSerializer] = attr.ib( @@ -56,7 +71,15 @@ class CoreClient(AsyncBaseCoreClient): @overrides async def all_collections(self, **kwargs) -> Collections: - """Read all collections from the database.""" + """Read all collections from the database. + + Returns: + Collections: A `Collections` object containing all the collections in the database and + links to various resources. + + Raises: + Exception: If any error occurs while reading the collections from the database. 
+ """ base_url = str(kwargs["request"].base_url) return Collections( @@ -85,7 +108,18 @@ async def all_collections(self, **kwargs) -> Collections: @overrides async def get_collection(self, collection_id: str, **kwargs) -> Collection: - """Get collection by id.""" + """Get a collection from the database by its id. + + Args: + collection_id (str): The id of the collection to retrieve. + kwargs: Additional keyword arguments passed to the API call. + + Returns: + Collection: A `Collection` object representing the requested collection. + + Raises: + NotFoundError: If the collection with the given id cannot be found in the database. + """ base_url = str(kwargs["request"].base_url) collection = await self.database.find_collection(collection_id=collection_id) return self.collection_serializer.db_to_stac(collection, base_url) @@ -100,7 +134,24 @@ async def item_collection( token: str = None, **kwargs, ) -> ItemCollection: - """Read an item collection from the database.""" + """Read items from a specific collection in the database. + + Args: + collection_id (str): The identifier of the collection to read items from. + bbox (Optional[List[NumType]]): The bounding box to filter items by. + datetime (Union[str, datetime_type, None]): The datetime range to filter items by. + limit (int): The maximum number of items to return. The default value is 10. + token (str): A token used for pagination. + request (Request): The incoming request. + + Returns: + ItemCollection: An `ItemCollection` object containing the items from the specified collection that meet + the filter criteria and links to various resources. + + Raises: + HTTPException: If the specified collection is not found. + Exception: If any error occurs while reading the items from the database. 
+ """ request: Request = kwargs["request"] base_url = str(request.base_url) @@ -163,7 +214,19 @@ async def item_collection( @overrides async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item: - """Get item by item id, collection id.""" + """Get an item from the database based on its id and collection id. + + Args: + collection_id (str): The ID of the collection the item belongs to. + item_id (str): The ID of the item to be retrieved. + + Returns: + Item: An `Item` object representing the requested item. + + Raises: + Exception: If any error occurs while getting the item from the database. + NotFoundError: If the item does not exist in the specified collection. + """ base_url = str(kwargs["request"].base_url) item = await self.database.get_one_item( item_id=item_id, collection_id=collection_id @@ -172,27 +235,44 @@ async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item: @staticmethod def _return_date(interval_str): + """ + Convert a date interval string into a dictionary for filtering search results. + + The date interval string should be formatted as either a single date or a range of dates separated + by "/". The date format should be ISO-8601 (YYYY-MM-DDTHH:MM:SSZ). If the interval string is a + single date, it will be converted to a dictionary with a single "eq" key whose value is the date in + the ISO-8601 format. If the interval string is a range of dates, it will be converted to a + dictionary with "gte" (greater than or equal to) and "lte" (less than or equal to) keys. If the + interval string is a range of dates with ".." instead of "/", the start and end dates will be + assigned default values to encompass the entire possible date range. + + Args: + interval_str (str): The date interval string to be converted. + + Returns: + dict: A dictionary representing the date interval for use in filtering search results. 
+ """ intervals = interval_str.split("/") if len(intervals) == 1: - datetime = intervals[0][0:19] + "Z" + datetime = f"{intervals[0][0:19]}Z" return {"eq": datetime} else: start_date = intervals[0] end_date = intervals[1] if ".." not in intervals: - start_date = start_date[0:19] + "Z" - end_date = end_date[0:19] + "Z" + start_date = f"{start_date[0:19]}Z" + end_date = f"{end_date[0:19]}Z" elif start_date != "..": - start_date = start_date[0:19] + "Z" + start_date = f"{start_date[0:19]}Z" end_date = "2200-12-01T12:31:12Z" elif end_date != "..": start_date = "1900-10-01T00:00:00Z" - end_date = end_date[0:19] + "Z" + end_date = f"{end_date[0:19]}Z" else: start_date = "1900-10-01T00:00:00Z" end_date = "2200-12-01T12:31:12Z" - return {"lte": end_date, "gte": start_date} + return {"lte": end_date, "gte": start_date} @overrides async def get_search( @@ -210,7 +290,26 @@ async def get_search( # filter_lang: Optional[str] = None, # todo: requires fastapi > 2.3 unreleased **kwargs, ) -> ItemCollection: - """GET search catalog.""" + """Get search results from the database. + + Args: + collections (Optional[List[str]]): List of collection IDs to search in. + ids (Optional[List[str]]): List of item IDs to search for. + bbox (Optional[List[NumType]]): Bounding box to search in. + datetime (Optional[Union[str, datetime_type]]): Filter items based on the datetime field. + limit (Optional[int]): Maximum number of results to return. + query (Optional[str]): Query string to filter the results. + token (Optional[str]): Access token to use when searching the catalog. + fields (Optional[List[str]]): Fields to include or exclude from the results. + sortby (Optional[str]): Sorting options for the results. + kwargs: Additional parameters to be passed to the API. + + Returns: + ItemCollection: Collection of `Item` objects representing the search results. + + Raises: + HTTPException: If any error occurs while searching the catalog. 
+ """ base_args = { "collections": collections, "ids": ids, @@ -267,7 +366,19 @@ async def get_search( async def post_search( self, search_request: BaseSearchPostRequest, **kwargs ) -> ItemCollection: - """POST search catalog.""" + """ + Perform a POST search on the catalog. + + Args: + search_request (BaseSearchPostRequest): Request object that includes the parameters for the search. + kwargs: Keyword arguments passed to the function. + + Returns: + ItemCollection: A collection of items matching the search criteria. + + Raises: + HTTPException: If there is an error with the cql2_json filter. + """ request: Request = kwargs["request"] base_url = str(request.base_url) @@ -391,7 +502,21 @@ class TransactionsClient(AsyncBaseTransactionsClient): async def create_item( self, collection_id: str, item: stac_types.Item, **kwargs ) -> stac_types.Item: - """Create item.""" + """Create an item in the collection. + + Args: + collection_id (str): The id of the collection to add the item to. + item (stac_types.Item): The item to be added to the collection. + kwargs: Additional keyword arguments. + + Returns: + stac_types.Item: The created item. + + Raises: + NotFound: If the specified collection is not found in the database. + ConflictError: If the item in the specified collection already exists. + + """ base_url = str(kwargs["request"].base_url) # If a feature collection is posted @@ -415,14 +540,26 @@ async def create_item( async def update_item( self, collection_id: str, item_id: str, item: stac_types.Item, **kwargs ) -> stac_types.Item: - """Update item.""" - base_url = str(kwargs["request"].base_url) + """Update an item in the collection. + + Args: + collection_id (str): The ID of the collection the item belongs to. + item_id (str): The ID of the item to be updated. + item (stac_types.Item): The new item data. + kwargs: Other optional arguments, including the request object. + + Returns: + stac_types.Item: The updated item object. 
+ Raises: + NotFound: If the specified collection is not found in the database. + + """ + base_url = str(kwargs["request"].base_url) now = datetime_type.now(timezone.utc).isoformat().replace("+00:00", "Z") - item["properties"]["updated"] = str(now) + item["properties"]["updated"] = now await self.database.check_collection_exists(collection_id) - # todo: index instead of delete and create await self.delete_item(item_id=item_id, collection_id=collection_id) await self.create_item(collection_id=collection_id, item=item, **kwargs) @@ -432,7 +569,15 @@ async def update_item( async def delete_item( self, item_id: str, collection_id: str, **kwargs ) -> stac_types.Item: - """Delete item.""" + """Delete an item from a collection. + + Args: + item_id (str): The identifier of the item to delete. + collection_id (str): The identifier of the collection that contains the item. + + Returns: + Optional[stac_types.Item]: The deleted item, or `None` if the item was successfully deleted. + """ await self.database.delete_item(item_id=item_id, collection_id=collection_id) return None # type: ignore @@ -440,7 +585,18 @@ async def delete_item( async def create_collection( self, collection: stac_types.Collection, **kwargs ) -> stac_types.Collection: - """Create collection.""" + """Create a new collection in the database. + + Args: + collection (stac_types.Collection): The collection to be created. + kwargs: Additional keyword arguments. + + Returns: + stac_types.Collection: The created collection object. + + Raises: + ConflictError: If the collection already exists. + """ base_url = str(kwargs["request"].base_url) collection_links = CollectionLinks( collection_id=collection["id"], base_url=base_url @@ -454,7 +610,21 @@ async def create_collection( async def update_collection( self, collection: stac_types.Collection, **kwargs ) -> stac_types.Collection: - """Update collection.""" + """ + Update a collection. 
+ + This method updates an existing collection in the database by first finding + the collection by its id, then deleting the old version, and finally creating + a new version of the updated collection. The updated collection is then returned. + + Args: + collection: A STAC collection that needs to be updated. + kwargs: Additional keyword arguments. + + Returns: + A STAC collection that has been updated in the database. + + """ base_url = str(kwargs["request"].base_url) await self.database.find_collection(collection_id=collection["id"]) @@ -467,14 +637,33 @@ async def update_collection( async def delete_collection( self, collection_id: str, **kwargs ) -> stac_types.Collection: - """Delete collection.""" + """ + Delete a collection. + + This method deletes an existing collection in the database. + + Args: + collection_id (str): The identifier of the collection that contains the item. + kwargs: Additional keyword arguments. + + Returns: + None. + + Raises: + NotFoundError: If the collection doesn't exist. + """ await self.database.delete_collection(collection_id=collection_id) return None # type: ignore @attr.s class BulkTransactionsClient(BaseBulkTransactionsClient): - """Postgres bulk transactions.""" + """A client for posting bulk transactions to a Postgres database. + + Attributes: + session: An instance of `Session` to use for database connection. + database: An instance of `DatabaseLogic` to perform database operations. + """ session: Session = attr.ib(default=attr.Factory(Session.create_from_env)) database = DatabaseLogic() @@ -485,14 +674,31 @@ def __attrs_post_init__(self): self.client = settings.create_client def preprocess_item(self, item: stac_types.Item, base_url) -> stac_types.Item: - """Preprocess items to match data model.""" + """Preprocess an item to match the data model. + + Args: + item: The item to preprocess. + base_url: The base URL of the request. + + Returns: + The preprocessed item. 
+ """ return self.database.sync_prep_create_item(item=item, base_url=base_url) @overrides def bulk_item_insert( self, items: Items, chunk_size: Optional[int] = None, **kwargs ) -> str: - """Bulk item insertion using es.""" + """Perform a bulk insertion of items into the database using Elasticsearch. + + Args: + items: The items to insert. + chunk_size: The size of each chunk for bulk processing. + **kwargs: Additional keyword arguments, such as `request` and `refresh`. + + Returns: + A string indicating the number of items successfully added. + """ request = kwargs.get("request") if request: base_url = str(request.base_url) @@ -530,6 +736,13 @@ async def get_queryables( under OGC CQL but it is allowed by the STAC API Filter Extension https://github.com/radiantearth/stac-api-spec/tree/master/fragments/filter#queryables + + Args: + collection_id (str, optional): The id of the collection to get queryables for. + **kwargs: additional keyword arguments + + Returns: + Dict[str, Any]: A dictionary containing the queryables for the given collection. """ return { "$schema": "https://json-schema.org/draft/2019-09/schema", diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index a59e5b2c..6a7a1f27 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -127,12 +127,28 @@ def index_by_collection_id(collection_id: str) -> str: - """Translate a collection id into an ES index name.""" + """ + Translate a collection id into an Elasticsearch index name. + + Args: + collection_id (str): The collection id to translate into an index name. + + Returns: + str: The index name derived from the collection id. 
+ """ return f"{ITEMS_INDEX_PREFIX}{collection_id}" def indices(collection_ids: Optional[List[str]]) -> str: - """Get a comma-separated string value of indexes for a given list of collection ids.""" + """ + Get a comma-separated string of index names for a given list of collection ids. + + Args: + collection_ids: A list of collection ids. + + Returns: + A string of comma-separated index names. If `collection_ids` is None, returns the default indices. + """ if collection_ids is None: return DEFAULT_INDICES else: @@ -140,7 +156,11 @@ def indices(collection_ids: Optional[List[str]]) -> str: async def create_collection_index() -> None: - """Create the index for Collections.""" + """Create the index for Collections in Elasticsearch. + + This function creates the Elasticsearch index for the `Collections` with the predefined mapping. + If the index already exists, the function ignores the error and continues execution. + """ client = AsyncElasticsearchSettings().create_client await client.indices.create( @@ -152,7 +172,16 @@ async def create_collection_index() -> None: async def create_item_index(collection_id: str): - """Create the index for Items.""" + """ + Create the index for Items. + + Args: + collection_id (str): Collection identifier. + + Returns: + None + + """ client = AsyncElasticsearchSettings().create_client await client.indices.create( @@ -165,7 +194,11 @@ async def create_item_index(collection_id: str): async def delete_item_index(collection_id: str): - """Delete the index for Items.""" + """Delete the index for items in a collection. + + Args: + collection_id (str): The ID of the collection whose items index will be deleted. 
+ """ client = AsyncElasticsearchSettings().create_client await client.indices.delete(index=index_by_collection_id(collection_id)) @@ -173,17 +206,47 @@ async def delete_item_index(collection_id: str): def bbox2polygon(b0: float, b1: float, b2: float, b3: float) -> List[List[List[float]]]: - """Transform bbox to polygon.""" + """Transform a bounding box represented by its four coordinates `b0`, `b1`, `b2`, and `b3` into a polygon. + + Args: + b0 (float): The x-coordinate of the lower-left corner of the bounding box. + b1 (float): The y-coordinate of the lower-left corner of the bounding box. + b2 (float): The x-coordinate of the upper-right corner of the bounding box. + b3 (float): The y-coordinate of the upper-right corner of the bounding box. + + Returns: + List[List[List[float]]]: A polygon represented as a list of lists of coordinates. + """ return [[[b0, b1], [b2, b1], [b2, b3], [b0, b3], [b0, b1]]] def mk_item_id(item_id: str, collection_id: str): - """Make the Elasticsearch document _id value from the Item id and collection.""" + """Create the document id for an Item in Elasticsearch. + + Args: + item_id (str): The id of the Item. + collection_id (str): The id of the Collection that the Item belongs to. + + Returns: + str: The document id for the Item, combining the Item id and the Collection id, separated by a `|` character. + """ return f"{item_id}|{collection_id}" def mk_actions(collection_id: str, processed_items: List[Item]): - """Make the Elasticsearch bulk action for a list of Items.""" + """Create Elasticsearch bulk actions for a list of processed items. + + Args: + collection_id (str): The identifier for the collection the items belong to. + processed_items (List[Item]): The list of processed items to be bulk indexed. + + Returns: + List[Dict[str, Union[str, Dict]]]: The list of bulk actions to be executed, + each action being a dictionary with the following keys: + - `_index`: the index to store the document in. 
+ - `_id`: the document's identifier. + - `_source`: the source of the document. + """ return [ { "_index": index_by_collection_id(collection_id), @@ -219,14 +282,38 @@ class DatabaseLogic: """CORE LOGIC""" async def get_all_collections(self) -> Iterable[Dict[str, Any]]: - """Database logic to retrieve a list of all collections.""" + """Retrieve a list of all collections from the database. + + Returns: + collections (Iterable[Dict[str, Any]]): A list of dictionaries containing the source data for each collection. + + Notes: + The collections are retrieved from the Elasticsearch database using the `client.search` method, + with the `COLLECTIONS_INDEX` as the target index and `size=1000` to retrieve up to 1000 records. + The result is a generator of dictionaries containing the source data for each collection. + """ # https://github.com/stac-utils/stac-fastapi-elasticsearch/issues/65 # collections should be paginated, but at least return more than the default 10 for now collections = await self.client.search(index=COLLECTIONS_INDEX, size=1000) return (c["_source"] for c in collections["hits"]["hits"]) async def get_one_item(self, collection_id: str, item_id: str) -> Dict: - """Database logic to retrieve a single item.""" + """Retrieve a single item from the database. + + Args: + collection_id (str): The id of the Collection that the Item belongs to. + item_id (str): The id of the Item. + + Returns: + item (Dict): A dictionary containing the source data for the Item. + + Raises: + NotFoundError: If the specified Item does not exist in the Collection. + + Notes: + The Item is retrieved from the Elasticsearch database using the `client.get` method, + with the index for the Collection as the target index and the combined `mk_item_id` as the document id. 
+ """ try: item = await self.client.get( index=index_by_collection_id(collection_id), @@ -255,7 +342,15 @@ def apply_collections_filter(search: Search, collection_ids: List[str]): @staticmethod def apply_datetime_filter(search: Search, datetime_search): - """Database logic to search datetime field.""" + """Apply a filter to search based on datetime field. + + Args: + search (Search): The search object to filter. + datetime_search (dict): The datetime filter criteria. + + Returns: + Search: The filtered search object. + """ if "eq" in datetime_search: search = search.filter( "term", **{"properties__datetime": datetime_search["eq"]} @@ -271,7 +366,19 @@ def apply_datetime_filter(search: Search, datetime_search): @staticmethod def apply_bbox_filter(search: Search, bbox: List): - """Database logic to search on bounding box.""" + """Filter search results based on bounding box. + + Args: + search (Search): The search object to apply the filter to. + bbox (List): The bounding box coordinates, represented as a list of four values [minx, miny, maxx, maxy]. + + Returns: + search (Search): The search object with the bounding box filter applied. + + Notes: + The bounding box is transformed into a polygon using the `bbox2polygon` function and + a geo_shape filter is added to the search object, set to intersect with the specified polygon. + """ return search.filter( Q( { @@ -293,7 +400,18 @@ def apply_intersects_filter( search: Search, intersects: Geometry, ): - """Database logic to search a geojson object.""" + """Filter search results based on intersecting geometry. + + Args: + search (Search): The search object to apply the filter to. + intersects (Geometry): The intersecting geometry, represented as a GeoJSON-like object. + + Returns: + search (Search): The search object with the intersecting geometry filter applied. + + Notes: + A geo_shape filter is added to the search object, set to intersect with the specified geometry. 
+ """ return search.filter( Q( { @@ -312,7 +430,18 @@ def apply_intersects_filter( @staticmethod def apply_stacql_filter(search: Search, op: str, field: str, value: float): - """Database logic to perform query for search endpoint.""" + """Filter search results based on a comparison between a field and a value. + + Args: + search (Search): The search object to apply the filter to. + op (str): The comparison operator to use. Can be 'eq' (equal), 'gt' (greater than), 'gte' (greater than or equal), + 'lt' (less than), or 'lte' (less than or equal). + field (str): The field to perform the comparison on. + value (float): The value to compare the field against. + + Returns: + search (Search): The search object with the specified filter applied. + """ if op != "eq": key_filter = {field: {f"{op}": value}} search = search.filter(Q("range", **key_filter)) @@ -345,7 +474,27 @@ async def execute_search( collection_ids: Optional[List[str]], ignore_unavailable: bool = True, ) -> Tuple[Iterable[Dict[str, Any]], Optional[int], Optional[str]]: - """Database logic to execute search with limit.""" + """Execute a search query with limit and other optional parameters. + + Args: + search (Search): The search query to be executed. + limit (int): The maximum number of results to be returned. + token (Optional[str]): The token used to return the next set of results. + sort (Optional[Dict[str, Dict[str, str]]]): Specifies how the results should be sorted. + collection_ids (Optional[List[str]]): The collection ids to search. + ignore_unavailable (bool, optional): Whether to ignore unavailable collections. Defaults to True. + + Returns: + Tuple[Iterable[Dict[str, Any]], Optional[int], Optional[str]]: A tuple containing: + - An iterable of search results, where each result is a dictionary with keys and values representing the + fields and values of each document. + - The total number of results (if the count could be computed), or None if the count could not be + computed. 
+ - The token to be used to retrieve the next set of results, or None if there are no more results. + + Raises: + NotFoundError: If the collections specified in `collection_ids` do not exist. + """ search_after = None if token: search_after = urlsafe_b64decode(token.encode()).decode().split(",") @@ -406,7 +555,20 @@ async def check_collection_exists(self, collection_id: str): raise NotFoundError(f"Collection {collection_id} does not exist") async def prep_create_item(self, item: Item, base_url: str) -> Item: - """Database logic for prepping an item for insertion.""" + """ + Preps an item for insertion into the database. + + Args: + item (Item): The item to be prepped for insertion. + base_url (str): The base URL used to create the item's self URL. + + Returns: + Item: The prepped item. + + Raises: + ConflictError: If the item already exists in the database. + + """ await self.check_collection_exists(collection_id=item["collection"]) if await self.client.exists( @@ -420,7 +582,24 @@ async def prep_create_item(self, item: Item, base_url: str) -> Item: return self.item_serializer.stac_to_db(item, base_url) def sync_prep_create_item(self, item: Item, base_url: str) -> Item: - """Database logic for prepping an item for insertion.""" + """ + Prepare an item for insertion into the database. + + This method performs pre-insertion preparation on the given `item`, + such as checking if the collection the item belongs to exists, + and verifying that an item with the same ID does not already exist in the database. + + Args: + item (Item): The item to be inserted into the database. + base_url (str): The base URL used for constructing URLs for the item. + + Returns: + Item: The item after preparation is done. + + Raises: + NotFoundError: If the collection that the item belongs to does not exist in the database. + ConflictError: If an item with the same ID already exists in the collection. 
+ """ item_id = item["id"] collection_id = item["collection"] if not self.sync_client.exists(index=COLLECTIONS_INDEX, id=collection_id): @@ -437,7 +616,18 @@ def sync_prep_create_item(self, item: Item, base_url: str) -> Item: return self.item_serializer.stac_to_db(item, base_url) async def create_item(self, item: Item, refresh: bool = False): - """Database logic for creating one item.""" + """Database logic for creating one item. + + Args: + item (Item): The item to be created. + refresh (bool, optional): Refresh the index after performing the operation. Defaults to False. + + Raises: + ConflictError: If the item already exists in the database. + + Returns: + None + """ # todo: check if collection exists, but cache item_id = item["id"] collection_id = item["collection"] @@ -456,7 +646,16 @@ async def create_item(self, item: Item, refresh: bool = False): async def delete_item( self, item_id: str, collection_id: str, refresh: bool = False ): - """Database logic for deleting one item.""" + """Delete a single item from the database. + + Args: + item_id (str): The id of the Item to be deleted. + collection_id (str): The id of the Collection that the Item belongs to. + refresh (bool, optional): Whether to refresh the index after the deletion. Default is False. + + Raises: + NotFoundError: If the Item does not exist in the database. + """ try: await self.client.delete( index=index_by_collection_id(collection_id), @@ -469,7 +668,18 @@ async def delete_item( ) async def create_collection(self, collection: Collection, refresh: bool = False): - """Database logic for creating one collection.""" + """Create a single collection in the database. + + Args: + collection (Collection): The Collection object to be created. + refresh (bool, optional): Whether to refresh the index after the creation. Default is False. + + Raises: + ConflictError: If a Collection with the same id already exists in the database. 
+ + Notes: + A new index is created for the items in the Collection using the `create_item_index` function. + """ collection_id = collection["id"] if await self.client.exists(index=COLLECTIONS_INDEX, id=collection_id): @@ -485,7 +695,22 @@ async def create_collection(self, collection: Collection, refresh: bool = False) await create_item_index(collection_id) async def find_collection(self, collection_id: str) -> Collection: - """Database logic to find and return a collection.""" + """Find and return a collection from the database. + + Args: + self: The instance of the object calling this function. + collection_id (str): The ID of the collection to be found. + + Returns: + Collection: The found collection, represented as a `Collection` object. + + Raises: + NotFoundError: If the collection with the given `collection_id` is not found in the database. + + Notes: + This function searches for a collection in the database using the specified `collection_id` and returns the found + collection as a `Collection` object. If the collection is not found, a `NotFoundError` is raised. + """ try: collection = await self.client.get( index=COLLECTIONS_INDEX, id=collection_id @@ -496,7 +721,21 @@ async def find_collection(self, collection_id: str) -> Collection: return collection["_source"] async def delete_collection(self, collection_id: str, refresh: bool = False): - """Database logic for deleting one collection.""" + """Delete a collection from the database. + + Parameters: + self: The instance of the object calling this function. + collection_id (str): The ID of the collection to be deleted. + refresh (bool): Whether to refresh the index after the deletion (default: False). + + Raises: + NotFoundError: If the collection with the given `collection_id` is not found in the database. + + Notes: + This function first verifies that the collection with the specified `collection_id` exists in the database, and then + deletes the collection. 
If `refresh` is set to True, the index is refreshed after the deletion. Additionally, this + function also calls `delete_item_index` to delete the index for the items in the collection. + """ await self.find_collection(collection_id=collection_id) await self.client.delete( index=COLLECTIONS_INDEX, id=collection_id, refresh=refresh @@ -506,7 +745,20 @@ async def delete_collection(self, collection_id: str, refresh: bool = False): async def bulk_async( self, collection_id: str, processed_items: List[Item], refresh: bool = False ) -> None: - """Database logic for async bulk item insertion.""" + """Perform a bulk insert of items into the database asynchronously. + + Args: + self: The instance of the object calling this function. + collection_id (str): The ID of the collection to which the items belong. + processed_items (List[Item]): A list of `Item` objects to be inserted into the database. + refresh (bool): Whether to refresh the index after the bulk insert (default: False). + + Notes: + This function performs a bulk insert of `processed_items` into the database using the specified `collection_id`. The + insert is performed asynchronously, and the event loop is used to run the operation in a separate executor. The + `mk_actions` function is called to generate a list of actions for the bulk insert. If `refresh` is set to True, the + index is refreshed after the bulk insert. The function does not return any value. + """ await asyncio.get_event_loop().run_in_executor( None, lambda: helpers.bulk( @@ -520,7 +772,20 @@ async def bulk_async( def bulk_sync( self, collection_id: str, processed_items: List[Item], refresh: bool = False ) -> None: - """Database logic for sync bulk item insertion.""" + """Perform a bulk insert of items into the database synchronously. + + Args: + self: The instance of the object calling this function. + collection_id (str): The ID of the collection to which the items belong. 
+            processed_items (List[Item]): A list of `Item` objects to be inserted into the database.
+            refresh (bool): Whether to refresh the index after the bulk insert (default: False).
+
+        Notes:
+            This function performs a bulk insert of `processed_items` into the database using the specified `collection_id`. The
+            insert is performed synchronously and is blocking, meaning that the function does not return until the insert has
+            completed. The `mk_actions` function is called to generate a list of actions for the bulk insert. If `refresh` is set to
+            True, the index is refreshed after the bulk insert. The function does not return any value.
+        """
         helpers.bulk(
             self.sync_client,
             mk_actions(collection_id, processed_items),
diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/serializers.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/serializers.py
index 4f853a79..690b475b 100644
--- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/serializers.py
+++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/serializers.py
@@ -9,14 +9,39 @@
 from stac_fastapi.types.links import CollectionLinks, ItemLinks, resolve_links
-@attr.s  # type:ignore
+@attr.s
 class Serializer(abc.ABC):
-    """Defines serialization methods between the API and the data model."""
+    """Defines serialization methods between the API and the data model.
+
+    This class is meant to be subclassed and implemented by specific serializers for different STAC objects (e.g. Item, Collection).
+    """
     @classmethod
     @abc.abstractmethod
     def db_to_stac(cls, item: dict, base_url: str) -> Any:
-        """Transform database model to stac."""
+        """Transform database model to STAC object.
+
+        Args:
+            item (dict): A dictionary representing the database model.
+            base_url (str): The base URL of the STAC API.
+
+        Returns:
+            Any: A STAC object, e.g. an `Item` or `Collection`, representing the input `item`.
+        """
+        ...
+
+    @classmethod
+    @abc.abstractmethod
+    def stac_to_db(cls, stac_object: Any, base_url: str) -> dict:
+        """Transform STAC object to database model.
+
+        Args:
+            stac_object (Any): A STAC object, e.g. an `Item` or `Collection`.
+            base_url (str): The base URL of the STAC API.
+
+        Returns:
+            dict: A dictionary representing the database model.
+        """
         ...
@@ -25,7 +50,15 @@ class ItemSerializer(Serializer):
     @classmethod
     def stac_to_db(cls, stac_data: stac_types.Item, base_url: str) -> stac_types.Item:
-        """Transform STAC Item to database-ready STAC Item."""
+        """Transform STAC item to database-ready STAC item.
+
+        Args:
+            stac_data (stac_types.Item): The STAC item object to be transformed.
+            base_url (str): The base URL for the STAC API.
+
+        Returns:
+            stac_types.Item: The database-ready STAC item object.
+        """
         item_links = ItemLinks(
             collection_id=stac_data["collection"],
             item_id=stac_data["id"],
@@ -33,14 +66,6 @@ def stac_to_db(cls, stac_data: stac_types.Item, base_url: str) -> stac_types.Ite
         ).create_links()
         stac_data["links"] = item_links
-        # elasticsearch doesn't like the fact that some values are float and some were int
-        if "eo:bands" in stac_data["properties"]:
-            for wave in stac_data["properties"]["eo:bands"]:
-                for k, v in wave.items():
-                    if type(v) != str:
-                        v = float(v)
-                    wave.update({k: v})
-
         now = now_to_rfc3339_str()
         if "created" not in stac_data["properties"]:
             stac_data["properties"]["created"] = now
@@ -49,30 +74,36 @@ def stac_to_db(cls, stac_data: stac_types.Item, base_url: str) -> stac_types.Ite
     @classmethod
     def db_to_stac(cls, item: dict, base_url: str) -> stac_types.Item:
-        """Transform database model to stac item."""
+        """Transform database-ready STAC item to STAC item.
+
+        Args:
+            item (dict): The database-ready STAC item to be transformed.
+            base_url (str): The base URL for the STAC API.
+
+        Returns:
+            stac_types.Item: The STAC item object.
+ """ item_id = item["id"] collection_id = item["collection"] item_links = ItemLinks( collection_id=collection_id, item_id=item_id, base_url=base_url ).create_links() - original_links = item["links"] + original_links = item.get("links", []) if original_links: item_links += resolve_links(original_links, base_url) return stac_types.Item( type="Feature", - stac_version=item["stac_version"] if "stac_version" in item else "", - stac_extensions=item["stac_extensions"] - if "stac_extensions" in item - else [], + stac_version=item.get("stac_version", ""), + stac_extensions=item.get("stac_extensions", []), id=item_id, - collection=item["collection"] if "collection" in item else "", - geometry=item["geometry"] if "geometry" in item else {}, - bbox=item["bbox"] if "bbox" in item else [], - properties=item["properties"] if "properties" in item else {}, - links=item_links if "links" in item else [], - assets=item["assets"] if "assets" in item else {}, + collection=item.get("collection", ""), + geometry=item.get("geometry", {}), + bbox=item.get("bbox", []), + properties=item.get("properties", {}), + links=item_links, + assets=item.get("assets", {}), ) @@ -81,32 +112,49 @@ class CollectionSerializer(Serializer): @classmethod def db_to_stac(cls, collection: dict, base_url: str) -> stac_types.Collection: - """Transform database model to stac collection.""" + """Transform database model to STAC collection. + + Args: + collection (dict): The collection data in dictionary form, extracted from the database. + base_url (str): The base URL for the collection. + + Returns: + stac_types.Collection: The STAC collection object. 
+ """ + # Use dictionary unpacking to extract values from the collection dictionary + collection_id = collection.get("id") + stac_extensions = collection.get("stac_extensions", []) + stac_version = collection.get("stac_version", "") + title = collection.get("title", "") + description = collection.get("description", "") + keywords = collection.get("keywords", []) + license = collection.get("license", "") + providers = collection.get("providers", {}) + summaries = collection.get("summaries", {}) + extent = collection.get("extent", {}) + + # Create the collection links using CollectionLinks collection_links = CollectionLinks( - collection_id=collection["id"], base_url=base_url + collection_id=collection_id, base_url=base_url ).create_links() - original_links = collection["links"] + # Add any additional links from the collection dictionary + original_links = collection.get("links") if original_links: collection_links += resolve_links(original_links, base_url) + # Return the stac_types.Collection object return stac_types.Collection( type="Collection", - id=collection["id"], - stac_extensions=collection["stac_extensions"] - if "stac_extensions" in collection - else [], - stac_version=collection["stac_version"] - if "stac_version" in collection - else "", - title=collection["title"] if "title" in collection else "", - description=collection["description"] - if "description" in collection - else "", - keywords=collection["keywords"] if "keywords" in collection else [], - license=collection["license"] if "license" in collection else "", - providers=collection["providers"] if "providers" in collection else {}, - summaries=collection["summaries"] if "summaries" in collection else {}, - extent=collection["extent"] if "extent" in collection else {}, + id=collection_id, + stac_extensions=stac_extensions, + stac_version=stac_version, + title=title, + description=description, + keywords=keywords, + license=license, + providers=providers, + summaries=summaries, + extent=extent, 
links=collection_links, )