diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py index 66d081bde..ed30d22c5 100644 --- a/pymc_extras/inference/pathfinder/pathfinder.py +++ b/pymc_extras/inference/pathfinder/pathfinder.py @@ -15,7 +15,6 @@ import collections import logging import time -import warnings as _warnings from collections import Counter from collections.abc import Callable, Iterator @@ -40,7 +39,7 @@ from pymc.model import modelcontext from pymc.model.core import Point from pymc.pytensorf import ( - compile_pymc, + compile, find_rng_nodes, reseed_rngs, ) @@ -76,9 +75,6 @@ ) logger = logging.getLogger(__name__) -_warnings.filterwarnings( - "ignore", category=FutureWarning, message="compile_pymc was renamed to compile" -) REGULARISATION_TERM = 1e-8 DEFAULT_LINKER = "cvm_nogc" @@ -142,7 +138,7 @@ def get_logp_dlogp_of_ravel_inputs( [model.logp(jacobian=jacobian), model.dlogp(jacobian=jacobian)], model.value_vars, ) - logp_dlogp_fn = compile_pymc([inputs], (logP, dlogP), **compile_kwargs) + logp_dlogp_fn = compile([inputs], (logP, dlogP), **compile_kwargs) logp_dlogp_fn.trust_input = True return logp_dlogp_fn @@ -502,9 +498,10 @@ def bfgs_sample_dense( logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1) - with _warnings.catch_warnings(): - _warnings.simplefilter("ignore", category=FutureWarning) - mu = x - pt.batched_dot(H_inv, g) + # mu = x - pt.einsum("ijk,ik->ij", H_inv, g) # causes error: Multiple destroyers of g + + batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)") + mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None])) phi = pt.matrix_transpose( # (L, N, 1) @@ -573,17 +570,16 @@ def bfgs_sample_sparse( logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1) logdet += pt.sum(pt.log(alpha), axis=-1) + # inverse Hessian + # (L, N, N) + (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N) + H_inv = alpha_diag + (beta @ gamma @ pt.matrix_transpose(beta)) + # NOTE: changed the sign from "x + " to "x -" of the expression to match Stan which differs from Zhang et al., (2022). same for dense version. - with _warnings.catch_warnings(): - _warnings.simplefilter("ignore", category=FutureWarning) - mu = x - ( - # (L, N), (L, N) -> (L, N) - pt.batched_dot(alpha_diag, g) - # beta @ gamma @ beta.T - # (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N) - # (L, N, N), (L, N) -> (L, N) - + pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g) - ) + + # mu = x - pt.einsum("ijk,ik->ij", H_inv, g) # causes error: Multiple destroyers of g + + batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)") + mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None])) phi = pt.matrix_transpose( # (L, N, 1) @@ -857,7 +853,7 @@ def make_pathfinder_body( # return psi, logP_psi, logQ_psi, elbo_argmax - pathfinder_body_fn = compile_pymc( + pathfinder_body_fn = compile( [x_full, g_full], [psi, logP_psi, logQ_psi, elbo_argmax], **compile_kwargs, diff --git a/tests/test_pathfinder.py b/tests/test_pathfinder.py index ea9d77480..b2f4b8158 100644 --- a/tests/test_pathfinder.py +++ b/tests/test_pathfinder.py @@ -18,10 +18,6 @@ import pymc as pm import pytest -pytestmark = pytest.mark.filterwarnings( - "ignore:compile_pymc was renamed to compile:FutureWarning", -) - import pymc_extras as pmx @@ -55,6 +51,7 @@ def reference_idata(): @pytest.mark.parametrize("inference_backend", ["pymc", "blackjax"]) +@pytest.mark.filterwarnings("ignore:JAXopt is no longer maintained.:DeprecationWarning") def test_pathfinder(inference_backend, reference_idata): if inference_backend == "blackjax" and sys.platform == "win32": pytest.skip("JAX not supported on windows")