Fix LBFGS iteration conditions and status handling #461

Merged · 1 commit · Apr 18, 2025

55 changes: 44 additions & 11 deletions pymc_extras/inference/pathfinder/lbfgs.py
@@ -52,7 +52,7 @@ def __post_init__(self) -> None:
self.count = 0

value, grad = self.value_grad_fn(self.x0)
if np.all(np.isfinite(grad)) and np.isfinite(value):
if self.entry_condition_met(self.x0, value, grad):
self.add_entry(self.x0, grad)

def add_entry(self, x: NDArray[np.float64], g: NDArray[np.float64]) -> None:
@@ -75,18 +75,40 @@ def get_history(self) -> LBFGSHistory:
x=self.x_history[: self.count], g=self.g_history[: self.count], count=self.count
)

    def entry_condition_met(self, x, value, grad) -> bool:
        """Check whether the current iterate should be added to the LBFGS history."""

        if np.all(np.isfinite(grad)) and np.isfinite(value) and (self.count < self.maxiter + 1):
            if self.count == 0:
                return True
            s = x - self.x_history[self.count - 1]
            z = grad - self.g_history[self.count - 1]
            sz = (s * z).sum(axis=-1)
            epsilon = 1e-8
            update_mask = sz > epsilon * np.sqrt(np.sum(z**2, axis=-1))
            return bool(update_mask)
        return False

def __call__(self, x: NDArray[np.float64]) -> None:
value, grad = self.value_grad_fn(x)
if np.all(np.isfinite(grad)) and np.isfinite(value) and self.count < self.maxiter + 1:
if self.entry_condition_met(x, value, grad):
self.add_entry(x, grad)


class LBFGSStatus(Enum):
CONVERGED = auto()
MAX_ITER_REACHED = auto()
DIVERGED = auto()
NON_FINITE = auto()
LOW_UPDATE_MASK_RATIO = auto()
# Statuses that lead to Exceptions:
INIT_FAILED = auto()
INIT_FAILED_LOW_UPDATE_MASK = auto()
LBFGS_FAILED = auto()


@@ -101,8 +123,8 @@ def __init__(self, message=None, status: LBFGSStatus = LBFGSStatus.LBFGS_FAILED)
class LBFGSInitFailed(LBFGSException):
DEFAULT_MESSAGE = "LBFGS failed to initialise."

def __init__(self, message=None):
super().__init__(message or self.DEFAULT_MESSAGE, LBFGSStatus.INIT_FAILED)
def __init__(self, status: LBFGSStatus, message=None):
super().__init__(message or self.DEFAULT_MESSAGE, status)


class LBFGS:
@@ -177,13 +199,24 @@ def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
history = history_manager.get_history()

# warnings and suggestions for LBFGSStatus are displayed at the end
if result.status == 1:
lbfgs_status = LBFGSStatus.MAX_ITER_REACHED
elif (result.status == 2) or (history.count <= 1):
if result.nit <= 1:
        # threshold for deciding whether the number of accepted updates is low relative to the number of iterations
low_update_threshold = 3

logging.warning(f"LBFGS status: {result} \n nit: {result.nit} \n count: {history.count}")

if history.count <= 1: # triggers LBFGSInitFailed
if result.nit < low_update_threshold:
lbfgs_status = LBFGSStatus.INIT_FAILED
elif result.fun == np.inf:
lbfgs_status = LBFGSStatus.DIVERGED
else:
lbfgs_status = LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK
elif result.status == 1:
# (result.nit > maxiter) or (result.nit > maxls)
lbfgs_status = LBFGSStatus.MAX_ITER_REACHED
elif result.status == 2:
            # precision loss resulting in inf or nan
lbfgs_status = LBFGSStatus.NON_FINITE
elif history.count < low_update_threshold * result.nit:
lbfgs_status = LBFGSStatus.LOW_UPDATE_MASK_RATIO
else:
lbfgs_status = LBFGSStatus.CONVERGED

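The acceptance rule introduced in entry_condition_met can be exercised in isolation: a new pair s = x_k − x_{k−1}, z = g_k − g_{k−1} is stored only when the curvature along the step satisfies s·z > epsilon·‖z‖. A minimal standalone sketch of the same check (the helper name accepts_pair and the toy values are illustrative, not part of the PR):

import numpy as np

def accepts_pair(x_prev, x_new, g_prev, g_new, epsilon=1e-8):
    # s is the parameter step, z the gradient change between iterates
    s = x_new - x_prev
    z = g_new - g_prev
    sz = (s * z).sum(axis=-1)
    # same check as entry_condition_met: keep the pair only when the
    # curvature s.z exceeds epsilon * ||z||
    return bool(sz > epsilon * np.sqrt(np.sum(z**2, axis=-1)))

# well-behaved step: gradient change aligned with the step -> accepted
accepts_pair(np.zeros(2), np.array([0.1, 0.2]),
             np.array([-1.0, -2.0]), np.array([-0.8, -1.6]))  # True

# degenerate step: no curvature along the step -> rejected
accepts_pair(np.zeros(2), np.array([1e-12, 0.0]),
             np.array([-1.0, 0.0]), np.array([-1.0, 1e-3]))   # False

Rejections feed the new status logic in minimize(): a history that never grows past its starting point is an initialisation failure (INIT_FAILED_LOW_UPDATE_MASK once result.nit reaches low_update_threshold and the objective has not diverged), while a history that keeps fewer than low_update_threshold entries per iteration is flagged LOW_UPDATE_MASK_RATIO.
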
19 changes: 10 additions & 9 deletions pymc_extras/inference/pathfinder/pathfinder.py
@@ -952,8 +952,8 @@ def single_pathfinder_fn(random_seed: int) -> PathfinderResult:
x0 = x_base + jitter_value
x, g, lbfgs_niter, lbfgs_status = lbfgs.minimize(x0)

if lbfgs_status == LBFGSStatus.INIT_FAILED:
raise LBFGSInitFailed()
if lbfgs_status in {LBFGSStatus.INIT_FAILED, LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK}:
raise LBFGSInitFailed(lbfgs_status)
elif lbfgs_status == LBFGSStatus.LBFGS_FAILED:
raise LBFGSException()

@@ -1391,15 +1391,16 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]:
warnings = []

lbfgs_status_message = {
        LBFGSStatus.MAX_ITER_REACHED: "LBFGS maximum number of iterations reached. Consider increasing maxiter if this occurrence is high relative to the number of paths.",
        LBFGSStatus.INIT_FAILED: "LBFGS failed to initialize. Consider reparameterizing the model or reducing jitter if this occurrence is high relative to the number of paths.",
        LBFGSStatus.DIVERGED: "LBFGS diverged to infinity. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurrence is high relative to the number of paths.",
        LBFGSStatus.MAX_ITER_REACHED: "MAX_ITER_REACHED: LBFGS maximum number of iterations reached. Consider increasing maxiter if this occurrence is high relative to the number of paths.",
        LBFGSStatus.INIT_FAILED: "INIT_FAILED: LBFGS failed to initialize. Consider reparameterizing the model or reducing jitter if this occurrence is high relative to the number of paths.",
        LBFGSStatus.NON_FINITE: "NON_FINITE: LBFGS objective function produced inf or nan at the last iteration. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurrence is high relative to the number of paths.",
        LBFGSStatus.LOW_UPDATE_MASK_RATIO: "LOW_UPDATE_MASK_RATIO: Majority of LBFGS iterations were not accepted because either: (1) the function or gradient values contained too many inf or nan values, or (2) the gradient changes were excessively large relative to the threshold set by epsilon. Consider reparameterizing the model, adjusting initvals, jitter or other pathfinder arguments if this occurrence is high relative to the number of paths.",
        LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK: "INIT_FAILED_LOW_UPDATE_MASK: LBFGS failed to initialize because either: (1) the function or gradient values contained too many inf or nan values, or (2) the gradient changes were excessively large relative to the threshold set by epsilon. Consider reparameterizing the model, adjusting initvals, jitter or other pathfinder arguments if this occurrence is high relative to the number of paths.",
}

path_status_message = {
        PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter may be too close to the posterior mean, giving poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.",
        PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter values are concentrated in high-density regions of the target distribution and may result in poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.",
        PathStatus.INVALID_LOGQ: "Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurrence is high relative to the number of paths.",
        PathStatus.ELBO_ARGMAX_AT_ZERO: "ELBO_ARGMAX_AT_ZERO: ELBO argmax at zero refers to the first iteration during LBFGS. A high occurrence suggests the model's default initial point + jitter values are concentrated in high-density regions of the target distribution and may result in poor exploration of the parameter space. Consider increasing jitter if this occurrence is high relative to the number of paths.",
        PathStatus.INVALID_LOGQ: "INVALID_LOGQ: Invalid logQ values occur when a path's logQ values are not finite. The failed path is not included in samples when importance sampling is used. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurrence is high relative to the number of paths.",
}

for lbfgs_status in mpr.lbfgs_status:
@@ -1621,7 +1622,7 @@ def fit_pathfinder(
maxiter: int = 1000, # L^max
ftol: float = 1e-5,
gtol: float = 1e-8,
maxls=1000,
maxls: int = 1000,
num_elbo_draws: int = 10, # K
jitter: float = 2.0,
epsilon: float = 1e-8,
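On the pathfinder.py side, both initialisation-failure statuses now abort the current path through LBFGSInitFailed, which carries the specific status. A minimal sketch of a single-path wrapper over the public pieces of pymc_extras.inference.pathfinder.lbfgs (the function try_single_path is illustrative, not the package's actual multipath driver):

import numpy as np

from pymc_extras.inference.pathfinder.lbfgs import (
    LBFGS,
    LBFGSException,
    LBFGSInitFailed,
    LBFGSStatus,
)

def try_single_path(lbfgs: LBFGS, x0: np.ndarray):
    """Run one path; return None when LBFGS cannot build a usable history."""
    try:
        x, g, lbfgs_niter, lbfgs_status = lbfgs.minimize(x0)
        # same mapping as single_pathfinder_fn: both INIT_FAILED variants
        # raise LBFGSInitFailed; LBFGS_FAILED propagates as a hard failure
        if lbfgs_status in {LBFGSStatus.INIT_FAILED, LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK}:
            raise LBFGSInitFailed(lbfgs_status)
        elif lbfgs_status == LBFGSStatus.LBFGS_FAILED:
            raise LBFGSException()
    except LBFGSInitFailed:
        # a failed initialisation only skips this jittered starting point;
        # other starting points may still produce successful paths
        return None
    return x, g, lbfgs_niter, lbfgs_status

This is the behaviour the new test below relies on: with moderate jitter some paths fail with INIT_FAILED_LOW_UPDATE_MASK or LOW_UPDATE_MASK_RATIO while others succeed, and only when every path fails does the fit raise "All paths failed".
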
84 changes: 84 additions & 0 deletions tests/test_pathfinder.py
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import sys

import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pytest

import pymc_extras as pmx
@@ -50,6 +52,88 @@ def reference_idata():
return idata


def unstable_lbfgs_update_mask_model() -> pm.Model:
    # data and model from: https://github.com/pymc-devs/pymc-extras/issues/445
    # this scenario made LBFGS struggle, producing many rejected iterations
    # (result.nit is moderate, but history.count <= 1).
    # it is used to test that the LBFGS history manager rejects iterations as
    # expected and that pathfinder can still run to completion.

# fmt: off
inp = np.array([0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 2, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 2, 0, 1, 0, 0, 0, 0, 1, 1, 1, 2, 0, 1, 2, 1, 0, 1, 0, 1, 0, 1, 0])

res = np.array([[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,0,1,0],[1,0,0,0,0],[0,1,0,0,0],[0,0,1,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,0,1,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,1,0,0,0],[1,0,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,0,1,0,0],[0,1,0,0,0],[0,1,0,0,0],[0,0,0,1,0],[0,0,1,0,0],[0,1,0,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,1,0,0],[1,0,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,0,1,0],[0,0,0,1,0],[1,0,0,0,0],[1,0,0,0,0],[0,1,0,0,0],[1,0,0,0,0],[0,0,1,0,0],[0,0,1,0,0],[1,0,0,0,0],[0,0,0,1,0]])
# fmt: on

n_ordered = res.shape[1]
coords = {
"obs": np.arange(len(inp)),
"inp": np.arange(max(inp) + 1),
"outp": np.arange(res.shape[1]),
}
with pm.Model(coords=coords) as mdl:
mu = pm.Normal("intercept", sigma=3.5)[None]

offset = pm.Normal(
"offset", dims=("inp"), transform=pm.distributions.transforms.ZeroSumTransform([0])
)

scale = 3.5 * pm.HalfStudentT("scale", nu=5)
mu += (scale * offset)[inp]

phi_delta = pm.Dirichlet("phi_diffs", [1.0] * (n_ordered - 1))
phi = pt.concatenate([[0], pt.cumsum(phi_delta)])
s_mu = pm.Normal(
"stereotype_intercept",
size=n_ordered,
transform=pm.distributions.transforms.ZeroSumTransform([-1]),
)
fprobs = pm.math.softmax(s_mu[None, :] + phi[None, :] * mu[:, None], axis=-1)

pm.Multinomial("y_res", p=fprobs, n=np.ones(len(inp)), observed=res, dims=("obs", "outp"))

return mdl


@pytest.mark.parametrize("jitter", [12.0, 500.0, 1000.0])
def test_unstable_lbfgs_update_mask(capsys, jitter):
model = unstable_lbfgs_update_mask_model()

if jitter < 1000:
with model:
idata = pmx.fit(
method="pathfinder",
jitter=jitter,
random_seed=4,
)
out, err = capsys.readouterr()
status_pattern = [
r"INIT_FAILED_LOW_UPDATE_MASK\s+\d+",
r"LOW_UPDATE_MASK_RATIO\s+\d+",
r"LBFGS_FAILED\s+\d+",
r"SUCCESS\s+\d+",
]
for pattern in status_pattern:
assert re.search(pattern, out) is not None

else:
with pytest.raises(ValueError, match="All paths failed"):
with model:
idata = pmx.fit(
method="pathfinder",
jitter=1000,
random_seed=2,
num_paths=4,
)
out, err = capsys.readouterr()

status_pattern = [
r"INIT_FAILED_LOW_UPDATE_MASK\s+2",
r"LOW_UPDATE_MASK_RATIO\s+2",
r"LBFGS_FAILED\s+4",
]
for pattern in status_pattern:
assert re.search(pattern, out) is not None


@pytest.mark.parametrize("inference_backend", ["pymc", "blackjax"])
@pytest.mark.filterwarnings("ignore:JAXopt is no longer maintained.:DeprecationWarning")
def test_pathfinder(inference_backend, reference_idata):
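
For reference, the knobs exercised by the new test correspond to the fit_pathfinder parameters shown above. A minimal usage sketch (the toy model and argument values are illustrative only):

import pymc as pm

import pymc_extras as pmx

with pm.Model() as model:
    x = pm.Normal("x", mu=0.0, sigma=1.0)

    idata = pmx.fit(
        method="pathfinder",
        maxiter=1000,  # L^max: cap on LBFGS iterations
        maxls=1000,  # line-search cap, now annotated as int
        num_elbo_draws=10,  # K
        jitter=12.0,  # larger jitter stresses the update-mask handling
        epsilon=1e-8,  # curvature threshold used by entry_condition_met
        random_seed=4,
    )

When LBFGS statuses such as LOW_UPDATE_MASK_RATIO occur often relative to the number of paths, the warnings from _get_status_warning suggest reparameterizing the model or adjusting jitter, initvals or epsilon.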