Skip to content

Commit 5740f2e

Browse files
committed
address comments, pass approx arg correctly, improve docstrings
1 parent 766d3b1 commit 5740f2e

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

pymc/gp/gp.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -685,17 +685,12 @@ def __init__(self, approx="VFE", *, mean_func=Zero(), cov_func=Constant(0.0)):
685685
super().__init__(mean_func=mean_func, cov_func=cov_func)
686686

687687
def __add__(self, other):
688-
# new_gp will default to FITC approx
689688
new_gp = super().__add__(other)
690-
# make sure new gp has correct approx
691689
if not self.approx == other.approx:
692690
raise TypeError("Cannot add GPs with different approximations")
693691
new_gp.approx = self.approx
694692
return new_gp
695693

696-
# Use y as first argument, so that we can use functools.partial
697-
# in marginal_likelihood instead of lambda. This makes pickling
698-
# possible.
699694
def _build_marginal_likelihood_logp(self, y, X, Xu, sigma, jitter):
700695
sigma2 = at.square(sigma)
701696
Kuu = self.cov_func(Xu)
@@ -765,7 +760,7 @@ def marginal_likelihood(
765760
self.sigma = noise
766761

767762
approx_logp = self._build_marginal_likelihood_logp(y, X, Xu, noise, JITTER_DEFAULT)
768-
pm.Potential("marginalapprox_logp_" + name, approx_logp)
763+
pm.Potential(f"marginalapprox_logp_{name}", approx_logp)
769764

770765
def _build_conditional(
771766
self, Xnew, pred_noise, diag, X, Xu, y, sigma, cov_total, mean_total, jitter

pymc/tests/test_gp.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -863,42 +863,48 @@ def setup_method(self):
863863
self.map_full = pm.find_MAP(method="bfgs") # bfgs seems to work much better than lbfgsb
864864

865865
self.x_new = np.linspace(-6, 6, 20)
866+
867+
# Include additive Gaussian noise, return diagonal of predicted covariance matrix
866868
with model:
867869
self.pred_mu, self.pred_var = self.gp.predict(
868870
self.x_new[:, None], point=self.map_full, pred_noise=True, diag=True
869871
)
870872

873+
# Dont include additive Gaussian noise, return full predicted covariance matrix
871874
with model:
872875
self.pred_mu, self.pred_covar = self.gp.predict(
873876
self.x_new[:, None], point=self.map_full, pred_noise=False, diag=False
874877
)
875878

876879
@pytest.mark.parametrize("approx", ["FITC", "VFE", "DTC"])
877880
def test_fits_and_preds(self, approx):
878-
# check logp & dlogp, optimization gets approximately correct result
881+
"""Get MAP estimate for GP approximation, compare results and predictions to what's returned
882+
by an unapproximated GP. The tolerances are fairly wide, but narrow relative to initial
883+
values of the unknown parameters.
884+
"""
885+
879886
with pm.Model() as model:
880887
cov_func = pm.gp.cov.Linear(1, c=0.0)
881888
c = pm.Normal("c", mu=20.0, sigma=100.0, initval=-500.0)
882889
mean_func = pm.gp.mean.Constant(c)
883-
gp = pm.gp.MarginalApprox(mean_func=mean_func, cov_func=cov_func, approx="VFE")
890+
gp = pm.gp.MarginalApprox(mean_func=mean_func, cov_func=cov_func, approx=approx)
884891
sigma = pm.HalfNormal("sigma", sigma=100, initval=50.0)
885892
gp.marginal_likelihood("lik", self.x[:, None], self.x[:, None], self.y, sigma)
886893
map_approx = pm.find_MAP(method="bfgs")
887894

888-
# use wide tolerances (but narrow relative to initial values of unknown parameters) because
889-
# test is likely flakey
895+
# Check MAP gets approximately correct result
890896
npt.assert_allclose(self.map_full["c"], map_approx["c"], atol=0.01, rtol=0.1)
891897
npt.assert_allclose(self.map_full["sigma"], map_approx["sigma"], atol=0.01, rtol=0.1)
892898

893-
# check that predict (and conditional) work, include noise, with diagonal non-full pred var
899+
# Check that predict (and conditional) work, include noise, with diagonal non-full pred var.
894900
with model:
895901
pred_mu_approx, pred_var_approx = gp.predict(
896902
self.x_new[:, None], point=map_approx, pred_noise=True, diag=True
897903
)
898904
npt.assert_allclose(self.pred_mu, pred_mu_approx, atol=0.0, rtol=0.1)
899905
npt.assert_allclose(self.pred_var, pred_var_approx, atol=0.0, rtol=0.1)
900906

901-
# check that predict (and conditional) work, no noise, full pred covariance
907+
# Check that predict (and conditional) work, no noise, full pred covariance.
902908
with model:
903909
pred_mu_approx, pred_var_approx = gp.predict(
904910
self.x_new[:, None], point=map_approx, pred_noise=True, diag=True

0 commit comments

Comments
 (0)