
Commit 17ec0cd

Allow building jacobian via vectorization instead of Scan

Also allow arbitrary expression dimensionality

1 parent e25e8a2
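
The new behavior is opt-in through a `vectorize` keyword, and the old restriction that `expression` be at most one-dimensional is lifted. A minimal usage sketch, assuming the public `pytensor.gradient.jacobian` API; the variable names are illustrative:

    import pytensor
    import pytensor.tensor as pt
    from pytensor.gradient import jacobian

    x = pt.vector("x")
    y = pt.exp(x)  # elementwise, so the Jacobian is diagonal

    # New in this commit: vectorize=True builds the Jacobian by
    # vectorizing a single L-op row instead of looping with Scan.
    J = jacobian(y, wrt=x, vectorize=True)

    # Also new: `expression` may have any dimensionality; the result
    # is reshaped to (*expression.shape, *wrt.shape).
    Z = pt.outer(x, x)       # shape (n, n)
    JZ = jacobian(Z, wrt=x)  # shape (n, n, n)

    f = pytensor.function([x], J)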

2 files changed (+62, -43 lines)


pytensor/gradient.py

Lines changed: 58 additions & 43 deletions
@@ -11,7 +11,7 @@
 import pytensor
 from pytensor.compile.ops import ViewOp
 from pytensor.configdefaults import config
-from pytensor.graph import utils
+from pytensor.graph import utils, vectorize_graph
 from pytensor.graph.basic import Apply, NominalVariable, Variable
 from pytensor.graph.null_type import NullType, null_type
 from pytensor.graph.op import get_test_values
@@ -2021,7 +2021,13 @@ def __str__(self):
         Exception args: {args_msg}"""
 
 
-def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise"):
+def jacobian(
+    expression,
+    wrt,
+    consider_constant=None,
+    disconnected_inputs="raise",
+    vectorize: bool = False,
+):
     """
     Compute the full Jacobian, row by row.
 
@@ -2051,62 +2057,71 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
     output, then a zero variable is returned. The return value is
     of same type as `wrt`: a list/tuple or TensorVariable in all cases.
     """
+    # from pytensor.tensor import arange, scalar
+    from pytensor.tensor import eye
 
     if not isinstance(expression, Variable):
         raise TypeError("jacobian expects a Variable as `expression`")
 
-    if expression.ndim > 1:
-        raise ValueError(
-            "jacobian expects a 1 dimensional variable as `expression`."
-            " If not use flatten to make it a vector"
-        )
-
     using_list = isinstance(wrt, list)
     using_tuple = isinstance(wrt, tuple)
+    grad_kwargs = {
+        "consider_constant": consider_constant,
+        "disconnected_inputs": disconnected_inputs,
+    }
 
     if isinstance(wrt, list | tuple):
         wrt = list(wrt)
     else:
         wrt = [wrt]
 
-    if expression.ndim == 0:
-        # expression is just a scalar, use grad
-        return as_list_or_tuple(
-            using_list,
-            using_tuple,
-            grad(
-                expression,
-                wrt,
-                consider_constant=consider_constant,
-                disconnected_inputs=disconnected_inputs,
-            ),
+    if all(expression.type.broadcastable):
+        jacobian_matrices = grad(expression.squeeze(), wrt, **grad_kwargs)
+
+    elif vectorize:
+        expression_flat = expression.ravel()
+        row_tangent = _float_ones_like(expression_flat).type("row_tangent")
+        jacobian_rows = Lop(expression.ravel(), wrt, row_tangent, **grad_kwargs)
+        jacobian_matrices = vectorize_graph(
+            jacobian_rows,
+            replace={row_tangent: eye(expression_flat.size, dtype=row_tangent.dtype)},
         )
+        # row_index = scalar("idx", dtype="int64")
+        # jacobian_rows = grad(
+        #     expression.ravel()[row_index],
+        #     wrt=wrt,
+        #     **grad_kwargs
+        # )
+        # rows_indices = arange(expression.size)
+        # jacobian_matrices = vectorize_graph(
+        #     jacobian_rows, replace={row_index: rows_indices}
+        # )
+    else:
 
-    def inner_function(*args):
-        idx = args[0]
-        expr = args[1]
-        rvals = []
-        for inp in args[2:]:
-            rval = grad(
-                expr[idx],
-                inp,
-                consider_constant=consider_constant,
-                disconnected_inputs=disconnected_inputs,
+        def inner_function(*args):
+            idx, expr, *wrt = args
+            return grad(expr[idx], wrt, **grad_kwargs)
+
+        jacobian_matrices, updates = pytensor.scan(
+            inner_function,
+            sequences=pytensor.tensor.arange(expression.size),
+            non_sequences=[expression.ravel(), *wrt],
+            return_list=True,
+        )
+        if updates:
+            raise ValueError(
+                "The scan used to build the jacobian matrices returned a list of updates"
             )
-            rvals.append(rval)
-        return rvals
-
-    # Computing the gradients does not affect the random seeds on any random
-    # generator used n expression (because during computing gradients we are
-    # just backtracking over old values. (rp Jan 2012 - if anyone has a
-    # counter example please show me)
-    jacobs, updates = pytensor.scan(
-        inner_function,
-        sequences=pytensor.tensor.arange(expression.shape[0]),
-        non_sequences=[expression, *wrt],
-    )
-    assert not updates, "Scan has returned a list of updates; this should not happen."
-    return as_list_or_tuple(using_list, using_tuple, jacobs)
+
+    if jacobian_matrices[0].ndim < (expression.ndim + wrt[0].ndim):
+        # There was some raveling or squeezing done prior to getting the jacobians.
+        # Reshape into the original shapes.
+        jacobian_matrices = [
+            jac_matrix.reshape((*expression.shape, *w.shape))
+            for jac_matrix, w in zip(jacobian_matrices, wrt, strict=True)
+        ]
+
+    return as_list_or_tuple(using_list, using_tuple, jacobian_matrices)
 
 
 def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
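
How the vectorize branch works: Lop(y, wrt, v) builds the graph of the vector-Jacobian product v @ J for a single symbolic row tangent v, and vectorize_graph then replaces that tangent with the whole identity matrix, batching the product over every basis row so the full Jacobian comes out without a Scan loop. A self-contained sketch of the same idea, assuming only the public `Lop` and `vectorize_graph` entry points; the names here are illustrative:

    import pytensor
    import pytensor.tensor as pt
    from pytensor.gradient import Lop
    from pytensor.graph import vectorize_graph

    x = pt.vector("x")
    y = pt.tanh(x)

    # Graph for a single Jacobian row: Lop gives v @ J for one tangent v.
    v = pt.vector("v")
    row = Lop(y, x, v)

    # Swapping v for the identity matrix batches the product over all
    # basis vectors e_i, so every row of J is computed at once.
    J = vectorize_graph(row, replace={v: pt.eye(y.shape[0])})

    f = pytensor.function([x], J)
    print(f([0.0, 1.0]))  # diag(1 - tanh(x)**2)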

pytensor/tensor/basic.py

Lines changed: 4 additions & 0 deletions
@@ -3074,6 +3074,10 @@ def flatten(x, ndim=1):
     else:
         dims = (-1,)
 
+    if len(dims) == _x.ndim:
+        # Nothing to ravel
+        return _x
+
    x_reshaped = _x.reshape(dims)
    shape_kept_dims = _x.type.shape[: ndim - 1]
    bcast_new_dim = builtins.all(s == 1 for s in _x.type.shape[ndim - 1 :])
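
The basic.py change adds a fast path to flatten: when the requested ndim already equals the input's ndim, there is nothing to ravel, so the input is returned as-is instead of going through a redundant Reshape. A quick illustration, assuming as_tensor_variable passes an existing TensorVariable through unchanged:

    import pytensor.tensor as pt
    from pytensor.tensor.basic import flatten

    x = pt.vector("x")
    # dims == (-1,) and x.ndim == 1, so the early return fires and no
    # Reshape node is inserted into the graph.
    assert flatten(x, ndim=1) is x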
