File tree Expand file tree Collapse file tree 2 files changed +6
-15
lines changed Expand file tree Collapse file tree 2 files changed +6
-15
lines changed Original file line number Diff line number Diff line change @@ -3081,6 +3081,10 @@ def flatten(x, ndim=1):
3081
3081
else :
3082
3082
dims = (- 1 ,)
3083
3083
3084
+ if len (dims ) == _x .ndim :
3085
+ # Nothing to ravel
3086
+ return _x
3087
+
3084
3088
x_reshaped = _x .reshape (dims )
3085
3089
shape_kept_dims = _x .type .shape [: ndim - 1 ]
3086
3090
bcast_new_dim = builtins .all (s == 1 for s in _x .type .shape [ndim - 1 :])
Original file line number Diff line number Diff line change @@ -3867,37 +3867,24 @@ class TestInferShape(utt.InferShapeTester):
3867
3867
def test_Flatten (self ):
3868
3868
atens3 = tensor3 ()
3869
3869
atens3_val = random (4 , 5 , 3 )
3870
- for ndim in (3 , 2 , 1 ):
3870
+ for ndim in (2 , 1 ):
3871
3871
self ._compile_and_check (
3872
3872
[atens3 ],
3873
3873
[flatten (atens3 , ndim )],
3874
3874
[atens3_val ],
3875
3875
Reshape ,
3876
- excluding = ["local_useless_reshape" ],
3877
3876
)
3878
3877
3879
3878
amat = matrix ()
3880
3879
amat_val = random (4 , 5 )
3881
- for ndim in (2 , 1 ):
3880
+ for ndim in (1 , ):
3882
3881
self ._compile_and_check (
3883
3882
[amat ],
3884
3883
[flatten (amat , ndim )],
3885
3884
[amat_val ],
3886
3885
Reshape ,
3887
- excluding = ["local_useless_reshape" ],
3888
3886
)
3889
3887
3890
- avec = vector ()
3891
- avec_val = random (4 )
3892
- ndim = 1
3893
- self ._compile_and_check (
3894
- [avec ],
3895
- [flatten (avec , ndim )],
3896
- [avec_val ],
3897
- Reshape ,
3898
- excluding = ["local_useless_reshape" ],
3899
- )
3900
-
3901
3888
def test_Eye (self ):
3902
3889
aiscal = iscalar ()
3903
3890
biscal = iscalar ()
You can’t perform that action at this time.
0 commit comments