diff --git a/docs/guide/configuration.asciidoc b/docs/guide/configuration.asciidoc index 61bc44e..b5055bf 100644 --- a/docs/guide/configuration.asciidoc +++ b/docs/guide/configuration.asciidoc @@ -264,6 +264,24 @@ es = Elasticsearch( ) ------------------------------------ +If the `orjson` package is installed, you can use the faster ``OrjsonSerializer`` for the default mimetype (``application/json``): + +[source,python] +------------------------------------ +from elasticsearch import Elasticsearch, OrjsonSerializer + +es = Elasticsearch( + ..., + serializer=OrjsonSerializer() +) +------------------------------------ + +orjson is particularly fast when serializing vectors as it has native numpy support. This will be the default in a future release. Note that you can install orjson with the `orjson` extra: + +[source,sh] +-------------------------------------------- +$ python -m pip install elasticsearch[orjson] +-------------------------------------------- [discrete] [[nodes]] diff --git a/elasticsearch_serverless/__init__.py b/elasticsearch_serverless/__init__.py index 8b38b39..723b3a2 100644 --- a/elasticsearch_serverless/__init__.py +++ b/elasticsearch_serverless/__init__.py @@ -63,6 +63,11 @@ ) from .serializer import JSONSerializer, JsonSerializer +try: + from .serializer import OrjsonSerializer +except ImportError: + OrjsonSerializer = None # type: ignore[assignment,misc] + # Only raise one warning per deprecation message so as not # to spam up the user if the same action is done multiple times. warnings.simplefilter("default", category=ElasticsearchWarning, append=True) @@ -86,6 +91,8 @@ "UnsupportedProductError", "ElasticsearchWarning", ] +if OrjsonSerializer is not None: + __all__.append("OrjsonSerializer") fixup_module_metadata(__name__, globals()) del fixup_module_metadata diff --git a/elasticsearch_serverless/serializer.py b/elasticsearch_serverless/serializer.py index 64592d2..37ad572 100644 --- a/elasticsearch_serverless/serializer.py +++ b/elasticsearch_serverless/serializer.py @@ -41,6 +41,13 @@ "MapboxVectorTileSerializer", ] +try: + from elastic_transport import OrjsonSerializer as _OrjsonSerializer + + __all__.append("OrjsonSerializer") +except ImportError: + _OrjsonSerializer = None # type: ignore[assignment,misc] + class JsonSerializer(_JsonSerializer): mimetype: ClassVar[str] = "application/json" @@ -73,6 +80,13 @@ def default(self, data: Any) -> Any: raise TypeError(f"Unable to serialize {data!r} (type: {type(data)})") +if _OrjsonSerializer is not None: + + class OrjsonSerializer(JsonSerializer, _OrjsonSerializer): + def default(self, data: Any) -> Any: + return JsonSerializer.default(self, data) + + class NdjsonSerializer(JsonSerializer, _NdjsonSerializer): mimetype: ClassVar[str] = "application/x-ndjson" diff --git a/noxfile.py b/noxfile.py index 9360f74..01a1282 100644 --- a/noxfile.py +++ b/noxfile.py @@ -86,7 +86,7 @@ def lint(session): session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES) # Workaround to make '-r' to still work despite uninstalling aiohttp below. - session.install(".[async,requests]", env=INSTALL_ENV) + session.install(".[async,requests,orjson]", env=INSTALL_ENV) # Run mypy on the package and then the type examples separately for # the two different mypy use-cases, ourselves and our users. @@ -118,5 +118,5 @@ def lint(session): @nox.session() def docs(session): - session.install(".[docs]") + session.install(".[docs,orjson]") session.run("sphinx-build", "docs/sphinx/", "docs/sphinx/_build", "-b", "html") diff --git a/pyproject.toml b/pyproject.toml index fcbbc29..ee1c98f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,12 +44,9 @@ dependencies = [ ] [project.optional-dependencies] -async = [ - "aiohttp>=3,<4", -] -requests = [ - "requests>=2.4.0, <3.0.0", -] +async = ["aiohttp>=3,<4"] +requests = ["requests>=2.4.0, <3.0.0" ] +orjson = ["orjson>=3"] dev = [ "requests>=2, <3", "aiohttp", @@ -66,6 +63,7 @@ dev = [ "twine", "build", "nox", + "orjson", "numpy", "pandas", "mapbox-vector-tile", diff --git a/test_elasticsearch_serverless/test_serializer.py b/test_elasticsearch_serverless/test_serializer.py index 4674e97..b9757a8 100644 --- a/test_elasticsearch_serverless/test_serializer.py +++ b/test_elasticsearch_serverless/test_serializer.py @@ -16,7 +16,6 @@ # specific language governing permissions and limitations # under the License. -import sys import uuid from datetime import datetime from decimal import Decimal @@ -33,41 +32,45 @@ from elasticsearch_serverless import Elasticsearch from elasticsearch_serverless.exceptions import SerializationError -from elasticsearch_serverless.serializer import JSONSerializer, TextSerializer +from elasticsearch_serverless.serializer import ( + JSONSerializer, + OrjsonSerializer, + TextSerializer, +) requires_numpy_and_pandas = pytest.mark.skipif( - np is None or pd is None, reason="Test requires numpy or pandas to be available" + np is None or pd is None, reason="Test requires numpy and pandas to be available" ) -def test_datetime_serialization(): - assert b'{"d":"2010-10-01T02:30:00"}' == JSONSerializer().dumps( +@pytest.fixture(params=[JSONSerializer, OrjsonSerializer]) +def json_serializer(request: pytest.FixtureRequest): + yield request.param() + + +def test_datetime_serialization(json_serializer): + assert b'{"d":"2010-10-01T02:30:00"}' == json_serializer.dumps( {"d": datetime(2010, 10, 1, 2, 30)} ) -def test_decimal_serialization(): - requires_numpy_and_pandas() +def test_decimal_serialization(json_serializer): + assert b'{"d":3.8}' == json_serializer.dumps({"d": Decimal("3.8")}) - if sys.version_info[:2] == (2, 6): - pytest.skip("Float rounding is broken in 2.6.") - assert b'{"d":3.8}' == JSONSerializer().dumps({"d": Decimal("3.8")}) - -def test_uuid_serialization(): - assert b'{"d":"00000000-0000-0000-0000-000000000003"}' == JSONSerializer().dumps( +def test_uuid_serialization(json_serializer): + assert b'{"d":"00000000-0000-0000-0000-000000000003"}' == json_serializer.dumps( {"d": uuid.UUID("00000000-0000-0000-0000-000000000003")} ) @requires_numpy_and_pandas -def test_serializes_numpy_bool(): - assert b'{"d":true}' == JSONSerializer().dumps({"d": np.bool_(True)}) +def test_serializes_numpy_bool(json_serializer): + assert b'{"d":true}' == json_serializer.dumps({"d": np.bool_(True)}) @requires_numpy_and_pandas -def test_serializes_numpy_integers(): - ser = JSONSerializer() +def test_serializes_numpy_integers(json_serializer): for np_type in ( np.int_, np.int8, @@ -75,7 +78,7 @@ def test_serializes_numpy_integers(): np.int32, np.int64, ): - assert ser.dumps({"d": np_type(-1)}) == b'{"d":-1}' + assert json_serializer.dumps({"d": np_type(-1)}) == b'{"d":-1}' for np_type in ( np.uint8, @@ -83,81 +86,84 @@ def test_serializes_numpy_integers(): np.uint32, np.uint64, ): - assert ser.dumps({"d": np_type(1)}) == b'{"d":1}' + assert json_serializer.dumps({"d": np_type(1)}) == b'{"d":1}' @requires_numpy_and_pandas -def test_serializes_numpy_floats(): - ser = JSONSerializer() +def test_serializes_numpy_floats(json_serializer): for np_type in ( np.float32, np.float64, ): - assert re.search(rb'^{"d":1\.2[\d]*}$', ser.dumps({"d": np_type(1.2)})) + assert re.search( + rb'^{"d":1\.2[\d]*}$', json_serializer.dumps({"d": np_type(1.2)}) + ) @requires_numpy_and_pandas -def test_serializes_numpy_datetime(): - assert b'{"d":"2010-10-01T02:30:00"}' == JSONSerializer().dumps( +def test_serializes_numpy_datetime(json_serializer): + assert b'{"d":"2010-10-01T02:30:00"}' == json_serializer.dumps( {"d": np.datetime64("2010-10-01T02:30:00")} ) @requires_numpy_and_pandas -def test_serializes_numpy_ndarray(): - assert b'{"d":[0,0,0,0,0]}' == JSONSerializer().dumps( +def test_serializes_numpy_ndarray(json_serializer): + assert b'{"d":[0,0,0,0,0]}' == json_serializer.dumps( {"d": np.zeros((5,), dtype=np.uint8)} ) # This isn't useful for Elasticsearch, just want to make sure it works. - assert b'{"d":[[0,0],[0,0]]}' == JSONSerializer().dumps( + assert b'{"d":[[0,0],[0,0]]}' == json_serializer.dumps( {"d": np.zeros((2, 2), dtype=np.uint8)} ) @requires_numpy_and_pandas def test_serializes_numpy_nan_to_nan(): - assert b'{"d":NaN}' == JSONSerializer().dumps({"d": np.nan}) + assert b'{"d":NaN}' == JSONSerializer().dumps({"d": float("NaN")}) + # NaN is invalid JSON, and orjson silently converts it to null + assert b'{"d":null}' == OrjsonSerializer().dumps({"d": float("NaN")}) @requires_numpy_and_pandas -def test_serializes_pandas_timestamp(): - assert b'{"d":"2010-10-01T02:30:00"}' == JSONSerializer().dumps( +def test_serializes_pandas_timestamp(json_serializer): + assert b'{"d":"2010-10-01T02:30:00"}' == json_serializer.dumps( {"d": pd.Timestamp("2010-10-01T02:30:00")} ) @requires_numpy_and_pandas -def test_serializes_pandas_series(): - assert b'{"d":["a","b","c","d"]}' == JSONSerializer().dumps( +def test_serializes_pandas_series(json_serializer): + assert b'{"d":["a","b","c","d"]}' == json_serializer.dumps( {"d": pd.Series(["a", "b", "c", "d"])} ) @requires_numpy_and_pandas @pytest.mark.skipif(not hasattr(pd, "NA"), reason="pandas.NA is required") -def test_serializes_pandas_na(): - assert b'{"d":null}' == JSONSerializer().dumps({"d": pd.NA}) +def test_serializes_pandas_na(json_serializer): + assert b'{"d":null}' == json_serializer.dumps({"d": pd.NA}) @requires_numpy_and_pandas @pytest.mark.skipif(not hasattr(pd, "NaT"), reason="pandas.NaT required") -def test_raises_serialization_error_pandas_nat(): +def test_raises_serialization_error_pandas_nat(json_serializer): with pytest.raises(SerializationError): - JSONSerializer().dumps({"d": pd.NaT}) + json_serializer.dumps({"d": pd.NaT}) @requires_numpy_and_pandas -def test_serializes_pandas_category(): +def test_serializes_pandas_category(json_serializer): cat = pd.Categorical(["a", "c", "b", "a"], categories=["a", "b", "c"]) - assert b'{"d":["a","c","b","a"]}' == JSONSerializer().dumps({"d": cat}) + assert b'{"d":["a","c","b","a"]}' == json_serializer.dumps({"d": cat}) cat = pd.Categorical([1, 2, 3], categories=[1, 2, 3]) - assert b'{"d":[1,2,3]}' == JSONSerializer().dumps({"d": cat}) + assert b'{"d":[1,2,3]}' == json_serializer.dumps({"d": cat}) -def test_json_raises_serialization_error_on_dump_error(): +def test_json_raises_serialization_error_on_dump_error(json_serializer): with pytest.raises(SerializationError): - JSONSerializer().dumps(object()) + json_serializer.dumps(object()) def test_raises_serialization_error_on_load_error():