-
Notifications
You must be signed in to change notification settings - Fork 131
first pass at unstack #1412
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: labeled_tensors
Are you sure you want to change the base?
first pass at unstack #1412
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,9 +9,9 @@ | |
import numpy as np | ||
from xarray import DataArray | ||
|
||
from pytensor.xtensor.shape import stack | ||
from pytensor.xtensor.shape import stack, unstack | ||
from pytensor.xtensor.type import xtensor | ||
from tests.xtensor.util import xr_assert_allclose, xr_function | ||
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function | ||
|
||
|
||
def powerset(iterable, min_group_size=0): | ||
|
@@ -41,10 +41,7 @@ def test_transpose(): | |
outs = [transpose(x, *perm) for perm in permutations] | ||
|
||
fn = xr_function([x], outs) | ||
x_test = DataArray( | ||
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), | ||
dims=x.type.dims, | ||
) | ||
x_test = xr_arange_like(x) | ||
res = fn(x_test) | ||
expected_res = [x_test.transpose(*perm) for perm in permutations] | ||
for outs_i, res_i, expected_res_i in zip(outs, res, expected_res): | ||
|
@@ -60,10 +57,7 @@ def test_stack(): | |
] | ||
|
||
fn = xr_function([x], outs) | ||
x_test = DataArray( | ||
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), | ||
dims=x.type.dims, | ||
) | ||
x_test = xr_arange_like(x) | ||
res = fn(x_test) | ||
|
||
expected_res = [ | ||
|
@@ -80,10 +74,7 @@ def test_stack_single_dim(): | |
assert out.type.dims == ("b", "c", "d") | ||
|
||
fn = xr_function([x], out) | ||
x_test = DataArray( | ||
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape(x.type.shape), | ||
dims=x.type.dims, | ||
) | ||
x_test = xr_arange_like(x) | ||
fn.fn.dprint(print_type=True) | ||
res = fn(x_test) | ||
expected_res = x_test.stack(d=["a"]) | ||
|
@@ -102,3 +93,90 @@ def test_multiple_stacks(): | |
res = fn(x_test) | ||
expected_res = x_test.stack(new_dim1=("a", "b"), new_dim2=("c", "d")) | ||
xr_assert_allclose(res[0], expected_res) | ||
|
||
|
||
def test_unstack(): | ||
unstacked_dims = {"a": 2, "b": 3, "c": 5, "d": 7} | ||
dims = ("abcd",) | ||
x = xtensor("x", dims=dims, shape=(2 * 3 * 5 * 7,)) | ||
outs = [ | ||
unstack( | ||
x, | ||
abcd=( | ||
{d: l for d, l in unstacked_dims.items() if d in dims_to_unstack} | ||
| ( | ||
{} | ||
if set(dims_to_unstack) == set(unstacked_dims) | ||
else { | ||
"other": int( | ||
np.prod( | ||
[ | ||
l | ||
for d, l in unstacked_dims.items() | ||
if d not in dims_to_unstack | ||
] | ||
) | ||
) | ||
} | ||
) | ||
), | ||
) | ||
Comment on lines
+103
to
+123
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a bit hard for me to read. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It was already hard to follow when I wrote it; now, after the formatting, it is a nightmare. I'll try to simplify things a bit tomorrow. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think I get what you were trying to do with the test (60% confidence), and I think it has no parallel to what our poor man's unstack can do. We basically can only unstack "consecutive dimensions", whereas xarray will always know what a bunch of stacked dimensions correspond to, and can unstack "non-consecutive/arbitrarily ordered" dimensions. I think for our purposes we want to always get an identity if we do a stack followed by the corresponding unstack. I added a test more like that, which maybe we can parametrize with the powerset approach? |
||
for dims_to_unstack in powerset(unstacked_dims.keys(), min_group_size=2) | ||
] | ||
fn = xr_function([x], outs) | ||
# we test through the complementary operation in xarray to avoid needing coords | ||
# which are required for unstack. We end up with a subset of {a, b, c, d} and | ||
# other after unstacking, so we create the fully unstacked dataarray | ||
# and stack to create this extra "other" dimension as needed | ||
x_test = DataArray( | ||
np.arange(np.prod(x.type.shape), dtype=x.type.dtype).reshape( | ||
list(unstacked_dims.values()) | ||
), | ||
dims=list(unstacked_dims.keys()), | ||
) | ||
res = fn(x_test) | ||
|
||
expected_res = [ | ||
x_test.stack( | ||
{} | ||
if set(dims_to_unstack) == set(unstacked_dims) | ||
else {"other": [d for d in unstacked_dims if d not in dims_to_unstack]} | ||
) | ||
for dims_to_unstack in powerset(unstacked_dims.keys(), min_group_size=2) | ||
] | ||
for res_i, expected_res_i in zip(res, expected_res): | ||
assert res_i.shape == expected_res_i.shape | ||
# the shapes are right but the "other" one has the elements in different order | ||
# I think it is an issue with the test not the function but not sure | ||
# xr_assert_allclose(res_i, expected_res_i) | ||
|
||
|
||
def test_unstack_simple(): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @OriolAbril I added a simple test just to convince myself that things look correct, and they do. It isn't meant to replace your more exhaustive test, and we can remove it. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks. It is potentially more exhaustive, but seeing this makes me a bit more convinced that the issue is in the test and not the function, so the complex one might need some rethinking. |
||
x = xtensor("x", dims=("a", "bc", "d"), shape=(2, 3 * 5, 7)) | ||
y = unstack(x, bc=dict(b=3, c=5)) | ||
assert y.type.dims == ("a", "d", "b", "c") | ||
assert y.type.shape == (2, 7, 3, 5) | ||
|
||
fn = xr_function([x], y) | ||
|
||
x_test = xr_arange_like(x) | ||
x_np = x_test.values | ||
res = fn(x_test) | ||
expected = ( | ||
DataArray(x_np.reshape(2, 3, 5, 7), dims=("a", "b", "c", "d")) | ||
.stack(bc=("b", "c")) | ||
.unstack("bc") | ||
) | ||
xr_assert_allclose(res, expected) | ||
|
||
|
||
def test_stack_unstack(): | ||
x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 5, 7)) | ||
stack_x = stack(x, bd=("b", "d")) | ||
unstack_x = unstack(stack_x, bd=dict(b=3, d=7)) | ||
|
||
x_test = xr_arange_like(x) | ||
fn = xr_function([x], unstack_x) | ||
res = fn(x_test) | ||
expected_res = x_test.transpose("a", "c", "b", "d") | ||
xr_assert_allclose(res, expected_res) |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like nothing requires "unstacked_lengths" to be constant/non-symbolic. So we could parametrize this Op just with
("old_dim_name", "unstacked_dims")
and pass "unstacked_lengths" to `make_node`.
We can convert those to scalar TensorVariables with `as_tensor(x, ndim=0)`
and check that the `dtype` is integer.
Everything in the rewrite with reshape would work the same, but we would extract them from
`node.inputs[1:]`.
This will allow stuff like:
Without the user having to pre-commit to static shapes for b, c
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CC @OriolAbril
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the pointers, I'll try to make the updates