-
Notifications
You must be signed in to change notification settings - Fork 132
Refactor stacking functions, add dstack and column_stack #624
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 13 commits
5e0abe5
2d7617e
1fc5e23
7c5b094
f98a6af
fe9e2da
d160f9f
528e79e
2c4355a
b860369
3ba05ae
2167178
4b5210a
6a962ab
e4d9087
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2758,15 +2758,34 @@ def concatenate(tensor_list, axis=0): | |
return join(axis, *tensor_list) | ||
|
||
|
||
def horizontal_stack(*args): | ||
def hstack(tup): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See below |
||
r"""Stack arrays in sequence horizontally (column wise).""" | ||
# Note: 'horizontal_stack' and 'vertical_stack' do not behave exactly like | ||
# Numpy's hstack and vstack functions. This is intended, because Numpy's | ||
# functions have potentially confusing/incoherent behavior (try them on 1D | ||
# arrays). If this is fixed in a future version of Numpy, it may be worth | ||
# trying to get closer to Numpy's way of doing things. In the meantime, | ||
# better keep different names to emphasize the implementation divergences. | ||
|
||
arrs = atleast_1d(*tup) | ||
if not isinstance(arrs, list): | ||
arrs = [arrs] | ||
# As a special case, dimension 0 of 1-dimensional arrays is "horizontal" | ||
if arrs and arrs[0].ndim == 1: | ||
return concatenate(arrs, axis=0) | ||
else: | ||
return concatenate(arrs, axis=1) | ||
|
||
|
||
def vstack(tup): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The name |
||
r"""Stack arrays in sequence vertically (row wise).""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be good to use this opportunity to start putting real docstrings on pytensor functions, by adding (at least) a Parameters and Returns section. A See Also section would also be nice. |
||
|
||
arrs = atleast_2d(*tup) | ||
if not isinstance(arrs, list): | ||
arrs = [arrs] | ||
|
||
return concatenate(arrs, axis=0) | ||
|
||
|
||
def horizontal_stack(*args): | ||
warnings.warn( | ||
"horizontal_stack was renamed to hstack and will be removed in a future release", | ||
FutureWarning, | ||
) | ||
if len(args) < 2: | ||
raise ValueError("Too few arguments") | ||
|
||
|
@@ -2781,8 +2800,10 @@ def horizontal_stack(*args): | |
|
||
|
||
def vertical_stack(*args): | ||
r"""Stack arrays in sequence vertically (row wise).""" | ||
|
||
warnings.warn( | ||
"vertical_stack was renamed to vstack and will be removed in a future release", | ||
FutureWarning, | ||
) | ||
if len(args) < 2: | ||
raise ValueError("Too few arguments") | ||
|
||
|
@@ -2796,6 +2817,33 @@ def vertical_stack(*args): | |
return concatenate(_args, axis=0) | ||
|
||
|
||
def dstack(tup): | ||
r"""Stack arrays in sequence along third axis (depth wise).""" | ||
|
||
# arrs = atleast_3d(*tup, left=False) | ||
# if not isinstance(arrs, list): | ||
# arrs = [arrs] | ||
arrs = [] | ||
for arr in tup: | ||
if arr.ndim == 1: | ||
arr = atleast_3d([arr], left=False) | ||
else: | ||
arr = atleast_3d(arr, left=False) | ||
arrs.append(arr) | ||
return concatenate(arrs, 2) | ||
|
||
|
||
def column_stack(tup): | ||
r"""Stack 1-D arrays as columns into a 2-D array.""" | ||
|
||
arrays = [] | ||
for arr in tup: | ||
if arr.ndim < 2: | ||
arr = atleast_2d(arr).transpose() | ||
arrays.append(arr) | ||
return concatenate(arrays, 1) | ||
|
||
|
||
def is_flat(var, ndim=1): | ||
""" | ||
Verifies the dimensionality of the var is equal to | ||
|
@@ -4298,8 +4346,12 @@ def ix_(*args): | |
"tile", | ||
"flatten", | ||
"is_flat", | ||
"vstack", | ||
"hstack", | ||
"vertical_stack", | ||
"horizontal_stack", | ||
"dstack", | ||
"column_stack", | ||
"get_vector_length", | ||
"concatenate", | ||
"stack", | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -41,9 +41,11 @@ | |||||
atleast_Nd, | ||||||
cast, | ||||||
choose, | ||||||
column_stack, | ||||||
constant, | ||||||
default, | ||||||
diag, | ||||||
dstack, | ||||||
expand_dims, | ||||||
extract_constant, | ||||||
eye, | ||||||
|
@@ -55,6 +57,7 @@ | |||||
get_underlying_scalar_constant_value, | ||||||
get_vector_length, | ||||||
horizontal_stack, | ||||||
hstack, | ||||||
identity_like, | ||||||
infer_static_shape, | ||||||
inverse_permutation, | ||||||
|
@@ -86,6 +89,7 @@ | |||||
triu_indices, | ||||||
triu_indices_from, | ||||||
vertical_stack, | ||||||
vstack, | ||||||
zeros_like, | ||||||
) | ||||||
from pytensor.tensor.blockwise import Blockwise | ||||||
|
@@ -1667,6 +1671,19 @@ def test_join_matrix_ints(self): | |||||
assert (np.asarray(grad(s.sum(), b).eval()) == 0).all() | ||||||
assert (np.asarray(grad(s.sum(), a).eval()) == 0).all() | ||||||
|
||||||
def test_join_matrix1_using_column_stack(self): | ||||||
av = np.array([0.1, 0.2, 0.3], dtype="float32") | ||||||
bv = np.array([0.7, 0.8, 0.9], dtype="float32") | ||||||
a = self.shared(av) | ||||||
b = as_tensor_variable(bv) | ||||||
s = column_stack((a, b)) | ||||||
want = np.array( | ||||||
[[0.1, 0.7], [0.2, 0.8], [0.3, 0.9]], | ||||||
dtype="float32", | ||||||
) | ||||||
out = self.eval_outputs_and_check_join([s]) | ||||||
assert (out == want).all() | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will also check the shapes match, which is important for these changes
Suggested change
|
||||||
|
||||||
def test_join_matrix1_using_vertical_stack(self): | ||||||
a = self.shared(np.array([[1, 2, 3], [4, 5, 6]], dtype=self.floatX)) | ||||||
b = as_tensor_variable(np.array([[7, 8, 9]], dtype=self.floatX)) | ||||||
|
@@ -4489,15 +4506,35 @@ def test_full_like(inp, shape): | |||||
) | ||||||
|
||||||
|
||||||
@pytest.mark.parametrize("func", [hstack, vstack, dstack]) | ||||||
@pytest.mark.parametrize("dimension", [1, 2, 3]) | ||||||
def test_stack_helpers(func, dimension): | ||||||
if dimension == 1: | ||||||
arrays = [np.arange(i * dimension, (i + 1) * dimension) for i in range(3)] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. May be more readable with |
||||||
else: | ||||||
arrays = [ | ||||||
np.arange( | ||||||
i * dimension * dimension, (i + 1) * dimension * dimension | ||||||
).reshape(dimension, dimension) | ||||||
for i in range(3) | ||||||
] | ||||||
|
||||||
result = func(arrays) | ||||||
np_result = getattr(np, func.__name__)(arrays) | ||||||
|
||||||
assert np.array_equal(result.eval(), np_result) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
|
||||||
@pytest.mark.parametrize("func", [horizontal_stack, vertical_stack]) | ||||||
def test_oriented_stack_functions(func): | ||||||
with pytest.raises(ValueError): | ||||||
func() | ||||||
with pytest.warns(FutureWarning): | ||||||
with pytest.raises(ValueError): | ||||||
func() | ||||||
|
||||||
a = ptb.tensor(dtype=np.float64, shape=(None, None, None)) | ||||||
a = ptb.tensor(dtype=np.float64, shape=(None, None, None)) | ||||||
|
||||||
with pytest.raises(ValueError): | ||||||
func(a, a) | ||||||
with pytest.raises(ValueError): | ||||||
func(a, a) | ||||||
|
||||||
|
||||||
def test_trace(): | ||||||
|
Uh oh!
There was an error while loading. Please reload this page.