@@ -138,14 +138,12 @@ def generic_solve_to_solve_triangular(fgraph, node):
138
138
]
139
139
140
140
141
- @register_stabilize
142
141
@register_specialize
143
142
@node_rewriter ([Blockwise ])
144
143
def batched_vector_b_solve_to_matrix_b_solve (fgraph , node ):
145
144
"""Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T
146
145
147
146
`a` must have no batched dimensions, while `b` can have arbitrary batched dimensions.
148
- Only the last two dimensions of `b` and the output are swapped.
149
147
"""
150
148
core_op = node .op .core_op
151
149
@@ -175,8 +173,17 @@ def batched_vector_b_solve_to_matrix_b_solve(fgraph, node):
175
173
new_core_op = type (core_op )(** props )
176
174
matrix_b_solve = Blockwise (new_core_op )
177
175
176
+ # Ravel any batched dims
177
+ original_b_shape = tuple (b .shape )
178
+ if len (original_b_shape ) > 2 :
179
+ b = b .reshape ((- 1 , original_b_shape [- 1 ]))
180
+
178
181
# Apply the rewrite
179
- new_solve = _T (matrix_b_solve (a , _T (b )))
182
+ new_solve = matrix_b_solve (a , b .T ).T
183
+
184
+ # Unravel any batched dims
185
+ if len (original_b_shape ) > 2 :
186
+ new_solve = new_solve .reshape (original_b_shape )
180
187
181
188
old_solve = node .outputs [0 ]
182
189
copy_stack_trace (old_solve , new_solve )
0 commit comments