diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index e30887cfe3..36eac9a461 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -4355,28 +4355,22 @@ def empty_like( def atleast_Nd( - *arys: np.ndarray | TensorVariable, n: int = 1, left: bool = True + arry: np.ndarray | TensorVariable, *, n: int = 1, left: bool = True ) -> TensorVariable: - """Convert inputs to arrays with at least `n` dimensions.""" - res = [] - for ary in arys: - ary = as_tensor(ary) + """Convert input to an array with at least `n` dimensions.""" - if ary.ndim >= n: - result = ary - else: - result = ( - shape_padleft(ary, n - ary.ndim) - if left - else shape_padright(ary, n - ary.ndim) - ) + arry = as_tensor(arry) - res.append(result) - - if len(res) == 1: - return res[0] + if arry.ndim >= n: + result = arry else: - return res + result = ( + shape_padleft(arry, n - arry.ndim) + if left + else shape_padright(arry, n - arry.ndim) + ) + + return result atleast_1d = partial(atleast_Nd, n=1) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 60643e2984..1186aeb35c 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -4364,7 +4364,8 @@ def test_atleast_Nd(): for n in range(1, 3): ary1, ary2 = dscalar(), dvector() - res_ary1, res_ary2 = atleast_Nd(ary1, ary2, n=n) + res_ary1 = atleast_Nd(ary1, n=n) + res_ary2 = atleast_Nd(ary2, n=n) assert res_ary1.ndim == n if n == ary2.ndim: