Skip to content

Commit 1e687ad

Browse files
committed
Expand batched_vector_b_solve_to_matrix rewrite
It now supports an arbitrary number of batched dimensions of b, by raveling them together
1 parent 5431080 commit 1e687ad

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,12 @@ def generic_solve_to_solve_triangular(fgraph, node):
138138
]
139139

140140

141-
@register_stabilize
142141
@register_specialize
143142
@node_rewriter([Blockwise])
144143
def batched_vector_b_solve_to_matrix_b_solve(fgraph, node):
145144
"""Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T
146145
147146
`a` must have no batched dimensions, while `b` can have arbitrary batched dimensions.
148-
Only the last two dimensions of `b` and the output are swapped.
149147
"""
150148
core_op = node.op.core_op
151149

@@ -175,8 +173,17 @@ def batched_vector_b_solve_to_matrix_b_solve(fgraph, node):
175173
new_core_op = type(core_op)(**props)
176174
matrix_b_solve = Blockwise(new_core_op)
177175

176+
# Ravel any batched dims
177+
original_b_shape = tuple(b.shape)
178+
if len(original_b_shape) > 2:
179+
b = b.reshape((-1, original_b_shape[-1]))
180+
178181
# Apply the rewrite
179-
new_solve = _T(matrix_b_solve(a, _T(b)))
182+
new_solve = matrix_b_solve(a, b.T).T
183+
184+
# Unravel any batched dims
185+
if len(original_b_shape) > 2:
186+
new_solve = new_solve.reshape(original_b_shape)
180187

181188
old_solve = node.outputs[0]
182189
copy_stack_trace(old_solve, new_solve)

0 commit comments

Comments
 (0)