Skip to content

Commit 3521ace

Browse files
committed
Eager optimization for no-op flatten
1 parent 3c43234 commit 3521ace

File tree

2 files changed

+6
-15
lines changed

2 files changed

+6
-15
lines changed

pytensor/tensor/basic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3081,6 +3081,10 @@ def flatten(x, ndim=1):
30813081
else:
30823082
dims = (-1,)
30833083

3084+
if len(dims) == _x.ndim:
3085+
# Nothing to ravel
3086+
return _x
3087+
30843088
x_reshaped = _x.reshape(dims)
30853089
shape_kept_dims = _x.type.shape[: ndim - 1]
30863090
bcast_new_dim = builtins.all(s == 1 for s in _x.type.shape[ndim - 1 :])

tests/tensor/test_basic.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3867,37 +3867,24 @@ class TestInferShape(utt.InferShapeTester):
38673867
def test_Flatten(self):
38683868
atens3 = tensor3()
38693869
atens3_val = random(4, 5, 3)
3870-
for ndim in (3, 2, 1):
3870+
for ndim in (2, 1):
38713871
self._compile_and_check(
38723872
[atens3],
38733873
[flatten(atens3, ndim)],
38743874
[atens3_val],
38753875
Reshape,
3876-
excluding=["local_useless_reshape"],
38773876
)
38783877

38793878
amat = matrix()
38803879
amat_val = random(4, 5)
3881-
for ndim in (2, 1):
3880+
for ndim in (1,):
38823881
self._compile_and_check(
38833882
[amat],
38843883
[flatten(amat, ndim)],
38853884
[amat_val],
38863885
Reshape,
3887-
excluding=["local_useless_reshape"],
38883886
)
38893887

3890-
avec = vector()
3891-
avec_val = random(4)
3892-
ndim = 1
3893-
self._compile_and_check(
3894-
[avec],
3895-
[flatten(avec, ndim)],
3896-
[avec_val],
3897-
Reshape,
3898-
excluding=["local_useless_reshape"],
3899-
)
3900-
39013888
def test_Eye(self):
39023889
aiscal = iscalar()
39033890
biscal = iscalar()

0 commit comments

Comments
 (0)