
Commit 2eaf76f

Author: Aarsh-Wankar (committed)
Refactor infer_shape methods to utilize _gufunc_to_out_shape for output shape computation
1 parent 52bbf59 commit 2eaf76f

File tree: 3 files changed, +47 −29 lines

pytensor/tensor/nlinalg.py

Lines changed: 7 additions & 16 deletions
@@ -17,6 +17,7 @@
 from pytensor.tensor.basic import as_tensor_variable, diagonal
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.type import Variable, dvector, lscalar, matrix, scalar, vector
+from pytensor.tensor.utils import _gufunc_to_out_shape


 class MatrixPinv(Op):
@@ -63,7 +64,7 @@ def L_op(self, inputs, outputs, g_outputs):
         return [grad]

     def infer_shape(self, fgraph, node, shapes):
-        return [list(reversed(shapes[0]))]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)


 def pinv(x, hermitian=False):
@@ -156,7 +157,7 @@ def R_op(self, inputs, eval_points):
         return [-matrix_dot(xi, ev, xi)]

     def infer_shape(self, fgraph, node, shapes):
-        return shapes
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)


 inv = matrix_inverse = Blockwise(MatrixInverse())
@@ -225,7 +226,7 @@ def grad(self, inputs, g_outputs):
         return [gz * self(x) * matrix_inverse(x).T]

     def infer_shape(self, fgraph, node, shapes):
-        return [()]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)

     def __str__(self):
         return "Det"
@@ -259,7 +260,7 @@ def perform(self, node, inputs, outputs):
             raise ValueError("Failed to compute determinant", x) from e

     def infer_shape(self, fgraph, node, shapes):
-        return [(), ()]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)

     def __str__(self):
         return "SLogDet"
@@ -317,8 +318,7 @@ def perform(self, node, inputs, outputs):
         w[0], v[0] = (z.astype(x.dtype) for z in np.linalg.eig(x))

     def infer_shape(self, fgraph, node, shapes):
-        n = shapes[0][0]
-        return [(n,), (n, n)]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)


 eig = Blockwise(Eig())
@@ -619,16 +619,7 @@ def perform(self, node, inputs, outputs):
             s[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv)

     def infer_shape(self, fgraph, node, shapes):
-        (x_shape,) = shapes
-        M, N = x_shape
-        K = ptm.minimum(M, N)
-        s_shape = (K,)
-        if self.compute_uv:
-            u_shape = (M, M) if self.full_matrices else (M, K)
-            vt_shape = (N, N) if self.full_matrices else (K, N)
-            return [u_shape, s_shape, vt_shape]
-        else:
-            return [s_shape]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)

     def L_op(
         self,
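
For illustration only, not part of the commit: assuming Det and Eig declare gufunc_signature strings along the lines of "(m,m)->()" and "(m,m)->(m),(m,m)" (the exact strings live on the Op classes and are not shown in this diff), a minimal sketch of how the new helper reproduces the shape tuples the hand-written infer_shape bodies used to return:

    from pytensor.tensor.utils import _gufunc_to_out_shape

    n = 5  # stand-in for a (normally symbolic) dimension

    # Det-like signature: a square matrix maps to a scalar, so the output
    # core shape is the empty tuple, matching the old `return [()]`.
    print(_gufunc_to_out_shape("(m,m)->()", [(n, n)]))  # [()]

    # Eig-like signature: eigenvalues of shape (n,) and eigenvectors of
    # shape (n, n), matching the old `return [(n,), (n, n)]`.
    print(_gufunc_to_out_shape("(m,m)->(m),(m,m)", [(n, n)]))  # [(5,), (5, 5)]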

pytensor/tensor/slinalg.py

Lines changed: 7 additions & 13 deletions
@@ -19,6 +19,7 @@
 from pytensor.tensor.nlinalg import kron, matrix_dot
 from pytensor.tensor.shape import reshape
 from pytensor.tensor.type import matrix, tensor, vector
+from pytensor.tensor.utils import _gufunc_to_out_shape
 from pytensor.tensor.variable import TensorVariable


@@ -50,7 +51,7 @@ def __init__(
             self.destroy_map = {0: [0]}

     def infer_shape(self, fgraph, node, shapes):
-        return [shapes[0]]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)

     def make_node(self, x):
         x = as_tensor_variable(x)
@@ -268,13 +269,7 @@ def make_node(self, A, b):
         return Apply(self, [A, b], [x])

     def infer_shape(self, fgraph, node, shapes):
-        Ashape, Bshape = shapes
-        rows = Ashape[1]
-        if len(Bshape) == 1:
-            return [(rows,)]
-        else:
-            cols = Bshape[1]
-            return [(rows, cols)]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)

     def L_op(self, inputs, outputs, output_gradients):
         r"""Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.
@@ -816,7 +811,7 @@ def perform(self, node, inputs, output_storage):
         X[0] = scipy_linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)

     def infer_shape(self, fgraph, node, shapes):
-        return [shapes[0]]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)

     def grad(self, inputs, output_grads):
         # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
@@ -888,7 +883,7 @@ def perform(self, node, inputs, output_storage):
         )

     def infer_shape(self, fgraph, node, shapes):
-        return [shapes[0]]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)

     def grad(self, inputs, output_grads):
         # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
@@ -1008,7 +1003,7 @@ def perform(self, node, inputs, output_storage):
         X[0] = scipy_linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)

     def infer_shape(self, fgraph, node, shapes):
-        return [shapes[0]]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)

     def grad(self, inputs, output_grads):
         # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
@@ -1106,8 +1101,7 @@ def grad(self, inputs, gout):
         return [gout[0][slc] for slc in slices]

     def infer_shape(self, fgraph, nodes, shapes):
-        first, second = zip(*shapes, strict=True)
-        return [(pt.add(*first), pt.add(*second))]
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)

     def _validate_and_prepare_inputs(self, matrices, as_tensor_func):
         if len(matrices) != self.n_inputs:
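
Likewise for the solve-style Ops, a sketch with hypothetical signature strings (the actual gufunc_signature is set per Op instance, e.g. depending on b_ndim, and is not shown in this diff), covering both the vector and matrix right-hand-side cases that the deleted branches handled:

    from pytensor.tensor.utils import _gufunc_to_out_shape

    A_shape, b_vec, b_mat = (4, 4), (4,), (4, 3)

    # b is a vector: the solution x has the same core shape as b.
    print(_gufunc_to_out_shape("(m,m),(m)->(m)", [A_shape, b_vec]))  # [(4,)]

    # b is a matrix: x has one column per right-hand side.
    print(_gufunc_to_out_shape("(m,m),(m,n)->(m,n)", [A_shape, b_mat]))  # [(4, 3)]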

pytensor/tensor/utils.py

Lines changed: 33 additions & 0 deletions
@@ -202,6 +202,39 @@ def _parse_gufunc_signature(
     )


+def _gufunc_to_out_shape(
+    signature: str, shapes: list[tuple[int, ...]]
+) -> list[tuple[int, ...]]:
+    """
+    Compute the shape of the output of an Op given its gufunc signature and the
+    shapes of its inputs.
+
+    Parameters
+    ----------
+    signature : str
+        The gufunc signature of the Op.
+        eg: "(m,n),(n,p)->(m,p)".
+
+    shapes : list of tuple of int
+        The list of shapes of the inputs.
+
+    Returns
+    -------
+    out_shape : list of tuple of int
+        The list of shapes of the outputs.
+    """
+    parsed = _parse_gufunc_signature(signature)
+    out_shape = []
+    dic = dict()
+    for i in range(len(parsed[0])):
+        for j in range(len(parsed[0][i])):
+            dic[parsed[0][i][j]] = shapes[i][j]
+    for i in range(len(parsed[1])):
+        temp_list = [dic[x] for x in parsed[1][i]]
+        out_shape.append(tuple(temp_list))
+    return out_shape
+
+
 def safe_signature(
     core_inputs_ndim: Sequence[int],
     core_outputs_ndim: Sequence[int],
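
A minimal usage sketch of the helper itself (assuming a pytensor build containing this commit), using the docstring's example signature with concrete integers; in a real infer_shape call the shape entries would be symbolic scalars, which the same dictionary lookup passes through unchanged:

    from pytensor.tensor.utils import _gufunc_to_out_shape

    # "(m,n),(n,p)->(m,p)": bind m=2, n=3 from the first input shape and
    # n=3, p=4 from the second, then substitute into the output spec (m, p).
    print(_gufunc_to_out_shape("(m,n),(n,p)->(m,p)", [(2, 3), (3, 4)]))
    # [(2, 4)]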
