Skip to content

Commit 93ba789

Browse files
committed
Apply rewrites to logp before apply grad
This adds safe rewrites to logp before the grad operator is applied. This is motivated by #6717, where expensive `cholesky(L.dot(L.T))` operations are removed. If these remain in the logp graph when the grad is taken, the resulting dlogp graph contains unnecessary operations. However this may improve the stability and performance of grad logp in other situation.
1 parent 7b08fc1 commit 93ba789

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

pymc/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
hessian,
7676
inputvars,
7777
replace_rvs_by_values,
78+
rewrite_pregrad,
7879
)
7980
from pymc.util import (
8081
UNSET,
@@ -381,6 +382,8 @@ def __init__(
381382
self._extra_vars_shared[var.name] = shared
382383
givens.append((var, shared))
383384

385+
cost = rewrite_pregrad(cost)
386+
384387
if compute_grads:
385388
grads = pytensor.grad(cost, grad_vars, disconnected_inputs="ignore")
386389
for grad_wrt, var in zip(grads, grad_vars):
@@ -824,6 +827,7 @@ def dlogp(
824827
)
825828

826829
cost = self.logp(jacobian=jacobian)
830+
cost = rewrite_pregrad(cost)
827831
return gradient(cost, value_vars)
828832

829833
def d2logp(
@@ -862,6 +866,7 @@ def d2logp(
862866
)
863867

864868
cost = self.logp(jacobian=jacobian)
869+
cost = rewrite_pregrad(cost)
865870
return hessian(cost, value_vars)
866871

867872
@property

pymc/pytensorf.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,3 +1228,10 @@ def constant_fold(
12281228
return tuple(
12291229
folded_x.data if isinstance(folded_x, Constant) else folded_x for folded_x in folded_xs
12301230
)
1231+
1232+
1233+
def rewrite_pregrad(graph):
1234+
"""Apply simplifying or stabilizing rewrites to graph that are safe to use
1235+
pre-grad.
1236+
"""
1237+
return rewrite_graph(graph, include=("canonicalize", "stabilize"))

0 commit comments

Comments
 (0)