Skip to content

Commit 56963e0

Browse files
committed
add scale parameter for HSGP
1 parent 32d9e86 commit 56963e0

File tree

2 files changed

+74
-45
lines changed

2 files changed

+74
-45
lines changed

pymc/gp/hsgp_approx.py

Lines changed: 47 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import pymc as pm
2525

26-
from pymc.gp.cov import Covariance
26+
from pymc.gp.cov import Covariance, Periodic
2727
from pymc.gp.gp import Base
2828
from pymc.gp.mean import Mean, Zero
2929

@@ -69,19 +69,16 @@ def calc_eigenvectors(
6969
def calc_basis_periodic(
7070
Xs: TensorLike,
7171
period: TensorLike,
72-
m: Sequence[int],
72+
m: int,
7373
tl: ModuleType = np,
7474
):
7575
"""
7676
Calculate basis vectors for the cosine series expansion of the periodic covariance function.
7777
These are derived from the Taylor series representation of the covariance.
7878
"""
79-
if len(m) != 1:
80-
raise ValueError("`Periodic` basis vectors only implemented for 1-dimensional case.")
81-
m0 = m[0] # for compatibility with other kernels, m must be a sequence
8279
w0 = (2 * np.pi) / period # angular frequency defining the periodicity
83-
m1 = tl.tile(w0 * Xs, m0)
84-
m2 = tl.diag(tl.arange(0, m0, 1))
80+
m1 = tl.tile(w0 * Xs, m)
81+
m2 = tl.diag(tl.arange(0, m, 1))
8582
mw0x = m1 @ m2
8683
phi_cos = tl.cos(mw0x)
8784
phi_sin = tl.sin(mw0x)
@@ -128,8 +125,8 @@ class HSGP(Base):
128125
parameterization: str
129126
Whether to use `centred` or `noncentered` parameterization when multiplying the
130127
basis by the coefficients.
131-
cov_func: None, 2D array, or instance of `Covariance`
132-
The covariance function. Defaults to zero.
128+
cov_func: Covariance function, must be an instance of `Stationary` and implement a
129+
`power_spectral_density` method.
133130
mean_func: None, instance of Mean
134131
The mean function. Defaults to zero.
135132
@@ -431,10 +428,11 @@ class HSGPPeriodic(Base):
431428
432429
Parameters
433430
----------
434-
m: list
435-
The number of basis vectors to use. Must be a list with one element.
436-
cov_func: Instance of `Periodic` covariance
437-
The covariance function. Defaults to zero.
431+
m: int
432+
The number of basis vectors to use. Must be a positive integer.
433+
scale: TensorLike
434+
The standard deviation (square root of the variance) of the GP effect. Defaults to 1.0.
435+
cov_func: Must be an instance of instance of `Periodic` covariance
438436
mean_func: None, instance of Mean
439437
The mean function. Defaults to zero.
440438
@@ -447,10 +445,11 @@ class HSGPPeriodic(Base):
447445
448446
with pm.Model() as model:
449447
# Specify the covariance function, only for the 1-D case
448+
scale = pm.HalfNormal(10)
450449
cov_func = pm.gp.cov.Periodic(1, period=1, ls=0.1)
451450
452451
# Specify the approximation with 25 basis vectors
453-
gp = pm.gp.HSGPPeriodic(m=[25], cov_func=cov_func)
452+
gp = pm.gp.HSGPPeriodic(m=25, scale=scale, cov_func=cov_func)
454453
455454
# Place a GP prior over the function f.
456455
f = gp.prior("f", X=X)
@@ -472,31 +471,37 @@ class HSGPPeriodic(Base):
472471

473472
def __init__(
474473
self,
475-
m: Sequence[int],
474+
m: int,
475+
scale: Optional[Union[float, TensorLike]] = 1.0,
476476
*,
477477
mean_func: Mean = Zero(),
478-
cov_func: Covariance,
478+
cov_func: Periodic,
479479
):
480-
arg_err_msg = (
481-
"`m` and `L`, if provided, must be sequences with one element per active "
482-
"dimension of the kernel or covariance function."
483-
)
480+
arg_err_msg = "`m` must be a positive integer as the `Periodic` kernel approximation is only implemented for 1-dimensional case."
484481

485-
if not isinstance(m, Sequence):
482+
if not isinstance(m, int):
483+
raise ValueError(arg_err_msg)
484+
485+
if m <= 0:
486486
raise ValueError(arg_err_msg)
487487

488+
if not isinstance(cov_func, Periodic):
489+
raise ValueError(
490+
"`cov_func` must be an instance of a `Periodic` kernel only. Use the `scale` parameter to control the variance."
491+
)
492+
488493
if cov_func.n_dims > 1:
489494
raise ValueError(
490495
"HSGP approximation for `Periodic` kernel only implemented for 1-dimensional case."
491496
)
492-
m = tuple(m)
493497

494498
self._m = m
499+
self.scale = scale
495500

496501
super().__init__(mean_func=mean_func, cov_func=cov_func)
497502

498503
def prior_linearized(self, Xs: TensorLike):
499-
"""Linearized version of the approximation. Returns the cosine and sine bases and coeffients
504+
"""Linearized version of the approximation. Returns the cosine and sine bases and coefficients
500505
of the expansion needed to create the GP.
501506
502507
This function allows the user to bypass the GP interface and work directly with the basis
@@ -529,10 +534,11 @@ def prior_linearized(self, Xs: TensorLike):
529534
X = np.linspace(0, 10, 100)[:, None]
530535
531536
with pm.Model() as model:
537+
scale = pm.HalfNormal(10)
532538
cov_func = pm.gp.cov.Periodic(1, period=1, ls=ell)
533539
534-
# m = [200] means 200 basis vectors for the first dimension
535-
gp = pm.gp.HSGPPeriodic(m=[200], cov_func=cov_func)
540+
# m=200 means 200 basis vectors
541+
gp = pm.gp.HSGPPeriodic(m=200, scale=scale, cov_func=cov_func)
536542
537543
# Order is important. First calculate the mean, then make X a shared variable,
538544
# then subtract the mean. When X is mutated later, the correct mean will be
@@ -542,19 +548,19 @@ def prior_linearized(self, Xs: TensorLike):
542548
Xs = X - X_mean
543549
544550
# Pass the zero-subtracted Xs in to the GP
545-
(phi_cos, phi_sin), psd, sqrt_psd = gp.prior_linearized(Xs=Xs)
551+
(phi_cos, phi_sin), psd = gp.prior_linearized(Xs=Xs)
546552
547553
# Specify standard normal prior in the coefficients. The number of which
548554
# is twice the number of basis vectors minus one.
549555
# This is so that each cosine term has a `beta` and all but one of the
550556
# sine terms, as first eigenfunction for the sine component is zero
551-
m0 = gp._m[0]
552-
beta = pm.Normal("beta", size=(m0 * 2 - 1))
557+
m = gp._m
558+
beta = pm.Normal("beta", size=(m * 2 - 1))
553559
554560
# The (non-centered) GP approximation is given by
555561
f = pm.Deterministic(
556562
"f",
557-
phi_cos @ (psd * self._beta[:m0]) + phi_sin[..., 1:] @ (psd[1:] * self._beta[m0:])
563+
phi_cos @ (psd * beta[:m]) + phi_sin[..., 1:] @ (psd[1:] * beta[m:])
558564
)
559565
...
560566
@@ -572,8 +578,9 @@ def prior_linearized(self, Xs: TensorLike):
572578
Xs, _ = self.cov_func._slice(Xs)
573579

574580
phi_cos, phi_sin = calc_basis_periodic(Xs, self.cov_func.period, self._m, tl=pt)
575-
J = pt.arange(0, self._m[0], 1)
576-
psd = self.cov_func.power_spectral_density_approx(J)
581+
J = pt.arange(0, self._m, 1)
582+
# rescale basis coefficients by the sqrt variance term
583+
psd = self.scale * self.cov_func.power_spectral_density_approx(J)
577584
return (phi_cos, phi_sin), psd
578585

579586
def prior(self, name: str, X: TensorLike, dims: Optional[str] = None): # type: ignore
@@ -594,14 +601,14 @@ def prior(self, name: str, X: TensorLike, dims: Optional[str] = None): # type:
594601

595602
(phi_cos, phi_sin), psd = self.prior_linearized(X - self._X_mean)
596603

597-
m0 = self._m[0]
598-
self._beta = pm.Normal(f"{name}_hsgp_coeffs_", size=(m0 * 2 - 1))
604+
m = self._m
605+
self._beta = pm.Normal(f"{name}_hsgp_coeffs_", size=(m * 2 - 1))
599606
# The first eigenfunction for the sine component is zero
600607
# and so does not contribute to the approximation.
601608
f = (
602609
self.mean_func(X)
603-
+ phi_cos @ (psd * self._beta[:m0]) # type: ignore
604-
+ phi_sin[..., 1:] @ (psd[1:] * self._beta[m0:]) # type: ignore
610+
+ phi_cos @ (psd * self._beta[:m]) # type: ignore
611+
+ phi_sin[..., 1:] @ (psd[1:] * self._beta[m:]) # type: ignore
605612
)
606613

607614
self.f = pm.Deterministic(name, f, dims=dims)
@@ -619,11 +626,12 @@ def _build_conditional(self, Xnew):
619626
Xnew, _ = self.cov_func._slice(Xnew)
620627

621628
phi_cos, phi_sin = calc_basis_periodic(Xnew - X_mean, self.cov_func.period, self._m, tl=pt)
622-
m0 = self._m[0]
623-
J = pt.arange(0, m0, 1)
624-
psd = self.cov_func.power_spectral_density_approx(J)
629+
m = self._m
630+
J = pt.arange(0, m, 1)
631+
# rescale basis coefficients by the sqrt variance term
632+
psd = self.scale * self.cov_func.power_spectral_density_approx(J)
625633

626-
phi = phi_cos @ (psd * beta[:m0]) + phi_sin[..., 1:] @ (psd[1:] * beta[m0:])
634+
phi = phi_cos @ (psd * beta[:m]) + phi_sin[..., 1:] @ (psd[1:] * beta[m:])
627635
return self.mean_func(Xnew) + phi
628636

629637
def conditional(self, name: str, Xnew: TensorLike, dims: Optional[str] = None): # type: ignore

tests/gp/test_hsgp_approx.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -230,30 +230,51 @@ def test_conditional(self, model, cov_func, X1, parameterization):
230230

231231
class TestHSGPPeriodic(_BaseFixtures):
232232
def test_parametrization(self):
233+
err_msg = "`m` must be a positive integer as the `Periodic` kernel approximation is only implemented for 1-dimensional case."
234+
235+
with pytest.raises(ValueError, match=err_msg):
236+
# `m` must be a positive integer, not a list
237+
cov_func = pm.gp.cov.Periodic(1, period=1, ls=0.1)
238+
pm.gp.HSGPPeriodic(m=[500], cov_func=cov_func)
239+
240+
with pytest.raises(ValueError, match=err_msg):
241+
# `m`` must be a positive integer
242+
cov_func = pm.gp.cov.Periodic(1, period=1, ls=0.1)
243+
pm.gp.HSGPPeriodic(m=-1, cov_func=cov_func)
244+
245+
with pytest.raises(
246+
ValueError,
247+
match="`cov_func` must be an instance of a `Periodic` kernel only. Use the `scale` parameter to control the variance.",
248+
):
249+
# `cov_func` must be `Periodic` only
250+
cov_func = 5.0 * pm.gp.cov.Periodic(1, period=1, ls=0.1)
251+
pm.gp.HSGPPeriodic(m=500, cov_func=cov_func)
252+
233253
with pytest.raises(
234254
ValueError,
235255
match="HSGP approximation for `Periodic` kernel only implemented for 1-dimensional case.",
236256
):
237257
cov_func = pm.gp.cov.Periodic(2, period=1, ls=[1, 2])
238-
pm.gp.HSGPPeriodic(m=[500, 500], cov_func=cov_func)
258+
pm.gp.HSGPPeriodic(m=500, scale=0.5, cov_func=cov_func)
239259

240260
@pytest.mark.parametrize("cov_func", [pm.gp.cov.Periodic(1, period=1, ls=1)])
261+
@pytest.mark.parametrize("eta", [100.0])
241262
@pytest.mark.xfail(
242263
reason="For `pm.gp.cov.Periodic`, this test does not pass.\
243264
The mmd is around `0.0468`.\
244265
The test passes more often when subtracting the mean from the mean from the samples.\
245266
It might be that the period is slightly off for the approximate power spectral density.\
246267
See https://github.com/pymc-devs/pymc/pull/6877/ for the full discussion."
247268
)
248-
def test_prior(self, model, cov_func, X1, rng):
269+
def test_prior(self, model, cov_func, eta, X1, rng):
249270
"""Compare HSGPPeriodic prior to unapproximated GP prior, pm.gp.Latent. Draw samples from the
250271
prior and compare them using MMD two sample test.
251272
"""
252273
with model:
253-
hsgp = pm.gp.HSGPPeriodic(m=[200], cov_func=cov_func)
274+
hsgp = pm.gp.HSGPPeriodic(m=200, scale=eta, cov_func=cov_func)
254275
f1 = hsgp.prior("f1", X=X1)
255276

256-
gp = pm.gp.Latent(cov_func=cov_func)
277+
gp = pm.gp.Latent(cov_func=eta**2 * cov_func)
257278
f2 = gp.prior("f2", X=X1)
258279

259280
idata = pm.sample_prior_predictive(samples=1000, random_seed=rng)
@@ -268,12 +289,12 @@ def test_prior(self, model, cov_func, X1, rng):
268289

269290
@pytest.mark.parametrize("cov_func", [pm.gp.cov.Periodic(1, period=1, ls=1)])
270291
def test_conditional_periodic(self, model, cov_func, X1):
271-
"""Compare HSGPPeriodic conditional to unapproximated GP prior, pm.gp.Latent. Draw samples
292+
"""Compare HSGPPeriodic conditional to HSGPPeriodic prior. Draw samples
272293
from the prior and compare them using MMD two sample test. The conditional should match the
273294
prior when no data is observed.
274295
"""
275296
with model:
276-
hsgp = pm.gp.HSGPPeriodic(m=[100], cov_func=cov_func)
297+
hsgp = pm.gp.HSGPPeriodic(m=100, cov_func=cov_func)
277298
f = hsgp.prior("f", X=X1)
278299
fc = hsgp.conditional("fc", Xnew=X1)
279300

0 commit comments

Comments
 (0)