Skip to content

first pass at unstack #1412

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: labeled_tensors
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion pytensor/xtensor/rewriting/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pytensor.tensor import moveaxis
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_xcanonicalize
from pytensor.xtensor.shape import Stack
from pytensor.xtensor.shape import Stack, UnStack


@register_xcanonicalize
Expand All @@ -27,3 +27,19 @@ def lower_stack(fgraph, node):

new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims)
return [new_out]


@register_xcanonicalize
@node_rewriter(tracks=[UnStack])
def lower_unstack(fgraph, node):
[x] = node.inputs
axis_to_unstack = x.type.dims.index(node.op.old_dim_name)

x_tensor = tensor_from_xtensor(x)
x_tensor_transposed = moveaxis(x_tensor, source=[axis_to_unstack], destination=[-1])
final_tensor = x_tensor_transposed.reshape(
(*x_tensor_transposed.shape[:-1], *node.op.unstacked_lengths)
)

new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims)
return [new_out]
72 changes: 72 additions & 0 deletions pytensor/xtensor/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,75 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
)
y = Stack(new_dim_name, tuple(stacked_dims))(y)
return y


class UnStack(XOp):
__props__ = ("old_dim_name", "unstacked_dims", "unstacked_lengths")
Copy link
Member Author

@ricardoV94 ricardoV94 May 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like nothing requires "unstacked_lengths" to be constant/non-symbolic. So we could parametrize this Op just with ("old_dim_name", "unstacked_dims") and pass "unstacked_lengths" to make_node. We can convert those to scalar TensorVariables as_tensor(x, ndim=0) and check that the dtype is integer.

Everything in the rewrite with reshape would work the same, but we would extract them from node.inputs[1:]

This will allow stuff like:

x = xtensor(dims=("a", "b", "c"))
y = stack(x, bc=("b", "c"))
# do something with stacked y
z = unstack(y, bc=dict(b=x.sizes["b"], c=x.sizes["c"]))

Without the user having to pre-commit to static shapes for b, c

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the pointers, I'll try to make the updates


def __init__(
self,
old_dim_name: str,
unstacked_dims: tuple[str, ...],
unstacked_lengths: tuple[int, ...],
):
super().__init__()
if old_dim_name in unstacked_dims:
raise ValueError(
f"Dim to be unstacked {old_dim_name} can't be in {unstacked_dims}"
)
if len(unstacked_dims) != len(unstacked_lengths):
raise ValueError(
"Tuples with unstacked dim names and lengths must have the same length "
f"but have {len(unstacked_dims)} and {len(unstacked_lengths)}"
)
if not unstacked_dims:
raise ValueError("Dims to unstack into can't be empty.")
if len(unstacked_dims) == 1:
raise ValueError("Only one dimension to unstack into, use rename instead")
self.old_dim_name = old_dim_name
self.unstacked_dims = unstacked_dims
self.unstacked_lengths = unstacked_lengths

def make_node(self, x):
x = as_xtensor(x)
if self.old_dim_name not in x.type.dims:
raise ValueError(
f"Dim to unstack {self.old_dim_name} must be in {x.type.dims}"
)
if not set(self.unstacked_dims).isdisjoint(x.type.dims):
raise ValueError(
f"Dims to unstack into {self.unstacked_dims} must not be in {x.type.dims}"
)
if x.type.ndim == 1:
batch_dims, batch_shape = (), ()
else:
batch_dims, batch_shape = zip(
*(
(dim, shape)
for dim, shape in zip(x.type.dims, x.type.shape)
if dim != self.old_dim_name
)
)

output = xtensor(
dtype=x.type.dtype,
shape=(*batch_shape, *self.unstacked_lengths),
dims=(*batch_dims, *self.unstacked_dims),
)
return Apply(self, [x], [output])


def unstack(x, dim: dict[str, dict[str, int]] | None = None, **dims: dict[str, int]):
if dim is not None:
if dims:
raise ValueError(
"Cannot use both positional dim and keyword dims in unstack"
)
dims = dim

y = x
for old_dim_name, unstacked_dict in dims.items():
y = UnStack(
old_dim_name, tuple(unstacked_dict.keys()), tuple(unstacked_dict.values())
)(y)
return y
106 changes: 92 additions & 14 deletions tests/xtensor/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import numpy as np
from xarray import DataArray

from pytensor.xtensor.shape import stack
from pytensor.xtensor.shape import stack, unstack
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_assert_allclose, xr_function
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function


def powerset(iterable, min_group_size=0):
Expand Down Expand Up @@ -41,10 +41,7 @@ def test_transpose():
outs = [transpose(x, *perm) for perm in permutations]

fn = xr_function([x], outs)
x_test = DataArray(
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape),
dims=x.type.dims,
)
x_test = xr_arange_like(x)
res = fn(x_test)
expected_res = [x_test.transpose(*perm) for perm in permutations]
for outs_i, res_i, expected_res_i in zip(outs, res, expected_res):
Expand All @@ -60,10 +57,7 @@ def test_stack():
]

