diff --git a/pymc_extras/inference/pathfinder/lbfgs.py b/pymc_extras/inference/pathfinder/lbfgs.py index ce73f8f30..8f9df2659 100644 --- a/pymc_extras/inference/pathfinder/lbfgs.py +++ b/pymc_extras/inference/pathfinder/lbfgs.py @@ -52,7 +52,7 @@ def __post_init__(self) -> None: self.count = 0 value, grad = self.value_grad_fn(self.x0) - if np.all(np.isfinite(grad)) and np.isfinite(value): + if self.entry_condition_met(self.x0, value, grad): self.add_entry(self.x0, grad) def add_entry(self, x: NDArray[np.float64], g: NDArray[np.float64]) -> None: @@ -75,18 +75,40 @@ def get_history(self) -> LBFGSHistory: x=self.x_history[: self.count], g=self.g_history[: self.count], count=self.count ) + def entry_condition_met(self, x, value, grad) -> bool: + """Checks if the LBFGS iteration should continue.""" + + if np.all(np.isfinite(grad)) and np.isfinite(value) and (self.count < self.maxiter + 1): + if self.count == 0: + return True + else: + s = x - self.x_history[self.count - 1] + z = grad - self.g_history[self.count - 1] + sz = (s * z).sum(axis=-1) + epsilon = 1e-8 + update_mask = sz > epsilon * np.sqrt(np.sum(z**2, axis=-1)) + + if update_mask: + return True + else: + return False + else: + return False + def __call__(self, x: NDArray[np.float64]) -> None: value, grad = self.value_grad_fn(x) - if np.all(np.isfinite(grad)) and np.isfinite(value) and self.count < self.maxiter + 1: + if self.entry_condition_met(x, value, grad): self.add_entry(x, grad) class LBFGSStatus(Enum): CONVERGED = auto() MAX_ITER_REACHED = auto() - DIVERGED = auto() + NON_FINITE = auto() + LOW_UPDATE_MASK_RATIO = auto() # Statuses that lead to Exceptions: INIT_FAILED = auto() + INIT_FAILED_LOW_UPDATE_MASK = auto() LBFGS_FAILED = auto() @@ -101,8 +123,8 @@ def __init__(self, message=None, status: LBFGSStatus = LBFGSStatus.LBFGS_FAILED) class LBFGSInitFailed(LBFGSException): DEFAULT_MESSAGE = "LBFGS failed to initialise." - def __init__(self, message=None): - super().__init__(message or self.DEFAULT_MESSAGE, LBFGSStatus.INIT_FAILED) + def __init__(self, status: LBFGSStatus, message=None): + super().__init__(message or self.DEFAULT_MESSAGE, status) class LBFGS: @@ -177,13 +199,24 @@ def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]: history = history_manager.get_history() # warnings and suggestions for LBFGSStatus are displayed at the end - if result.status == 1: - lbfgs_status = LBFGSStatus.MAX_ITER_REACHED - elif (result.status == 2) or (history.count <= 1): - if result.nit <= 1: + # threshold determining if the number of update mask is low compared to iterations + low_update_threshold = 3 + + logging.warning(f"LBFGS status: {result} \n nit: {result.nit} \n count: {history.count}") + + if history.count <= 1: # triggers LBFGSInitFailed + if result.nit < low_update_threshold: lbfgs_status = LBFGSStatus.INIT_FAILED - elif result.fun == np.inf: - lbfgs_status = LBFGSStatus.DIVERGED + else: + lbfgs_status = LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK + elif result.status == 1: + # (result.nit > maxiter) or (result.nit > maxls) + lbfgs_status = LBFGSStatus.MAX_ITER_REACHED + elif result.status == 2: + # precision loss resulting to inf or nan + lbfgs_status = LBFGSStatus.NON_FINITE + elif history.count < low_update_threshold * result.nit: + lbfgs_status = LBFGSStatus.LOW_UPDATE_MASK_RATIO else: lbfgs_status = LBFGSStatus.CONVERGED diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py index 846c00fa9..ef2b6679b 100644 --- a/pymc_extras/inference/pathfinder/pathfinder.py +++ b/pymc_extras/inference/pathfinder/pathfinder.py @@ -952,8 +952,8 @@ def single_pathfinder_fn(random_seed: int) -> PathfinderResult: x0 = x_base + jitter_value x, g, lbfgs_niter, lbfgs_status = lbfgs.minimize(x0) - if lbfgs_status == LBFGSStatus.INIT_FAILED: - raise LBFGSInitFailed() + if lbfgs_status in {LBFGSStatus.INIT_FAILED, LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK}: + raise LBFGSInitFailed(lbfgs_status) elif lbfgs_status == LBFGSStatus.LBFGS_FAILED: raise LBFGSException() @@ -1391,15 +1391,16 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]: warnings = [] lbfgs_status_message = { - LBFGSStatus.MAX_ITER_REACHED: "LBFGS maximum number of iterations reached. Consider increasing maxiter if this occurence is high relative to the number of paths.", - LBFGSStatus.INIT_FAILED: "LBFGS failed to initialise. Consider reparameterizing the model or reducing jitter if this occurence is high relative to the number of paths.", - LBFGSStatus.DIVERGED: "LBFGS diverged to infinity. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.", + LBFGSStatus.MAX_ITER_REACHED: "MAX_ITER_REACHED: LBFGS maximum number of iterations reached. Consider increasing maxiter if this occurence is high relative to the number of paths.", + LBFGSStatus.INIT_FAILED: "INIT_FAILED: LBFGS failed to initialize. Consider reparameterizing the model or reducing jitter if this occurence is high relative to the number of paths.", + LBFGSStatus.NON_FINITE: "NON_FINITE: LBFGS objective function produced inf or nan at the last iteration. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.", + LBFGSStatus.LOW_UPDATE_MASK_RATIO: "LOW_UPDATE_MASK_RATIO: Majority of LBFGS iterations were not accepted due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.", + LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK: "INIT_FAILED_LOW_UPDATE_MASK: LBFGS failed to initialize due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.", } path_status_message = { - PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter is may be too close to the mean posterior and a poor exploration of the parameter space. Consider increasing jitter if this occurence is high relative to the number of paths.", - PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter values are concentrated in high-density regions in the target distribution and may result in poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.", - PathStatus.INVALID_LOGQ: "Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.", + PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO_ARGMAX_AT_ZERO: ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter values are concentrated in high-density regions in the target distribution and may result in poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.", + PathStatus.INVALID_LOGQ: "INVALID_LOGQ: Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.", } for lbfgs_status in mpr.lbfgs_status: @@ -1621,7 +1622,7 @@ def fit_pathfinder( maxiter: int = 1000, # L^max ftol: float = 1e-5, gtol: float = 1e-8, - maxls=1000, + maxls: int = 1000, num_elbo_draws: int = 10, # K jitter: float = 2.0, epsilon: float = 1e-8, diff --git a/tests/test_pathfinder.py b/tests/test_pathfinder.py index b0776a1ca..6b30226f0 100644 --- a/tests/test_pathfinder.py +++ b/tests/test_pathfinder.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re import sys import numpy as np import pymc as pm +import pytensor.tensor as pt import pytest import pymc_extras as pmx @@ -50,6 +52,88 @@ def reference_idata(): return idata +def unstable_lbfgs_update_mask_model() -> pm.Model: + # data and model from: https://github.com/pymc-devs/pymc-extras/issues/445 + # this scenario made LBFGS struggle leading to a lot of rejected iterations, (result.nit being moderate, but only history.count <= 1). + # this scenario is used to test that the LBFGS history manager is rejecting iterations as expected and PF can run to completion. + + # fmt: off + inp = np.array([0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 2, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 2, 0, 1, 0, 0, 0, 0, 1, 1, 1, 2, 0, 1, 2, 1, 0, 1, 0, 1, 0, 1, 0]) + + res = np.array([[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,0,1,0],[1,0,0,0,0],[0,1,0,0,0],[0,0,1,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,0,1,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,1,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,1,0,0],[1,0,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,0,1,0],[1,0,0,0,0],[1,0,0,0,0],[0,1,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,1,0,0],[1,0,0,0,0],[0,0,0,1,0]]) + # fmt: on + + n_ordered = res.shape[1] + coords = { + "obs": np.arange(len(inp)), + "inp": np.arange(max(inp) + 1), + "outp": np.arange(res.shape[1]), + } + with pm.Model(coords=coords) as mdl: + mu = pm.Normal("intercept", sigma=3.5)[None] + + offset = pm.Normal( + "offset", dims=("inp"), transform=pm.distributions.transforms.ZeroSumTransform([0]) + ) + + scale = 3.5 * pm.HalfStudentT("scale", nu=5) + mu += (scale * offset)[inp] + + phi_delta = pm.Dirichlet("phi_diffs", [1.0] * (n_ordered - 1)) + phi = pt.concatenate([[0], pt.cumsum(phi_delta)]) + s_mu = pm.Normal( + "stereotype_intercept", + size=n_ordered, + transform=pm.distributions.transforms.ZeroSumTransform([-1]), + ) + fprobs = pm.math.softmax(s_mu[None, :] + phi[None, :] * mu[:, None], axis=-1) + + pm.Multinomial("y_res", p=fprobs, n=np.ones(len(inp)), observed=res, dims=("obs", "outp")) + + return mdl + + +@pytest.mark.parametrize("jitter", [12.0, 500.0, 1000.0]) +def test_unstable_lbfgs_update_mask(capsys, jitter): + model = unstable_lbfgs_update_mask_model() + + if jitter < 1000: + with model: + idata = pmx.fit( + method="pathfinder", + jitter=jitter, + random_seed=4, + ) + out, err = capsys.readouterr() + status_pattern = [ + r"INIT_FAILED_LOW_UPDATE_MASK\s+\d+", + r"LOW_UPDATE_MASK_RATIO\s+\d+", + r"LBFGS_FAILED\s+\d+", + r"SUCCESS\s+\d+", + ] + for pattern in status_pattern: + assert re.search(pattern, out) is not None + + else: + with pytest.raises(ValueError, match="All paths failed"): + with model: + idata = pmx.fit( + method="pathfinder", + jitter=1000, + random_seed=2, + num_paths=4, + ) + out, err = capsys.readouterr() + + status_pattern = [ + r"INIT_FAILED_LOW_UPDATE_MASK\s+2", + r"LOW_UPDATE_MASK_RATIO\s+2", + r"LBFGS_FAILED\s+4", + ] + for pattern in status_pattern: + assert re.search(pattern, out) is not None + + @pytest.mark.parametrize("inference_backend", ["pymc", "blackjax"]) @pytest.mark.filterwarnings("ignore:JAXopt is no longer maintained.:DeprecationWarning") def test_pathfinder(inference_backend, reference_idata):