Skip to content

Commit 95c784c

Browse files
committed
Add change_dist_size
Add tests.
1 parent 8f0ffb3 commit 95c784c

File tree

2 files changed

+94
-19
lines changed

2 files changed

+94
-19
lines changed

pymc/distributions/timeseries.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from aesara.tensor.random.op import RandomVariable
2828

2929
from pymc.aesaraf import constant_fold, floatX, intX
30-
from pymc.distributions import distribution
3130
from pymc.distributions.continuous import Normal, get_tau_sigma
3231
from pymc.distributions.distribution import (
3332
Distribution,
@@ -461,7 +460,7 @@ class AR(Distribution):
461460
process.
462461
init_dist : unnamed distribution, optional
463462
Scalar or vector distribution for initial values. Unnamed refers to distributions
464-
created with the ``.dist()`` API. Distributions should have shape (*shape[:-1], ar_order).
463+
created with the ``.dist()`` API. Distributions should have shape (*shape[:-1], ar_order).
465464
If not, it will be automatically resized. Defaults to pm.Normal.dist(0, 100, shape=...).
466465
467466
.. warning:: init_dist will be cloned, rendering it independent of the one passed as input.
@@ -914,7 +913,7 @@ class EulerMaruyama(Distribution):
914913
parameters of the SDE, passed as ``*args`` to ``sde_fn``
915914
init_dist : unnamed distribution, optional
916915
Scalar or vector distribution for initial values. Unnamed refers to distributions
917-
created with the ``.dist()`` API. Distributions should have shape (*shape[:-1], ar_order).
916+
created with the ``.dist()`` API. Distributions should have shape (*shape[:-1]).
918917
If not, it will be automatically resized. Defaults to pm.Normal.dist(0, 100, shape=...).
919918
920919
.. warning:: init_dist will be cloned, rendering it independent of the one passed as input.
@@ -953,7 +952,7 @@ def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs):
953952
f"Init dist must be a distribution created via the `.dist()` API, "
954953
f"got {type(init_dist)}"
955954
)
956-
check_dist_not_registered(init_dist)
955+
check_dist_not_registered(init_dist)
957956
if init_dist.owner.op.ndim_supp > 1:
958957
raise ValueError(
959958
"Init distribution must have a scalar or vector support dimension, ",
@@ -970,17 +969,15 @@ def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs):
970969
# Tell Aeppl to ignore init_dist, as it will be accounted for in the logp term
971970
init_dist = ignore_logprob(init_dist)
972971

973-
return super().dist([dt, sde_fn, sde_pars, init_dist, steps], **kwargs)
972+
return super().dist([init_dist, steps, sde_pars, dt, sde_fn], **kwargs)
974973

975974
@classmethod
976-
def rv_op(cls, dt, sde_fn, sde_pars, init_dist, steps, size=None):
977-
# Init dist should have shape (*size, ar_order)
975+
def rv_op(cls, init_dist, steps, sde_pars, dt, sde_fn, size=None):
976+
# Init dist should have shape (*size,)
978977
if size is not None:
979978
batch_size = size
980979
else:
981-
# In this case the size of the init_dist depends on the parameters shape
982-
# The last dimension of rho and init_dist does not matter
983-
batch_size = at.broadcast_shape(*sde_pars, at.atleast_1d(init_dist)[..., 0])
980+
batch_size = at.broadcast_shape(*sde_pars, init_dist)
984981
init_dist = change_dist_size(init_dist, batch_size)
985982

986983
# Create OpFromGraph representing random draws form AR process
@@ -1024,6 +1021,24 @@ def step(*prev_args):
10241021
return eulermaruyama
10251022

10261023

1024+
@_change_dist_size.register(EulerMaruyamaRV)
1025+
def change_eulermaruyama_size(op, dist, new_size, expand=False):
1026+
1027+
if expand:
1028+
old_size = dist.shape[:-1]
1029+
new_size = tuple(new_size) + tuple(old_size)
1030+
1031+
init_dist, steps, *sde_pars, _ = dist.owner.inputs
1032+
return EulerMaruyama.rv_op(
1033+
init_dist,
1034+
steps,
1035+
sde_pars,
1036+
dt=op.dt,
1037+
sde_fn=op.sde_fn,
1038+
size=new_size,
1039+
)
1040+
1041+
10271042
@_logprob.register(EulerMaruyamaRV)
10281043
def eulermaruyama_logp(op, values, init_dist, steps, *sde_pars_noise_arg, **kwargs):
10291044
(x,) = values