fn = xr_function([x], outs)
x_test = DataArray(
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape),
dims=x.type.dims,
)
x_test = xr_arange_like(x)
res = fn(x_test)

expected_res = [
Expand All @@ -80,10 +74,7 @@ def test_stack_single_dim():
assert out.type.dims == ("b", "c", "d")

fn = xr_function([x], out)
x_test = DataArray(
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape),
dims=x.type.dims,
)
x_test = xr_arange_like(x)
fn.fn.dprint(print_type=True)
res = fn(x_test)
expected_res = x_test.stack(d=["a"])
Expand All @@ -102,3 +93,90 @@ def test_multiple_stacks():
res = fn(x_test)
expected_res = x_test.stack(new_dim1=("a", "b"), new_dim2=("c", "d"))
xr_assert_allclose(res[0], expected_res)


def test_unstack():
unstacked_dims = {"a": 2, "b": 3, "c": 5, "d": 7}
dims = ("abcd",)
x = xtensor("x", dims=dims, shape=(2 * 3 * 5 * 7,))
outs = [
unstack(
x,
abcd=(
{d: l for d, l in unstacked_dims.items() if d in dims_to_unstack}
| (
{}
if set(dims_to_unstack) == set(unstacked_dims)
else {
"other": int(
np.prod(
[
l
for d, l in unstacked_dims.items()
if d not in dims_to_unstack
]
)
)
}
)
),
)
Comment on lines +103 to +123
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit hard for me to read

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was already hard to follow when I wrote it, now after the formatting it is a nightmare. I'll try to simplify things a bit tomorrow.

Copy link
Member Author

@ricardoV94 ricardoV94 May 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I get what you were trying to do with the test (60% confidence), and I think it has no parallel to what our poor-mans unstack can do. We basically can only unstack "consecutive dimensions", whereas xarray will always know what a bunch of stacked dimensions correspond to, and can unstack "non-consecutive/arbitrarily ordered" dimensions.

I think for our purposes we want to always get an identity if we do transpose(unstack(stack(new_dim=stacked_dims), new_dim=original_stacked_dims), original_dims), whereoriginal_stacked_dims contains the same dims, in the same order and with the same sizes.

I added a test more like that, that maybe we can parametrize with the powerset approach?

for dims_to_unstack in powerset(unstacked_dims.keys(), min_group_size=2)
]
fn = xr_function([x], outs)
# we test through the complementary operation in xarray to avoid needing coords
# which are required for unstack. We end up with a subset of {a, b, c, d} and
# other after unstacking, so we create the fully unstacked dataarray
# and stack to create this extra "other" dimension as needed
x_test = DataArray(
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(
list(unstacked_dims.values())
),
dims=list(unstacked_dims.keys()),
)
res = fn(x_test)

expected_res = [
x_test.stack(
{}
if set(dims_to_unstack) == set(unstacked_dims)
else {"other": [d for d in unstacked_dims if d not in dims_to_unstack]}
)
for dims_to_unstack in powerset(unstacked_dims.keys(), min_group_size=2)
]
for res_i, expected_res_i in zip(res, expected_res):
assert res_i.shape == expected_res_i.shape
# the shapes are right but the "other" one has the elements in different order
# I think it is an issue with the test not the function but not sure
# xr_assert_allclose(res_i, expected_res_i)


def test_unstack_simple():
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@OriolAbril I added a simple just test to convince me things look correct and they do. Doesn't mean to replace your more exhaustive test and we can remove it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. It is potentially more exhaustive but seeing this makes me yet a bit more convinced the issue is in the test and not the function so the complex one might need some rethinking.

x = xtensor("x", dims=("a", "bc", "d"), shape=(2, 3 * 5, 7))
y = unstack(x, bc=dict(b=3, c=5))
assert y.type.dims == ("a", "d", "b", "c")
assert y.type.shape == (2, 7, 3, 5)

fn = xr_function([x], y)

x_test = xr_arange_like(x)
x_np = x_test.values
res = fn(x_test)
expected = (
DataArray(x_np.reshape(2, 3, 5, 7), dims=("a", "b", "c", "d"))
.stack(bc=("b", "c"))
.unstack("bc")
)
xr_assert_allclose(res, expected)


def test_stack_unstack():
x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 5, 7))
stack_x = stack(x, bd=("b", "d"))
unstack_x = unstack(stack_x, bd=dict(b=3, d=7))

x_test = xr_arange_like(x)
fn = xr_function([x], unstack_x)
res = fn(x_test)
expected_res = x_test.transpose("a", "c", "b", "d")
xr_assert_allclose(res, expected_res)
8 changes: 8 additions & 0 deletions tests/xtensor/util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from xarray import DataArray
from xarray.testing import assert_allclose

Expand Down Expand Up @@ -35,3 +36,10 @@ def xr_assert_allclose(x, y, *args, **kwargs):
x = x.drop_vars(x.coords)
y = y.drop_vars(y.coords)
assert_allclose(x, y, *args, **kwargs)


def xr_arange_like(x):
return DataArray(
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape),
dims=x.type.dims,
)
Loading