From 12a7a1598324e1f952e8cbe8cd9a06f9e0831cfb Mon Sep 17 00:00:00 2001 From: Quentin Pradet Date: Fri, 11 Oct 2024 13:41:25 +0400 Subject: [PATCH] Add optional Arrow deserialization support --- docs/guide/configuration.asciidoc | 2 +- elasticsearch_serverless/serializer.py | 34 +++++++++++++++++++ noxfile.py | 3 +- pyproject.toml | 2 ++ .../test_client/test_deprecated_options.py | 2 ++ .../test_client/test_serializers.py | 3 ++ .../test_serializer.py | 21 ++++++++++++ 7 files changed, 64 insertions(+), 3 deletions(-) diff --git a/docs/guide/configuration.asciidoc b/docs/guide/configuration.asciidoc index b5055bf..f9e03b0 100644 --- a/docs/guide/configuration.asciidoc +++ b/docs/guide/configuration.asciidoc @@ -242,7 +242,7 @@ When using the `ignore_status` parameter the error response will be returned ser [[serializer]] === Serializers -Serializers transform bytes on the wire into native Python objects and vice-versa. By default the client ships with serializers for `application/json`, `application/x-ndjson`, `text/*`, and `application/mapbox-vector-tile`. +Serializers transform bytes on the wire into native Python objects and vice-versa. By default the client ships with serializers for `application/json`, `application/x-ndjson`, `text/*`, and `application/vnd.mapbox-vector-tile`. A deserializer for `application/vnd.apache.arrow.stream` is also registered when the optional `pyarrow` package is installed. 
You can define custom serializers via the `serializers` parameter: diff --git a/elasticsearch_serverless/serializer.py b/elasticsearch_serverless/serializer.py index 37ad572..2e7ffb6 100644 --- a/elasticsearch_serverless/serializer.py +++ b/elasticsearch_serverless/serializer.py @@ -48,6 +48,13 @@ except ImportError: _OrjsonSerializer = None # type: ignore[assignment,misc] +try: + import pyarrow as pa + + __all__.append("PyArrowSerializer") +except ImportError: + pa = None + class JsonSerializer(_JsonSerializer): mimetype: ClassVar[str] = "application/json" @@ -114,6 +121,29 @@ def dumps(self, data: bytes) -> bytes: raise SerializationError(f"Cannot serialize {data!r} into a MapBox vector tile") +if pa is not None: + + class PyArrowSerializer(Serializer): + """PyArrow serializer for deserializing Arrow Stream data.""" + + mimetype: ClassVar[str] = "application/vnd.apache.arrow.stream" + + def loads(self, data: bytes) -> pa.Table: + try: + with pa.ipc.open_stream(data) as reader: + return reader.read_all() + except pa.ArrowException as e: + raise SerializationError( + message=f"Unable to deserialize as Arrow stream: {data!r}", + errors=(e,), + ) + + def dumps(self, data: Any) -> bytes: + raise SerializationError( + message="Elasticsearch does not accept Arrow input data" + ) + + DEFAULT_SERIALIZERS: Dict[str, Serializer] = { JsonSerializer.mimetype: JsonSerializer(), MapboxVectorTileSerializer.mimetype: MapboxVectorTileSerializer(), @@ -122,6 +152,10 @@ def dumps(self, data: bytes) -> bytes: CompatibilityModeNdjsonSerializer.mimetype: CompatibilityModeNdjsonSerializer(), } +if pa is not None: + DEFAULT_SERIALIZERS[PyArrowSerializer.mimetype] = PyArrowSerializer() + + # Alias for backwards compatibility JSONSerializer = JsonSerializer diff --git a/noxfile.py b/noxfile.py index 01a1282..5c0365a 100644 --- a/noxfile.py +++ b/noxfile.py @@ -85,8 +85,7 @@ def lint(session): session.run("flake8", *SOURCE_FILES) session.run("python", "utils/license-headers.py", "check", 
*SOURCE_FILES) - # Workaround to make '-r' to still work despite uninstalling aiohttp below. - session.install(".[async,requests,orjson]", env=INSTALL_ENV) + session.install(".[async,requests,orjson,pyarrow]", env=INSTALL_ENV) # Run mypy on the package and then the type examples separately for # the two different mypy use-cases, ourselves and our users. diff --git a/pyproject.toml b/pyproject.toml index ee1c98f..27cf479 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ async = ["aiohttp>=3,<4"] requests = ["requests>=2.4.0, <3.0.0" ] orjson = ["orjson>=3"] +pyarrow = ["pyarrow>=1"] dev = [ "requests>=2, <3", "aiohttp", @@ -65,6 +66,7 @@ dev = [ "nox", "orjson", "numpy", + "pyarrow", "pandas", "mapbox-vector-tile", ] diff --git a/test_elasticsearch_serverless/test_client/test_deprecated_options.py b/test_elasticsearch_serverless/test_client/test_deprecated_options.py index a0871f1..84cef2e 100644 --- a/test_elasticsearch_serverless/test_client/test_deprecated_options.py +++ b/test_elasticsearch_serverless/test_client/test_deprecated_options.py @@ -73,6 +73,7 @@ class CustomSerializer(JsonSerializer): "application/x-ndjson", "application/json", "text/*", + "application/vnd.apache.arrow.stream", "application/vnd.elasticsearch+json", "application/vnd.elasticsearch+x-ndjson", } @@ -93,6 +94,7 @@ class CustomSerializer(JsonSerializer): "application/x-ndjson", "application/json", "text/*", + "application/vnd.apache.arrow.stream", "application/vnd.elasticsearch+json", "application/vnd.elasticsearch+x-ndjson", "application/cbor", diff --git a/test_elasticsearch_serverless/test_client/test_serializers.py b/test_elasticsearch_serverless/test_client/test_serializers.py index c09237c..369c6a0 100644 --- a/test_elasticsearch_serverless/test_client/test_serializers.py +++ b/test_elasticsearch_serverless/test_client/test_serializers.py @@ -71,6 +71,7 @@ class CustomSerializer: "application/json", "text/*", "application/x-ndjson", + 
"application/vnd.apache.arrow.stream", "application/vnd.mapbox-vector-tile", "application/vnd.elasticsearch+json", "application/vnd.elasticsearch+x-ndjson", @@ -98,6 +99,7 @@ class CustomSerializer: "application/json", "text/*", "application/x-ndjson", + "application/vnd.apache.arrow.stream", "application/vnd.mapbox-vector-tile", "application/vnd.elasticsearch+json", "application/vnd.elasticsearch+x-ndjson", @@ -117,6 +119,7 @@ class CustomSerializer: "application/json", "text/*", "application/x-ndjson", + "application/vnd.apache.arrow.stream", "application/vnd.mapbox-vector-tile", "application/vnd.elasticsearch+json", "application/vnd.elasticsearch+x-ndjson", diff --git a/test_elasticsearch_serverless/test_serializer.py b/test_elasticsearch_serverless/test_serializer.py index b9757a8..fd59b2c 100644 --- a/test_elasticsearch_serverless/test_serializer.py +++ b/test_elasticsearch_serverless/test_serializer.py @@ -20,6 +20,7 @@ from datetime import datetime from decimal import Decimal +import pyarrow as pa import pytest try: @@ -35,6 +36,7 @@ from elasticsearch_serverless.serializer import ( JSONSerializer, OrjsonSerializer, + PyArrowSerializer, TextSerializer, ) @@ -161,6 +163,25 @@ def test_serializes_pandas_category(json_serializer): assert b'{"d":[1,2,3]}' == json_serializer.dumps({"d": cat}) +def test_pyarrow_loads(): + data = [ + pa.array([1, 2, 3, 4]), + pa.array(["foo", "bar", "baz", None]), + pa.array([True, None, False, True]), + ] + batch = pa.record_batch(data, names=["f0", "f1", "f2"]) + sink = pa.BufferOutputStream() + with pa.ipc.new_stream(sink, batch.schema) as writer: + writer.write_batch(batch) + + serializer = PyArrowSerializer() + assert serializer.loads(sink.getvalue()).to_pydict() == { + "f0": [1, 2, 3, 4], + "f1": ["foo", "bar", "baz", None], + "f2": [True, None, False, True], + } + + def test_json_raises_serialization_error_on_dump_error(json_serializer): with pytest.raises(SerializationError): json_serializer.dumps(object())