From 8f0ffb3f60f16e5c8249b3e525d9c565eea88453 Mon Sep 17 00:00:00 2001 From: junpenglao Date: Tue, 18 Oct 2022 16:53:07 +0200 Subject: [PATCH 1/4] Refactor EulerMaruyama to work in v4 --- pymc/distributions/timeseries.py | 171 ++++++++++++++++---- pymc/tests/distributions/test_timeseries.py | 68 ++++---- 2 files changed, 175 insertions(+), 64 deletions(-) diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index 2ef30ff36e..43627f7b38 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ -15,7 +15,7 @@ import warnings from abc import ABCMeta -from typing import Optional +from typing import Callable, Optional import aesara import aesara.tensor as at @@ -881,7 +881,26 @@ def garch11_moment(op, rv, omega, alpha_1, beta_1, initial_vol, init_dist, steps return at.zeros_like(rv) -class EulerMaruyama(distribution.Continuous): +class EulerMaruyamaRV(SymbolicRandomVariable): + """A placeholder used to specify a log-likelihood for a EulerMaruyama sub-graph.""" + + default_output = 1 + dt: float + sde_fn: Callable + _print_name = ("EulerMaruyama", "\\operatorname{EulerMaruyama}") + + def __init__(self, *args, dt, sde_fn, **kwargs): + self.dt = dt + self.sde_fn = sde_fn + super().__init__(*args, **kwargs) + + def update(self, node: Node): + """Return the update mapping for the noise RV.""" + # Since noise is a shared variable it shows up as the last node input + return {node.inputs[-1]: node.outputs[0]} + + +class EulerMaruyama(Distribution): r""" Stochastic differential equation discretized with the Euler-Maruyama method. @@ -893,39 +912,131 @@ class EulerMaruyama(distribution.Continuous): function returning the drift and diffusion coefficients of SDE sde_pars: tuple parameters of the SDE, passed as ``*args`` to ``sde_fn`` + init_dist : unnamed distribution, optional + Scalar or vector distribution for initial values. Unnamed refers to distributions + created with the ``.dist()`` API. Distributions should have shape (*shape[:-1], ar_order). + If not, it will be automatically resized. Defaults to pm.Normal.dist(0, 100, shape=...). + + .. warning:: init_dist will be cloned, rendering it independent of the one passed as input. """ - def __new__(cls, *args, **kwargs): - raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.") + rv_type = EulerMaruyamaRV + + def __new__(cls, name, dt, sde_fn, *args, steps=None, **kwargs): + dt = at.as_tensor_variable(floatX(dt)) + steps = get_support_shape_1d( + support_shape=steps, + shape=None, # Shape will be checked in `cls.dist` + dims=kwargs.get("dims", None), + observed=kwargs.get("observed", None), + support_shape_offset=1, + ) + return super().__new__(cls, name, dt, sde_fn, *args, steps=steps, **kwargs) @classmethod - def dist(cls, *args, **kwargs): - raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.") + def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs): + steps = get_support_shape_1d( + support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=1 + ) + if steps is None: + raise ValueError("Must specify steps or shape parameter") + steps = at.as_tensor_variable(intX(steps), ndim=0) - def __init__(self, dt, sde_fn, sde_pars, *args, **kwds): - super().__init__(*args, **kwds) - self.dt = dt = at.as_tensor_variable(dt) - self.sde_fn = sde_fn - self.sde_pars = sde_pars + dt = at.as_tensor_variable(floatX(dt)) + sde_pars = [at.as_tensor_variable(x) for x in sde_pars] - def logp(self, x): - """ - Calculate log-probability of EulerMaruyama distribution at specified value. + if init_dist is not None: + if not isinstance(init_dist, TensorVariable) or not isinstance( + init_dist.owner.op, (RandomVariable, SymbolicRandomVariable) + ): + raise ValueError( + f"Init dist must be a distribution created via the `.dist()` API, " + f"got {type(init_dist)}" + ) + check_dist_not_registered(init_dist) + if init_dist.owner.op.ndim_supp > 1: + raise ValueError( + "Init distribution must have a scalar or vector support dimension, ", + f"got ndim_supp={init_dist.owner.op.ndim_supp}.", + ) + else: + warnings.warn( + "Initial distribution not specified, defaulting to " + "`Normal.dist(0, 100, shape=...)`. You can specify an init_dist " + "manually to suppress this warning.", + UserWarning, + ) + init_dist = Normal.dist(0, 100, shape=sde_pars[0].shape) + # Tell Aeppl to ignore init_dist, as it will be accounted for in the logp term + init_dist = ignore_logprob(init_dist) - Parameters - ---------- - x: numeric - Value for which log-probability is calculated. + return super().dist([dt, sde_fn, sde_pars, init_dist, steps], **kwargs) - Returns - ------- - TensorVariable - """ - xt = x[:-1] - f, g = self.sde_fn(x[:-1], *self.sde_pars) - mu = xt + self.dt * f - sigma = at.sqrt(self.dt) * g - return at.sum(Normal.dist(mu=mu, sigma=sigma).logp(x[1:])) - - def _distr_parameters_for_repr(self): - return ["dt"] + @classmethod + def rv_op(cls, dt, sde_fn, sde_pars, init_dist, steps, size=None): + # Init dist should have shape (*size, ar_order) + if size is not None: + batch_size = size + else: + # In this case the size of the init_dist depends on the parameters shape + # The last dimension of rho and init_dist does not matter + batch_size = at.broadcast_shape(*sde_pars, at.atleast_1d(init_dist)[..., 0]) + init_dist = change_dist_size(init_dist, batch_size) + + # Create OpFromGraph representing random draws form AR process + # Variables with underscore suffix are dummy inputs into the OpFromGraph + init_ = init_dist.type() + sde_pars_ = [x.type() for x in sde_pars] + steps_ = steps.type() + + noise_rng = aesara.shared(np.random.default_rng()) + + def step(*prev_args): + prev_y, *prev_sde_pars, rng = prev_args + f, g = sde_fn(prev_y, *prev_sde_pars) + mu = prev_y + dt * f + sigma = at.sqrt(dt) * g + next_rng, next_y = Normal.dist(mu=mu, sigma=sigma, rng=rng).owner.outputs + return next_y, {rng: next_rng} + + y_t, innov_updates_ = aesara.scan( + fn=step, + outputs_info=[init_], + non_sequences=sde_pars_ + [noise_rng], + n_steps=steps_, + strict=True, + ) + (noise_next_rng,) = tuple(innov_updates_.values()) + + sde_out_ = at.concatenate([init_[None, ...], y_t], axis=0).dimshuffle( + tuple(range(1, y_t.ndim)) + (0,) + ) + + eulermaruyama_op = EulerMaruyamaRV( + inputs=[init_, steps_] + sde_pars_, + outputs=[noise_next_rng, sde_out_], + dt=dt, + sde_fn=sde_fn, + ndim_supp=1, + ) + + eulermaruyama = eulermaruyama_op(init_dist, steps, *sde_pars) + return eulermaruyama + + +@_logprob.register(EulerMaruyamaRV) +def eulermaruyama_logp(op, values, init_dist, steps, *sde_pars_noise_arg, **kwargs): + (x,) = values + # noise arg is unused, but is needed to make the logp signature match the rv_op signature + *sde_pars, _ = sde_pars_noise_arg + xtm1 = x[..., :-1] + xt = x[..., 1:] + f, g = op.sde_fn(xtm1, *sde_pars) + mu = xtm1 + op.dt * f + sigma = at.sqrt(op.dt) * g + # Compute and collapse logp across time dimension + sde_logp = at.sum(logp(Normal.dist(mu, sigma), xt), axis=-1) + init_logp = logp(init_dist, x[..., :1]) + if init_dist.owner.op.ndim_supp == 0: + init_logp = at.sum(init_logp, axis=-1) + return init_logp + sde_logp diff --git a/pymc/tests/distributions/test_timeseries.py b/pymc/tests/distributions/test_timeseries.py index 2cc8e2e091..b13aedf205 100644 --- a/pymc/tests/distributions/test_timeseries.py +++ b/pymc/tests/distributions/test_timeseries.py @@ -830,37 +830,37 @@ def test_change_dist_size(self): assert new_dist.eval().shape == (4, 3, 10) -def _gen_sde_path(sde, pars, dt, n, x0): - xs = [x0] - wt = np.random.normal(size=(n,) if isinstance(x0, float) else (n, x0.size)) - for i in range(n): - f, g = sde(xs[-1], *pars) - xs.append(xs[-1] + f * dt + np.sqrt(dt) * g * wt[i]) - return np.array(xs) - - -@pytest.mark.xfail(reason="Euleryama not refactored", raises=NotImplementedError) -def test_linear(): - lam = -0.78 - sig2 = 5e-3 - N = 300 - dt = 1e-1 - sde = lambda x, lam: (lam * x, sig2) - x = floatX(_gen_sde_path(sde, (lam,), dt, N, 5.0)) - z = x + np.random.randn(x.size) * sig2 - # build model - with Model() as model: - lamh = Flat("lamh") - xh = EulerMaruyama("xh", dt, sde, (lamh,), shape=N + 1, initval=x) - Normal("zh", mu=xh, sigma=sig2, observed=z) - # invert - with model: - trace = sample(init="advi+adapt_diag", chains=1) - - ppc = sample_posterior_predictive(trace, model=model) - - p95 = [2.5, 97.5] - lo, hi = np.percentile(trace[lamh], p95, axis=0) - assert (lo < lam) and (lam < hi) - lo, hi = np.percentile(ppc["zh"], p95, axis=0) - assert ((lo < z) * (z < hi)).mean() > 0.95 +class TestEulerMaruyama: + + def _gen_sde_path(self, sde, pars, dt, n, x0): + xs = [x0] + wt = np.random.normal(size=(n,) if isinstance(x0, float) else (n, x0.size)) + for i in range(n): + f, g = sde(xs[-1], *pars) + xs.append(xs[-1] + f * dt + np.sqrt(dt) * g * wt[i]) + return np.array(xs) + + def test_linear(self): + lam = -0.78 + sig2 = 5e-3 + N = 300 + dt = 1e-1 + sde = lambda x, lam: (lam * x, sig2) + x = floatX(self._gen_sde_path(sde, (lam,), dt, N, 5.0)) + z = x + np.random.randn(x.size) * sig2 + # build model + with Model() as model: + lamh = Flat("lamh") + xh = EulerMaruyama("xh", dt, sde, (lamh,), steps=N, initval=x) + Normal("zh", mu=xh, sigma=sig2, observed=z) + # invert + with model: + trace = sample(chains=1) + + ppc = sample_posterior_predictive(trace, model=model) + + p95 = [2.5, 97.5] + lo, hi = np.percentile(trace.posterior["lamh"], p95, axis=[0, 1]) + assert (lo < lam) and (lam < hi) + lo, hi = np.percentile(ppc.posterior_predictive["zh"], p95, axis=[0, 1]) + assert ((lo < z) * (z < hi)).mean() > 0.95 From 95c784c5033b98b92d96cb221d8fd4e5d9b12fa7 Mon Sep 17 00:00:00 2001 From: junpenglao Date: Wed, 19 Oct 2022 14:29:25 +0200 Subject: [PATCH 2/4] Add change_dist_size Add tests. --- pymc/distributions/timeseries.py | 35 ++++++--- pymc/tests/distributions/test_timeseries.py | 78 ++++++++++++++++++--- 2 files changed, 94 insertions(+), 19 deletions(-) diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index 43627f7b38..f97221599a 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ -27,7 +27,6 @@ from aesara.tensor.random.op import RandomVariable from pymc.aesaraf import constant_fold, floatX, intX -from pymc.distributions import distribution from pymc.distributions.continuous import Normal, get_tau_sigma from pymc.distributions.distribution import ( Distribution, @@ -461,7 +460,7 @@ class AR(Distribution): process. init_dist : unnamed distribution, optional Scalar or vector distribution for initial values. Unnamed refers to distributions - created with the ``.dist()`` API. Distributions should have shape (*shape[:-1], ar_order). + created with the ``.dist()`` API. Distributions should have shape (*shape[:-1], ar_order). If not, it will be automatically resized. Defaults to pm.Normal.dist(0, 100, shape=...). .. warning:: init_dist will be cloned, rendering it independent of the one passed as input. @@ -914,7 +913,7 @@ class EulerMaruyama(Distribution): parameters of the SDE, passed as ``*args`` to ``sde_fn`` init_dist : unnamed distribution, optional Scalar or vector distribution for initial values. Unnamed refers to distributions - created with the ``.dist()`` API. Distributions should have shape (*shape[:-1], ar_order). + created with the ``.dist()`` API. Distributions should have shape (*shape[:-1]). If not, it will be automatically resized. Defaults to pm.Normal.dist(0, 100, shape=...). .. warning:: init_dist will be cloned, rendering it independent of the one passed as input. @@ -953,7 +952,7 @@ def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs): f"Init dist must be a distribution created via the `.dist()` API, " f"got {type(init_dist)}" ) - check_dist_not_registered(init_dist) + check_dist_not_registered(init_dist) if init_dist.owner.op.ndim_supp > 1: raise ValueError( "Init distribution must have a scalar or vector support dimension, ", @@ -970,17 +969,15 @@ def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs): # Tell Aeppl to ignore init_dist, as it will be accounted for in the logp term init_dist = ignore_logprob(init_dist) - return super().dist([dt, sde_fn, sde_pars, init_dist, steps], **kwargs) + return super().dist([init_dist, steps, sde_pars, dt, sde_fn], **kwargs) @classmethod - def rv_op(cls, dt, sde_fn, sde_pars, init_dist, steps, size=None): - # Init dist should have shape (*size, ar_order) + def rv_op(cls, init_dist, steps, sde_pars, dt, sde_fn, size=None): + # Init dist should have shape (*size,) if size is not None: batch_size = size else: - # In this case the size of the init_dist depends on the parameters shape - # The last dimension of rho and init_dist does not matter - batch_size = at.broadcast_shape(*sde_pars, at.atleast_1d(init_dist)[..., 0]) + batch_size = at.broadcast_shape(*sde_pars, init_dist) init_dist = change_dist_size(init_dist, batch_size) # Create OpFromGraph representing random draws form AR process @@ -1024,6 +1021,24 @@ def step(*prev_args): return eulermaruyama +@_change_dist_size.register(EulerMaruyamaRV) +def change_eulermaruyama_size(op, dist, new_size, expand=False): + + if expand: + old_size = dist.shape[:-1] + new_size = tuple(new_size) + tuple(old_size) + + init_dist, steps, *sde_pars, _ = dist.owner.inputs + return EulerMaruyama.rv_op( + init_dist, + steps, + sde_pars, + dt=op.dt, + sde_fn=op.sde_fn, + size=new_size, + ) + + @_logprob.register(EulerMaruyamaRV) def eulermaruyama_logp(op, values, init_dist, steps, *sde_pars_noise_arg, **kwargs): (x,) = values diff --git a/pymc/tests/distributions/test_timeseries.py b/pymc/tests/distributions/test_timeseries.py index b13aedf205..bbebe710a7 100644 --- a/pymc/tests/distributions/test_timeseries.py +++ b/pymc/tests/distributions/test_timeseries.py @@ -831,22 +831,82 @@ def test_change_dist_size(self): class TestEulerMaruyama: + @pytest.mark.parametrize("batched_param", [1, 2]) + @pytest.mark.parametrize("explicit_shape", (True, False)) + def test_batched_size(self, explicit_shape, batched_param): + steps, batch_size = 100, 5 + param_val = np.square(np.random.randn(batch_size)) + if explicit_shape: + kwargs = {"shape": (batch_size, steps)} + else: + kwargs = {"steps": steps - 1} + + def sde_fn(x, k, d, s): + return (k - d * x, s) + + sde_pars = [1.0, 2.0, 0.1] + sde_pars[batched_param] = sde_pars[batched_param] * param_val + with Model() as t0: + y = EulerMaruyama("y", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars, **kwargs) + + y_eval = draw(y, draws=2) + assert y_eval[0].shape == (batch_size, steps) + assert not np.any(np.isclose(y_eval[0], y_eval[1])) + + if explicit_shape: + kwargs["shape"] = steps + with Model() as t1: + for i in range(batch_size): + sde_pars_slice = sde_pars.copy() + sde_pars_slice[batched_param] = sde_pars[batched_param][i] + EulerMaruyama(f"y_{i}", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars, **kwargs) + + np.testing.assert_allclose( + t0.compile_logp()(t0.initial_point()), + t1.compile_logp()(t1.initial_point()), + ) + + def test_change_dist_size1(self): + def sde1(x, k, d, s): + return (k - d * x, s) + + base_dist = EulerMaruyama.dist(dt=0.01, sde_fn=sde1, sde_pars=(1, 2, 0.1), shape=(5, 10)) + + new_dist = change_dist_size(base_dist, (4,)) + assert new_dist.eval().shape == (4, 10) + + new_dist = change_dist_size(base_dist, (4,), expand=True) + assert new_dist.eval().shape == (4, 5, 10) - def _gen_sde_path(self, sde, pars, dt, n, x0): - xs = [x0] - wt = np.random.normal(size=(n,) if isinstance(x0, float) else (n, x0.size)) - for i in range(n): - f, g = sde(xs[-1], *pars) - xs.append(xs[-1] + f * dt + np.sqrt(dt) * g * wt[i]) - return np.array(xs) + def test_change_dist_size2(self): + def sde2(p, s): + N = 500.0 + return s * p * (1 - p) / (1 + s * p), pm.math.sqrt(p * (1 - p) / N) - def test_linear(self): + base_dist = EulerMaruyama.dist(dt=0.01, sde_fn=sde2, sde_pars=(0.1,), shape=(3, 10)) + + new_dist = change_dist_size(base_dist, (4,)) + assert new_dist.eval().shape == (4, 10) + + new_dist = change_dist_size(base_dist, (4,), expand=True) + assert new_dist.eval().shape == (4, 3, 10) + + def test_linear_model(self): lam = -0.78 sig2 = 5e-3 N = 300 dt = 1e-1 + + def _gen_sde_path(sde, pars, dt, n, x0): + xs = [x0] + wt = np.random.normal(size=(n,) if isinstance(x0, float) else (n, x0.size)) + for i in range(n): + f, g = sde(xs[-1], *pars) + xs.append(xs[-1] + f * dt + np.sqrt(dt) * g * wt[i]) + return np.array(xs) + sde = lambda x, lam: (lam * x, sig2) - x = floatX(self._gen_sde_path(sde, (lam,), dt, N, 5.0)) + x = floatX(_gen_sde_path(sde, (lam,), dt, N, 5.0)) z = x + np.random.randn(x.size) * sig2 # build model with Model() as model: From b6fa76b784ef69e2fee76e16de73d25e681018dd Mon Sep 17 00:00:00 2001 From: junpenglao Date: Wed, 19 Oct 2022 19:32:06 +0200 Subject: [PATCH 3/4] Fix broadcasting --- pymc/distributions/timeseries.py | 12 +++++++----- pymc/tests/distributions/test_timeseries.py | 8 +++++--- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index f97221599a..5d1e5f4844 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ -980,7 +980,7 @@ def rv_op(cls, init_dist, steps, sde_pars, dt, sde_fn, size=None): batch_size = at.broadcast_shape(*sde_pars, init_dist) init_dist = change_dist_size(init_dist, batch_size) - # Create OpFromGraph representing random draws form AR process + # Create OpFromGraph representing random draws from SDE process # Variables with underscore suffix are dummy inputs into the OpFromGraph init_ = init_dist.type() sde_pars_ = [x.type() for x in sde_pars] @@ -1044,14 +1044,16 @@ def eulermaruyama_logp(op, values, init_dist, steps, *sde_pars_noise_arg, **kwar (x,) = values # noise arg is unused, but is needed to make the logp signature match the rv_op signature *sde_pars, _ = sde_pars_noise_arg + # sde_fn is user provided and likely not broadcastable to additional time dimension, + # since the input x is now [..., t], we need to broadcast each input to [..., None] + # below as best effort attempt to make it work + sde_pars_broadcast = [x[..., None] for x in sde_pars] xtm1 = x[..., :-1] xt = x[..., 1:] - f, g = op.sde_fn(xtm1, *sde_pars) + f, g = op.sde_fn(xtm1, *sde_pars_broadcast) mu = xtm1 + op.dt * f sigma = at.sqrt(op.dt) * g # Compute and collapse logp across time dimension sde_logp = at.sum(logp(Normal.dist(mu, sigma), xt), axis=-1) - init_logp = logp(init_dist, x[..., :1]) - if init_dist.owner.op.ndim_supp == 0: - init_logp = at.sum(init_logp, axis=-1) + init_logp = logp(init_dist, x[..., 0]) return init_logp + sde_logp diff --git a/pymc/tests/distributions/test_timeseries.py b/pymc/tests/distributions/test_timeseries.py index bbebe710a7..838532b456 100644 --- a/pymc/tests/distributions/test_timeseries.py +++ b/pymc/tests/distributions/test_timeseries.py @@ -859,11 +859,13 @@ def sde_fn(x, k, d, s): for i in range(batch_size): sde_pars_slice = sde_pars.copy() sde_pars_slice[batched_param] = sde_pars[batched_param][i] - EulerMaruyama(f"y_{i}", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars, **kwargs) + EulerMaruyama(f"y_{i}", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars_slice, **kwargs) + t0_init = t0.initial_point() + t1_init = {f"y_{i}": t0_init["y"][i] for i in range(batch_size)} np.testing.assert_allclose( - t0.compile_logp()(t0.initial_point()), - t1.compile_logp()(t1.initial_point()), + t0.compile_logp()(t0_init), + t1.compile_logp()(t1_init), ) def test_change_dist_size1(self): From fe9cb3dd850f1f0811d7e490fbb39505f006d04a Mon Sep 17 00:00:00 2001 From: junpenglao Date: Thu, 20 Oct 2022 22:17:25 +0200 Subject: [PATCH 4/4] Update to limit support to univariate time series --- pymc/distributions/timeseries.py | 8 +++--- pymc/tests/distributions/test_timeseries.py | 31 +++++++++++++++++---- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index 5d1e5f4844..57028d90d5 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ -912,8 +912,8 @@ class EulerMaruyama(Distribution): sde_pars: tuple parameters of the SDE, passed as ``*args`` to ``sde_fn`` init_dist : unnamed distribution, optional - Scalar or vector distribution for initial values. Unnamed refers to distributions - created with the ``.dist()`` API. Distributions should have shape (*shape[:-1]). + Scalar distribution for initial values. Unnamed refers to distributions created with + the ``.dist()`` API. Distributions should have shape (*shape[:-1]). If not, it will be automatically resized. Defaults to pm.Normal.dist(0, 100, shape=...). .. warning:: init_dist will be cloned, rendering it independent of the one passed as input. @@ -953,9 +953,9 @@ def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs): f"got {type(init_dist)}" ) check_dist_not_registered(init_dist) - if init_dist.owner.op.ndim_supp > 1: + if init_dist.owner.op.ndim_supp > 0: raise ValueError( - "Init distribution must have a scalar or vector support dimension, ", + "Init distribution must have a scalar support dimension, ", f"got ndim_supp={init_dist.owner.op.ndim_supp}.", ) else: diff --git a/pymc/tests/distributions/test_timeseries.py b/pymc/tests/distributions/test_timeseries.py index 838532b456..e4dd45900c 100644 --- a/pymc/tests/distributions/test_timeseries.py +++ b/pymc/tests/distributions/test_timeseries.py @@ -847,7 +847,10 @@ def sde_fn(x, k, d, s): sde_pars = [1.0, 2.0, 0.1] sde_pars[batched_param] = sde_pars[batched_param] * param_val with Model() as t0: - y = EulerMaruyama("y", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars, **kwargs) + init_dist = pm.Normal.dist(0, 10, shape=(batch_size,)) + y = EulerMaruyama( + "y", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars, init_dist=init_dist, **kwargs + ) y_eval = draw(y, draws=2) assert y_eval[0].shape == (batch_size, steps) @@ -859,7 +862,15 @@ def sde_fn(x, k, d, s): for i in range(batch_size): sde_pars_slice = sde_pars.copy() sde_pars_slice[batched_param] = sde_pars[batched_param][i] - EulerMaruyama(f"y_{i}", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars_slice, **kwargs) + init_dist = pm.Normal.dist(0, 10) + EulerMaruyama( + f"y_{i}", + dt=0.02, + sde_fn=sde_fn, + sde_pars=sde_pars_slice, + init_dist=init_dist, + **kwargs, + ) t0_init = t0.initial_point() t1_init = {f"y_{i}": t0_init["y"][i] for i in range(batch_size)} @@ -872,7 +883,13 @@ def test_change_dist_size1(self): def sde1(x, k, d, s): return (k - d * x, s) - base_dist = EulerMaruyama.dist(dt=0.01, sde_fn=sde1, sde_pars=(1, 2, 0.1), shape=(5, 10)) + base_dist = EulerMaruyama.dist( + dt=0.01, + sde_fn=sde1, + sde_pars=(1, 2, 0.1), + init_dist=pm.Normal.dist(0, 10), + shape=(5, 10), + ) new_dist = change_dist_size(base_dist, (4,)) assert new_dist.eval().shape == (4, 10) @@ -885,7 +902,9 @@ def sde2(p, s): N = 500.0 return s * p * (1 - p) / (1 + s * p), pm.math.sqrt(p * (1 - p) / N) - base_dist = EulerMaruyama.dist(dt=0.01, sde_fn=sde2, sde_pars=(0.1,), shape=(3, 10)) + base_dist = EulerMaruyama.dist( + dt=0.01, sde_fn=sde2, sde_pars=(0.1,), init_dist=pm.Normal.dist(0, 10), shape=(3, 10) + ) new_dist = change_dist_size(base_dist, (4,)) assert new_dist.eval().shape == (4, 10) @@ -913,7 +932,9 @@ def _gen_sde_path(sde, pars, dt, n, x0): # build model with Model() as model: lamh = Flat("lamh") - xh = EulerMaruyama("xh", dt, sde, (lamh,), steps=N, initval=x) + xh = EulerMaruyama( + "xh", dt, sde, (lamh,), steps=N, initval=x, init_dist=pm.Normal.dist(0, 10) + ) Normal("zh", mu=xh, sigma=sig2, observed=z) # invert with model: