From 681b93bf14df983898f4d516cb3da1b8723b2718 Mon Sep 17 00:00:00 2001
From: aphc14 <177544929+aphc14@users.noreply.github.com>
Date: Sat, 19 Apr 2025 04:26:37 +1000
Subject: [PATCH 1/4] Refactor alpha_recover and inverse_hessian_factors to
 remove update_mask parameter

- Removed the update_mask variable from the alpha_recover and
  inverse_hessian_factors functions.
- Simplified the logic in alpha_recover by directly computing alpha without
  filtering updates.
- These changes should offer speed-ups by reducing reliance on scan functions
  and by performing vectorised operations.
---
 pymc_extras/inference/pathfinder/lbfgs.py     |  8 ++-
 .../inference/pathfinder/pathfinder.py        | 51 ++++---------
 2 files changed, 17 insertions(+), 42 deletions(-)

diff --git a/pymc_extras/inference/pathfinder/lbfgs.py b/pymc_extras/inference/pathfinder/lbfgs.py
index 8f9df2659..61e5c21b2 100644
--- a/pymc_extras/inference/pathfinder/lbfgs.py
+++ b/pymc_extras/inference/pathfinder/lbfgs.py
@@ -37,11 +37,14 @@ class LBFGSHistoryManager:
         initial position
     maxiter : int
         maximum number of iterations to store
+    epsilon : float
+        tolerance for lbfgs update
     """

     value_grad_fn: Callable[[NDArray[np.float64]], tuple[np.float64, NDArray[np.float64]]]
     x0: NDArray[np.float64]
     maxiter: int
+    epsilon: float
     x_history: NDArray[np.float64] = field(init=False)
     g_history: NDArray[np.float64] = field(init=False)
     count: int = field(init=False)
@@ -147,7 +150,7 @@ class LBFGS:
     """

     def __init__(
-        self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000
+        self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000, epsilon=1e-8
     ) -> None:
         self.value_grad_fn = value_grad_fn
         self.maxcor = maxcor
@@ -155,6 +158,7 @@ def __init__(
         self.ftol = ftol
         self.gtol = gtol
         self.maxls = maxls
+        self.epsilon = epsilon

     def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
         """minimizes objective function starting from initial position.
@@ -179,7 +183,7 @@ def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
         x0 = np.array(x0, dtype=np.float64)

         history_manager = LBFGSHistoryManager(
-            value_grad_fn=self.value_grad_fn, x0=x0, maxiter=self.maxiter
+            value_grad_fn=self.value_grad_fn, x0=x0, maxiter=self.maxiter, epsilon=self.epsilon
         )

         result = minimize(
diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py
index 19a4e2d8d..4d4cda664 100644
--- a/pymc_extras/inference/pathfinder/pathfinder.py
+++ b/pymc_extras/inference/pathfinder/pathfinder.py
@@ -259,8 +259,6 @@ def alpha_recover(
         position differences, shape (L, N)
     z : TensorVariable
         gradient differences, shape (L, N)
-    update_mask : TensorVariable
-        mask for filtering updates, shape (L,)

     Notes
     -----
@@ -281,43 +279,28 @@ def compute_alpha_l(alpha_lm1, s_l, z_l) -> TensorVariable:
         )  # fmt:off
         return 1.0 / inv_alpha_l

-    def return_alpha_lm1(alpha_lm1, s_l, z_l) -> TensorVariable:
-        return alpha_lm1[-1]
-
-    def scan_body(update_mask_l, s_l, z_l, alpha_lm1) -> TensorVariable:
-        return pt.switch(
-            update_mask_l,
-            compute_alpha_l(alpha_lm1, s_l, z_l),
-            return_alpha_lm1(alpha_lm1, s_l, z_l),
-        )
-
     Lp1, N = x.shape
     s = pt.diff(x, axis=0)
     z = pt.diff(g, axis=0)
     alpha_l_init = pt.ones(N)
-    sz = (s * z).sum(axis=-1)

-    # update_mask = sz > epsilon * pt.linalg.norm(z, axis=-1)
-    # pt.linalg.norm does not work with JAX!!
-    update_mask = sz > epsilon * pt.sqrt(pt.sum(z**2, axis=-1))

     alpha, _ = pytensor.scan(
-        fn=scan_body,
+        fn=compute_alpha_l,
         outputs_info=alpha_l_init,
-        sequences=[update_mask, s, z],
+        sequences=[s, z],
         n_steps=Lp1 - 1,
         allow_gc=False,
     )

     # assert np.all(alpha.eval() > 0), "alpha cannot be negative"
     # alpha: (L, N), update_mask: (L, N)
-    return alpha, s, z, update_mask
+    return alpha, s, z


 def inverse_hessian_factors(
     alpha: TensorVariable,
     s: TensorVariable,
     z: TensorVariable,
-    update_mask: TensorVariable,
     J: TensorConstant,
 ) -> tuple[TensorVariable, TensorVariable]:
     """compute the inverse hessian factors for the BFGS approximation.

@@ -330,8 +313,6 @@ def inverse_hessian_factors(
         position differences, shape (L, N)
     z : TensorVariable
         gradient differences, shape (L, N)
-    update_mask : TensorVariable
-        mask for filtering updates, shape (L,)
     J : TensorConstant
         history size for L-BFGS

@@ -350,9 +331,8 @@ def inverse_hessian_factors(
     # NOTE: get_chi_matrix_1 is a modified version of get_chi_matrix_2 to closely follow Zhang et al., (2022)
     # NOTE: get_chi_matrix_2 is from blackjax which MAYBE incorrectly implemented

-    def get_chi_matrix_1(
-        diff: TensorVariable, update_mask: TensorVariable, J: TensorConstant
-    ) -> TensorVariable:
+    def get_chi_matrix_1(diff: TensorVariable, J: TensorConstant) -> TensorVariable:
+        # TODO: vectorize this!
         L, N = diff.shape
         j_last = pt.as_tensor(J - 1)  # since indexing starts at 0

@@ -360,20 +340,11 @@ def get_chi_matrix_1(
         def chi_update(chi_lm1, diff_l) -> TensorVariable:
             chi_l = pt.roll(chi_lm1, -1, axis=0)
             return pt.set_subtensor(chi_l[j_last], diff_l)

-        def no_op(chi_lm1, diff_l) -> TensorVariable:
-            return chi_lm1
-
-        def scan_body(update_mask_l, diff_l, chi_lm1) -> TensorVariable:
-            return pt.switch(update_mask_l, chi_update(chi_lm1, diff_l), no_op(chi_lm1, diff_l))
-
         chi_init = pt.zeros((J, N))
         chi_mat, _ = pytensor.scan(
-            fn=scan_body,
+            fn=chi_update,
             outputs_info=chi_init,
-            sequences=[
-                update_mask,
-                diff,
-            ],
+            sequences=[diff],
             allow_gc=False,
         )

@@ -403,8 +374,8 @@ def get_chi_matrix_2(
         return chi_mat

     L, N = alpha.shape
-    S = get_chi_matrix_1(s, update_mask, J)
-    Z = get_chi_matrix_1(z, update_mask, J)
+    S = get_chi_matrix_1(s, J)
+    Z = get_chi_matrix_1(z, J)

     # E: (L, J, J)
     Ij = pt.eye(J)[None, ...]
@@ -830,8 +801,8 @@ def make_pathfinder_body(
     epsilon = pt.constant(epsilon, "epsilon", dtype="float64")
     maxcor = pt.constant(maxcor, "maxcor", dtype="int32")

-    alpha, s, z, update_mask = alpha_recover(x_full, g_full, epsilon=epsilon)
-    beta, gamma = inverse_hessian_factors(alpha, s, z, update_mask, J=maxcor)
+    alpha, s, z = alpha_recover(x_full, g_full, epsilon=epsilon)
+    beta, gamma = inverse_hessian_factors(alpha, s, z, J=maxcor)

     # ignore initial point - x, g: (L, N)
     x = x_full[1:]

From c850d8240a259379fada2fc5f207573f05b41bdb Mon Sep 17 00:00:00 2001
From: aphc14 <177544929+aphc14@users.noreply.github.com>
Date: Sun, 20 Apr 2025 00:48:24 +1000
Subject: [PATCH 2/4] WIP: tidying up and shortening var names

---
 pymc_extras/inference/pathfinder/lbfgs.py     | 19 +++++++++----------
 .../inference/pathfinder/pathfinder.py        |  6 +++---
 2 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/pymc_extras/inference/pathfinder/lbfgs.py b/pymc_extras/inference/pathfinder/lbfgs.py
index 61e5c21b2..e910f39db 100644
--- a/pymc_extras/inference/pathfinder/lbfgs.py
+++ b/pymc_extras/inference/pathfinder/lbfgs.py
@@ -88,10 +88,9 @@ def entry_condition_met(self, x, value, grad) -> bool:
         s = x - self.x_history[self.count - 1]
         z = grad - self.g_history[self.count - 1]
         sz = (s * z).sum(axis=-1)
-        epsilon = 1e-8
-        update_mask = sz > epsilon * np.sqrt(np.sum(z**2, axis=-1))
+        update = sz > self.epsilon * np.sqrt(np.sum(z**2, axis=-1))

-        if update_mask:
+        if update:
             return True
         else:
             return False
@@ -108,10 +107,10 @@ class LBFGSStatus(Enum):
     CONVERGED = auto()
     MAX_ITER_REACHED = auto()
     NON_FINITE = auto()
-    LOW_UPDATE_MASK_RATIO = auto()
+    LOW_UPDATE_PCT = auto()
    # Statuses that lead to Exceptions:
     INIT_FAILED = auto()
-    INIT_FAILED_LOW_UPDATE_MASK = auto()
+    INIT_FAILED_LOW_UPDATE_PCT = auto()
     LBFGS_FAILED = auto()
@@ -147,6 +146,8 @@ class LBFGS:
         gradient tolerance for convergence, defaults to 1e-8
     maxls : int, optional
         maximum number of line search steps, defaults to 1000
+    epsilon : float, optional
+        tolerance for lbfgs update, defaults to 1e-8
     """

     def __init__(
@@ -203,16 +204,14 @@ def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
         history = history_manager.get_history()

         # warnings and suggestions for LBFGSStatus are displayed at the end
-        # threshold determining if the number of update mask is low compared to iterations
+        # threshold determining if the number of lbfgs updates is low compared to iterations
         low_update_threshold = 3

-        logging.warning(f"LBFGS status: {result} \n nit: {result.nit} \n count: {history.count}")
-
         if history.count <= 1:  # triggers LBFGSInitFailed
             if result.nit < low_update_threshold:
                 lbfgs_status = LBFGSStatus.INIT_FAILED
             else:
-                lbfgs_status = LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK
+                lbfgs_status = LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT
         elif result.status == 1:
             # (result.nit > maxiter) or (result.nit > maxls)
             lbfgs_status = LBFGSStatus.MAX_ITER_REACHED
@@ -220,7 +219,7 @@ def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
             # precision loss resulting to inf or nan
             lbfgs_status = LBFGSStatus.NON_FINITE
         elif history.count < low_update_threshold * result.nit:
-            lbfgs_status = LBFGSStatus.LOW_UPDATE_MASK_RATIO
+            lbfgs_status = LBFGSStatus.LOW_UPDATE_PCT
         else:
             lbfgs_status = LBFGSStatus.CONVERGED

diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py
index 4d4cda664..f1e0991ad 100644
--- a/pymc_extras/inference/pathfinder/pathfinder.py
+++ b/pymc_extras/inference/pathfinder/pathfinder.py
@@ -928,7 +928,7 @@ def single_pathfinder_fn(random_seed: int) -> PathfinderResult:
         x0 = x_base + jitter_value
         x, g, lbfgs_niter, lbfgs_status = lbfgs.minimize(x0)

-        if lbfgs_status in {LBFGSStatus.INIT_FAILED, LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK}:
+        if lbfgs_status in {LBFGSStatus.INIT_FAILED, LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT}:
             raise LBFGSInitFailed(lbfgs_status)
         elif lbfgs_status == LBFGSStatus.LBFGS_FAILED:
             raise LBFGSException()
@@ -1370,8 +1370,8 @@ def _get_status_warning(mpr: MultiPathfinderResult) -> list[str]:
         LBFGSStatus.MAX_ITER_REACHED: "MAX_ITER_REACHED: LBFGS maximum number of iterations reached. Consider increasing maxiter if this occurence is high relative to the number of paths.",
         LBFGSStatus.INIT_FAILED: "INIT_FAILED: LBFGS failed to initialize. Consider reparameterizing the model or reducing jitter if this occurence is high relative to the number of paths.",
         LBFGSStatus.NON_FINITE: "NON_FINITE: LBFGS objective function produced inf or nan at the last iteration. Consider reparameterizing the model or adjusting the pathfinder arguments if this occurence is high relative to the number of paths.",
-        LBFGSStatus.LOW_UPDATE_MASK_RATIO: "LOW_UPDATE_MASK_RATIO: Majority of LBFGS iterations were not accepted due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.",
-        LBFGSStatus.INIT_FAILED_LOW_UPDATE_MASK: "INIT_FAILED_LOW_UPDATE_MASK: LBFGS failed to initialize due to the either: (1) LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurence is high relative to the number of paths.",
+        LBFGSStatus.LOW_UPDATE_PCT: "LOW_UPDATE_PCT: The majority of LBFGS iterations were not accepted due to either: (1) the LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurrence is high relative to the number of paths.",
+        LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT: "INIT_FAILED_LOW_UPDATE_PCT: LBFGS failed to initialize due to either: (1) the LBFGS function or gradient values containing too many inf or nan values or (2) gradient changes being significantly large, set by epsilon. Consider reparameterizing the model, adjusting initvals or jitter or other pathfinder arguments if this occurrence is high relative to the number of paths.",
     }

     path_status_message = {

From 90bba5d1b6176b7791e86fcd2a8ff5961a4908b9 Mon Sep 17 00:00:00 2001
From: aphc14 <177544929+aphc14@users.noreply.github.com>
Date: Sun, 20 Apr 2025 02:02:40 +1000
Subject: [PATCH 3/4] WIP: modified get_chi_matrix

---
 .../inference/pathfinder/pathfinder.py        | 34 +++++++------------
 1 file changed, 12 insertions(+), 22 deletions(-)

diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py
index f1e0991ad..9829e9cf5 100644
--- a/pymc_extras/inference/pathfinder/pathfinder.py
+++ b/pymc_extras/inference/pathfinder/pathfinder.py
@@ -237,8 +237,8 @@ def convert_flat_trace_to_idata(


 def alpha_recover(
-    x: TensorVariable, g: TensorVariable, epsilon: TensorVariable
-) -> tuple[TensorVariable, TensorVariable, TensorVariable, TensorVariable]:
+    x: TensorVariable, g: TensorVariable
+) -> tuple[TensorVariable, TensorVariable, TensorVariable]:
     """compute the diagonal elements of the inverse Hessian at each iterations of L-BFGS and filter updates.

     Parameters
@@ -247,9 +247,6 @@ def alpha_recover(
         position array, shape (L+1, N)
     g : TensorVariable
         gradient array, shape (L+1, N)
-    epsilon : float
-        threshold for filtering updates based on inner product of position
-        and gradient differences

     Returns
     -------
@@ -332,11 +329,10 @@ def inverse_hessian_factors(
     # NOTE: get_chi_matrix_2 is from blackjax which MAYBE incorrectly implemented

     def get_chi_matrix_1(diff: TensorVariable, J: TensorConstant) -> TensorVariable:
-        # TODO: vectorize this!
         L, N = diff.shape
         j_last = pt.as_tensor(J - 1)  # since indexing starts at 0

-        def chi_update(chi_lm1, diff_l) -> TensorVariable:
+        def chi_update(diff_l, chi_lm1) -> TensorVariable:
             chi_l = pt.roll(chi_lm1, -1, axis=0)
             return pt.set_subtensor(chi_l[j_last], diff_l)

@@ -353,19 +349,15 @@ def chi_update(chi_lm1, diff_l) -> TensorVariable:
         # (L, N, J)
         return chi_mat

-    def get_chi_matrix_2(
-        diff: TensorVariable, update_mask: TensorVariable, J: TensorConstant
-    ) -> TensorVariable:
+    def get_chi_matrix_2(diff: TensorVariable, J: TensorConstant) -> TensorVariable:
         L, N = diff.shape

-        diff_masked = update_mask[:, None] * diff
-
         # diff_padded: (L+J, N)
         pad_width = pt.zeros(shape=(2, 2), dtype="int32")
-        pad_width = pt.set_subtensor(pad_width[0, 0], J)
-        diff_padded = pt.pad(diff_masked, pad_width, mode="constant")
+        pad_width = pt.set_subtensor(pad_width[0, 0], J - 1)
+        diff_padded = pt.pad(diff, pad_width, mode="constant")

-        index = pt.arange(L)[:, None] + pt.arange(J)[None, :]
+        index = pt.arange(L)[..., None] + pt.arange(J)[None, ...]
         index = index.reshape((L, J))

         chi_mat = pt.matrix_transpose(diff_padded[index])
@@ -374,6 +366,8 @@ def get_chi_matrix_2(
         return chi_mat

     L, N = alpha.shape
+
+    # changed to get_chi_matrix_2 after removing update_mask
     S = get_chi_matrix_1(s, J)
     Z = get_chi_matrix_1(z, J)

@@ -756,7 +750,6 @@ def make_pathfinder_body(
     num_draws: int,
     maxcor: int,
     num_elbo_draws: int,
-    epsilon: float,
     **compile_kwargs: dict,
 ) -> Function:
     """
@@ -772,8 +765,6 @@ def make_pathfinder_body(
         The maximum number of iterations for the L-BFGS algorithm.
     num_elbo_draws : int
         The number of draws for the Evidence Lower Bound (ELBO) estimation.
-    epsilon : float
-        The value used to filter out large changes in the direction of the update gradient at each iteration l in L.
-        Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L.
     compile_kwargs : dict
         Additional keyword arguments for the PyTensor compiler.

@@ -798,10 +789,9 @@ def make_pathfinder_body(

     num_draws = pt.constant(num_draws, "num_draws", dtype="int32")
     num_elbo_draws = pt.constant(num_elbo_draws, "num_elbo_draws", dtype="int32")
-    epsilon = pt.constant(epsilon, "epsilon", dtype="float64")
     maxcor = pt.constant(maxcor, "maxcor", dtype="int32")

-    alpha, s, z = alpha_recover(x_full, g_full, epsilon=epsilon)
+    alpha, s, z = alpha_recover(x_full, g_full)
     beta, gamma = inverse_hessian_factors(alpha, s, z, J=maxcor)

     # ignore initial point - x, g: (L, N)
@@ -912,11 +902,11 @@ def neg_logp_dlogp_func(x):
     x_base = DictToArrayBijection.map(ip).data

     # lbfgs
-    lbfgs = LBFGS(neg_logp_dlogp_func, maxcor, maxiter, ftol, gtol, maxls)
+    lbfgs = LBFGS(neg_logp_dlogp_func, maxcor, maxiter, ftol, gtol, maxls, epsilon)

     # pathfinder body
     pathfinder_body_fn = make_pathfinder_body(
-        logp_func, num_draws, maxcor, num_elbo_draws, epsilon, **compile_kwargs
+        logp_func, num_draws, maxcor, num_elbo_draws, **compile_kwargs
     )
     rngs = find_rng_nodes(pathfinder_body_fn.maker.fgraph.outputs)

From 32fa4139dc771b762292057af48903824c637df7 Mon Sep 17 00:00:00 2001
From: aphc14 <177544929+aphc14@users.noreply.github.com>
Date: Sun, 20 Apr 2025 03:12:53 +1000
Subject: [PATCH 4/4] Updated LBFGS status handling and alpha_recover function

- Corrected the condition for LOW_UPDATE_PCT in LBFGS status handling.
- Removed update_mask references in alpha_recover and inverse_hessian_factors.
- Adjusted test cases to reflect changes in status messages and function
  signatures.
---
 pymc_extras/inference/pathfinder/lbfgs.py      |  2 +-
 pymc_extras/inference/pathfinder/pathfinder.py |  8 ++++----
 tests/test_pathfinder.py                       | 17 ++++++++---------
 3 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/pymc_extras/inference/pathfinder/lbfgs.py b/pymc_extras/inference/pathfinder/lbfgs.py
index e910f39db..b567c05cc 100644
--- a/pymc_extras/inference/pathfinder/lbfgs.py
+++ b/pymc_extras/inference/pathfinder/lbfgs.py
@@ -218,7 +218,7 @@ def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
         elif result.status == 2:
             # precision loss resulting to inf or nan
             lbfgs_status = LBFGSStatus.NON_FINITE
-        elif history.count < low_update_threshold * result.nit:
+        elif history.count * low_update_threshold < result.nit:
             lbfgs_status = LBFGSStatus.LOW_UPDATE_PCT
         else:
             lbfgs_status = LBFGSStatus.CONVERGED
diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py
index 9829e9cf5..cddc175ba 100644
--- a/pymc_extras/inference/pathfinder/pathfinder.py
+++ b/pymc_extras/inference/pathfinder/pathfinder.py
@@ -262,7 +262,7 @@ def alpha_recover(
     shapes: L=batch_size, N=num_params
     """

-    def compute_alpha_l(alpha_lm1, s_l, z_l) -> TensorVariable:
+    def compute_alpha_l(s_l, z_l, alpha_lm1) -> TensorVariable:
         # alpha_lm1: (N,)
         # s_l: (N,)
         # z_l: (N,)
@@ -290,7 +290,7 @@ def compute_alpha_l(s_l, z_l, alpha_lm1) -> TensorVariable:
     )

     # assert np.all(alpha.eval() > 0), "alpha cannot be negative"
-    # alpha: (L, N), update_mask: (L, N)
+    # alpha: (L, N)
     return alpha, s, z

@@ -368,8 +368,8 @@ def get_chi_matrix_2(diff: TensorVariable, J: TensorConstant) -> TensorVariable:
     L, N = alpha.shape

     # changed to get_chi_matrix_2 after removing update_mask
-    S = get_chi_matrix_1(s, J)
-    Z = get_chi_matrix_1(z, J)
+    S = get_chi_matrix_2(s, J)
+    Z = get_chi_matrix_2(z, J)

     # E: (L, J, J)
     Ij = pt.eye(J)[None, ...]
diff --git a/tests/test_pathfinder.py b/tests/test_pathfinder.py
index 6b30226f0..5e8773310 100644
--- a/tests/test_pathfinder.py
+++ b/tests/test_pathfinder.py
@@ -106,8 +106,8 @@ def test_unstable_lbfgs_update_mask(capsys, jitter):
     )
     out, err = capsys.readouterr()
     status_pattern = [
-        r"INIT_FAILED_LOW_UPDATE_MASK\s+\d+",
-        r"LOW_UPDATE_MASK_RATIO\s+\d+",
+        r"INIT_FAILED_LOW_UPDATE_PCT\s+\d+",
+        r"LOW_UPDATE_PCT\s+\d+",
         r"LBFGS_FAILED\s+\d+",
         r"SUCCESS\s+\d+",
     ]
@@ -126,8 +126,8 @@ def test_unstable_lbfgs_update_mask(capsys, jitter):

         out, err = capsys.readouterr()
         status_pattern = [
-            r"INIT_FAILED_LOW_UPDATE_MASK\s+2",
-            r"LOW_UPDATE_MASK_RATIO\s+2",
+            r"INIT_FAILED_LOW_UPDATE_PCT\s+2",
+            r"LOW_UPDATE_PCT\s+2",
             r"LBFGS_FAILED\s+4",
         ]
         for pattern in status_pattern:
@@ -232,12 +232,11 @@ def test_bfgs_sample():
     # get factors
     x_full = pt.as_tensor(x_data, dtype="float64")
     g_full = pt.as_tensor(g_data, dtype="float64")
-    epsilon = 1e-11

     x = x_full[1:]
     g = g_full[1:]
-    alpha, S, Z, update_mask = alpha_recover(x_full, g_full, epsilon)
-    beta, gamma = inverse_hessian_factors(alpha, S, Z, update_mask, J)
+    alpha, s, z = alpha_recover(x_full, g_full)
+    beta, gamma = inverse_hessian_factors(alpha, s, z, J)

     # sample
     phi, logq = bfgs_sample(
@@ -252,8 +251,8 @@ def test_bfgs_sample():
     # check shapes
     assert beta.eval().shape == (L, N, 2 * J)
     assert gamma.eval().shape == (L, 2 * J, 2 * J)
-    assert phi.eval().shape == (L, num_samples, N)
-    assert logq.eval().shape == (L, num_samples)
+    assert all(phi.shape.eval() == (L, num_samples, N))
+    assert all(logq.shape.eval() == (L, num_samples))


 @pytest.mark.parametrize("importance_sampling", ["psis", "psir", "identity", None])
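
Note: the snippets below are illustrative sketches, not part of the patch series; helper names and toy shapes are hypothetical.

The eventual switch from the scan-based get_chi_matrix_1 to the padded-gather get_chi_matrix_2 rests on the two constructions being equivalent once the pad width is J - 1, so that each window ends with the current difference (with the previous pad width J, each window would end at diff[l-1], which is presumably why the NOTE flags the blackjax variant as possibly incorrect). A minimal NumPy check of that equivalence:

import numpy as np

L, N, J = 6, 3, 4
diff = np.random.default_rng(0).normal(size=(L, N))

# scan-style construction (as in get_chi_matrix_1): roll a (J, N) buffer
# and overwrite its last row with the newest difference at every step
chi = np.zeros((J, N))
chi_scan = np.empty((L, N, J))
for l in range(L):
    chi = np.roll(chi, -1, axis=0)
    chi[J - 1] = diff[l]
    chi_scan[l] = chi.T

# vectorised construction (as in get_chi_matrix_2): pad J - 1 zero rows on
# top, then gather the windows [l, l + J - 1]; padded row l + J - 1 holds
# diff[l], so every window ends with the current difference
diff_padded = np.pad(diff, ((J - 1, 0), (0, 0)))
index = np.arange(L)[:, None] + np.arange(J)[None, :]
chi_vec = np.swapaxes(diff_padded[index], -1, -2)

assert np.allclose(chi_scan, chi_vec)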
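
The filtering that previously produced update_mask now happens up front in LBFGSHistoryManager.entry_condition_met, so only accepted (x, g) pairs enter the history. A self-contained sketch of that acceptance test (the function name accept_update is hypothetical):

import numpy as np

def accept_update(x_new, g_new, x_prev, g_prev, epsilon=1e-8):
    # accept an iterate only if s.z > epsilon * ||z||_2 -- the same
    # condition entry_condition_met applies before storing the pair
    s = x_new - x_prev  # position difference
    z = g_new - g_prev  # gradient difference
    return float(s @ z) > epsilon * float(np.sqrt(z @ z))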
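
Patch 4's reordering of compute_alpha_l's parameters to (s_l, z_l, alpha_lm1) follows pytensor.scan's calling convention: slices of sequences are passed to fn first, then the carried outputs_info value. A minimal sketch of that convention (a running sum, chosen only for illustration):

import pytensor
import pytensor.tensor as pt

xs = pt.vector("xs")

def step(x_l, acc_lm1):
    # sequence slice first, carried state second
    return acc_lm1 + x_l

cumsum, _ = pytensor.scan(fn=step, sequences=[xs], outputs_info=pt.constant(0.0))
f = pytensor.function([xs], cumsum)
print(f([1.0, 2.0, 3.0]))  # [1. 3. 6.]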