Closed
Description
Describe the issue:
As the title says: pytensor.tensor.transpose fails when the axes are passed as a NumPy array. It works with a list or tuple.
Reproducible code example:
import pytensor.tensor as pt, numpy as np
pt.zeros( (2,3) ).transpose(np.array([1,0]))
Error message:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/home/velochy/salk/census/estonia.ipynb Cell 27 line 2
1 import pytensor.tensor as pt
----> 2 pt.zeros( (2,3) ).transpose(np.array([1,0]))
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pytensor/tensor/variable.py:260, in _tensor_py_operators.transpose(self, *axes)
258 iterable = False
259 if len(axes) == 1 and iterable:
--> 260 return pt.basic.transpose(self, axes[0])
261 else:
262 return pt.basic.transpose(self, axes)
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pytensor/tensor/basic.py:2046, in transpose(x, axes)
2042 if tuple(axes) == tuple(range(len(axes))):
2043 # No-op
2044 return _x
-> 2046 ret = _x.dimshuffle(axes)
2048 if _x.name and axes == tuple(range((_x.type.ndim - 1), -1, -1)):
2049 ret.name = _x.name + ".T"
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pytensor/tensor/variable.py:347, in _tensor_py_operators.dimshuffle(self, *pattern)
345 if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple)):
346 pattern = pattern[0]
--> 347 ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
348 return ds_op(self)
File ~/anaconda3/envs/salk/lib/python3.12/site-packages/pytensor/tensor/elemwise.py:146, in DimShuffle.__init__(self, input_ndim, new_order)
143 print("NO",input_ndim,new_order)
145 for i, j in enumerate(new_order):
--> 146 if j != "x":
147 if not isinstance(j, int | np.integer):
148 raise TypeError(
149 "DimShuffle indices must be Python ints; got "
150 f"{j} of type {type(j)}."
151 )
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
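Reading the traceback, the array appears to reach DimShuffle.__init__ still wrapped in a one-element tuple, because the unpacking check in dimshuffle only recognises list and tuple. The sketch below imitates that check using only NumPy. Note that unpack_pattern is a hypothetical stand-in written for illustration, not a PyTensor function.

```python
import numpy as np


def unpack_pattern(*pattern):
    """Hypothetical stand-in for the unpacking step shown in the traceback."""
    # Only a single list or tuple argument is unpacked; anything else is kept as-is.
    if len(pattern) == 1 and isinstance(pattern[0], (list, tuple)):
        pattern = pattern[0]
    return pattern


from_list = unpack_pattern([1, 0])             # unpacked: the list itself
from_array = unpack_pattern(np.array([1, 0]))  # not unpacked: a 1-tuple holding the array

print(type(from_list).__name__, len(from_list))
print(type(from_array).__name__, len(from_array), type(from_array[0]).__name__)
```

In the environment from the report, the loop over new_order then compares that whole array with "x", which is the line where the ValueError is raised.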
PyTensor version information:
pytensor v 2.26.4
Context for the issue:
Not a major issue, since the workaround is easy (just cast the array to a list), but if compatibility with NumPy is desired, it would be nice if this worked :)
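For anyone hitting the same error, here is a minimal sketch of the workaround as a reusable helper. The name as_axis_tuple is made up for this example and is not part of PyTensor or NumPy. It only assumes that the axes are an iterable of integer-like values.

```python
import operator

import numpy as np


def as_axis_tuple(axes):
    """Hypothetical helper: turn an iterable of integer-like axes into a tuple of ints."""
    # operator.index accepts Python ints and NumPy integers, and rejects floats.
    return tuple(operator.index(a) for a in axes)


axes = as_axis_tuple(np.array([1, 0]))
print(axes)  # (1, 0)
```

With PyTensor installed, the call from the report would then become pt.zeros((2, 3)).transpose(as_axis_tuple(np.array([1, 0]))), which goes through the tuple path that already works. Calling .tolist() on the array is an equivalent one-liner.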