diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 5ca7173c37..f4df7d15ea 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1719,6 +1719,9 @@ def full(shape, fill_value, dtype=None): fill_value = as_tensor_variable(fill_value) if dtype: fill_value = fill_value.astype(dtype) + + if np.ndim(shape) == 0: + shape = (shape,) return alloc(fill_value, *shape) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 0f161760bd..fdf183e13d 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -848,10 +848,15 @@ def test_zeros(self): inp = np.zeros(shp, dtype=config.floatX) assert np.allclose(zeros_tensor(inp), np.zeros(shp)) - def test_full(self): - full_pt = ptb.full((2, 3), 3, dtype="int64") + @pytest.mark.parametrize( + "shape", [(2, 3), 5, np.int32(5), np.array(5), constant(5)] + ) + def test_full(self, shape): + full_pt = ptb.full(shape, 3, dtype="int64") res = pytensor.function([], full_pt, mode=self.mode)() - assert np.array_equal(res, np.full((2, 3), 3, dtype="int64")) + if isinstance(shape, ptb.TensorVariable): + shape = shape.eval() + assert np.array_equal(res, np.full(shape, 3, dtype="int64")) @pytest.mark.parametrize("func", (ptb.zeros, ptb.empty)) def test_rebuild(self, func):