Skip to content

Commit 282e528

Browse files
author
Aarsh-Wankar
committed
Redacted changes for Ops with non-conclusive gufunc_signature
1 parent 2eaf76f commit 282e528

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

pytensor/tensor/nlinalg.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,16 @@ def perform(self, node, inputs, outputs):
619619
s[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv)
620620

621621
def infer_shape(self, fgraph, node, shapes):
622-
return _gufunc_to_out_shape(self.gufunc_signature, shapes)
622+
(x_shape,) = shapes
623+
M, N = x_shape
624+
K = ptm.minimum(M, N)
625+
s_shape = (K,)
626+
if self.compute_uv:
627+
u_shape = (M, M) if self.full_matrices else (M, K)
628+
vt_shape = (N, N) if self.full_matrices else (K, N)
629+
return [u_shape, s_shape, vt_shape]
630+
else:
631+
return [s_shape]
623632

624633
def L_op(
625634
self,

pytensor/tensor/slinalg.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from pytensor.tensor.nlinalg import kron, matrix_dot
2020
from pytensor.tensor.shape import reshape
2121
from pytensor.tensor.type import matrix, tensor, vector
22-
from pytensor.tensor.utils import _gufunc_to_out_shape
2322
from pytensor.tensor.variable import TensorVariable
23+
from pytensor.tensor.utils import _gufunc_to_out_shape
2424

2525

2626
logger = logging.getLogger(__name__)
@@ -1101,7 +1101,8 @@ def grad(self, inputs, gout):
11011101
return [gout[0][slc] for slc in slices]
11021102

11031103
def infer_shape(self, fgraph, nodes, shapes):
1104-
return _gufunc_to_out_shape(self.gufunc_signature, shapes)
1104+
first, second = zip(*shapes, strict=True)
1105+
return [(pt.add(*first), pt.add(*second))]
11051106

11061107
def _validate_and_prepare_inputs(self, matrices, as_tensor_func):
11071108
if len(matrices) != self.n_inputs:

0 commit comments

Comments
 (0)