
Commit 8f0ffb3

Refactor EulerMaruyama to work in v4
1 parent 98aadfd

2 files changed: +175 -64 lines changed


pymc/distributions/timeseries.py

Lines changed: 141 additions & 30 deletions
@@ -15,7 +15,7 @@
 import warnings
 
 from abc import ABCMeta
-from typing import Optional
+from typing import Callable, Optional
 
 import aesara
 import aesara.tensor as at
@@ -881,7 +881,26 @@ def garch11_moment(op, rv, omega, alpha_1, beta_1, initial_vol, init_dist, steps
     return at.zeros_like(rv)
 
 
-class EulerMaruyama(distribution.Continuous):
+class EulerMaruyamaRV(SymbolicRandomVariable):
+    """A placeholder used to specify a log-likelihood for an EulerMaruyama sub-graph."""
+
+    default_output = 1
+    dt: float
+    sde_fn: Callable
+    _print_name = ("EulerMaruyama", "\\operatorname{EulerMaruyama}")
+
+    def __init__(self, *args, dt, sde_fn, **kwargs):
+        self.dt = dt
+        self.sde_fn = sde_fn
+        super().__init__(*args, **kwargs)
+
+    def update(self, node: Node):
+        """Return the update mapping for the noise RV."""
+        # Since noise is a shared variable it shows up as the last node input
+        return {node.inputs[-1]: node.outputs[0]}
+
+
+class EulerMaruyama(Distribution):
     r"""
     Stochastic differential equation discretized with the Euler-Maruyama method.
 
@@ -893,39 +912,131 @@ class EulerMaruyama(distribution.Continuous):
         function returning the drift and diffusion coefficients of SDE
     sde_pars: tuple
         parameters of the SDE, passed as ``*args`` to ``sde_fn``
+    init_dist : unnamed distribution, optional
+        Scalar or vector distribution for initial values. Unnamed refers to distributions
+        created with the ``.dist()`` API. Distributions should have shape ``(*shape[:-1],)``.
+        If not, it will be automatically resized. Defaults to pm.Normal.dist(0, 100, shape=...).
+
+        .. warning:: init_dist will be cloned, rendering it independent of the one passed as input.
     """
 
-    def __new__(cls, *args, **kwargs):
-        raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.")
+    rv_type = EulerMaruyamaRV
+
+    def __new__(cls, name, dt, sde_fn, *args, steps=None, **kwargs):
+        dt = at.as_tensor_variable(floatX(dt))
+        steps = get_support_shape_1d(
+            support_shape=steps,
+            shape=None,  # Shape will be checked in `cls.dist`
+            dims=kwargs.get("dims", None),
+            observed=kwargs.get("observed", None),
+            support_shape_offset=1,
+        )
+        return super().__new__(cls, name, dt, sde_fn, *args, steps=steps, **kwargs)
 
     @classmethod
-    def dist(cls, *args, **kwargs):
-        raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.")
+    def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs):
+        steps = get_support_shape_1d(
+            support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=1
+        )
+        if steps is None:
+            raise ValueError("Must specify steps or shape parameter")
+        steps = at.as_tensor_variable(intX(steps), ndim=0)
 
-    def __init__(self, dt, sde_fn, sde_pars, *args, **kwds):
-        super().__init__(*args, **kwds)
-        self.dt = dt = at.as_tensor_variable(dt)
-        self.sde_fn = sde_fn
-        self.sde_pars = sde_pars
+        dt = at.as_tensor_variable(floatX(dt))
+        sde_pars = [at.as_tensor_variable(x) for x in sde_pars]
 
-    def logp(self, x):
-        """
-        Calculate log-probability of EulerMaruyama distribution at specified value.
+        if init_dist is not None:
+            if not isinstance(init_dist, TensorVariable) or not isinstance(
+                init_dist.owner.op, (RandomVariable, SymbolicRandomVariable)
+            ):
+                raise ValueError(
+                    f"Init dist must be a distribution created via the `.dist()` API, "
+                    f"got {type(init_dist)}"
+                )
+            check_dist_not_registered(init_dist)
+            if init_dist.owner.op.ndim_supp > 1:
+                raise ValueError(
+                    "Init distribution must have a scalar or vector support dimension, ",
+                    f"got ndim_supp={init_dist.owner.op.ndim_supp}.",
+                )
+        else:
+            warnings.warn(
+                "Initial distribution not specified, defaulting to "
+                "`Normal.dist(0, 100, shape=...)`. You can specify an init_dist "
+                "manually to suppress this warning.",
+                UserWarning,
+            )
+            init_dist = Normal.dist(0, 100, shape=sde_pars[0].shape)
+        # Tell Aeppl to ignore init_dist, as it will be accounted for in the logp term
+        init_dist = ignore_logprob(init_dist)
 
-        Parameters
-        ----------
-        x: numeric
-            Value for which log-probability is calculated.
+        return super().dist([dt, sde_fn, sde_pars, init_dist, steps], **kwargs)
 
-        Returns
-        -------
-        TensorVariable
-        """
-        xt = x[:-1]
-        f, g = self.sde_fn(x[:-1], *self.sde_pars)
-        mu = xt + self.dt * f
-        sigma = at.sqrt(self.dt) * g
-        return at.sum(Normal.dist(mu=mu, sigma=sigma).logp(x[1:]))
-
-    def _distr_parameters_for_repr(self):
-        return ["dt"]
+    @classmethod
+    def rv_op(cls, dt, sde_fn, sde_pars, init_dist, steps, size=None):
+        # Init dist should have batch shape (*size,)
+        if size is not None:
+            batch_size = size
+        else:
+            # In this case the size of the init_dist depends on the parameters shape
+            # The last dimension of init_dist does not matter
+            batch_size = at.broadcast_shape(*sde_pars, at.atleast_1d(init_dist)[..., 0])
+        init_dist = change_dist_size(init_dist, batch_size)
+
+        # Create OpFromGraph representing random draws from the SDE process
+        # Variables with underscore suffix are dummy inputs into the OpFromGraph
+        init_ = init_dist.type()
+        sde_pars_ = [x.type() for x in sde_pars]
+        steps_ = steps.type()
+
+        noise_rng = aesara.shared(np.random.default_rng())
+
+        def step(*prev_args):
+            prev_y, *prev_sde_pars, rng = prev_args
+            f, g = sde_fn(prev_y, *prev_sde_pars)
+            mu = prev_y + dt * f
+            sigma = at.sqrt(dt) * g
+            next_rng, next_y = Normal.dist(mu=mu, sigma=sigma, rng=rng).owner.outputs
+            return next_y, {rng: next_rng}
+
+        y_t, innov_updates_ = aesara.scan(
+            fn=step,
+            outputs_info=[init_],
+            non_sequences=sde_pars_ + [noise_rng],
+            n_steps=steps_,
+            strict=True,
+        )
+        (noise_next_rng,) = tuple(innov_updates_.values())
+
+        sde_out_ = at.concatenate([init_[None, ...], y_t], axis=0).dimshuffle(
+            tuple(range(1, y_t.ndim)) + (0,)
+        )
+
+        eulermaruyama_op = EulerMaruyamaRV(
+            inputs=[init_, steps_] + sde_pars_,
+            outputs=[noise_next_rng, sde_out_],
+            dt=dt,
+            sde_fn=sde_fn,
+            ndim_supp=1,
+        )
+
+        eulermaruyama = eulermaruyama_op(init_dist, steps, *sde_pars)
+        return eulermaruyama
+
+
+@_logprob.register(EulerMaruyamaRV)
+def eulermaruyama_logp(op, values, init_dist, steps, *sde_pars_noise_arg, **kwargs):
+    (x,) = values
+    # noise arg is unused, but is needed to make the logp signature match the rv_op signature
+    *sde_pars, _ = sde_pars_noise_arg
+    xtm1 = x[..., :-1]
+    xt = x[..., 1:]
+    f, g = op.sde_fn(xtm1, *sde_pars)
+    mu = xtm1 + op.dt * f
+    sigma = at.sqrt(op.dt) * g
+    # Compute and collapse logp across time dimension
+    sde_logp = at.sum(logp(Normal.dist(mu, sigma), xt), axis=-1)
+    init_logp = logp(init_dist, x[..., :1])
+    if init_dist.owner.op.ndim_supp == 0:
+        init_logp = at.sum(init_logp, axis=-1)
+    return init_logp + sde_logp
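
For reference, the forward `step` function and `eulermaruyama_logp` above are two views of the same discretization. For an SDE dX = f(X, theta) dt + g(X, theta) dW, the Euler-Maruyama transition implemented here is, in LaTeX notation,

    x_{t+1} = x_t + f(x_t, \theta)\,\Delta t + \sqrt{\Delta t}\,g(x_t, \theta)\,w_t,
    \qquad w_t \sim \mathcal{N}(0, 1),

which is equivalent to the Gaussian transition density

    x_{t+1} \mid x_t \sim \mathcal{N}\!\big(x_t + f(x_t, \theta)\,\Delta t,\; g(x_t, \theta)^2\,\Delta t\big).

This is exactly the `mu = prev_y + dt * f` / `sigma = at.sqrt(dt) * g` pair used in both the scan step and the logp; the total log-density adds the `init_dist` term for x_0 and sums the transition terms over the time axis.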

pymc/tests/distributions/test_timeseries.py

Lines changed: 34 additions & 34 deletions
@@ -830,37 +830,37 @@ def test_change_dist_size(self):
         assert new_dist.eval().shape == (4, 3, 10)
 
 
-def _gen_sde_path(sde, pars, dt, n, x0):
-    xs = [x0]
-    wt = np.random.normal(size=(n,) if isinstance(x0, float) else (n, x0.size))
-    for i in range(n):
-        f, g = sde(xs[-1], *pars)
-        xs.append(xs[-1] + f * dt + np.sqrt(dt) * g * wt[i])
-    return np.array(xs)
-
-
-@pytest.mark.xfail(reason="Euleryama not refactored", raises=NotImplementedError)
-def test_linear():
-    lam = -0.78
-    sig2 = 5e-3
-    N = 300
-    dt = 1e-1
-    sde = lambda x, lam: (lam * x, sig2)
-    x = floatX(_gen_sde_path(sde, (lam,), dt, N, 5.0))
-    z = x + np.random.randn(x.size) * sig2
-    # build model
-    with Model() as model:
-        lamh = Flat("lamh")
-        xh = EulerMaruyama("xh", dt, sde, (lamh,), shape=N + 1, initval=x)
-        Normal("zh", mu=xh, sigma=sig2, observed=z)
-    # invert
-    with model:
-        trace = sample(init="advi+adapt_diag", chains=1)
-
-    ppc = sample_posterior_predictive(trace, model=model)
-
-    p95 = [2.5, 97.5]
-    lo, hi = np.percentile(trace[lamh], p95, axis=0)
-    assert (lo < lam) and (lam < hi)
-    lo, hi = np.percentile(ppc["zh"], p95, axis=0)
-    assert ((lo < z) * (z < hi)).mean() > 0.95
+class TestEulerMaruyama:
+
+    def _gen_sde_path(self, sde, pars, dt, n, x0):
+        xs = [x0]
+        wt = np.random.normal(size=(n,) if isinstance(x0, float) else (n, x0.size))
+        for i in range(n):
+            f, g = sde(xs[-1], *pars)
+            xs.append(xs[-1] + f * dt + np.sqrt(dt) * g * wt[i])
+        return np.array(xs)
+
+    def test_linear(self):
+        lam = -0.78
+        sig2 = 5e-3
+        N = 300
+        dt = 1e-1
+        sde = lambda x, lam: (lam * x, sig2)
+        x = floatX(self._gen_sde_path(sde, (lam,), dt, N, 5.0))
+        z = x + np.random.randn(x.size) * sig2
+        # build model
+        with Model() as model:
+            lamh = Flat("lamh")
+            xh = EulerMaruyama("xh", dt, sde, (lamh,), steps=N, initval=x)
+            Normal("zh", mu=xh, sigma=sig2, observed=z)
+        # invert
+        with model:
+            trace = sample(chains=1)
+
+        ppc = sample_posterior_predictive(trace, model=model)
+
+        p95 = [2.5, 97.5]
+        lo, hi = np.percentile(trace.posterior["lamh"], p95, axis=[0, 1])
+        assert (lo < lam) and (lam < hi)
+        lo, hi = np.percentile(ppc.posterior_predictive["zh"], p95, axis=[0, 1])
+        assert ((lo < z) * (z < hi)).mean() > 0.95
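
Outside the test-suite, the refactored API can be exercised as below. This is a minimal sketch assuming the v4-era package at this commit (imported as `pymc`, with `EulerMaruyama` exported at the top level); the SDE function, parameter values, and variable names are illustrative only, not taken from the diff:

import pymc as pm

# Linear drift with constant diffusion, mirroring the test above (illustrative values)
sde = lambda x, lam: (lam * x, 5e-3)

with pm.Model() as model:
    lam = pm.Normal("lam", 0, 1)
    # steps=100 transitions give 101 time points; the first is drawn from init_dist
    path = pm.EulerMaruyama(
        "path", 0.1, sde, (lam,), init_dist=pm.Normal.dist(0, 1), steps=100
    )
    # Forward draws go through the scan-based rv_op; logp uses eulermaruyama_logp
    prior = pm.sample_prior_predictive(samples=10)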
