diff --git a/docs/source/api.rst b/docs/source/api.rst
index 020a7e88..36af800a 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -804,14 +804,14 @@ A :class:`neo4j.Result` is attached to an active connection, through a :class:`n

     .. automethod:: graph

-        **This is experimental.** (See :ref:`filter-warnings-ref`)
-
     .. automethod:: value

     .. automethod:: values

     .. automethod:: data

+    .. automethod:: to_df
+
     .. automethod:: closed

 See https://neo4j.com/docs/python-manual/current/cypher-workflow/#python-driver-type-mapping for more about type mapping.
@@ -987,7 +987,7 @@ Path
 :class:`neo4j.graph.Path`

 Node
 ====

-.. autoclass:: neo4j.graph.Node()
+.. autoclass:: neo4j.graph.Node

 .. describe:: node == other
@@ -1022,6 +1022,8 @@ Node

     .. autoattribute:: id

+    .. autoattribute:: element_id
+
     .. autoattribute:: labels

     .. automethod:: get
@@ -1036,7 +1038,7 @@ Node

 Relationship
 ============

-.. autoclass:: neo4j.graph.Relationship()
+.. autoclass:: neo4j.graph.Relationship

 .. describe:: relationship == other
@@ -1076,6 +1078,8 @@ Relationship

     .. autoattribute:: id

+    .. autoattribute:: element_id
+
     .. autoattribute:: nodes

     .. autoattribute:: start_node
@@ -1097,7 +1101,7 @@ Relationship

 Path
 ====

-.. autoclass:: neo4j.graph.Path()
+.. autoclass:: neo4j.graph.Path

 .. describe:: path == other
diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst
index 5f831141..10952b06 100644
--- a/docs/source/async_api.rst
+++ b/docs/source/async_api.rst
@@ -511,14 +511,14 @@ A :class:`neo4j.AsyncResult` is attached to an active connection, through a :cla

     .. automethod:: graph

-        **This is experimental.** (See :ref:`filter-warnings-ref`)
-
     .. automethod:: value

     .. automethod:: values

     .. automethod:: data

+    .. automethod:: to_df
+
     .. automethod:: closed

 See https://neo4j.com/docs/python-manual/current/cypher-workflow/#python-driver-type-mapping for more about type mapping.
diff --git a/neo4j/_async/work/result.py b/neo4j/_async/work/result.py
index cf7937c3..ebca415b 100644
--- a/neo4j/_async/work/result.py
+++ b/neo4j/_async/work/result.py
@@ -20,11 +20,15 @@ from warnings import warn

 from ..._async_compat.util import AsyncUtil
-from ...data import DataDehydrator
+from ...data import (
+    DataDehydrator,
+    RecordTableRowExporter,
+)
 from ...exceptions import (
     ResultConsumedError,
     ResultNotSingleError,
 )
+from ...meta import experimental
 from ...work import ResultSummary
 from ..io import ConnectionErrorHandler
@@ -455,6 +459,8 @@ async def graph(self):
             was obtained has been closed or the Result has been explicitly
             consumed.

+        **This is experimental.** (See :ref:`filter-warnings-ref`)
+
         .. versionchanged:: 5.0
             Can raise :exc:`ResultConsumedError`.
         """
@@ -519,6 +525,127 @@ async def data(self, *keys):
         """
         return [record.data(*keys) async for record in self]

+    @experimental("pandas support is experimental and might be changed or "
+                  "removed in future versions")
+    async def to_df(self, expand=False):
+        r"""Convert (the rest of) the result to a pandas DataFrame.
+
+        This method is only available if the `pandas` library is installed.
+
+        ::
+
+            res = await tx.run("UNWIND range(1, 10) AS n RETURN n, n+1 AS m")
+            df = await res.to_df()
+
+        This, for instance, will return a DataFrame with two columns
+        (``n`` and ``m``) and 10 rows.
+
+        :param expand: if :const:`True`, some structures in the result will be
+            recursively expanded (flattened out into multiple columns) like so
+            (everything inside ``<...>`` is a placeholder):
+
+            * :class:`.Node` objects under any variable ``<n>`` will be
+              expanded into columns (the recursion stops here)
+
+              * ``<n>().prop.<property_name>`` (any) for each property of the
+                node.
+              * ``<n>().element_id`` (str) the node's element id.
+                See :attr:`.Node.element_id`.
+              * ``<n>().labels`` (frozenset of str) the node's labels.
+                See :attr:`.Node.labels`.
+
+            * :class:`.Relationship` objects under any variable ``<r>``
+              will be expanded into columns (the recursion stops here)
+
+              * ``<r>->.prop.<property_name>`` (any) for each property of the
+                relationship.
+              * ``<r>->.element_id`` (str) the relationship's element id.
+                See :attr:`.Relationship.element_id`.
+              * ``<r>->.start.element_id`` (str) the relationship's
+                start node's element id.
+                See :attr:`.Relationship.start_node`.
+              * ``<r>->.end.element_id`` (str) the relationship's
+                end node's element id.
+                See :attr:`.Relationship.end_node`.
+              * ``<r>->.type`` (str) the relationship's type.
+                See :attr:`.Relationship.type`.
+
+            * :const:`list` objects under any variable ``<l>`` will be
+              expanded into
+
+              * ``<l>[].0`` (any) the 1st list element
+              * ``<l>[].1`` (any) the 2nd list element
+              * ...
+
+            * :const:`dict` objects under any variable ``<d>`` will be
+              expanded into
+
+              * ``<d>{}.<key1>`` (any) the value for the 1st key of the dict
+              * ``<d>{}.<key2>`` (any) the value for the 2nd key of the dict
+              * ...
+
+            * :const:`list` and :const:`dict` objects are expanded
+              recursively. Example::
+
+                  variable x: [{"foo": "bar", "baz": [42, 0]}, "foobar"]
+
+              will be expanded to::
+
+                  {
+                      "x[].0{}.foo": "bar",
+                      "x[].0{}.baz[].0": 42,
+                      "x[].0{}.baz[].1": 0,
+                      "x[].1": "foobar"
+                  }
+
+            * Everything else (including :class:`.Path` objects) will not
+              be flattened.
+
+            :const:`dict` keys and variable names that contain ``.`` or ``\``
+            will be escaped with a backslash (``\.`` and ``\\`` respectively).
+        :type expand: bool
+
+        :rtype: :py:class:`pandas.DataFrame`
+        :raises ImportError: if `pandas` library is not available.
+        :raises ResultConsumedError: if the transaction from which this result
+            was obtained has been closed or the Result has been explicitly
+            consumed.
+
+        **This is experimental.**
+        ``pandas`` support might be changed or removed in future versions
+        without warning. (See :ref:`filter-warnings-ref`)
+        """
+        import pandas as pd
+
+        if not expand:
+            return pd.DataFrame(await self.values(), columns=self._keys)
+        else:
+            df_keys = None
+            rows = []
+            async for record in self:
+                row = RecordTableRowExporter().transform(dict(record.items()))
+                if df_keys == row.keys():
+                    rows.append(row.values())
+                elif df_keys is None:
+                    df_keys = row.keys()
+                    rows.append(row.values())
+                elif df_keys is False:
+                    rows.append(row)
+                else:
+                    # The rows have different keys. We need to pass a list
+                    # of dicts to pandas.
+                    rows = [{k: v for k, v in zip(df_keys, r)} for r in rows]
+                    df_keys = False
+                    rows.append(row)
+            if df_keys is False:
+                return pd.DataFrame(rows)
+            else:
+                columns = df_keys or [
+                    # escape the backslash first so that the escape
+                    # characters introduced for the dots don't get re-escaped
+                    k.replace("\\", "\\\\").replace(".", "\\.")
+                    for k in self._keys
+                ]
+                return pd.DataFrame(rows, columns=columns)
+
     def closed(self):
         """Return True if the result has been closed.
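A minimal usage sketch of the new async API (assumptions: a locally running Neo4j server at ``bolt://localhost:7687`` and illustrative credentials; the query and all names below are examples, not part of this change):

    import asyncio

    import neo4j

    async def main():
        driver = neo4j.AsyncGraphDatabase.driver(
            "bolt://localhost:7687", auth=("neo4j", "password")
        )
        async with driver.session() as session:
            result = await session.run(
                "UNWIND range(1, 10) AS n RETURN n, n+1 AS m"
            )
            # requires pandas; emits an ExperimentalWarning (see
            # filter-warnings-ref) since to_df might change or be removed
            df = await result.to_df()
        await driver.close()
        print(df.dtypes)  # both n and m come back as int64 columns

    asyncio.run(main())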
diff --git a/neo4j/_sync/work/result.py b/neo4j/_sync/work/result.py
index 69cd409d..e2a4b448 100644
--- a/neo4j/_sync/work/result.py
+++ b/neo4j/_sync/work/result.py
@@ -20,11 +20,15 @@ from warnings import warn

 from ..._async_compat.util import Util
-from ...data import DataDehydrator
+from ...data import (
+    DataDehydrator,
+    RecordTableRowExporter,
+)
 from ...exceptions import (
     ResultConsumedError,
     ResultNotSingleError,
 )
+from ...meta import experimental
 from ...work import ResultSummary
 from ..io import ConnectionErrorHandler
@@ -455,6 +459,8 @@ def graph(self):
             was obtained has been closed or the Result has been explicitly
             consumed.

+        **This is experimental.** (See :ref:`filter-warnings-ref`)
+
         .. versionchanged:: 5.0
             Can raise :exc:`ResultConsumedError`.
         """
@@ -519,6 +525,127 @@ def data(self, *keys):
         """
         return [record.data(*keys) for record in self]

+    @experimental("pandas support is experimental and might be changed or "
+                  "removed in future versions")
+    def to_df(self, expand=False):
+        r"""Convert (the rest of) the result to a pandas DataFrame.
+
+        This method is only available if the `pandas` library is installed.
+
+        ::
+
+            res = tx.run("UNWIND range(1, 10) AS n RETURN n, n+1 AS m")
+            df = res.to_df()
+
+        This, for instance, will return a DataFrame with two columns
+        (``n`` and ``m``) and 10 rows.
+
+        :param expand: if :const:`True`, some structures in the result will be
+            recursively expanded (flattened out into multiple columns) like so
+            (everything inside ``<...>`` is a placeholder):
+
+            * :class:`.Node` objects under any variable ``<n>`` will be
+              expanded into columns (the recursion stops here)
+
+              * ``<n>().prop.<property_name>`` (any) for each property of the
+                node.
+              * ``<n>().element_id`` (str) the node's element id.
+                See :attr:`.Node.element_id`.
+              * ``<n>().labels`` (frozenset of str) the node's labels.
+                See :attr:`.Node.labels`.
+
+            * :class:`.Relationship` objects under any variable ``<r>``
+              will be expanded into columns (the recursion stops here)
+
+              * ``<r>->.prop.<property_name>`` (any) for each property of the
+                relationship.
+              * ``<r>->.element_id`` (str) the relationship's element id.
+                See :attr:`.Relationship.element_id`.
+              * ``<r>->.start.element_id`` (str) the relationship's
+                start node's element id.
+                See :attr:`.Relationship.start_node`.
+              * ``<r>->.end.element_id`` (str) the relationship's
+                end node's element id.
+                See :attr:`.Relationship.end_node`.
+              * ``<r>->.type`` (str) the relationship's type.
+                See :attr:`.Relationship.type`.
+
+            * :const:`list` objects under any variable ``<l>`` will be
+              expanded into
+
+              * ``<l>[].0`` (any) the 1st list element
+              * ``<l>[].1`` (any) the 2nd list element
+              * ...
+
+            * :const:`dict` objects under any variable ``<d>`` will be
+              expanded into
+
+              * ``<d>{}.<key1>`` (any) the value for the 1st key of the dict
+              * ``<d>{}.<key2>`` (any) the value for the 2nd key of the dict
+              * ...
+
+            * :const:`list` and :const:`dict` objects are expanded
+              recursively. Example::
+
+                  variable x: [{"foo": "bar", "baz": [42, 0]}, "foobar"]
+
+              will be expanded to::
+
+                  {
+                      "x[].0{}.foo": "bar",
+                      "x[].0{}.baz[].0": 42,
+                      "x[].0{}.baz[].1": 0,
+                      "x[].1": "foobar"
+                  }
+
+            * Everything else (including :class:`.Path` objects) will not
+              be flattened.
+
+            :const:`dict` keys and variable names that contain ``.`` or ``\``
+            will be escaped with a backslash (``\.`` and ``\\`` respectively).
+        :type expand: bool
+
+        :rtype: :py:class:`pandas.DataFrame`
+        :raises ImportError: if `pandas` library is not available.
+        :raises ResultConsumedError: if the transaction from which this result
+            was obtained has been closed or the Result has been explicitly
+            consumed.
+
+        **This is experimental.**
+        ``pandas`` support might be changed or removed in future versions
+        without warning. (See :ref:`filter-warnings-ref`)
+        """
+        import pandas as pd
+
+        if not expand:
+            return pd.DataFrame(self.values(), columns=self._keys)
+        else:
+            df_keys = None
+            rows = []
+            for record in self:
+                row = RecordTableRowExporter().transform(dict(record.items()))
+                if df_keys == row.keys():
+                    rows.append(row.values())
+                elif df_keys is None:
+                    df_keys = row.keys()
+                    rows.append(row.values())
+                elif df_keys is False:
+                    rows.append(row)
+                else:
+                    # The rows have different keys. We need to pass a list
+                    # of dicts to pandas.
+                    rows = [{k: v for k, v in zip(df_keys, r)} for r in rows]
+                    df_keys = False
+                    rows.append(row)
+            if df_keys is False:
+                return pd.DataFrame(rows)
+            else:
+                columns = df_keys or [
+                    # escape the backslash first so that the escape
+                    # characters introduced for the dots don't get re-escaped
+                    k.replace("\\", "\\\\").replace(".", "\\.")
+                    for k in self._keys
+                ]
+                return pd.DataFrame(rows, columns=columns)
+
     def closed(self):
         """Return True if the result has been closed.
diff --git a/neo4j/data.py b/neo4j/data.py
index d60937cd..c4f38845 100644
--- a/neo4j/data.py
+++ b/neo4j/data.py
@@ -297,6 +297,59 @@ def transform(self, x):
         return x


+class RecordTableRowExporter(DataTransformer):
+    """Transformer class used by the :meth:`.Result.to_df` method."""
+
+    def transform(self, x):
+        assert isinstance(x, Mapping)
+        t = type(x)
+        return t(item
+                 for k, v in x.items()
+                 for item in self._transform(
+                     v, prefix=k.replace("\\", "\\\\").replace(".", "\\.")
+                 ).items())
+
+    def _transform(self, x, prefix):
+        if isinstance(x, Node):
+            res = {
+                "%s().element_id" % prefix: x.element_id,
+                "%s().labels" % prefix: x.labels,
+            }
+            res.update(("%s().prop.%s" % (prefix, k), v) for k, v in x.items())
+            return res
+        elif isinstance(x, Relationship):
+            res = {
+                "%s->.element_id" % prefix: x.element_id,
+                "%s->.start.element_id" % prefix: x.start_node.element_id,
+                "%s->.end.element_id" % prefix: x.end_node.element_id,
+                "%s->.type" % prefix: x.__class__.__name__,
+            }
+            res.update(("%s->.prop.%s" % (prefix, k), v) for k, v in x.items())
+            return res
+        elif isinstance(x, Path) or isinstance(x, str):
+            return {prefix: x}
+        elif isinstance(x, Sequence):
+            return dict(
+                item
+                for i, v in enumerate(x)
+                for item in self._transform(
+                    v, prefix="%s[].%i" % (prefix, i)
+                ).items()
+            )
+        elif isinstance(x, Mapping):
+            t = type(x)
+            return t(
+                item
+                for k, v in x.items()
+                for item in self._transform(
+                    v, prefix="%s{}.%s" % (prefix, k.replace("\\", "\\\\")
+                                                    .replace(".", "\\."))
+                ).items()
+            )
+        else:
+            return {prefix: x}
+
+
 class DataHydrator:

     # TODO: extend DataTransformer
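To illustrate the flattening scheme implemented above, here is a small standalone sketch (``RecordTableRowExporter`` is the internal transformer backing ``to_df(expand=True)``; calling it directly is for demonstration only):

    from neo4j.data import RecordTableRowExporter

    exporter = RecordTableRowExporter()
    # plain lists and dicts are expanded recursively; strings are left alone
    row = exporter.transform(
        {"x": [{"foo": "bar", "baz": [42, 0]}, "foobar"]}
    )
    print(row)
    # {'x[].0{}.foo': 'bar', 'x[].0{}.baz[].0': 42,
    #  'x[].0{}.baz[].1': 0, 'x[].1': 'foobar'}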
diff --git a/neo4j/graph/__init__.py b/neo4j/graph/__init__.py
index cc0d97ed..62b1d5bc 100644
--- a/neo4j/graph/__init__.py
+++ b/neo4j/graph/__init__.py
@@ -207,6 +207,10 @@ def id(self):
         Depending on the version of the server this entity was retrieved
         from, this may be empty (None).

+        .. warning::
+            This value can change for the same entity across multiple
+            queries. Don't rely on it for cross-query computations.
+
         .. deprecated:: 5.0
             Use :attr:`.element_id` instead.

@@ -218,7 +222,11 @@ def element_id(self):
     def element_id(self):
         """The identity of this entity in its container :class:`.Graph`.

-        .. added:: 5.0
+        .. warning::
+            This value can change for the same entity across multiple
+            queries. Don't rely on it for cross-query computations.
+
+        .. versionadded:: 5.0

         :rtype: str
         """
diff --git a/tests/requirements.txt b/tests/requirements.txt
index 297be88b..31d0646a 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -7,3 +7,4 @@ pytest-cov>=3.0.0
 pytest-mock>=3.6.1
 mock>=4.0.3
 teamcity-messages>=1.29
+pandas>=1.0.0
diff --git a/tests/unit/async_/work/test_result.py b/tests/unit/async_/work/test_result.py
index 44f43f0b..f4043b5e 100644
--- a/tests/unit/async_/work/test_result.py
+++ b/tests/unit/async_/work/test_result.py
@@ -16,10 +16,10 @@
 #   limitations under the License.


-from itertools import product
 from unittest import mock
 import warnings

+import pandas as pd
 import pytest

 from neo4j import (
@@ -32,21 +32,25 @@
     Version,
 )
 from neo4j._async_compat.util import AsyncUtil
-from neo4j.data import DataHydrator
+from neo4j.data import (
+    DataHydrator,
+    Node,
+    Relationship,
+)
 from neo4j.exceptions import (
     ResultConsumedError,
     ResultNotSingleError,
 )
+from neo4j.packstream import Structure

 from ...._async_compat import mark_async_test


 class Records:
     def __init__(self, fields, records):
-        assert all(len(fields) == len(r) for r in records)
-        self.fields = fields
-        # self.records = [{"record_values": r} for r in records]
-        self.records = records
+        self.fields = tuple(fields)
+        self.records = tuple(records)
+        assert all(len(self.fields) == len(r) for r in self.records)

     def __len__(self):
         return self.records.__len__()
@@ -569,3 +573,214 @@ async def test_data(num_records):
     assert await result.data("hello", "world") == expected_data
     for record in records:
         assert record.data.called_once_with("hello", "world")
+
+
+@pytest.mark.parametrize(
+    ("keys", "values", "types", "instances"),
+    (
+        (["i"], list(zip(range(5))), ["int64"], None),
+        (["x"], list(zip((n - .5) / 5 for n in range(5))), ["float64"], None),
+        (["s"], list(zip(("foo", "bar", "baz", "foobar"))), ["object"], None),
+        (["l"], list(zip(([1, 2], [3, 4]))), ["object"], None),
+        (
+            ["n"],
+            list(zip((
+                Structure(b"N", 0, ["LABEL_A"], {"a": 1, "b": 2}),
+                Structure(b"N", 2, ["LABEL_B"], {"a": 1, "c": 1.2}),
+                Structure(b"N", 1, ["LABEL_A", "LABEL_B"], {"a": [1, "a"]}),
+                Structure(b"N", None, ["LABEL_A", "LABEL_B"], {"a": [1, "a"]},
+                          "cool_id"),
+            ))),
+            ["object"],
+            [Node]
+        ),
+        (
+            ["r"],
+            list(zip((
+                Structure(b"R", 0, 1, 2, "TYPE", {"a": 1, "b": 2}),
+                Structure(b"R", 420, 1337, 69, "HYPE", {"all memes": True}),
+                Structure(b"R", None, None, None, "HYPE", {"all memes": True},
+                          "420", "1337", "69"),
+            ))),
+            ["object"],
+            [Relationship]
+        ),
+    )
+)
+@pytest.mark.parametrize("test_default_expand", (True, False))
+@mark_async_test
+async def test_to_df(keys, values, types, instances, test_default_expand):
+    connection = AsyncConnectionStub(records=Records(keys, values))
+    result = AsyncResult(connection, DataHydrator(), 1, noop, noop)
+    await result._run("CYPHER", {}, None, None, "r", None)
+    if test_default_expand:
+        df = await result.to_df()
+    else:
+        df = await result.to_df(expand=False)
+
+    assert isinstance(df, pd.DataFrame)
+    assert df.keys().to_list() == keys
+    assert len(df) == len(values)
+    assert df.dtypes.to_list() == types
+
+    expected_df = pd.DataFrame(
+        {k: [v[i] for v in values] for i, k in enumerate(keys)}
+    )
+
+    if instances:
+        for i, k in enumerate(keys):
+            assert all(isinstance(v, instances[i]) for v in df[k])
+    else:
+        assert df.equals(expected_df)
+
+
+@pytest.mark.parametrize(
+    ("keys", "values", "expected_columns", "expected_rows", "expected_types"),
+    (
+        (
+            ["i"],
+            list(zip(range(5))),
+            ["i"],
+            [[0], [1], [2], [3], [4]],
+ ["int64"], + ), + # test variable name escaping + ( + ["i.[]->.().{}.\\"], + list(zip(range(5))), + ["i\\.[]->\\.()\\.{}\\.\\\\"], + [[0], [1], [2], [3], [4]], + ["int64"], + ), + ( + ["x"], + list(zip((n - .5) / 5 for n in range(5))), + ["x"], + [[-0.1], [0.1], [0.3], [0.5], [0.7]], + ["float64"], + ), + ( + ["s"], + list(zip(("foo", "bar", "baz", "foobar"))), + ["s"], + [["foo"], ["bar"], ["baz"], ["foobar"]], + ["object"], + ), + ( + ["l"], + list(zip(([1, 2], [3, 4]))), + ["l[].0", "l[].1"], + [[1, 2], [3, 4]], + ["int64", "int64"], + ), + ( + ["l"], + list(zip(([1, 2], [3, 4, 5], [6]))), + ["l[].0", "l[].1", "l[].2"], + [[1, 2, None], [3, 4, 5], [6, None, None]], + # pandas turns None in int columns into NaN + # which requires the column to become a float column + ["int64", "float64", "float64"], + ), + ( + ["d"], + list(zip(({"a": 1, "b": 2}, {"a": 3, "b": 4, "": 0}))), + ["d{}.a", "d{}.b", "d{}."], + [[1, 2, None], [3, 4, 0]], + ["int64", "int64", "float64"], + ), + # test key escaping + ( + ["d"], + list(zip(({"a.[]\\{}->.().{}.": 1, "b": 2},))), + ["d{}.a\\.[]\\\\{}->\\.()\\.{}\\.", "d{}.b"], + [[1, 2]], + ["int64", "int64"], + ), + ( + ["d"], + list(zip(({"a": 1, "b": 2}, {"a": 3, "c": 4}))), + ["d{}.a", "d{}.b", "d{}.c"], + [[1, 2, None], [3, None, 4]], + # pandas turns None in int columns into NaN + # which requires the column to become a float column + ["int64", "float64", "float64"], + ), + ( + ["x"], + list(zip(([{"foo": "bar", "baz": [42, 0.1]}, "foobar"],))), + ["x[].0{}.foo", "x[].0{}.baz[].0", "x[].0{}.baz[].1", "x[].1"], + [["bar", 42, 0.1, "foobar"]], + ["object", "int64", "float64", "object"], + ), + ( + ["n"], + list(zip(( + Structure(b"N", 0, ["LABEL_A"], + {"a": 1, "b": 2, "d": 1}, "00"), + Structure(b"N", 2, ["LABEL_B"], + {"a": 1, "c": 1.2, "d": 2}, "02"), + Structure(b"N", 1, ["LABEL_A", "LABEL_B"], + {"a": [1, "a"], "d": 3}, "01"), + ))), + [ + "n().element_id", "n().labels", "n().prop.a", "n().prop.b", + "n().prop.c", "n().prop.d" + ], + [ + ["00", frozenset(("LABEL_A",)), 1, 2, None, 1], + ["02", frozenset(("LABEL_B",)), 1, None, 1.2, 2], + [ + "01", frozenset(("LABEL_A", "LABEL_B")), + [1, "a"], None, None, 3 + ], + ], + ["object", "object", "object", "float64", "float64", "int64"], + ), + ( + ["r"], + list(zip(( + Structure(b"R", 0, 1, 2, "TYPE", {"a": 1, "all memes": False}, + "r-0", "r-1", "r-2"), + Structure(b"R", 420, 1337, 69, "HYPE", {"all memes": True}, + "r-420", "r-1337", "r-69"), + ))), + [ + "r->.element_id", "r->.start.element_id", "r->.end.element_id", + "r->.type", "r->.prop.a", "r->.prop.all memes" + ], + [ + ["r-0", "r-1", "r-2", "TYPE", 1, False], + ["r-420", "r-1337", "r-69", "HYPE", None, True], + ], + ["object", "object", "object", "object", "float64", "bool"], + ), + ) +) +@mark_async_test +async def test_to_df_expand(keys, values, expected_columns, expected_rows, + expected_types): + connection = AsyncConnectionStub(records=Records(keys, values)) + result = AsyncResult(connection, DataHydrator(), 1, noop, noop) + await result._run("CYPHER", {}, None, None, "r", None) + df = await result.to_df(expand=True) + + assert isinstance(df, pd.DataFrame) + assert len(set(expected_columns)) == len(expected_columns) + assert set(df.keys().to_list()) == set(expected_columns) + + # We don't expect the columns to be in a specific order. + # Hence, we need to sort them before comparing. 
+    new_order = [df.keys().get_loc(ex_c) for ex_c in expected_columns]
+    expected_rows = [
+        [row[i] for i in new_order]
+        for row in expected_rows
+    ]
+    expected_types = [expected_types[i] for i in new_order]
+    expected_columns = [expected_columns[i] for i in new_order]
+
+    assert len(df) == len(values)
+    assert df.dtypes.to_list() == expected_types
+
+    expected_df = pd.DataFrame(expected_rows, columns=expected_columns)
+    assert df.equals(expected_df)
diff --git a/tests/unit/sync/work/test_result.py b/tests/unit/sync/work/test_result.py
index 3c629cdf..d21d121a 100644
--- a/tests/unit/sync/work/test_result.py
+++ b/tests/unit/sync/work/test_result.py
@@ -16,10 +16,10 @@
 #   limitations under the License.


-from itertools import product
 from unittest import mock
 import warnings

+import pandas as pd
 import pytest

 from neo4j import (
@@ -32,21 +32,25 @@
     Version,
 )
 from neo4j._async_compat.util import Util
-from neo4j.data import DataHydrator
+from neo4j.data import (
+    DataHydrator,
+    Node,
+    Relationship,
+)
 from neo4j.exceptions import (
     ResultConsumedError,
     ResultNotSingleError,
 )
+from neo4j.packstream import Structure

 from ...._async_compat import mark_sync_test


 class Records:
     def __init__(self, fields, records):
-        assert all(len(fields) == len(r) for r in records)
-        self.fields = fields
-        # self.records = [{"record_values": r} for r in records]
-        self.records = records
+        self.fields = tuple(fields)
+        self.records = tuple(records)
+        assert all(len(self.fields) == len(r) for r in self.records)

     def __len__(self):
         return self.records.__len__()
@@ -569,3 +573,214 @@ def test_data(num_records):
     assert result.data("hello", "world") == expected_data
     for record in records:
         assert record.data.called_once_with("hello", "world")
+
+
+@pytest.mark.parametrize(
+    ("keys", "values", "types", "instances"),
+    (
+        (["i"], list(zip(range(5))), ["int64"], None),
+        (["x"], list(zip((n - .5) / 5 for n in range(5))), ["float64"], None),
+        (["s"], list(zip(("foo", "bar", "baz", "foobar"))), ["object"], None),
+        (["l"], list(zip(([1, 2], [3, 4]))), ["object"], None),
+        (
+            ["n"],
+            list(zip((
+                Structure(b"N", 0, ["LABEL_A"], {"a": 1, "b": 2}),
+                Structure(b"N", 2, ["LABEL_B"], {"a": 1, "c": 1.2}),
+                Structure(b"N", 1, ["LABEL_A", "LABEL_B"], {"a": [1, "a"]}),
+                Structure(b"N", None, ["LABEL_A", "LABEL_B"], {"a": [1, "a"]},
+                          "cool_id"),
+            ))),
+            ["object"],
+            [Node]
+        ),
+        (
+            ["r"],
+            list(zip((
+                Structure(b"R", 0, 1, 2, "TYPE", {"a": 1, "b": 2}),
+                Structure(b"R", 420, 1337, 69, "HYPE", {"all memes": True}),
+                Structure(b"R", None, None, None, "HYPE", {"all memes": True},
+                          "420", "1337", "69"),
+            ))),
+            ["object"],
+            [Relationship]
+        ),
+    )
+)
+@pytest.mark.parametrize("test_default_expand", (True, False))
+@mark_sync_test
+def test_to_df(keys, values, types, instances, test_default_expand):
+    connection = ConnectionStub(records=Records(keys, values))
+    result = Result(connection, DataHydrator(), 1, noop, noop)
+    result._run("CYPHER", {}, None, None, "r", None)
+    if test_default_expand:
+        df = result.to_df()
+    else:
+        df = result.to_df(expand=False)
+
+    assert isinstance(df, pd.DataFrame)
+    assert df.keys().to_list() == keys
+    assert len(df) == len(values)
+    assert df.dtypes.to_list() == types
+
+    expected_df = pd.DataFrame(
+        {k: [v[i] for v in values] for i, k in enumerate(keys)}
+    )
+
+    if instances:
+        for i, k in enumerate(keys):
+            assert all(isinstance(v, instances[i]) for v in df[k])
+    else:
+        assert df.equals(expected_df)
+
+
+@pytest.mark.parametrize(
+    ("keys", "values", "expected_columns",
"expected_rows", "expected_types"), + ( + ( + ["i"], + list(zip(range(5))), + ["i"], + [[0], [1], [2], [3], [4]], + ["int64"], + ), + # test variable name escaping + ( + ["i.[]->.().{}.\\"], + list(zip(range(5))), + ["i\\.[]->\\.()\\.{}\\.\\\\"], + [[0], [1], [2], [3], [4]], + ["int64"], + ), + ( + ["x"], + list(zip((n - .5) / 5 for n in range(5))), + ["x"], + [[-0.1], [0.1], [0.3], [0.5], [0.7]], + ["float64"], + ), + ( + ["s"], + list(zip(("foo", "bar", "baz", "foobar"))), + ["s"], + [["foo"], ["bar"], ["baz"], ["foobar"]], + ["object"], + ), + ( + ["l"], + list(zip(([1, 2], [3, 4]))), + ["l[].0", "l[].1"], + [[1, 2], [3, 4]], + ["int64", "int64"], + ), + ( + ["l"], + list(zip(([1, 2], [3, 4, 5], [6]))), + ["l[].0", "l[].1", "l[].2"], + [[1, 2, None], [3, 4, 5], [6, None, None]], + # pandas turns None in int columns into NaN + # which requires the column to become a float column + ["int64", "float64", "float64"], + ), + ( + ["d"], + list(zip(({"a": 1, "b": 2}, {"a": 3, "b": 4, "": 0}))), + ["d{}.a", "d{}.b", "d{}."], + [[1, 2, None], [3, 4, 0]], + ["int64", "int64", "float64"], + ), + # test key escaping + ( + ["d"], + list(zip(({"a.[]\\{}->.().{}.": 1, "b": 2},))), + ["d{}.a\\.[]\\\\{}->\\.()\\.{}\\.", "d{}.b"], + [[1, 2]], + ["int64", "int64"], + ), + ( + ["d"], + list(zip(({"a": 1, "b": 2}, {"a": 3, "c": 4}))), + ["d{}.a", "d{}.b", "d{}.c"], + [[1, 2, None], [3, None, 4]], + # pandas turns None in int columns into NaN + # which requires the column to become a float column + ["int64", "float64", "float64"], + ), + ( + ["x"], + list(zip(([{"foo": "bar", "baz": [42, 0.1]}, "foobar"],))), + ["x[].0{}.foo", "x[].0{}.baz[].0", "x[].0{}.baz[].1", "x[].1"], + [["bar", 42, 0.1, "foobar"]], + ["object", "int64", "float64", "object"], + ), + ( + ["n"], + list(zip(( + Structure(b"N", 0, ["LABEL_A"], + {"a": 1, "b": 2, "d": 1}, "00"), + Structure(b"N", 2, ["LABEL_B"], + {"a": 1, "c": 1.2, "d": 2}, "02"), + Structure(b"N", 1, ["LABEL_A", "LABEL_B"], + {"a": [1, "a"], "d": 3}, "01"), + ))), + [ + "n().element_id", "n().labels", "n().prop.a", "n().prop.b", + "n().prop.c", "n().prop.d" + ], + [ + ["00", frozenset(("LABEL_A",)), 1, 2, None, 1], + ["02", frozenset(("LABEL_B",)), 1, None, 1.2, 2], + [ + "01", frozenset(("LABEL_A", "LABEL_B")), + [1, "a"], None, None, 3 + ], + ], + ["object", "object", "object", "float64", "float64", "int64"], + ), + ( + ["r"], + list(zip(( + Structure(b"R", 0, 1, 2, "TYPE", {"a": 1, "all memes": False}, + "r-0", "r-1", "r-2"), + Structure(b"R", 420, 1337, 69, "HYPE", {"all memes": True}, + "r-420", "r-1337", "r-69"), + ))), + [ + "r->.element_id", "r->.start.element_id", "r->.end.element_id", + "r->.type", "r->.prop.a", "r->.prop.all memes" + ], + [ + ["r-0", "r-1", "r-2", "TYPE", 1, False], + ["r-420", "r-1337", "r-69", "HYPE", None, True], + ], + ["object", "object", "object", "object", "float64", "bool"], + ), + ) +) +@mark_sync_test +def test_to_df_expand(keys, values, expected_columns, expected_rows, + expected_types): + connection = ConnectionStub(records=Records(keys, values)) + result = Result(connection, DataHydrator(), 1, noop, noop) + result._run("CYPHER", {}, None, None, "r", None) + df = result.to_df(expand=True) + + assert isinstance(df, pd.DataFrame) + assert len(set(expected_columns)) == len(expected_columns) + assert set(df.keys().to_list()) == set(expected_columns) + + # We don't expect the columns to be in a specific order. + # Hence, we need to sort them before comparing. 
+ new_order = [df.keys().get_loc(ex_c) for ex_c in expected_columns] + expected_rows = [ + [row[i] for i in new_order] + for row in expected_rows + ] + expected_types = [expected_types[i] for i in new_order] + expected_columns = [expected_columns[i] for i in new_order] + + assert len(df) == len(values) + assert df.dtypes.to_list() == expected_types + + expected_df = pd.DataFrame(expected_rows, columns=expected_columns) + assert df.equals(expected_df)
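As a complement to the unit tests above, an end-to-end sketch of ``expand=True`` with the sync driver (the server address, credentials, and the ``Person`` data set are assumptions for illustration only):

    import neo4j

    driver = neo4j.GraphDatabase.driver(
        "bolt://localhost:7687", auth=("neo4j", "password")
    )
    with driver.session() as session:
        result = session.run("MATCH (p:Person) RETURN p LIMIT 5")
        # nodes are flattened into p().element_id, p().labels, and one
        # p().prop.<name> column per property encountered (experimental API)
        df = result.to_df(expand=True)
    driver.close()
    print(df.columns)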