Skip to content

Commit ef2956f

Browse files
committed
Sync updates with draft PR #386.
- Added pytensor.function for bfgs_sample
1 parent 1fd7a11 commit ef2956f

File tree

1 file changed

+89
-31
lines changed

1 file changed

+89
-31
lines changed

pymc_experimental/inference/pathfinder/pathfinder.py

Lines changed: 89 additions & 31 deletions
Original file line number · Diff line number · Diff line change
@@ -302,7 +302,8 @@ def bfgs_sample(
302302
alpha,
303303
beta,
304304
gamma,
305-
random_seed: RandomSeed | None = None,
305+
# random_seed: RandomSeed | None = None,
306+
rng,
306307
):
307308
# batch: L = 8
308309
# alpha_l: (N,) => (L, N)
@@ -315,7 +316,7 @@ def bfgs_sample(
315316
# logdensity: (M,) => (L, M)
316317
# theta: (J, N)
317318

318-
rng = pytensor.shared(np.random.default_rng(seed=random_seed))
319+
# rng = pytensor.shared(np.random.default_rng(seed=random_seed))
319320

320321
def batched(x, g, alpha, beta, gamma):
321322
var_list = [x, g, alpha, beta, gamma]
@@ -380,6 +381,64 @@ def compute_logp(logp_func, arr):
380381
return np.where(np.isnan(logP), -np.inf, logP)
381382

382383

384+
_x = pt.matrix("_x", dtype="float64")
385+
_g = pt.matrix("_g", dtype="float64")
386+
_alpha = pt.matrix("_alpha", dtype="float64")
387+
_beta = pt.tensor3("_beta", dtype="float64")
388+
_gamma = pt.tensor3("_gamma", dtype="float64")
389+
_epsilon = pt.scalar("_epsilon", dtype="float64")
390+
_maxcor = pt.iscalar("_maxcor")
391+
_alpha, _S, _Z, _update_mask = alpha_recover(_x, _g, epsilon=_epsilon)
392+
_beta, _gamma = inverse_hessian_factors(_alpha, _S, _Z, _update_mask, J=_maxcor)
393+
394+
_num_elbo_draws = pt.iscalar("_num_elbo_draws")
395+
_dummy_rng = pytensor.shared(np.random.default_rng(), name="_dummy_rng")
396+
_phi, _logQ_phi = bfgs_sample(
397+
num_samples=_num_elbo_draws,
398+
x=_x,
399+
g=_g,
400+
alpha=_alpha,
401+
beta=_beta,
402+
gamma=_gamma,
403+
rng=_dummy_rng,
404+
)
405+
406+
_num_draws = pt.iscalar("_num_draws")
407+
_x_lstar = pt.dvector("_x_lstar")
408+
_g_lstar = pt.dvector("_g_lstar")
409+
_alpha_lstar = pt.dvector("_alpha_lstar")
410+
_beta_lstar = pt.dmatrix("_beta_lstar")
411+
_gamma_lstar = pt.dmatrix("_gamma_lstar")
412+
413+
414+
_psi, _logQ_psi = bfgs_sample(
415+
num_samples=_num_draws,
416+
x=_x_lstar,
417+
g=_g_lstar,
418+
alpha=_alpha_lstar,
419+
beta=_beta_lstar,
420+
gamma=_gamma_lstar,
421+
rng=_dummy_rng,
422+
)
423+
424+
alpha_recover_compiled = pytensor.function(
425+
inputs=[_x, _g, _epsilon],
426+
outputs=[_alpha, _S, _Z, _update_mask],
427+
)
428+
inverse_hessian_factors_compiled = pytensor.function(
429+
inputs=[_alpha, _S, _Z, _update_mask, _maxcor],
430+
outputs=[_beta, _gamma],
431+
)
432+
bfgs_sample_compiled = pytensor.function(
433+
inputs=[_num_elbo_draws, _x, _g, _alpha, _beta, _gamma],
434+
outputs=[_phi, _logQ_phi],
435+
)
436+
bfgs_sample_lstar_compiled = pytensor.function(
437+
inputs=[_num_draws, _x_lstar, _g_lstar, _alpha_lstar, _beta_lstar, _gamma_lstar],
438+
outputs=[_psi, _logQ_psi],
439+
)
440+
441+
383442
def single_pathfinder(
384443
model,
385444
num_draws: int,
@@ -423,47 +482,46 @@ def neg_dlogp_func(x):
423482
maxls=maxls,
424483
)
425484

426-
# x_full, g_full: (L+1, N)
427-
x_full = pt.as_tensor(lbfgs_history.x, dtype="float64")
428-
g_full = pt.as_tensor(lbfgs_history.g, dtype="float64")
485+
# x, g: (L+1, N)
486+
x = lbfgs_history.x
487+
g = lbfgs_history.g
488+
alpha, S, Z, update_mask = alpha_recover_compiled(x, g, epsilon)
489+
beta, gamma = inverse_hessian_factors_compiled(alpha, S, Z, update_mask, maxcor)
429490

430491
# ignore initial point - x, g: (L, N)
431-
x = x_full[1:]
432-
g = g_full[1:]
433-
434-
alpha, S, Z, update_mask = alpha_recover(x_full, g_full, epsilon=epsilon)
435-
beta, gamma = inverse_hessian_factors(alpha, S, Z, update_mask, J=maxcor)
436-
437-
phi, logQ_phi = bfgs_sample(
438-
num_samples=num_elbo_draws,
439-
x=x,
440-
g=g,
441-
alpha=alpha,
442-
beta=beta,
443-
gamma=gamma,
444-
random_seed=pathfinder_seed,
492+
x = x[1:]
493+
g = g[1:]
494+
495+
rng = pytensor.shared(np.random.default_rng(pathfinder_seed), borrow=True)
496+
phi, logQ_phi = bfgs_sample_compiled.copy(swap={_dummy_rng: rng})(
497+
num_elbo_draws,
498+
x,
499+
g,
500+
alpha,
501+
beta,
502+
gamma,
445503
)
446504

447505
# .vectorize is slower than apply_along_axis
448-
logP_phi = compute_logp(logp_func, phi.eval())
449-
logQ_phi = logQ_phi.eval()
506+
logP_phi = compute_logp(logp_func, phi)
507+
# logQ_phi = logQ_phi.eval()
450508
elbo = (logP_phi - logQ_phi).mean(axis=-1)
451509
lstar = np.argmax(elbo)
452510

453511
# BUG: elbo may all be -inf for all l in L. So np.argmax(elbo) will return 0 which is wrong. Still, this won't affect the posterior samples in the multipath Pathfinder scenario because of PSIS/PSIR step. However, the user is left unaware of a failed Pathfinder run.
454512
# TODO: handle this case, e.g. by warning of a failed Pathfinder run and skip the following bfgs_sample step to save time.
455513

456-
psi, logQ_psi = bfgs_sample(
457-
num_samples=num_draws,
458-
x=x[lstar],
459-
g=g[lstar],
460-
alpha=alpha[lstar],
461-
beta=beta[lstar],
462-
gamma=gamma[lstar],
463-
random_seed=sample_seed,
514+
rng.set_value(np.random.default_rng(sample_seed), borrow=True)
515+
psi, logQ_psi = bfgs_sample_lstar_compiled.copy(swap={_dummy_rng: rng})(
516+
num_draws,
517+
x[lstar],
518+
g[lstar],
519+
alpha[lstar],
520+
beta[lstar],
521+
gamma[lstar],
464522
)
465-
psi = psi.eval()
466-
logQ_psi = logQ_psi.eval()
523+
# psi = psi.eval()
524+
# logQ_psi = logQ_psi.eval()
467525
logP_psi = compute_logp(logp_func, psi)
468526
# psi: (1, M, N)
469527
# logP_psi: (1, M)

0 commit comments

Comments (0)