
Commit ff732d6

Allow building jacobian via vectorization instead of Scan
Also allow arbitrary expression dimensionality
1 parent 3c43234 commit ff732d6
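
For context, here is a minimal sketch of how the changed `jacobian` is meant to be used after this commit. The variable names and shapes are illustrative, not taken from the diff:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import jacobian

x = pt.vector("x")
# `expression` no longer needs to be 1-dimensional; a matrix output is allowed
y = pt.outer(x, x)

# Default behaviour still builds the Jacobian row by row with Scan
jac_scan = jacobian(y, x)
# New keyword: build it by vectorizing a single-row L-operator graph instead
jac_vec = jacobian(y, x, vectorize=True)

# For an n-d expression the result has shape (*y.shape, *x.shape)
f = pytensor.function([x], jac_vec)
```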

File tree

3 files changed: +267 −204 lines


pytensor/gradient.py

Lines changed: 68 additions & 51 deletions
@@ -11,7 +11,7 @@
 import pytensor
 from pytensor.compile.ops import ViewOp
 from pytensor.configdefaults import config
-from pytensor.graph import utils
+from pytensor.graph import utils, vectorize_graph
 from pytensor.graph.basic import Apply, NominalVariable, Variable
 from pytensor.graph.null_type import NullType, null_type
 from pytensor.graph.op import get_test_values
@@ -703,15 +703,15 @@ def grad(
         grad_dict[var] = g_var
 
     def handle_disconnected(var):
-        message = (
-            "grad method was asked to compute the gradient "
-            "with respect to a variable that is not part of "
-            "the computational graph of the cost, or is used "
-            f"only by a non-differentiable operator: {var}"
-        )
         if disconnected_inputs == "ignore":
-            pass
+            return
         elif disconnected_inputs == "warn":
+            message = (
+                "grad method was asked to compute the gradient "
+                "with respect to a variable that is not part of "
+                "the computational graph of the cost, or is used "
+                f"only by a non-differentiable operator: {var}"
+            )
             warnings.warn(message, stacklevel=2)
         elif disconnected_inputs == "raise":
             message = utils.get_variable_trace_string(var)
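
As a quick illustration of the `disconnected_inputs` modes touched above (a sketch, not part of the diff):

```python
import pytensor.tensor as pt
from pytensor.gradient import grad

x = pt.scalar("x")
y = pt.scalar("y")
cost = x**2  # y plays no role in the cost

# "ignore" now returns early without building the warning message;
# the disconnected input still gets a zero gradient in the result
gx, gy = grad(cost, [x, y], disconnected_inputs="ignore")

# "warn" emits the message above, "raise" (the default) raises an error
```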
@@ -2021,13 +2021,19 @@ def __str__(self):
         Exception args: {args_msg}"""
 
 
-def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise"):
+def jacobian(
+    expression,
+    wrt,
+    consider_constant=None,
+    disconnected_inputs="raise",
+    vectorize: bool = False,
+):
     """
     Compute the full Jacobian, row by row.
 
     Parameters
     ----------
-    expression : Vector (1-dimensional) :class:`~pytensor.graph.basic.Variable`
+    expression :class:`~pytensor.graph.basic.Variable`
         Values that we are differentiating (that we want the Jacobian of)
     wrt : :class:`~pytensor.graph.basic.Variable` or list of Variables
         Term[s] with respect to which we compute the Jacobian
@@ -2051,62 +2057,73 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
         output, then a zero variable is returned. The return value is
         of same type as `wrt`: a list/tuple or TensorVariable in all cases.
     """
+    from pytensor.tensor import broadcast_to, eye
 
     if not isinstance(expression, Variable):
         raise TypeError("jacobian expects a Variable as `expression`")
 
-    if expression.ndim > 1:
-        raise ValueError(
-            "jacobian expects a 1 dimensional variable as `expression`."
-            " If not use flatten to make it a vector"
-        )
-
     using_list = isinstance(wrt, list)
     using_tuple = isinstance(wrt, tuple)
+    grad_kwargs = {
+        "consider_constant": consider_constant,
+        "disconnected_inputs": disconnected_inputs,
+    }
 
     if isinstance(wrt, list | tuple):
         wrt = list(wrt)
     else:
         wrt = [wrt]
 
     if all(expression.type.broadcastable):
-        # expression is just a scalar, use grad
-        return as_list_or_tuple(
-            using_list,
-            using_tuple,
-            grad(
-                expression.squeeze(),
-                wrt,
-                consider_constant=consider_constant,
-                disconnected_inputs=disconnected_inputs,
-            ),
-        )
+        jacobian_matrices = grad(expression.squeeze(), wrt, **grad_kwargs)
 
-    def inner_function(*args):
-        idx = args[0]
-        expr = args[1]
-        rvals = []
-        for inp in args[2:]:
-            rval = grad(
-                expr[idx],
-                inp,
-                consider_constant=consider_constant,
-                disconnected_inputs=disconnected_inputs,
-            )
-            rvals.append(rval)
-        return rvals
-
-    # Computing the gradients does not affect the random seeds on any random
-    # generator used n expression (because during computing gradients we are
-    # just backtracking over old values. (rp Jan 2012 - if anyone has a
-    # counter example please show me)
-    jacobs, updates = pytensor.scan(
-        inner_function,
-        sequences=pytensor.tensor.arange(expression.shape[0]),
-        non_sequences=[expression, *wrt],
-    )
-    assert not updates, "Scan has returned a list of updates; this should not happen."
-    return as_list_or_tuple(using_list, using_tuple, jacobs)
+    elif vectorize:
+        expression_flat = expression.ravel()
+        row_tangent = _float_ones_like(expression_flat).type("row_tangent")
+        jacobian_single_rows = Lop(expression.ravel(), wrt, row_tangent, **grad_kwargs)
+
+        n_rows = expression_flat.size
+        jacobian_matrices = vectorize_graph(
+            jacobian_single_rows,
+            replace={row_tangent: eye(n_rows, dtype=row_tangent.dtype)},
+        )
+        if disconnected_inputs != "raise":
+            # If the input is disconnected from the cost, `vectorize_graph` has no effect on the respective jacobian
+            # We have to broadcast the zeros explicitly here
+            for i, (jacobian_single_row, jacobian_matrix) in enumerate(
+                zip(jacobian_single_rows, jacobian_matrices, strict=True)
+            ):
+                if jacobian_single_row.ndim == jacobian_matrix.ndim:
+                    jacobian_matrices[i] = broadcast_to(
+                        jacobian_matrix, shape=(n_rows, *jacobian_matrix.shape)
+                    )
+
+    else:
+
+        def inner_function(*args):
+            idx, expr, *wrt = args
+            return grad(expr[idx], wrt, **grad_kwargs)
+
+        jacobian_matrices, updates = pytensor.scan(
+            inner_function,
+            sequences=pytensor.tensor.arange(expression.size),
+            non_sequences=[expression.ravel(), *wrt],
+            return_list=True,
+        )
+        if updates:
+            raise ValueError(
+                "The scan used to build the jacobian matrices returned a list of updates"
+            )
+
+    if jacobian_matrices[0].ndim < (expression.ndim + wrt[0].ndim):
+        # There was some raveling or squeezing done prior to getting the jacobians
+        # Reshape into original shapes
+        jacobian_matrices = [
+            jac_matrix.reshape((*expression.shape, *w.shape))
+            for jac_matrix, w in zip(jacobian_matrices, wrt, strict=True)
+        ]
+
+    return as_list_or_tuple(using_list, using_tuple, jacobian_matrices)
 
 
 def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
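
To see the idea behind the new `vectorize` branch in isolation, here is a standalone sketch (illustrative names, assuming a simple elementwise expression) of building a full Jacobian by batching one vector-Jacobian product over the rows of an identity matrix with `vectorize_graph`, the same technique the diff applies inside `jacobian`:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import Lop
from pytensor.graph import vectorize_graph

x = pt.vector("x")
y = pt.exp(x)  # elementwise, so the Jacobian is diagonal

# One row of the Jacobian: a vector-Jacobian product with a symbolic row tangent
row_tangent = pt.vector("row_tangent")
jacobian_row = Lop(y, x, row_tangent)

# Replace the single tangent by the identity matrix; vectorize_graph adds the
# batch dimension, so the whole Jacobian is built without a Scan loop
jacobian_full = vectorize_graph(jacobian_row, replace={row_tangent: pt.eye(y.size)})

f = pytensor.function([x], jacobian_full)
```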
@@ -2302,7 +2319,7 @@ def _is_zero(x):
 
 class ZeroGrad(ViewOp):
     def grad(self, args, g_outs):
-        return [g_out.zeros_like() for g_out in g_outs]
+        return [g_out.zeros_like(g_out) for g_out in g_outs]
 
     def R_op(self, inputs, eval_points):
         if eval_points[0] is None:
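
For reference, `ZeroGrad` is the Op behind the `zero_grad` helper; a small hedged example of its effect (not part of the diff):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import grad, zero_grad

x = pt.vector("x")
# Gradients do not flow through the zero_grad branch, so d(cost)/dx here is
# just the value of x (the zero_grad factor is treated as a constant)
cost = (zero_grad(x) * x).sum()
g = grad(cost, x)

f = pytensor.function([x], g)
```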

pytensor/tensor/basic.py

Lines changed: 4 additions & 0 deletions
@@ -3081,6 +3081,10 @@ def flatten(x, ndim=1):
     else:
         dims = (-1,)
 
+    if len(dims) == _x.ndim:
+        # Nothing to ravel
+        return _x
+
     x_reshaped = _x.reshape(dims)
     shape_kept_dims = _x.type.shape[: ndim - 1]
     bcast_new_dim = builtins.all(s == 1 for s in _x.type.shape[ndim - 1 :])
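
A quick sketch of the effect of this early return, assuming the behaviour shown in the hunk above:

```python
import pytensor.tensor as pt
from pytensor.tensor.basic import flatten

x = pt.matrix("x")

# When ndim matches the variable's ndim there is nothing to ravel,
# so the variable is returned as-is instead of going through a Reshape Op
assert flatten(x, ndim=2) is x

# The default still ravels everything into a vector
assert flatten(x).ndim == 1
```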

0 commit comments

Comments
 (0)