pymc/tests/distributions/test_timeseries.py

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -831,22 +831,82 @@ def test_change_dist_size(self):
831831

832832

833833
class TestEulerMaruyama:
834+
@pytest.mark.parametrize("batched_param", [1, 2])
835+
@pytest.mark.parametrize("explicit_shape", (True, False))
836+
def test_batched_size(self, explicit_shape, batched_param):
837+
steps, batch_size = 100, 5
838+
param_val = np.square(np.random.randn(batch_size))
839+
if explicit_shape:
840+
kwargs = {"shape": (batch_size, steps)}
841+
else:
842+
kwargs = {"steps": steps - 1}
843+
844+
def sde_fn(x, k, d, s):
845+
return (k - d * x, s)
846+
847+
sde_pars = [1.0, 2.0, 0.1]
848+
sde_pars[batched_param] = sde_pars[batched_param] * param_val
849+
with Model() as t0:
850+
y = EulerMaruyama("y", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars, **kwargs)
851+
852+
y_eval = draw(y, draws=2)
853+
assert y_eval[0].shape == (batch_size, steps)
854+
assert not np.any(np.isclose(y_eval[0], y_eval[1]))
855+
856+
if explicit_shape:
857+
kwargs["shape"] = steps
858+
with Model() as t1:
859+
for i in range(batch_size):
860+
sde_pars_slice = sde_pars.copy()
861+
sde_pars_slice[batched_param] = sde_pars[batched_param][i]
862+
EulerMaruyama(f"y_{i}", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars, **kwargs)
863+
864+
np.testing.assert_allclose(
865+
t0.compile_logp()(t0.initial_point()),
866+
t1.compile_logp()(t1.initial_point()),
867+
)
868+
869+
def test_change_dist_size1(self):
870+
def sde1(x, k, d, s):
871+
return (k - d * x, s)
872+
873+
base_dist = EulerMaruyama.dist(dt=0.01, sde_fn=sde1, sde_pars=(1, 2, 0.1), shape=(5, 10))
874+
875+
new_dist = change_dist_size(base_dist, (4,))
876+
assert new_dist.eval().shape == (4, 10)
877+
878+
new_dist = change_dist_size(base_dist, (4,), expand=True)
879+
assert new_dist.eval().shape == (4, 5, 10)
834880

835-
def _gen_sde_path(self, sde, pars, dt, n, x0):
836-
xs = [x0]
837-
wt = np.random.normal(size=(n,) if isinstance(x0, float) else (n, x0.size))
838-
for i in range(n):
839-
f, g = sde(xs[-1], *pars)
840-
xs.append(xs[-1] + f * dt + np.sqrt(dt) * g * wt[i])
841-
return np.array(xs)
881+
def test_change_dist_size2(self):
882+
def sde2(p, s):
883+
N = 500.0
884+
return s * p * (1 - p) / (1 + s * p), pm.math.sqrt(p * (1 - p) / N)
842885

843-
def test_linear(self):
886+
base_dist = EulerMaruyama.dist(dt=0.01, sde_fn=sde2, sde_pars=(0.1,), shape=(3, 10))
887+
888+
new_dist = change_dist_size(base_dist, (4,))
889+
assert new_dist.eval().shape == (4, 10)
890+
891+
new_dist = change_dist_size(base_dist, (4,), expand=True)
892+
assert new_dist.eval().shape == (4, 3, 10)
893+
894+
def test_linear_model(self):
844895
lam = -0.78
845896
sig2 = 5e-3
846897
N = 300
847898
dt = 1e-1
899+
900+
def _gen_sde_path(sde, pars, dt, n, x0):
901+
xs = [x0]
902+
wt = np.random.normal(size=(n,) if isinstance(x0, float) else (n, x0.size))
903+
for i in range(n):
904+
f, g = sde(xs[-1], *pars)
905+
xs.append(xs[-1] + f * dt + np.sqrt(dt) * g * wt[i])
906+
return np.array(xs)
907+
848908
sde = lambda x, lam: (lam * x, sig2)
849-
x = floatX(self._gen_sde_path(sde, (lam,), dt, N, 5.0))
909+
x = floatX(_gen_sde_path(sde, (lam,), dt, N, 5.0))
850910
z = x + np.random.randn(x.size) * sig2
851911
# build model
852912
with Model() as model:

0 commit comments

Comments
 (0)