Skip to content

Commit e48ff56

Browse files
Dhruvanshu-JoshiricardoV94
authored andcommitted
Support squeezing of unit dimension broadcastable axis
1 parent 044910b commit e48ff56

File tree

2 files changed

+29
-8
lines changed

2 files changed

+29
-8
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from pytensor.tensor.math import ge, lt, maximum, minimum, prod, switch
3131
from pytensor.tensor.math import max as pt_max
3232
from pytensor.tensor.math import sum as pt_sum
33+
from pytensor.tensor.shape import specify_broadcastable
3334
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
3435
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
3536
from pytensor.tensor.variable import TensorVariable
@@ -592,6 +593,15 @@ def squeeze(x, axis=None):
592593
# Nothing to do
593594
return _x
594595

596+
if _x.ndim == 0:
597+
# Nothing could be squeezed
598+
return _x
599+
600+
# `Dimshuffle` raises when we try to drop an axis that is not statically broadcastable.
601+
# We add a `specify_broadcastable` instead of raising.
602+
non_broadcastable_axis = [i for i in axis if not _x.broadcastable[i]]
603+
_x = specify_broadcastable(_x, *non_broadcastable_axis)
604+
595605
return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis])
596606

597607

tests/tensor/test_extra_ops.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -463,14 +463,6 @@ def test_axis(self):
463463

464464
assert res.broadcastable == (False, True, False)
465465

466-
def test_invalid_axis(self):
467-
# Test that trying to squeeze a non broadcastable dimension raises error
468-
variable = TensorType(config.floatX, shape=(1, None))()
469-
with pytest.raises(
470-
ValueError, match="Cannot drop a non-broadcastable dimension"
471-
):
472-
squeeze(variable, axis=1)
473-
474466
def test_scalar_input(self):
475467
x = pt.scalar("x")
476468

@@ -482,6 +474,25 @@ def test_scalar_input(self):
482474
):
483475
squeeze(x, axis=1)
484476

477+
def test_invalid_input(self):
478+
x = pt.vector("x")
479+
axis = 0
480+
481+
f = pytensor.function([x], pt.squeeze(x, axis))
482+
483+
# Test that we allow squeezing of valid non-broadcastable dimension
484+
assert f([0]) == 0
485+
486+
# Test that we cannot squeeze dimensions whose length is greater than 1
487+
error_txt_1 = re.escape("SpecifyShape: Got shape (3,), expected (1,).")
488+
error_txt_2 = re.escape("SpecifyShape: dim 0 of input has shape 3, expected 1")
489+
match = error_txt_1 if pytensor.config.mode == "FAST_COMPILE" else error_txt_2
490+
with pytest.raises(
491+
AssertionError,
492+
match=match,
493+
):
494+
f([0, 1, 2])
495+
485496

486497
class TestCompress(utt.InferShapeTester):
487498
def setup_method(self):

0 commit comments

Comments
 (0)