Commit e2a2665

Allow building jacobian via vectorization instead of Scan
Also allow arbitrary expression dimensionality
Parent: 3c43234

File tree: 4 files changed (+300, -210 lines)

doc/tutorial/gradients.rst

Lines changed: 34 additions & 7 deletions
@@ -101,18 +101,20 @@ PyTensor implements the :func:`pytensor.gradient.jacobian` macro that does all
 that is needed to compute the Jacobian. The following text explains how
 to do it manually.

+Using Scan
+----------
+
 In order to manually compute the Jacobian of some function ``y`` with
-respect to some parameter ``x`` we need to use `scan`. What we
-do is to loop over the entries in ``y`` and compute the gradient of
+respect to some parameter ``x`` we can use `scan`.
+We loop over the entries in ``y`` and compute the gradient of
 ``y[i]`` with respect to ``x``.

 .. note::

     `scan` is a generic op in PyTensor that allows writing in a symbolic
     manner all kinds of recurrent equations. While creating
     symbolic loops (and optimizing them for performance) is a hard task,
-    effort is being done for improving the performance of `scan`. We
-    shall return to :ref:`scan<tutloop>` later in this tutorial.
+    effort is being done for improving the performance of `scan`.

 >>> import pytensor
 >>> import pytensor.tensor as pt
@@ -124,9 +126,9 @@ do is to loop over the entries in ``y`` and compute the gradient of
 array([[ 8.,  0.],
        [ 0.,  8.]])

-What we do in this code is to generate a sequence of integers from ``0`` to
-``y.shape[0]`` using `pt.arange`. Then we loop through this sequence, and
-at each step, we compute the gradient of element ``y[i]`` with respect to
+This code generates a sequence of integers from ``0`` to
+``y.shape[0]`` using `pt.arange`. Then it loops through this sequence, and
+at each step, computes the gradient of element ``y[i]`` with respect to
 ``x``. `scan` automatically concatenates all these rows, generating a
 matrix which corresponds to the Jacobian.

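The full doctest that produces the output shown above lies outside the diff context. For reference, a scan-based construction along the lines the tutorial describes looks roughly like this sketch (a minimal illustration, not the tutorial's exact listing; only `pytensor.scan`, `pytensor.grad`, and `pt.arange` are assumed):

>>> import pytensor
>>> import pytensor.tensor as pt
>>> x = pt.dvector('x')
>>> y = x ** 2
>>> # one Jacobian row per entry of y: the gradient of y[i] with respect to x
>>> J, updates = pytensor.scan(
...     lambda i, y, x: pytensor.grad(y[i], x),
...     sequences=pt.arange(y.shape[0]),
...     non_sequences=[y, x],
... )
>>> f = pytensor.function([x], J, updates=updates)
>>> f([4, 4])
array([[ 8.,  0.],
       [ 0.,  8.]])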
@@ -139,6 +141,31 @@ matrix which corresponds to the Jacobian.
 ``x`` anymore, while ``y[i]`` still is.


+Using automatic vectorization
+-----------------------------
+An alternative way to build the Jacobian is to vectorize the graph that computes a single row or column of the Jacobian.
+We can use `Lop` or `Rop` (more about them below) to obtain the row or column of the Jacobian, and `vectorize_graph`
+to vectorize it into the full Jacobian matrix.
+
+>>> import pytensor
+>>> import pytensor.tensor as pt
+>>> from pytensor.gradient import Lop
+>>> from pytensor.graph import vectorize_graph
+>>> x = pt.dvector('x')
+>>> y = x ** 2
+>>> row_tangent = pt.dvector("row_tangent")  # Helper variable; it will be replaced during vectorization
+>>> J_row = Lop(y, x, row_tangent)
+>>> J = vectorize_graph(J_row, replace={row_tangent: pt.eye(x.size)})
+>>> f = pytensor.function([x], J)
+>>> f([4, 4])
+array([[ 8.,  0.],
+       [ 0.,  8.]])
+
+This avoids the overhead of scan, at the cost of higher memory usage if the Jacobian expression has large intermediate operations.
+Also, not all graphs are safely vectorizable (e.g., if different rows require intermediate operations of different sizes).
+For these reasons, `jacobian` uses scan by default. The behavior can be changed by setting `vectorize=True`.
+
+
 Computing the Hessian
 =====================

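The closing sentence above refers to the `vectorize` flag added to `jacobian` in pytensor/gradient.py below. A minimal usage sketch (the output is assumed to match the doctests above):

>>> import pytensor
>>> import pytensor.tensor as pt
>>> from pytensor.gradient import jacobian
>>> x = pt.dvector("x")
>>> y = x ** 2
>>> J = jacobian(y, x, vectorize=True)  # built via Lop + vectorize_graph instead of scan
>>> f = pytensor.function([x], J)
>>> f([4, 4])
array([[ 8.,  0.],
       [ 0.,  8.]])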
pytensor/gradient.py

Lines changed: 67 additions & 50 deletions
@@ -11,7 +11,7 @@
 import pytensor
 from pytensor.compile.ops import ViewOp
 from pytensor.configdefaults import config
-from pytensor.graph import utils
+from pytensor.graph import utils, vectorize_graph
 from pytensor.graph.basic import Apply, NominalVariable, Variable
 from pytensor.graph.null_type import NullType, null_type
 from pytensor.graph.op import get_test_values
@@ -703,15 +703,15 @@ def grad(
         grad_dict[var] = g_var

     def handle_disconnected(var):
-        message = (
-            "grad method was asked to compute the gradient "
-            "with respect to a variable that is not part of "
-            "the computational graph of the cost, or is used "
-            f"only by a non-differentiable operator: {var}"
-        )
         if disconnected_inputs == "ignore":
-            pass
+            return
         elif disconnected_inputs == "warn":
+            message = (
+                "grad method was asked to compute the gradient "
+                "with respect to a variable that is not part of "
+                "the computational graph of the cost, or is used "
+                f"only by a non-differentiable operator: {var}"
+            )
             warnings.warn(message, stacklevel=2)
         elif disconnected_inputs == "raise":
             message = utils.get_variable_trace_string(var)
@@ -2021,13 +2021,19 @@ def __str__(self):
             Exception args: {args_msg}"""


-def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise"):
+def jacobian(
+    expression,
+    wrt,
+    consider_constant=None,
+    disconnected_inputs="raise",
+    vectorize: bool = False,
+):
     """
     Compute the full Jacobian, row by row.

     Parameters
     ----------
-    expression : Vector (1-dimensional) :class:`~pytensor.graph.basic.Variable`
+    expression :class:`~pytensor.graph.basic.Variable`
         Values that we are differentiating (that we want the Jacobian of)
     wrt : :class:`~pytensor.graph.basic.Variable` or list of Variables
         Term[s] with respect to which we compute the Jacobian
@@ -2051,62 +2057,73 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
         output, then a zero variable is returned. The return value is
         of same type as `wrt`: a list/tuple or TensorVariable in all cases.
     """
+    from pytensor.tensor import broadcast_to, eye

     if not isinstance(expression, Variable):
         raise TypeError("jacobian expects a Variable as `expression`")

-    if expression.ndim > 1:
-        raise ValueError(
-            "jacobian expects a 1 dimensional variable as `expression`."
-            " If not use flatten to make it a vector"
-        )
-
     using_list = isinstance(wrt, list)
     using_tuple = isinstance(wrt, tuple)
+    grad_kwargs = {
+        "consider_constant": consider_constant,
+        "disconnected_inputs": disconnected_inputs,
+    }

     if isinstance(wrt, list | tuple):
         wrt = list(wrt)
     else:
         wrt = [wrt]

     if all(expression.type.broadcastable):
-        # expression is just a scalar, use grad
-        return as_list_or_tuple(
-            using_list,
-            using_tuple,
-            grad(
-                expression.squeeze(),
-                wrt,
-                consider_constant=consider_constant,
-                disconnected_inputs=disconnected_inputs,
-            ),
+        jacobian_matrices = grad(expression.squeeze(), wrt, **grad_kwargs)
+
+    elif vectorize:
+        expression_flat = expression.ravel()
+        row_tangent = _float_ones_like(expression_flat).type("row_tangent")
+        jacobian_single_rows = Lop(expression.ravel(), wrt, row_tangent, **grad_kwargs)
+
+        n_rows = expression_flat.size
+        jacobian_matrices = vectorize_graph(
+            jacobian_single_rows,
+            replace={row_tangent: eye(n_rows, dtype=row_tangent.dtype)},
         )
+        if disconnected_inputs != "raise":
+            # If the input is disconnected from the cost, `vectorize_graph` has no effect on the respective jacobian
+            # We have to broadcast the zeros explicitly here
+            for i, (jacobian_single_row, jacobian_matrix) in enumerate(
+                zip(jacobian_single_rows, jacobian_matrices, strict=True)
+            ):
+                if jacobian_single_row.ndim == jacobian_matrix.ndim:
+                    jacobian_matrices[i] = broadcast_to(
+                        jacobian_matrix, shape=(n_rows, *jacobian_matrix.shape)
+                    )

-    def inner_function(*args):
-        idx = args[0]
-        expr = args[1]
-        rvals = []
-        for inp in args[2:]:
-            rval = grad(
-                expr[idx],
-                inp,
-                consider_constant=consider_constant,
-                disconnected_inputs=disconnected_inputs,
+    else:
+
+        def inner_function(*args):
+            idx, expr, *wrt = args
+            return grad(expr[idx], wrt, **grad_kwargs)
+
+        jacobian_matrices, updates = pytensor.scan(
+            inner_function,
+            sequences=pytensor.tensor.arange(expression.size),
+            non_sequences=[expression.ravel(), *wrt],
+            return_list=True,
+        )
+        if updates:
+            raise ValueError(
+                "The scan used to build the jacobian matrices returned a list of updates"
             )
-            rvals.append(rval)
-        return rvals
-
-    # Computing the gradients does not affect the random seeds on any random
-    # generator used n expression (because during computing gradients we are
-    # just backtracking over old values. (rp Jan 2012 - if anyone has a
-    # counter example please show me)
-    jacobs, updates = pytensor.scan(
-        inner_function,
-        sequences=pytensor.tensor.arange(expression.shape[0]),
-        non_sequences=[expression, *wrt],
-    )
-    assert not updates, "Scan has returned a list of updates; this should not happen."
-    return as_list_or_tuple(using_list, using_tuple, jacobs)
+
+    if jacobian_matrices[0].ndim < (expression.ndim + wrt[0].ndim):
+        # There was some raveling or squeezing done prior to getting the jacobians
+        # Reshape into original shapes
+        jacobian_matrices = [
+            jac_matrix.reshape((*expression.shape, *w.shape))
+            for jac_matrix, w in zip(jacobian_matrices, wrt, strict=True)
+        ]
+
+    return as_list_or_tuple(using_list, using_tuple, jacobian_matrices)


 def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):

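Because the 1-D restriction on `expression` is removed and the result is reshaped to `(*expression.shape, *wrt.shape)`, higher-dimensional expressions are now accepted. A small sketch of the expected shape behavior (the `pt.outer` example and printed values are illustrative assumptions, not taken from the commit):

>>> import pytensor
>>> import pytensor.tensor as pt
>>> from pytensor.gradient import jacobian
>>> x = pt.dvector("x")
>>> y = pt.outer(x, x)                  # 2-D expression of a 1-D input
>>> J = jacobian(y, x, vectorize=True)  # the scan path (the default) should give the same shape
>>> pytensor.function([x], J.shape)([1.0, 2.0])
array([2, 2, 2])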
pytensor/tensor/basic.py

Lines changed: 4 additions & 0 deletions
@@ -3081,6 +3081,10 @@ def flatten(x, ndim=1):
     else:
         dims = (-1,)

+    if len(dims) == _x.ndim:
+        # Nothing to ravel
+        return _x
+
     x_reshaped = _x.reshape(dims)
     shape_kept_dims = _x.type.shape[: ndim - 1]
     bcast_new_dim = builtins.all(s == 1 for s in _x.type.shape[ndim - 1 :])

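The early return added above makes `flatten` a no-op when the requested `ndim` already equals the input's. A quick sketch of the expected behavior (not a test from the commit):

>>> import pytensor.tensor as pt
>>> from pytensor.tensor.basic import flatten
>>> x = pt.dmatrix("x")
>>> flatten(x, ndim=2) is x  # nothing to ravel, the input variable is returned unchanged
True
>>> flatten(x, ndim=1).ndim  # still builds a 1-D reshape
1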