add pagination with search_after, other refactoring #85

Merged · 2 commits · Mar 25, 2022
2 changes: 1 addition & 1 deletion docker-compose.yml
@@ -11,7 +11,7 @@ services:
     environment:
       - APP_HOST=0.0.0.0
       - APP_PORT=8080
-      - RELOAD=false
+      - RELOAD=true
       - ENVIRONMENT=local
       - WEB_CONCURRENCY=10
       - ES_HOST=172.17.0.1
1 change: 1 addition & 0 deletions stac_fastapi/elasticsearch/setup.py
@@ -19,6 +19,7 @@
     "pystac[validation]",
     "uvicorn",
     "overrides",
+    "starlette",
 ]

 extra_reqs = {
79 changes: 47 additions & 32 deletions stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py
@@ -13,10 +13,12 @@
 from pydantic import ValidationError
 from stac_pydantic.links import Relations
 from stac_pydantic.shared import MimeTypes
+from starlette.requests import Request

 from stac_fastapi.elasticsearch import serializers
 from stac_fastapi.elasticsearch.config import ElasticsearchSettings
 from stac_fastapi.elasticsearch.database_logic import DatabaseLogic
+from stac_fastapi.elasticsearch.models.links import PagingLinks
 from stac_fastapi.elasticsearch.serializers import CollectionSerializer, ItemSerializer
 from stac_fastapi.elasticsearch.session import Session
 from stac_fastapi.extensions.third_party.bulk_transactions import (
@@ -84,11 +86,17 @@ async def item_collection(
         self, collection_id: str, limit: int = 10, token: str = None, **kwargs
     ) -> ItemCollection:
         """Read an item collection from the database."""
-        links = []
-        base_url = str(kwargs["request"].base_url)
-
-        items, maybe_count = await self.database.get_collection_items(
-            collection_id=collection_id, limit=limit, base_url=base_url
+        request: Request = kwargs["request"]
+        base_url = str(request.base_url)
+
+        items, maybe_count, next_token = await self.database.execute_search(
+            search=self.database.apply_collections_filter(
+                self.database.make_search(), [collection_id]
+            ),
+            limit=limit,
+            token=token,
+            sort=None,
+            base_url=base_url,
         )

         context_obj = None
@@ -100,6 +108,10 @@
             if maybe_count is not None:
                 context_obj["matched"] = maybe_count

+        links = []
+        if next_token:
+            links = await PagingLinks(request=request, next=next_token).get_links()
+
         return ItemCollection(
             type="FeatureCollection",
             features=items,
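
The visible effect of this hunk is that `GET /collections/{collection_id}/items` now returns a `next` link carrying a `token` query parameter whenever there are more results. A minimal client-side sketch of paging through a collection — not part of the PR; it assumes a local instance on port 8080 (per the docker-compose file above), a hypothetical collection id, and uses `requests`, which is not a project dependency:

```python
import requests  # illustration only, not a dependency of this project


def iter_collection_items(base_url: str, collection_id: str, limit: int = 10):
    """Yield every item in a collection by following `next` links."""
    url = f"{base_url}/collections/{collection_id}/items"
    params = {"limit": limit}
    while True:
        page = requests.get(url, params=params, timeout=30).json()
        yield from page["features"]
        # The server only emits a `next` link when there are more results;
        # its href already carries the token (see merge_params below).
        next_link = next(
            (link for link in page.get("links", []) if link["rel"] == "next"), None
        )
        if next_link is None:
            return
        url, params = next_link["href"], {}


# "my-collection" is a hypothetical collection id
for item in iter_collection_items("http://localhost:8080", "my-collection"):
    print(item["id"])
```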
@@ -203,16 +215,10 @@ async def post_search(
         self, search_request: stac_pydantic.api.Search, **kwargs
     ) -> ItemCollection:
         """POST search catalog."""
-        base_url = str(kwargs["request"].base_url)
-        search = self.database.make_search()
+        request: Request = kwargs["request"]
+        base_url = str(request.base_url)

-        if search_request.query:
-            for (field_name, expr) in search_request.query.items():
-                field = "properties__" + field_name
-                for (op, value) in expr.items():
-                    search = self.database.apply_stacql_filter(
-                        search=search, op=op, field=field, value=value
-                    )
+        search = self.database.make_search()

         if search_request.ids:
             search = self.database.apply_ids_filter(
@@ -242,16 +248,28 @@
                 search=search, intersects=search_request.intersects
             )

+        if search_request.query:
+            for (field_name, expr) in search_request.query.items():
+                field = "properties__" + field_name
+                for (op, value) in expr.items():
+                    search = self.database.apply_stacql_filter(
+                        search=search, op=op, field=field, value=value
+                    )
+
+        sort = None
         if search_request.sortby:
-            for sort in search_request.sortby:
-                if sort.field == "datetime":
-                    sort.field = "properties__datetime"
-                search = self.database.apply_sort(
-                    search=search, field=sort.field, direction=sort.direction
-                )
+            sort = self.database.populate_sort(search_request.sortby)
+
+        limit = 10
+        if search_request.limit:
+            limit = search_request.limit

-        response_features, maybe_count = await self.database.execute_search(
-            search=search, limit=search_request.limit, base_url=base_url
+        items, maybe_count, next_token = await self.database.execute_search(
+            search=search,
+            limit=limit,
+            token=search_request.token,  # type: ignore
+            sort=sort,
+            base_url=base_url,
         )

         # if self.extension_is_enabled("FieldsExtension"):
@@ -274,26 +292,23 @@
         #     for feat in response_features
         #     ]

-        if search_request.limit:
-            limit = search_request.limit
-            response_features = response_features[0:limit]
-        else:
-            limit = 10
-            response_features = response_features[0:limit]
-
         context_obj = None
         if self.extension_is_enabled("ContextExtension"):
             context_obj = {
-                "returned": len(response_features),
+                "returned": len(items),
                 "limit": limit,
             }
             if maybe_count is not None:
                 context_obj["matched"] = maybe_count

+        links = []
+        if next_token:
+            links = await PagingLinks(request=request, next=next_token).get_links()
+
         return ItemCollection(
             type="FeatureCollection",
-            features=response_features,
-            links=[],
+            features=items,
+            links=links,
             context=context_obj,
         )

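With this refactor, sorting is no longer baked into the `Search` object: `populate_sort` (below) turns the request's `sortby` clauses into a plain dict that `execute_search` hands straight to the Elasticsearch client. A minimal sketch of that translation — the `SimpleNamespace` objects stand in for the pydantic sortby models, and the field names are illustrative:

```python
from types import SimpleNamespace

# Stand-ins for the pydantic sortby objects carried on the search request.
sortby = [
    SimpleNamespace(field="properties__datetime", direction="desc"),
    SimpleNamespace(field="id", direction="asc"),
]

# Equivalent of DatabaseLogic.populate_sort(sortby) in the diff below.
sort = {s.field: {"order": s.direction} for s in sortby}
assert sort == {
    "properties__datetime": {"order": "desc"},
    "id": {"order": "asc"},
}
```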
stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py
@@ -1,5 +1,6 @@
 """Database logic."""
 import logging
+from base64 import urlsafe_b64decode, urlsafe_b64encode
 from typing import Dict, List, Optional, Tuple, Type, Union

 import attr
@@ -75,17 +76,6 @@ async def get_all_collections(self, base_url: str) -> List[Collection]:
             for c in collections["hits"]["hits"]
         ]

-    async def get_collection_items(
-        self, collection_id: str, limit: int, base_url: str
-    ) -> Tuple[List[Item], Optional[int]]:
-        """Database logic to retrieve an ItemCollection and a count of items contained."""
-        search = self.apply_collections_filter(Search(), [collection_id])
-        items, maybe_count = await self.execute_search(
-            search=search, limit=limit, base_url=base_url
-        )
-
-        return items, maybe_count
-
     async def get_one_item(self, collection_id: str, item_id: str) -> Dict:
         """Database logic to retrieve a single item."""
         try:
@@ -190,31 +180,53 @@ def apply_stacql_filter(search: Search, op: str, field: str, value: float):
         return search

     @staticmethod
-    def apply_sort(search: Search, field, direction):
+    def populate_sort(sortby: List) -> Optional[Dict[str, Dict[str, str]]]:
         """Database logic to sort search instance."""
-        return search.sort({field: {"order": direction}})
+        if sortby:
+            return {s.field: {"order": s.direction} for s in sortby}
+        else:
+            return None

     async def execute_search(
-        self, search, limit: int, base_url: str
-    ) -> Tuple[List[Item], Optional[int]]:
+        self,
+        search: Search,
+        limit: int,
+        token: Optional[str],
[Author review comment: now we need the token, and pass sort as a separate value instead of part of search]
+        sort: Optional[Dict[str, Dict[str, str]]],
+        base_url: str,
+    ) -> Tuple[List[Item], Optional[int], Optional[str]]:
         """Database logic to execute search with limit."""
-        search = search[0:limit]
         body = search.to_dict()

         maybe_count = (
             await self.client.count(index=ITEMS_INDEX, body=search.to_dict(count=True))
         ).get("count")

+        search_after = None
+        if token:
+            search_after = urlsafe_b64decode(token.encode()).decode().split(",")
+
         es_response = await self.client.search(
-            index=ITEMS_INDEX, query=body.get("query"), sort=body.get("sort")
+            index=ITEMS_INDEX,
+            query=body.get("query"),
+            sort=sort or DEFAULT_SORT,
+            search_after=search_after,
+            size=limit,
         )

+        hits = es_response["hits"]["hits"]
         items = [
             self.item_serializer.db_to_stac(hit["_source"], base_url=base_url)
-            for hit in es_response["hits"]["hits"]
+            for hit in hits
         ]

-        return items, maybe_count
+        next_token = None
+        if hits and (sort_array := hits[-1].get("sort")):
+            next_token = urlsafe_b64encode(
+                ",".join([str(x) for x in sort_array]).encode()
+            ).decode()
+
+        return items, maybe_count, next_token

     """ TRANSACTION LOGIC """
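As the hunk above shows, the paging token is nothing more than the `sort` array of the last Elasticsearch hit, comma-joined and URL-safe base64-encoded so it can round-trip through a query parameter or request body. A standalone sketch of that encode/decode pair — the hit values are invented for illustration:

```python
from base64 import urlsafe_b64decode, urlsafe_b64encode

# Invented `sort` values for the last hit of a page, e.g. a datetime in
# epoch millis plus a tiebreaker id (the actual fields depend on DEFAULT_SORT
# or the requested sortby).
last_hit_sort = [1648166400000, "item-0042"]

# Encode, as execute_search does when building next_token.
token = urlsafe_b64encode(
    ",".join(str(x) for x in last_hit_sort).encode()
).decode()

# Decode, as execute_search does when a token is passed back in.
search_after = urlsafe_b64decode(token.encode()).decode().split(",")
assert search_after == ["1648166400000", "item-0042"]
```

Note that the decoded values come back as strings; Elasticsearch coerces `search_after` values against the types of the sort fields, so the round trip through text is safe.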
138 changes: 138 additions & 0 deletions stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/models/links.py
@@ -0,0 +1,138 @@
"""link helpers."""

from typing import Any, Dict, List, Optional
from urllib.parse import ParseResult, parse_qs, unquote, urlencode, urljoin, urlparse

import attr
from stac_pydantic.links import Relations
from stac_pydantic.shared import MimeTypes
from starlette.requests import Request

# Copied from pgstac links

# These can be inferred from the item/collection, so they aren't included in the database
# Instead they are dynamically generated when querying the database using the classes defined below
INFERRED_LINK_RELS = ["self", "item", "parent", "collection", "root"]


def merge_params(url: str, newparams: Dict) -> str:
"""Merge url parameters."""
u = urlparse(url)
params = parse_qs(u.query)
params.update(newparams)
param_string = unquote(urlencode(params, True))

href = ParseResult(
scheme=u.scheme,
netloc=u.netloc,
path=u.path,
params=u.params,
query=param_string,
fragment=u.fragment,
).geturl()
return href


@attr.s
class BaseLinks:
"""Create inferred links common to collections and items."""

request: Request = attr.ib()

@property
def base_url(self):
"""Get the base url."""
return str(self.request.base_url)

@property
def url(self):
"""Get the current request url."""
return str(self.request.url)

def resolve(self, url):
"""Resolve url to the current request url."""
return urljoin(str(self.base_url), str(url))

def link_self(self) -> Dict:
"""Return the self link."""
return dict(rel=Relations.self.value, type=MimeTypes.json.value, href=self.url)

def link_root(self) -> Dict:
"""Return the catalog root."""
return dict(
rel=Relations.root.value, type=MimeTypes.json.value, href=self.base_url
)

def create_links(self) -> List[Dict[str, Any]]:
"""Return all inferred links."""
links = []
for name in dir(self):
if name.startswith("link_") and callable(getattr(self, name)):
link = getattr(self, name)()
if link is not None:
links.append(link)
return links

async def get_links(
self, extra_links: Optional[List[Dict[str, Any]]] = None
) -> List[Dict[str, Any]]:
"""
Generate all the links.

Get the links object for a stac resource by iterating through
available methods on this class that start with link_.
"""
# TODO: Pass request.json() into function so this doesn't need to be coroutine
if self.request.method == "POST":
self.request.postbody = await self.request.json()
# join passed in links with generated links
# and update relative paths
links = self.create_links()

if extra_links:
# For extra links passed in,
# add links modified with a resolved href.
# Drop any links that are dynamically
# determined by the server (e.g. self, parent, etc.)
# Resolving the href allows for relative paths
# to be stored in pgstac and for the hrefs in the
# links of response STAC objects to be resolved
# to the request url.
links += [
{**link, "href": self.resolve(link["href"])}
for link in extra_links
if link["rel"] not in INFERRED_LINK_RELS
]

return links


@attr.s
class PagingLinks(BaseLinks):
"""Create links for paging."""

next: Optional[str] = attr.ib(kw_only=True, default=None)

def link_next(self) -> Optional[Dict[str, Any]]:
"""Create link for next page."""
if self.next is not None:
method = self.request.method
if method == "GET":
href = merge_params(self.url, {"token": self.next})
link = dict(
rel=Relations.next.value,
type=MimeTypes.json.value,
method=method,
href=href,
)
return link
if method == "POST":
return {
"rel": Relations.next,
"type": MimeTypes.json,
"method": method,
"href": f"{self.request.url}",
"body": {**self.request.postbody, "token": self.next},
}

return None
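
For GET requests, the next link reuses the current request URL and only swaps in the token, so any other query parameters (`limit`, `bbox`, and so on) survive paging. A runnable sketch of `merge_params` in isolation — the function body is copied from the file above, and the URL is invented for illustration:

```python
from urllib.parse import ParseResult, parse_qs, unquote, urlencode, urlparse


def merge_params(url: str, newparams: dict) -> str:
    """Merge new query parameters into an existing URL (copy of the helper above)."""
    u = urlparse(url)
    params = parse_qs(u.query)
    params.update(newparams)
    param_string = unquote(urlencode(params, True))
    return ParseResult(
        scheme=u.scheme,
        netloc=u.netloc,
        path=u.path,
        params=u.params,
        query=param_string,
        fragment=u.fragment,
    ).geturl()


print(merge_params("http://localhost:8080/search?limit=5", {"token": "abc123"}))
# -> http://localhost:8080/search?limit=5&token=abc123
```

For POST requests, `link_next` instead echoes the original JSON body (captured as `postbody` in `get_links`) with the token merged in, so the client can replay the same search body to fetch the next page.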