diff --git a/CHANGELOG.md b/CHANGELOG.md index cafe4bcb..16f5ab46 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ### Changed +- Make `orjson` usage more consistent [#402](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/402) - Improved datetime query handling to only check start and end datetime values when datetime is None [#396](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/396) - Optimize data_loader.py script [#395](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/395) - Refactored test configuration to use shared app config pattern [#399](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/399) diff --git a/stac_fastapi/core/stac_fastapi/core/route_dependencies.py b/stac_fastapi/core/stac_fastapi/core/route_dependencies.py index 29dcc58b..fa5e4934 100644 --- a/stac_fastapi/core/stac_fastapi/core/route_dependencies.py +++ b/stac_fastapi/core/stac_fastapi/core/route_dependencies.py @@ -2,11 +2,11 @@ import importlib import inspect -import json import logging import os from typing import List +import orjson from fastapi import Depends from jsonschema import validate @@ -84,14 +84,14 @@ def get_route_dependencies_conf(route_dependencies_env: str) -> list: """Get Route dependencies configuration from file or environment variable.""" - if os.path.exists(route_dependencies_env): - with open(route_dependencies_env, encoding="utf-8") as route_dependencies_file: - route_dependencies_conf = json.load(route_dependencies_file) + if os.path.isfile(route_dependencies_env): + with open(route_dependencies_env, "rb") as f: + route_dependencies_conf = orjson.loads(f.read()) else: try: - route_dependencies_conf = json.loads(route_dependencies_env) - except json.JSONDecodeError as exception: + route_dependencies_conf = orjson.loads(route_dependencies_env) + except orjson.JSONDecodeError as exception: _LOGGER.error("Invalid JSON format for route dependencies. %s", exception) raise diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index 2d630769..94f2530f 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -1,7 +1,6 @@ """Database logic.""" import asyncio -import json import logging from base64 import urlsafe_b64decode, urlsafe_b64encode from copy import deepcopy @@ -9,6 +8,7 @@ import attr import elasticsearch.helpers as helpers +import orjson from elasticsearch.dsl import Q, Search from elasticsearch.exceptions import NotFoundError as ESNotFoundError from starlette.requests import Request @@ -503,7 +503,7 @@ async def execute_search( search_after = None if token: - search_after = json.loads(urlsafe_b64decode(token).decode()) + search_after = orjson.loads(urlsafe_b64decode(token)) query = search.query.to_dict() if search.query else None @@ -543,7 +543,7 @@ async def execute_search( next_token = None if len(hits) > limit and limit < max_result_window: if hits and (sort_array := hits[limit - 1].get("sort")): - next_token = urlsafe_b64encode(json.dumps(sort_array).encode()).decode() + next_token = urlsafe_b64encode(orjson.dumps(sort_array)).decode() matched = ( es_response["hits"]["total"]["value"] diff --git a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py index ba6fb9e8..979a0f8f 100644 --- a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py +++ b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py @@ -1,13 +1,13 @@ """Database logic.""" import asyncio -import json import logging from base64 import urlsafe_b64decode, urlsafe_b64encode from copy import deepcopy from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union import attr +import orjson from opensearchpy import exceptions, helpers from opensearchpy.helpers.query import Q from opensearchpy.helpers.search import Search @@ -527,7 +527,7 @@ async def execute_search( search_after = None if token: - search_after = json.loads(urlsafe_b64decode(token).decode()) + search_after = orjson.loads(urlsafe_b64decode(token)) if search_after: search_body["search_after"] = search_after @@ -567,7 +567,7 @@ async def execute_search( next_token = None if len(hits) > limit and limit < max_result_window: if hits and (sort_array := hits[limit - 1].get("sort")): - next_token = urlsafe_b64encode(json.dumps(sort_array).encode()).decode() + next_token = urlsafe_b64encode(orjson.dumps(sort_array)).decode() matched = ( es_response["hits"]["total"]["value"]