diff --git a/pymc_extras/inference/pathfinder/lbfgs.py b/pymc_extras/inference/pathfinder/lbfgs.py index 8f9df2659..b567c05cc 100644 --- a/pymc_extras/inference/pathfinder/lbfgs.py +++ b/pymc_extras/inference/pathfinder/lbfgs.py @@ -37,11 +37,14 @@ class LBFGSHistoryManager: initial position maxiter : int maximum number of iterations to store + epsilon : float + tolerance for lbfgs update """ value_grad_fn: Callable[[NDArray[np.float64]], tuple[np.float64, NDArray[np.float64]]] x0: NDArray[np.float64] maxiter: int + epsilon: float x_history: NDArray[np.float64] = field(init=False) g_history: NDArray[np.float64] = field(init=False) count: int = field(init=False) @@ -85,10 +88,9 @@ def entry_condition_met(self, x, value, grad) -> bool: s = x - self.x_history[self.count - 1] z = grad - self.g_history[self.count - 1] sz = (s * z).sum(axis=-1) - epsilon = 1e-8 - update_mask = sz > epsilon * np.sqrt(np.sum(z**2, axis=-1)) + update = sz > self.epsilon * np.sqrt(np.sum(z**2, axis=-1)) - if update_mask: + if update: return True else: return False @@ -105,10 +107,10 @@ class LBFGSStatus(Enum): CONVERGED = auto() MAX_ITER_REACHED = auto() NON_FINITE = auto() - LOW_UPDATE_MASK_RATIO = auto() + LOW_UPDATE_PCT = auto() # Statuses that lead to Exceptions: INIT_FAILED = auto() - INIT_FAILED_LOW_UPDATE_MASK = auto() + INIT_FAILED_LOW_UPDATE_PCT = auto() LBFGS_FAILED = auto() @@ -144,10 +146,12 @@ class LBFGS: gradient tolerance for convergence, defaults to 1e-8 maxls : int, optional maximum number of line search steps, defaults to 1000 + epsilon : float, optional + tolerance for lbfgs update, defaults to 1e-8 """ def __init__( - self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000 + self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000, epsilon=1e-8 ) -> None: self.value_grad_fn = value_grad_fn self.maxcor = maxcor @@ -155,6 +159,7 @@ def __init__( self.ftol = ftol self.gtol = gtol self.maxls = maxls + self.epsilon = epsilon def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]: """minimizes objective function starting from initial position. @@ -179,7 +184,7 @@ def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]: x0 = np.array(x0, dtype=np.float64) history_manager = LBFGSHistoryManager( - value_grad_fn=self.value_grad_fn, x0=x0, maxiter=self.maxiter + value_grad_fn=self.value_grad_fn, x0=x0, maxiter=self.maxiter, epsilon=self.epsilon ) result = minimize( @@ -199,24 +204,22 @@ def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]: history = history_manager.get_history() # warnings and suggestions for LBFGSStatus are displayed at the end - # threshold determining if the number of update mask is low compared to iterations + # threshold determining if the number of lbfgs updates is low compared to iterations low_update_threshold = 3 - logging.warning(f"LBFGS status: {result} \n nit: {result.nit} \n count: {history.count}") - if history.count <= 1: # triggers LBFGSInitFailed if result.nit < low_update_threshold: lbfgs_status = LBFGSStatus.INIT_FAILED else: - lbfgs_status = LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK + lbfgs_status = LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT elif result.status == 1: # (result.nit > maxiter) or (result.nit > maxls) lbfgs_status = LBFGSStatus.MAX_ITER_REACHED elif result.status == 2: # precision loss resulting to inf or nan lbfgs_status = LBFGSStatus.NON_FINITE - elif history.count < low_update_threshold * result.nit: - lbfgs_status = LBFGSStatus.LOW_UPDATE_MASK_RATIO + elif history.count * low_update_threshold < result.nit: + lbfgs_status = LBFGSStatus.LOW_UPDATE_PCT else: lbfgs_status = LBFGSStatus.CONVERGED diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py index 19a4e2d8d..cddc175ba 100644 --- a/pymc_extras/inference/pathfinder/pathfinder.py +++ b/pymc_extras/inference/pathfinder/pathfinder.py @@ -237,8 +237,8 @@ def convert_flat_trace_to_idata( def alpha_recover( - x: TensorVariable, g: TensorVariable, epsilon: TensorVariable -) -> tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable]: + x: TensorVariable, g: TensorVariable +) -> tuple[TensorVariable, TensorVariable, TensorVariable]: """compute the diagonal elements of the inverse Hessian at each iterations of L-BFGS and filter updates. Parameters @@ -247,9 +247,6 @@ def alpha_recover( position array, shape (L+1, N) g : TensorVariable gradient array, shape (L+1, N) - epsilon : float - threshold for filtering updates based on inner product of position - and gradient differences Returns ------- @@ -259,15 +256,13 @@ def alpha_recover( position differences, shape (L, N) z : TensorVariable gradient differences, shape (L, N) - update_mask : TensorVariable - mask for filtering updates, shape (L,) Notes ----- shapes: L=batch_size, N=num_params """ - def compute_alpha_l(alpha_lm1, s_l, z_l) -> TensorVariable: + def compute_alpha_l(s_l, z_l, alpha_lm1) -> TensorVariable: # alpha_lm1: (N,) # s_l: (N,) # z_l: (N,) @@ -281,43 +276,28 @@ def compute_alpha_l(alpha_lm1, s_l, z_l) -> TensorVariable: ) # fmt:off return 1.0 / inv_alpha_l - def return_alpha_lm1(alpha_lm1, s_l, z_l) -> TensorVariable: - return alpha_lm1[-1] - - def scan_body(update_mask_l, s_l, z_l, alpha_lm1) -> TensorVariable: - return pt.switch( - update_mask_l, - compute_alpha_l(alpha_lm1, s_l, z_l), - return_alpha_lm1(alpha_lm1, s_l, z_l), - ) - Lp1, N = x.shape s = pt.diff(x, axis=0) z = pt.diff(g, axis=0) alpha_l_init = pt.ones(N) - sz = (s * z).sum(axis=-1) - # update_mask = sz > epsilon * pt.linalg.norm(z, axis=-1) - # pt.linalg.norm does not work with JAX!! - update_mask = sz > epsilon * pt.sqrt(pt.sum(z**2, axis=-1)) alpha, _ = pytensor.scan( - fn=scan_body, + fn=compute_alpha_l, outputs_info=alpha_l_init, - sequences=[update_mask, s, z], + sequences=[s, z], n_steps=Lp1 - 1, allow_gc=False, ) # assert np.all(alpha.eval() > 0), "alpha cannot be negative" - # alpha: (L, N), update_mask: (L, N) - return alpha, s, z, update_mask + # alpha: (L, N) + return alpha, s, z def inverse_hessian_factors( alpha: TensorVariable, s: TensorVariable, z: TensorVariable, - update_mask: TensorVariable, J: TensorConstant, ) -> tuple[TensorVariable, TensorVariable]: """compute the inverse hessian factors for the BFGS approximation. @@ -330,8 +310,6 @@ def inverse_hessian_factors( position differences, shape (L, N) z : TensorVariable gradient differences, shape (L, N) - update_mask : TensorVariable - mask for filtering updates, shape (L,) J : TensorConstant history size for L-BFGS @@ -350,30 +328,19 @@ def inverse_hessian_factors( # NOTE: get_chi_matrix_1 is a modified version of get_chi_matrix_2 to closely follow Zhang et al., (2022) # NOTE: get_chi_matrix_2 is from blackjax which MAYBE incorrectly implemented - def get_chi_matrix_1( - diff: TensorVariable, update_mask: TensorVariable, J: TensorConstant - ) -> TensorVariable: + def get_chi_matrix_1(diff: TensorVariable, J: TensorConstant) -> TensorVariable: L, N = diff.shape j_last = pt.as_tensor(J - 1) # since indexing starts at 0 - def chi_update(chi_lm1, diff_l) -> TensorVariable: + def chi_update(diff_l, chi_lm1) -> TensorVariable: chi_l = pt.roll(chi_lm1, -1, axis=0) return pt.set_subtensor(chi_l[j_last], diff_l) - def no_op(chi_lm1, diff_l) -> TensorVariable: - return chi_lm1 - - def scan_body(update_mask_l, diff_l, chi_lm1) -> TensorVariable: - return pt.switch(update_mask_l, chi_update(chi_lm1, diff_l), no_op(chi_lm1, diff_l)) - chi_init = pt.zeros((J, N)) chi_mat, _ = pytensor.scan( - fn=scan_body, + fn=chi_update, outputs_info=chi_init, - sequences=[ - update_mask, - diff, - ], + sequences=[diff], allow_gc=False, ) @@ -382,19 +349,15 @@ def scan_body(update_mask_l, diff_l, chi_lm1) -> TensorVariable: # (L, N, J) return chi_mat - def get_chi_matrix_2( - diff: TensorVariable, update_mask: TensorVariable, J: TensorConstant - ) -> TensorVariable: + def get_chi_matrix_2(diff: TensorVariable, J: TensorConstant) -> TensorVariable: L, N = diff.shape - diff_masked = update_mask[:, None] * diff - # diff_padded: (L+J, N) pad_width = pt.zeros(shape=(2, 2), dtype="int32") - pad_width = pt.set_subtensor(pad_width[0, 0], J) - diff_padded = pt.pad(diff_masked, pad_width, mode="constant") + pad_width = pt.set_subtensor(pad_width[0, 0], J - 1) + diff_padded = pt.pad(diff, pad_width, mode="constant") - index = pt.arange(L)[:, None] + pt.arange(J)[None, :] + index = pt.arange(L)[..., None] + pt.arange(J)[None, ...] index = index.reshape((L, J)) chi_mat = pt.matrix_transpose(diff_padded[index]) @@ -403,8 +366,10 @@ def get_chi_matrix_2( return chi_mat L, N = alpha.shape - S = get_chi_matrix_1(s, update_mask, J) - Z = get_chi_matrix_1(z, update_mask, J) + + # changed to get_chi_matrix_2 after removing update_mask + S = get_chi_matrix_2(s, J) + Z = get_chi_matrix_2(z, J) # E: (L, J, J) Ij = pt.eye(J)[None, ...] @@ -785,7 +750,6 @@ def make_pathfinder_body( num_draws: int, maxcor: int, num_elbo_draws: int, - epsilon: float, **compile_kwargs: dict, ) -> Function: """ @@ -801,8 +765,6 @@ def make_pathfinder_body( The maximum number of iterations for the L-BFGS algorithm. num_elbo_draws : int The number of draws for the Evidence Lower Bound (ELBO) estimation. - epsilon : float - The value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. compile_kwargs : dict Additional keyword arguments for the PyTensor compiler. @@ -827,11 +789,10 @@ def make_pathfinder_body( num_draws = pt.constant(num_draws, "num_draws", dtype="int32") num_elbo_draws = pt.constant(num_elbo_draws, "num_elbo_draws", dtype="int32") - epsilon = pt.constant(epsilon, "epsilon", dtype="float64") maxcor = pt.constant(maxcor, "maxcor", dtype="int32") - alpha, s, z, update_mask = alpha_recover(x_full, g_full, epsilon=epsilon) - beta, gamma = inverse_hessian_factors(alpha, s, z, update_mask, J=maxcor) + alpha, s, z = alpha_recover(x_full, g_full) + beta, gamma = inverse_hessian_factors(alpha, s, z, J=maxcor) # ignore initial point - x, g: (L, N) x = x_full[1:] @@ -941,11 +902,11 @@ def neg_logp_dlogp_func(x): x_base = DictToArrayBijection.map(ip).data # lbfgs - lbfgs = LBFGS(neg_logp_dlogp_func, maxcor, maxiter, ftol, gtol, maxls) + lbfgs = LBFGS(neg_logp_dlogp_func, maxcor, maxiter, ftol, gtol, maxls, epsilon) # pathfinder body pathfinder_body_fn = make_pathfinder_body( - logp_func, num_draws, maxcor, num_elbo_draws, epsilon, **compile_kwargs + logp_func, num_draws, maxcor, num_elbo_draws, **compile_kwargs ) rngs = find_rng_nodes(pathfinder_body_fn.maker.fgraph.outputs) @@ -957,7 +918,7 @@ def single_pathfinder_fn(random_seed: int) -> PathfinderResult: x0 = x_base + jitter_value x, g, lbfgs_niter, lbfgs_status = lbfgs.minimize(x0) - if lbfgs_status in {LBFGSStatus.INIT_FAILED, LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK}: + if lbfgs_status in {LBFGSStatus.INIT_FAILED, LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT}: raise LBFGSInitFailed(lbfgs_status) elif lbfgs_status == LBFGSStatus.LBFGS_FAILED: raise LBFGSException() @@ -1399,8 +1360,8 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]: LBFGSStatus.MAX_ITER_REACHED: "MAX_ITER_REACHED: LBFGS maximum number of iterations reached. Consider increasing maxiter if this occurence is high relative to the number of paths.", LBFGSStatus.INIT_FAILED: "INIT_FAILED: LBFGS failed to initialize. Consider reparameterizing the model or reducing jitter if this occurence is high relative to the number of paths.", LBFGSStatus.NON_FINITE: "NON_FINITE: LBFGS objective function produced inf or nan at the last iteration. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.", - LBFGSStatus.LOW_UPDATE_MASK_RATIO: "LOW_UPDATE_MASK_RATIO: Majority of LBFGS iterations were not accepted due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.", - LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK: "INIT_FAILED_LOW_UPDATE_MASK: LBFGS failed to initialize due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.", + LBFGSStatus.LOW_UPDATE_PCT: "LOW_UPDATE_PCT: Majority of LBFGS iterations were not accepted due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.", + LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT: "INIT_FAILED_LOW_UPDATE_PCT: LBFGS failed to initialize due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.", } path_status_message = { diff --git a/tests/test_pathfinder.py b/tests/test_pathfinder.py index 6b30226f0..5e8773310 100644 --- a/tests/test_pathfinder.py +++ b/tests/test_pathfinder.py @@ -106,8 +106,8 @@ def test_unstable_lbfgs_update_mask(capsys, jitter): ) out, err = capsys.readouterr() status_pattern = [ - r"INIT_FAILED_LOW_UPDATE_MASK\s+\d+", - r"LOW_UPDATE_MASK_RATIO\s+\d+", + r"INIT_FAILED_LOW_UPDATE_PCT\s+\d+", + r"LOW_UPDATE_PCT\s+\d+", r"LBFGS_FAILED\s+\d+", r"SUCCESS\s+\d+", ] @@ -126,8 +126,8 @@ def test_unstable_lbfgs_update_mask(capsys, jitter): out, err = capsys.readouterr() status_pattern = [ - r"INIT_FAILED_LOW_UPDATE_MASK\s+2", - r"LOW_UPDATE_MASK_RATIO\s+2", + r"INIT_FAILED_LOW_UPDATE_PCT\s+2", + r"LOW_UPDATE_PCT\s+2", r"LBFGS_FAILED\s+4", ] for pattern in status_pattern: @@ -232,12 +232,11 @@ def test_bfgs_sample(): # get factors x_full = pt.as_tensor(x_data, dtype="float64") g_full = pt.as_tensor(g_data, dtype="float64") - epsilon = 1e-11 x = x_full[1:] g = g_full[1:] - alpha, S, Z, update_mask = alpha_recover(x_full, g_full, epsilon) - beta, gamma = inverse_hessian_factors(alpha, S, Z, update_mask, J) + alpha, s, z = alpha_recover(x_full, g_full) + beta, gamma = inverse_hessian_factors(alpha, s, z, J) # sample phi, logq = bfgs_sample( @@ -252,8 +251,8 @@ def test_bfgs_sample(): # check shapes assert beta.eval().shape == (L, N, 2 * J) assert gamma.eval().shape == (L, 2 * J, 2 * J) - assert phi.eval().shape == (L, num_samples, N) - assert logq.eval().shape == (L, num_samples) + assert all(phi.shape.eval() == (L, num_samples, N)) + assert all(logq.shape.eval() == (L, num_samples)) @pytest.mark.parametrize("importance_sampling", ["psis", "psir", "identity", None])