Skip to content

Commit c9f5f65

Browse files
authored
Support int-like shapes in pt.full (#759)
1 parent 146a0a8 commit c9f5f65

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

pytensor/tensor/basic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1719,6 +1719,9 @@ def full(shape, fill_value, dtype=None):
17191719
fill_value = as_tensor_variable(fill_value)
17201720
if dtype:
17211721
fill_value = fill_value.astype(dtype)
1722+
1723+
if np.ndim(shape) == 0:
1724+
shape = (shape,)
17221725
return alloc(fill_value, *shape)
17231726

17241727

tests/tensor/test_basic.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -848,10 +848,15 @@ def test_zeros(self):
848848
inp = np.zeros(shp, dtype=config.floatX)
849849
assert np.allclose(zeros_tensor(inp), np.zeros(shp))
850850

851-
def test_full(self):
852-
full_pt = ptb.full((2, 3), 3, dtype="int64")
851+
@pytest.mark.parametrize(
852+
"shape", [(2, 3), 5, np.int32(5), np.array(5), constant(5)]
853+
)
854+
def test_full(self, shape):
855+
full_pt = ptb.full(shape, 3, dtype="int64")
853856
res = pytensor.function([], full_pt, mode=self.mode)()
854-
assert np.array_equal(res, np.full((2, 3), 3, dtype="int64"))
857+
if isinstance(shape, ptb.TensorVariable):
858+
shape = shape.eval()
859+
assert np.array_equal(res, np.full(shape, 3, dtype="int64"))
855860

856861
@pytest.mark.parametrize("func", (ptb.zeros, ptb.empty))
857862
def test_rebuild(self, func):

0 commit comments

Comments
 (0)