
Commit 5fd729d

ricardoV94 and aseyboldt committed

Add helper to build hessian vector product

Co-authored-by: Adrian Seyboldt <aseyboldt@users.noreply.github.com>

1 parent db1c161 commit 5fd729d

File tree

3 files changed: +128 -0 lines changed

doc/tutorial/gradients.rst

Lines changed: 10 additions & 0 deletions

@@ -267,6 +267,16 @@ or, making use of the R-operator:
 >>> f([4, 4], [2, 2])
 array([ 4.,  4.])
 
+There is a built-in helper that uses the first method:
+
+>>> x = pt.dvector('x')
+>>> v = pt.dvector('v')
+>>> y = pt.sum(x ** 2)
+>>> Hv = pytensor.gradient.hessian_vector_product(y, x, v)
+>>> f = pytensor.function([x, v], Hv)
+>>> f([4, 4], [2, 2])
+array([ 4.,  4.])
+
 
 Final Pointers
 ==============
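
The "first method" referenced in the added text is the double application of backward-mode autodiff shown earlier in the tutorial. For comparison, a minimal sketch of the manual graph the helper builds internally (same toy cost as in the doctest above):

    import pytensor
    import pytensor.tensor as pt
    from pytensor.gradient import grad

    x = pt.dvector("x")
    v = pt.dvector("v")
    y = pt.sum(x ** 2)

    # First backward pass: the gradient of the cost.
    g = grad(y, x)
    # Dot the gradient with v and differentiate again: this yields H @ v
    # without ever materializing the full Hessian.
    Hv_manual = grad(pt.sum(g * v), x)

    f = pytensor.function([x, v], Hv_manual)
    print(f([4, 4], [2, 2]))  # [4. 4.]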

pytensor/gradient.py

Lines changed: 79 additions & 0 deletions

@@ -2050,6 +2050,85 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
     return as_list_or_tuple(using_list, using_tuple, hessians)
 
 
+def hessian_vector_product(cost, wrt, p, **grad_kwargs):
+    """Return the expression of the Hessian times a vector p.
+
+    Notes
+    -----
+    This function uses backward autodiff twice to obtain the desired expression.
+    You may want to manually build the equivalent expression by combining
+    backward followed by forward autodiff (if all Ops support it).
+    See {ref}`docs/_tutcomputinggrads#Hessian-times-a-Vector` for how to do this.
+
+    Parameters
+    ----------
+    cost: Scalar (0-dimensional) variable.
+    wrt: Vector (1-dimensional tensor) `Variable` or list of Vectors
+    p: Vector (1-dimensional tensor) `Variable` or list of Vectors
+        Each vector will be used in the Hessian-vector product with respect
+        to the corresponding `wrt` variable.
+    **grad_kwargs:
+        Keyword arguments passed to the `grad` function.
+
+    Returns
+    -------
+    Vector or list of Vectors
+        The Hessian times `p` of the `cost` with respect to (elements of) `wrt`.
+
+    Examples
+    --------
+
+    .. testcode::
+
+        import numpy as np
+        from scipy.optimize import minimize
+        from pytensor import function
+        from pytensor.tensor import vector
+        from pytensor.gradient import grad, hessian_vector_product
+
+        x = vector('x')
+        p = vector('p')
+
+        rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()
+        rosen_jac = grad(rosen, x)
+        rosen_hessp = hessian_vector_product(rosen, x, p)
+
+        rosen_fn = function([x], rosen)
+        rosen_jac_fn = function([x], rosen_jac)
+        rosen_hessp_fn = function([x, p], rosen_hessp)
+        x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2])
+        res = minimize(
+            rosen_fn,
+            x0,
+            method="Newton-CG",
+            jac=rosen_jac_fn,
+            hessp=rosen_hessp_fn,
+            options={"xtol": 1e-8},
+        )
+        print(res.x)
+
+    .. testoutput::
+
+        [1.         1.         1.         0.99999999 0.99999999]
+
+    """
+    wrt_list = wrt if isinstance(wrt, Sequence) else [wrt]
+    p_list = p if isinstance(p, Sequence) else [p]
+    grad_wrt_list = grad(cost, wrt=wrt_list, **grad_kwargs)
+    hessian_cost = pytensor.tensor.add(
+        *[
+            (grad_wrt * p).sum()
+            for grad_wrt, p in zip(grad_wrt_list, p_list, strict=True)
+        ]
+    )
+    Hp_list = grad(hessian_cost, wrt=wrt_list, **grad_kwargs)
+
+    if isinstance(wrt, Variable):
+        return Hp_list[0]
+    return Hp_list
+
+
 def _is_zero(x):
     """
     Returns 'yes', 'no', or 'maybe' indicating whether x
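
The docstring's note about combining backward with forward autodiff refers to the L- and R-operator approach covered in the tutorial. A minimal sketch of that alternative, assuming all Ops in the gradient graph implement the R-operator (`Rop` is importable from `pytensor.gradient`):

    import pytensor
    import pytensor.tensor as pt
    from pytensor.gradient import Rop, grad

    x = pt.dvector("x")
    v = pt.dvector("v")
    cost = pt.sum(x ** 2)

    # One backward pass for the gradient, then one forward-mode pass
    # pushing v through it: Rop(grad(cost, x), x, v) gives H @ v.
    g = grad(cost, x)
    Hv = Rop(g, x, v)

    f = pytensor.function([x, v], Hv)
    print(f([4.0, 4.0], [2.0, 2.0]))  # [4. 4.]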

tests/test_gradient.py

Lines changed: 39 additions & 0 deletions

@@ -1,5 +1,6 @@
 import numpy as np
 import pytest
+from scipy.optimize import rosen_hess_prod
 
 import pytensor
 import pytensor.tensor.basic as ptb
@@ -20,6 +21,7 @@
     grad_scale,
     grad_undefined,
     hessian,
+    hessian_vector_product,
     jacobian,
     subgraph_grad,
     zero_grad,
@@ -1079,3 +1081,40 @@ def test_jacobian_disconnected_inputs():
     func_s = pytensor.function([s2], jacobian_s)
     val = np.array(1.0).astype(pytensor.config.floatX)
     assert np.allclose(func_s(val), np.zeros(1))
+
+
+class TestHessianVectorProduct:
+    def test_rosen(self):
+        x = vector("x", dtype="float64")
+        rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()
+
+        p = vector("p", dtype="float64")
+        rosen_hess_prod_pt = hessian_vector_product(rosen, wrt=x, p=p)
+
+        x_test = 0.1 * np.arange(9)
+        p_test = 0.5 * np.arange(9)
+        np.testing.assert_allclose(
+            rosen_hess_prod_pt.eval({x: x_test, p: p_test}),
+            rosen_hess_prod(x_test, p_test),
+        )
+
+    def test_multiple_wrt(self):
+        x = vector("x", dtype="float64")
+        y = vector("y", dtype="float64")
+        p_x = vector("p_x", dtype="float64")
+        p_y = vector("p_y", dtype="float64")
+
+        cost = (x**2 - y**2).sum()
+        hessp_x, hessp_y = hessian_vector_product(cost, wrt=[x, y], p=[p_x, p_y])
+
+        hessp_fn = pytensor.function([x, y, p_x, p_y], [hessp_x, hessp_y])
+        test = {
+            # x and y values don't matter: the Hessian blocks are constant
+            "x": np.full((3,), np.nan),
+            "y": np.full((3,), np.nan),
+            "p_x": [1, 2, 3],
+            "p_y": [3, 2, 1],
+        }
+        hessp_x_eval, hessp_y_eval = hessp_fn(**test)
+        np.testing.assert_allclose(hessp_x_eval, [2, 4, 6])
+        np.testing.assert_allclose(hessp_y_eval, [-6, -4, -2])