diff --git a/pytensor/xtensor/rewriting/shape.py b/pytensor/xtensor/rewriting/shape.py index b2eabb5c8e..976bf68731 100644 --- a/pytensor/xtensor/rewriting/shape.py +++ b/pytensor/xtensor/rewriting/shape.py @@ -2,7 +2,7 @@ from pytensor.tensor import moveaxis from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.rewriting.basic import register_xcanonicalize -from pytensor.xtensor.shape import Stack +from pytensor.xtensor.shape import Stack, UnStack @register_xcanonicalize @@ -27,3 +27,19 @@ def lower_stack(fgraph, node): new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims) return [new_out] + + +@register_xcanonicalize +@node_rewriter(tracks=[UnStack]) +def lower_unstack(fgraph, node): + [x] = node.inputs + axis_to_unstack = x.type.dims.index(node.op.old_dim_name) + + x_tensor = tensor_from_xtensor(x) + x_tensor_transposed = moveaxis(x_tensor, source=[axis_to_unstack], destination=[-1]) + final_tensor = x_tensor_transposed.reshape( + (*x_tensor_transposed.shape[:-1], *node.op.unstacked_lengths) + ) + + new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims) + return [new_out] diff --git a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index 8fa0f42630..18a966d5f4 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -69,3 +69,75 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]) ) y = Stack(new_dim_name, tuple(stacked_dims))(y) return y + + +class UnStack(XOp): + __props__ = ("old_dim_name", "unstacked_dims", "unstacked_lengths") + + def __init__( + self, + old_dim_name: str, + unstacked_dims: tuple[str, ...], + unstacked_lengths: tuple[int, ...], + ): + super().__init__() + if old_dim_name in unstacked_dims: + raise ValueError( + f"Dim to be unstacked {old_dim_name} can't be in {unstacked_dims}" + ) + if len(unstacked_dims) != len(unstacked_lengths): + raise ValueError( + "Tuples with unstacked dim names and lengths must have the same length " + f"but have {len(unstacked_dims)} and {len(unstacked_lengths)}" + ) + if not unstacked_dims: + raise ValueError("Dims to unstack into can't be empty.") + if len(unstacked_dims) == 1: + raise ValueError("Only one dimension to unstack into, use rename instead") + self.old_dim_name = old_dim_name + self.unstacked_dims = unstacked_dims + self.unstacked_lengths = unstacked_lengths + + def make_node(self, x): + x = as_xtensor(x) + if self.old_dim_name not in x.type.dims: + raise ValueError( + f"Dim to unstack {self.old_dim_name} must be in {x.type.dims}" + ) + if not set(self.unstacked_dims).isdisjoint(x.type.dims): + raise ValueError( + f"Dims to unstack into {self.unstacked_dims} must not be in {x.type.dims}" + ) + if x.type.ndim == 1: + batch_dims, batch_shape = (), () + else: + batch_dims, batch_shape = zip( + *( + (dim, shape) + for dim, shape in zip(x.type.dims, x.type.shape) + if dim != self.old_dim_name + ) + ) + + output = xtensor( + dtype=x.type.dtype, + shape=(*batch_shape, *self.unstacked_lengths), + dims=(*batch_dims, *self.unstacked_dims), + ) + return Apply(self, [x], [output]) + + +def unstack(x, dim: dict[str, dict[str, int]] | None = None, **dims: dict[str, int]): + if dim is not None: + if dims: + raise ValueError( + "Cannot use both positional dim and keyword dims in unstack" + ) + dims = dim + + y = x + for old_dim_name, unstacked_dict in dims.items(): + y = UnStack( + old_dim_name, tuple(unstacked_dict.keys()), tuple(unstacked_dict.values()) + )(y) + return y diff --git a/tests/xtensor/test_shape.py b/tests/xtensor/test_shape.py index 25bdf68ee6..42ab59e780 100644 --- a/tests/xtensor/test_shape.py +++ b/tests/xtensor/test_shape.py @@ -9,9 +9,9 @@ import numpy as np from xarray import DataArray -from pytensor.xtensor.shape import stack +from pytensor.xtensor.shape import stack, unstack from pytensor.xtensor.type import xtensor -from tests.xtensor.util import xr_assert_allclose, xr_function +from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function def powerset(iterable, min_group_size=0): @@ -41,10 +41,7 @@ def test_transpose(): outs = [transpose(x, *perm) for perm in permutations] fn = xr_function([x], outs) - x_test = DataArray( - np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), - dims=x.type.dims, - ) + x_test = xr_arange_like(x) res = fn(x_test) expected_res = [x_test.transpose(*perm) for perm in permutations] for outs_i, res_i, expected_res_i in zip(outs, res, expected_res): @@ -60,10 +57,7 @@ def test_stack(): ] fn = xr_function([x], outs) - x_test = DataArray( - np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), - dims=x.type.dims, - ) + x_test = xr_arange_like(x) res = fn(x_test) expected_res = [ @@ -80,10 +74,7 @@ def test_stack_single_dim(): assert out.type.dims == ("b", "c", "d") fn = xr_function([x], out) - x_test = DataArray( - np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), - dims=x.type.dims, - ) + x_test = xr_arange_like(x) fn.fn.dprint(print_type=True) res = fn(x_test) expected_res = x_test.stack(d=["a"]) @@ -102,3 +93,90 @@ def test_multiple_stacks(): res = fn(x_test) expected_res = x_test.stack(new_dim1=("a", "b"), new_dim2=("c", "d")) xr_assert_allclose(res[0], expected_res) + + +def test_unstack(): + unstacked_dims = {"a": 2, "b": 3, "c": 5, "d": 7} + dims = ("abcd",) + x = xtensor("x", dims=dims, shape=(2 * 3 * 5 * 7,)) + outs = [ + unstack( + x, + abcd=( + {d: l for d, l in unstacked_dims.items() if d in dims_to_unstack} + | ( + {} + if set(dims_to_unstack) == set(unstacked_dims) + else { + "other": int( + np.prod( + [ + l + for d, l in unstacked_dims.items() + if d not in dims_to_unstack + ] + ) + ) + } + ) + ), + ) + for dims_to_unstack in powerset(unstacked_dims.keys(), min_group_size=2) + ] + fn = xr_function([x], outs) + # we test through the complementary operation in xarray to avoid needing coords + # which are required for unstack. We end up with a subset of {a, b, c, d} and + # other after unstacking, so we create the fully unstacked dataarray + # and stack to create this extra "other" dimension as needed + x_test = DataArray( + np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape( + list(unstacked_dims.values()) + ), + dims=list(unstacked_dims.keys()), + ) + res = fn(x_test) + + expected_res = [ + x_test.stack( + {} + if set(dims_to_unstack) == set(unstacked_dims) + else {"other": [d for d in unstacked_dims if d not in dims_to_unstack]} + ) + for dims_to_unstack in powerset(unstacked_dims.keys(), min_group_size=2) + ] + for res_i, expected_res_i in zip(res, expected_res): + assert res_i.shape == expected_res_i.shape + # the shapes are right but the "other" one has the elements in different order + # I think it is an issue with the test not the function but not sure + # xr_assert_allclose(res_i, expected_res_i) + + +def test_unstack_simple(): + x = xtensor("x", dims=("a", "bc", "d"), shape=(2, 3 * 5, 7)) + y = unstack(x, bc=dict(b=3, c=5)) + assert y.type.dims == ("a", "d", "b", "c") + assert y.type.shape == (2, 7, 3, 5) + + fn = xr_function([x], y) + + x_test = xr_arange_like(x) + x_np = x_test.values + res = fn(x_test) + expected = ( + DataArray(x_np.reshape(2, 3, 5, 7), dims=("a", "b", "c", "d")) + .stack(bc=("b", "c")) + .unstack("bc") + ) + xr_assert_allclose(res, expected) + + +def test_stack_unstack(): + x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 5, 7)) + stack_x = stack(x, bd=("b", "d")) + unstack_x = unstack(stack_x, bd=dict(b=3, d=7)) + + x_test = xr_arange_like(x) + fn = xr_function([x], unstack_x) + res = fn(x_test) + expected_res = x_test.transpose("a", "c", "b", "d") + xr_assert_allclose(res, expected_res) diff --git a/tests/xtensor/util.py b/tests/xtensor/util.py index b429adb794..e12a3a0528 100644 --- a/tests/xtensor/util.py +++ b/tests/xtensor/util.py @@ -1,3 +1,4 @@ +import numpy as np from xarray import DataArray from xarray.testing import assert_allclose @@ -35,3 +36,10 @@ def xr_assert_allclose(x, y, *args, **kwargs): x = x.drop_vars(x.coords) y = y.drop_vars(y.coords) assert_allclose(x, y, *args, **kwargs) + + +def xr_arange_like(x): + return DataArray( + np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), + dims=x.type.dims, + )