Skip to content

Commit e2a943b

Browse files
author
Phil Varner
authored
Merge pull request #85 from stac-utils/pv/pagination-with-search-after
add pagination with search_after, other refactoring
2 parents 8f62e23 + d901041 commit e2a943b

File tree

9 files changed

+294
-143
lines changed

9 files changed

+294
-143
lines changed

docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ services:
1111
environment:
1212
- APP_HOST=0.0.0.0
1313
- APP_PORT=8080
14-
- RELOAD=false
14+
- RELOAD=true
1515
- ENVIRONMENT=local
1616
- WEB_CONCURRENCY=10
1717
- ES_HOST=172.17.0.1

stac_fastapi/elasticsearch/setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"pystac[validation]",
2020
"uvicorn",
2121
"overrides",
22+
"starlette",
2223
]
2324

2425
extra_reqs = {

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/core.py

Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
from pydantic import ValidationError
1414
from stac_pydantic.links import Relations
1515
from stac_pydantic.shared import MimeTypes
16+
from starlette.requests import Request
1617

1718
from stac_fastapi.elasticsearch import serializers
1819
from stac_fastapi.elasticsearch.config import ElasticsearchSettings
1920
from stac_fastapi.elasticsearch.database_logic import DatabaseLogic
21+
from stac_fastapi.elasticsearch.models.links import PagingLinks
2022
from stac_fastapi.elasticsearch.serializers import CollectionSerializer, ItemSerializer
2123
from stac_fastapi.elasticsearch.session import Session
2224
from stac_fastapi.extensions.third_party.bulk_transactions import (
@@ -84,11 +86,17 @@ async def item_collection(
8486
self, collection_id: str, limit: int = 10, token: str = None, **kwargs
8587
) -> ItemCollection:
8688
"""Read an item collection from the database."""
87-
links = []
88-
base_url = str(kwargs["request"].base_url)
89-
90-
items, maybe_count = await self.database.get_collection_items(
91-
collection_id=collection_id, limit=limit, base_url=base_url
89+
request: Request = kwargs["request"]
90+
base_url = str(request.base_url)
91+
92+
items, maybe_count, next_token = await self.database.execute_search(
93+
search=self.database.apply_collections_filter(
94+
self.database.make_search(), [collection_id]
95+
),
96+
limit=limit,
97+
token=token,
98+
sort=None,
99+
base_url=base_url,
92100
)
93101

94102
context_obj = None
@@ -100,6 +108,10 @@ async def item_collection(
100108
if maybe_count is not None:
101109
context_obj["matched"] = maybe_count
102110

111+
links = []
112+
if next_token:
113+
links = await PagingLinks(request=request, next=next_token).get_links()
114+
103115
return ItemCollection(
104116
type="FeatureCollection",
105117
features=items,
@@ -203,16 +215,10 @@ async def post_search(
203215
self, search_request: stac_pydantic.api.Search, **kwargs
204216
) -> ItemCollection:
205217
"""POST search catalog."""
206-
base_url = str(kwargs["request"].base_url)
207-
search = self.database.make_search()
218+
request: Request = kwargs["request"]
219+
base_url = str(request.base_url)
208220

209-
if search_request.query:
210-
for (field_name, expr) in search_request.query.items():
211-
field = "properties__" + field_name
212-
for (op, value) in expr.items():
213-
search = self.database.apply_stacql_filter(
214-
search=search, op=op, field=field, value=value
215-
)
221+
search = self.database.make_search()
216222

217223
if search_request.ids:
218224
search = self.database.apply_ids_filter(
@@ -242,16 +248,28 @@ async def post_search(
242248
search=search, intersects=search_request.intersects
243249
)
244250

251+
if search_request.query:
252+
for (field_name, expr) in search_request.query.items():
253+
field = "properties__" + field_name
254+
for (op, value) in expr.items():
255+
search = self.database.apply_stacql_filter(
256+
search=search, op=op, field=field, value=value
257+
)
258+
259+
sort = None
245260
if search_request.sortby:
246-
for sort in search_request.sortby:
247-
if sort.field == "datetime":
248-
sort.field = "properties__datetime"
249-
search = self.database.apply_sort(
250-
search=search, field=sort.field, direction=sort.direction
251-
)
261+
sort = self.database.populate_sort(search_request.sortby)
262+
263+
limit = 10
264+
if search_request.limit:
265+
limit = search_request.limit
252266

253-
response_features, maybe_count = await self.database.execute_search(
254-
search=search, limit=search_request.limit, base_url=base_url
267+
items, maybe_count, next_token = await self.database.execute_search(
268+
search=search,
269+
limit=limit,
270+
token=search_request.token, # type: ignore
271+
sort=sort,
272+
base_url=base_url,
255273
)
256274

257275
# if self.extension_is_enabled("FieldsExtension"):
@@ -274,26 +292,23 @@ async def post_search(
274292
# for feat in response_features
275293
# ]
276294

277-
if search_request.limit:
278-
limit = search_request.limit
279-
response_features = response_features[0:limit]
280-
else:
281-
limit = 10
282-
response_features = response_features[0:limit]
283-
284295
context_obj = None
285296
if self.extension_is_enabled("ContextExtension"):
286297
context_obj = {
287-
"returned": len(response_features),
298+
"returned": len(items),
288299
"limit": limit,
289300
}
290301
if maybe_count is not None:
291302
context_obj["matched"] = maybe_count
292303

304+
links = []
305+
if next_token:
306+
links = await PagingLinks(request=request, next=next_token).get_links()
307+
293308
return ItemCollection(
294309
type="FeatureCollection",
295-
features=response_features,
296-
links=[],
310+
features=items,
311+
links=links,
297312
context=context_obj,
298313
)
299314

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Database logic."""
22
import logging
3+
from base64 import urlsafe_b64decode, urlsafe_b64encode
34
from typing import Dict, List, Optional, Tuple, Type, Union
45

56
import attr
@@ -75,17 +76,6 @@ async def get_all_collections(self, base_url: str) -> List[Collection]:
7576
for c in collections["hits"]["hits"]
7677
]
7778

78-
async def get_collection_items(
79-
self, collection_id: str, limit: int, base_url: str
80-
) -> Tuple[List[Item], Optional[int]]:
81-
"""Database logic to retrieve an ItemCollection and a count of items contained."""
82-
search = self.apply_collections_filter(Search(), [collection_id])
83-
items, maybe_count = await self.execute_search(
84-
search=search, limit=limit, base_url=base_url
85-
)
86-
87-
return items, maybe_count
88-
8979
async def get_one_item(self, collection_id: str, item_id: str) -> Dict:
9080
"""Database logic to retrieve a single item."""
9181
try:
@@ -190,31 +180,53 @@ def apply_stacql_filter(search: Search, op: str, field: str, value: float):
190180
return search
191181

192182
@staticmethod
193-
def apply_sort(search: Search, field, direction):
183+
def populate_sort(sortby: List) -> Optional[Dict[str, Dict[str, str]]]:
194184
"""Database logic to sort search instance."""
195-
return search.sort({field: {"order": direction}})
185+
if sortby:
186+
return {s.field: {"order": s.direction} for s in sortby}
187+
else:
188+
return None
196189

197190
async def execute_search(
198-
self, search, limit: int, base_url: str
199-
) -> Tuple[List[Item], Optional[int]]:
191+
self,
192+
search: Search,
193+
limit: int,
194+
token: Optional[str],
195+
sort: Optional[Dict[str, Dict[str, str]]],
196+
base_url: str,
197+
) -> Tuple[List[Item], Optional[int], Optional[str]]:
200198
"""Database logic to execute search with limit."""
201-
search = search[0:limit]
202199
body = search.to_dict()
203200

204201
maybe_count = (
205202
await self.client.count(index=ITEMS_INDEX, body=search.to_dict(count=True))
206203
).get("count")
207204

205+
search_after = None
206+
if token:
207+
search_after = urlsafe_b64decode(token.encode()).decode().split(",")
208+
208209
es_response = await self.client.search(
209-
index=ITEMS_INDEX, query=body.get("query"), sort=body.get("sort")
210+
index=ITEMS_INDEX,
211+
query=body.get("query"),
212+
sort=sort or DEFAULT_SORT,
213+
search_after=search_after,
214+
size=limit,
210215
)
211216

217+
hits = es_response["hits"]["hits"]
212218
items = [
213219
self.item_serializer.db_to_stac(hit["_source"], base_url=base_url)
214-
for hit in es_response["hits"]["hits"]
220+
for hit in hits
215221
]
216222

217-
return items, maybe_count
223+
next_token = None
224+
if hits and (sort_array := hits[-1].get("sort")):
225+
next_token = urlsafe_b64encode(
226+
",".join([str(x) for x in sort_array]).encode()
227+
).decode()
228+
229+
return items, maybe_count, next_token
218230

219231
""" TRANSACTION LOGIC """
220232

stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/models/links.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
"""link helpers."""
2+
3+
from typing import Any, Dict, List, Optional
4+
from urllib.parse import ParseResult, parse_qs, unquote, urlencode, urljoin, urlparse
5+
6+
import attr
7+
from stac_pydantic.links import Relations
8+
from stac_pydantic.shared import MimeTypes
9+
from starlette.requests import Request
10+
11+
# Copied from pgstac links
12+
13+
# These can be inferred from the item/collection, so they aren't included in the database
14+
# Instead they are dynamically generated when querying the database using the classes defined below
15+
INFERRED_LINK_RELS = ["self", "item", "parent", "collection", "root"]
16+
17+
18+
def merge_params(url: str, newparams: Dict) -> str:
    """Merge query parameters into a url.

    The query string of ``url`` is parsed, the entries of ``newparams`` are
    added (replacing any existing parameter of the same name), and the url is
    rebuilt with every other component left untouched.

    Args:
        url: The url whose query string should be extended.
        newparams: Mapping of parameter name to a value or sequence of values.

    Returns:
        The url carrying the merged, percent-encoded query string.
    """
    u = urlparse(url)
    params = parse_qs(u.query)
    params.update(newparams)
    # Keep the query percent-encoded. Unquoting it again would let values
    # containing "&", "=", "#" or "%" break the query string apart
    # (e.g. "q=a%26b" would be emitted as "q=a&b").
    param_string = urlencode(params, doseq=True)

    href = ParseResult(
        scheme=u.scheme,
        netloc=u.netloc,
        path=u.path,
        params=u.params,
        query=param_string,
        fragment=u.fragment,
    ).geturl()
    return href
34+
35+
36+
@attr.s
class BaseLinks:
    """Build the links that can be derived from the incoming request."""

    request: Request = attr.ib()

    @property
    def base_url(self):
        """Return the base url of the request as a string."""
        return str(self.request.base_url)

    @property
    def url(self):
        """Return the full url of the current request as a string."""
        return str(self.request.url)

    def resolve(self, url):
        """Join ``url`` onto the base url so relative hrefs become absolute."""
        return urljoin(self.base_url, str(url))

    def link_self(self) -> Dict:
        """Build the ``self`` link, pointing at the current request url."""
        return {
            "rel": Relations.self.value,
            "type": MimeTypes.json.value,
            "href": self.url,
        }

    def link_root(self) -> Dict:
        """Build the ``root`` link, pointing at the base url."""
        return {
            "rel": Relations.root.value,
            "type": MimeTypes.json.value,
            "href": self.base_url,
        }

    def create_links(self) -> List[Dict[str, Any]]:
        """Call every callable ``link_*`` attribute and collect the non-None results."""
        generated = (
            getattr(self, attr_name)()
            for attr_name in dir(self)
            if attr_name.startswith("link_") and callable(getattr(self, attr_name))
        )
        return [link for link in generated if link is not None]

    async def get_links(
        self, extra_links: Optional[List[Dict[str, Any]]] = None
    ) -> List[Dict[str, Any]]:
        """
        Generate all the links.

        The generated links come from the ``link_*`` methods of this class;
        any ``extra_links`` passed in are appended after them.
        """
        # TODO: Pass request.json() into function so this doesn't need to be coroutine
        if self.request.method == "POST":
            # Stash the parsed body on the request so the synchronous
            # link_* methods can read it.
            self.request.postbody = await self.request.json()

        links = self.create_links()
        if not extra_links:
            return links

        # Server-generated rels (self, parent, ...) are dropped from the
        # extra links; every remaining href is resolved against the base
        # url so relative paths come back absolute.
        links.extend(
            {**link, "href": self.resolve(link["href"])}
            for link in extra_links
            if link["rel"] not in INFERRED_LINK_RELS
        )
        return links
108+
109+
110+
@attr.s
class PagingLinks(BaseLinks):
    """Create links for paging."""

    # Opaque token identifying the next page; no next link is built when None.
    next: Optional[str] = attr.ib(kw_only=True, default=None)

    def link_next(self) -> Optional[Dict[str, Any]]:
        """Create link for next page.

        For GET requests the token is merged into the query string of the
        current url. For POST requests the current url is reused and the
        token is added to a copy of the posted body.

        Returns:
            The next link, or None when there is no next token or the
            request method is neither GET nor POST.
        """
        if self.next is None:
            return None

        method = self.request.method
        if method == "GET":
            return {
                "rel": Relations.next.value,
                "type": MimeTypes.json.value,
                "method": method,
                "href": merge_params(self.url, {"token": self.next}),
            }
        if method == "POST":
            # Use the plain enum values here as well, so both branches (and
            # the self/root links) emit the same kind of rel/type entries.
            # request.postbody is set by get_links() for POST requests.
            return {
                "rel": Relations.next.value,
                "type": MimeTypes.json.value,
                "method": method,
                "href": self.url,
                "body": {**self.request.postbody, "token": self.next},
            }

        return None

0 commit comments

Comments
 (0)