From 595751a219117a0706a9b769d5e5188219b48020 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Thu, 9 May 2024 10:40:54 +0200 Subject: [PATCH 1/4] convert intlike to tuple --- pytensor/tensor/basic.py | 3 +++ tests/tensor/test_basic.py | 7 ++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 5ca7173c37..b021f70ee1 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1719,6 +1719,9 @@ def full(shape, fill_value, dtype=None): fill_value = as_tensor_variable(fill_value) if dtype: fill_value = fill_value.astype(dtype) + + if isinstance(shape, int | np.integer): + shape = (shape,) return alloc(fill_value, *shape) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 0f161760bd..da1ca242b7 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -848,10 +848,11 @@ def test_zeros(self): inp = np.zeros(shp, dtype=config.floatX) assert np.allclose(zeros_tensor(inp), np.zeros(shp)) - def test_full(self): - full_pt = ptb.full((2, 3), 3, dtype="int64") + @pytest.mark.parametrize("shape", [(2, 3), 5, np.int32(5)]) + def test_full(self, shape): + full_pt = ptb.full(shape, 3, dtype="int64") res = pytensor.function([], full_pt, mode=self.mode)() - assert np.array_equal(res, np.full((2, 3), 3, dtype="int64")) + assert np.array_equal(res, np.full(shape, 3, dtype="int64")) @pytest.mark.parametrize("func", (ptb.zeros, ptb.empty)) def test_rebuild(self, func): From bb38f0b9301c87f5846e32760ad9866dabdf054e Mon Sep 17 00:00:00 2001 From: Will Dean Date: Thu, 9 May 2024 11:54:40 +0200 Subject: [PATCH 2/4] change to check ndim --- pytensor/tensor/basic.py | 2 +- tests/tensor/test_basic.py | 20 +++++++++++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index b021f70ee1..f4df7d15ea 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1720,7 +1720,7 @@ def full(shape, fill_value, dtype=None): if dtype: fill_value = fill_value.astype(dtype) - if isinstance(shape, int | np.integer): + if np.ndim(shape) == 0: shape = (shape,) return alloc(fill_value, *shape) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index da1ca242b7..a0db1aff46 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -848,12 +848,30 @@ def test_zeros(self): inp = np.zeros(shp, dtype=config.floatX) assert np.allclose(zeros_tensor(inp), np.zeros(shp)) - @pytest.mark.parametrize("shape", [(2, 3), 5, np.int32(5)]) + @pytest.mark.parametrize("shape", [(2, 3), 5, np.int32(5), np.array(5)]) def test_full(self, shape): full_pt = ptb.full(shape, 3, dtype="int64") res = pytensor.function([], full_pt, mode=self.mode)() assert np.array_equal(res, np.full(shape, 3, dtype="int64")) + def test_full_with_scalar(self): + shape = scalar("shape", dtype="int64") + full_pt = ptb.full(shape, 3, dtype="int64") + res = pytensor.function([shape], full_pt, mode=self.mode)(5) + assert np.array_equal(res, np.full(5, 3, dtype="int64")) + + @pytest.mark.parametrize( + "shape", + [ + 5.5, + np.array(5.5), + scalar("shape", dtype="float64"), + ], + ) + def test_full_with_float(self, shape) -> None: + with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"): + ptb.full(shape, 3, dtype="int64") + @pytest.mark.parametrize("func", (ptb.zeros, ptb.empty)) def test_rebuild(self, func): x = vector(shape=(50,)) From 57b45c5a8694fcdc5b9de7bd039c14c19cece52f Mon Sep 17 00:00:00 2001 From: Will Dean Date: Thu, 9 May 2024 12:04:21 +0200 Subject: [PATCH 3/4] remove alloc duplicated tests --- tests/tensor/test_basic.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index a0db1aff46..a942a81ec5 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -860,18 +860,6 @@ def test_full_with_scalar(self): res = pytensor.function([shape], full_pt, mode=self.mode)(5) assert np.array_equal(res, np.full(5, 3, dtype="int64")) - @pytest.mark.parametrize( - "shape", - [ - 5.5, - np.array(5.5), - scalar("shape", dtype="float64"), - ], - ) - def test_full_with_float(self, shape) -> None: - with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"): - ptb.full(shape, 3, dtype="int64") - @pytest.mark.parametrize("func", (ptb.zeros, ptb.empty)) def test_rebuild(self, func): x = vector(shape=(50,)) From 256630dd84ca068e090d22a8745affe2a1b07fa0 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Thu, 9 May 2024 12:09:25 +0200 Subject: [PATCH 4/4] consolidate tests --- tests/tensor/test_basic.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index a942a81ec5..fdf183e13d 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -848,18 +848,16 @@ def test_zeros(self): inp = np.zeros(shp, dtype=config.floatX) assert np.allclose(zeros_tensor(inp), np.zeros(shp)) - @pytest.mark.parametrize("shape", [(2, 3), 5, np.int32(5), np.array(5)]) + @pytest.mark.parametrize( + "shape", [(2, 3), 5, np.int32(5), np.array(5), constant(5)] + ) def test_full(self, shape): full_pt = ptb.full(shape, 3, dtype="int64") res = pytensor.function([], full_pt, mode=self.mode)() + if isinstance(shape, ptb.TensorVariable): + shape = shape.eval() assert np.array_equal(res, np.full(shape, 3, dtype="int64")) - def test_full_with_scalar(self): - shape = scalar("shape", dtype="int64") - full_pt = ptb.full(shape, 3, dtype="int64") - res = pytensor.function([shape], full_pt, mode=self.mode)(5) - assert np.array_equal(res, np.full(5, 3, dtype="int64")) - @pytest.mark.parametrize("func", (ptb.zeros, ptb.empty)) def test_rebuild(self, func): x = vector(shape=(50,))