Skip to content

Commit be5efae

Browse files
authored
Refactor EulerMaruyama to work in v4 (#6227)
1 parent 49ad534 commit be5efae

File tree

2 files changed

+277
-66
lines changed

2 files changed

+277
-66
lines changed

pymc/distributions/timeseries.py

Lines changed: 160 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import warnings
1616

1717
from abc import ABCMeta
18-
from typing import Optional
18+
from typing import Callable, Optional
1919

2020
import aesara
2121
import aesara.tensor as at
@@ -27,7 +27,6 @@
2727
from aesara.tensor.random.op import RandomVariable
2828

2929
from pymc.aesaraf import constant_fold, floatX, intX
30-
from pymc.distributions import distribution
3130
from pymc.distributions.continuous import Normal, get_tau_sigma
3231
from pymc.distributions.distribution import (
3332
Distribution,
@@ -461,7 +460,7 @@ class AR(Distribution):
461460
process.
462461
init_dist : unnamed distribution, optional
463462
Scalar or vector distribution for initial values. Unnamed refers to distributions
464-
created with the ``.dist()`` API. Distributions should have shape (*shape[:-1], ar_order).
463+
created with the ``.dist()`` API. Distributions should have shape (*shape[:-1], ar_order).
465464
If not, it will be automatically resized. Defaults to pm.Normal.dist(0, 100, shape=...).
466465
467466
.. warning:: init_dist will be cloned, rendering it independent of the one passed as input.
@@ -881,7 +880,26 @@ def garch11_moment(op, rv, omega, alpha_1, beta_1, initial_vol, init_dist, steps
881880
return at.zeros_like(rv)
882881

883882

884-
class EulerMaruyama(distribution.Continuous):
883+
class EulerMaruyamaRV(SymbolicRandomVariable):
    """Placeholder Op used to attach a log-likelihood to an EulerMaruyama sub-graph.

    The Op keeps the discretization step ``dt`` and the user supplied ``sde_fn``
    as attributes so that the registered size-change and logp implementations
    can read them back from the Op instance.
    """

    # Output 0 is the updated noise RNG, output 1 is the simulated path.
    default_output = 1
    # Stored exactly as handed to ``__init__``.
    dt: float
    sde_fn: Callable
    _print_name = ("EulerMaruyama", "\\operatorname{EulerMaruyama}")

    def __init__(self, *args, dt, sde_fn, **kwargs):
        """Record ``dt`` and ``sde_fn`` and defer everything else to the parent class."""
        self.dt = dt
        self.sde_fn = sde_fn
        super().__init__(*args, **kwargs)

    def update(self, node: Node):
        """Return the update mapping for the noise RV.

        The noise RNG is a shared variable, so it appears as the last input of
        ``node``; it is mapped to the first output of ``node``.
        """
        rng_input = node.inputs[-1]
        next_rng = node.outputs[0]
        return {rng_input: next_rng}
900+
901+
902+
class EulerMaruyama(Distribution):
885903
r"""
886904
Stochastic differential equation discretized with the Euler-Maruyama method.
887905
@@ -893,39 +911,149 @@ class EulerMaruyama(distribution.Continuous):
893911
function returning the drift and diffusion coefficients of SDE
894912
sde_pars: tuple
895913
parameters of the SDE, passed as ``*args`` to ``sde_fn``
914+
init_dist : unnamed distribution, optional
915+
Scalar distribution for initial values. Unnamed refers to distributions created with
916+
the ``.dist()`` API. Distributions should have shape (*shape[:-1]).
917+
If not, it will be automatically resized. Defaults to pm.Normal.dist(0, 100, shape=...).
918+
919+
.. warning:: init_dist will be cloned, rendering it independent of the one passed as input.
896920
"""
897921

898-
def __new__(cls, *args, **kwargs):
899-
raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.")
922+
rv_type = EulerMaruyamaRV
923+
924+
def __new__(cls, name, dt, sde_fn, *args, steps=None, **kwargs):
    """Create a named EulerMaruyama variable.

    ``dt`` is converted to a tensor and ``steps`` is resolved from the explicit
    argument, ``dims`` or ``observed`` before delegating to the parent
    ``__new__``.
    """
    dt_tensor = at.as_tensor_variable(floatX(dt))
    dims = kwargs.get("dims", None)
    observed = kwargs.get("observed", None)
    resolved_steps = get_support_shape_1d(
        support_shape=steps,
        # `shape` is deliberately not inspected here; `cls.dist` checks it.
        shape=None,
        dims=dims,
        observed=observed,
        support_shape_offset=1,
    )
    return super().__new__(cls, name, dt_tensor, sde_fn, *args, steps=resolved_steps, **kwargs)
900934

901935
@classmethod
def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs):
    """Create an unnamed EulerMaruyama variable after validating its inputs.

    Parameters
    ----------
    dt : float
        Time step of the discretization. Converted to a ``floatX`` tensor.
    sde_fn : callable
        Function returning the drift and diffusion coefficients of the SDE.
    sde_pars : sequence
        Parameters of the SDE. Each entry is converted to a tensor variable.
    init_dist : unnamed distribution, optional
        Distribution of the initial value, created with the ``.dist()`` API and
        with scalar support. When omitted, a warning is emitted and
        ``Normal.dist(0, 100, shape=sde_pars[0].shape)`` is used.
    steps : int, optional
        Number of steps. When omitted it is derived from ``shape`` in
        ``kwargs`` with a support-shape offset of 1.
    **kwargs
        Forwarded to the parent ``dist``.

    Raises
    ------
    ValueError
        If neither ``steps`` nor ``shape`` is given, if ``init_dist`` was not
        created via ``.dist()``, or if ``init_dist`` has non-scalar support.
    """
    steps = get_support_shape_1d(
        support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=1
    )
    if steps is None:
        raise ValueError("Must specify steps or shape parameter")
    steps = at.as_tensor_variable(intX(steps), ndim=0)

    dt = at.as_tensor_variable(floatX(dt))
    sde_pars = [at.as_tensor_variable(x) for x in sde_pars]

    if init_dist is not None:
        # A root variable (no owner) cannot be a random variable; check it up
        # front so the user gets the ValueError below, not an AttributeError.
        owner = getattr(init_dist, "owner", None)
        if (
            not isinstance(init_dist, TensorVariable)
            or owner is None
            or not isinstance(owner.op, (RandomVariable, SymbolicRandomVariable))
        ):
            raise ValueError(
                "Init dist must be a distribution created via the `.dist()` API, "
                f"got {type(init_dist)}"
            )
        check_dist_not_registered(init_dist)
        if owner.op.ndim_supp > 0:
            # The two literals are concatenated into ONE message; a comma between
            # them would pass two arguments and render the error as a tuple.
            raise ValueError(
                "Init distribution must have a scalar support dimension, "
                f"got ndim_supp={owner.op.ndim_supp}."
            )
    else:
        warnings.warn(
            "Initial distribution not specified, defaulting to "
            "`Normal.dist(0, 100, shape=...)`. You can specify an init_dist "
            "manually to suppress this warning.",
            UserWarning,
        )
        init_dist = Normal.dist(0, 100, shape=sde_pars[0].shape)
    # Tell Aeppl to ignore init_dist, as it will be accounted for in the logp term
    init_dist = ignore_logprob(init_dist)

    return super().dist([init_dist, steps, sde_pars, dt, sde_fn], **kwargs)
974+
@classmethod
def rv_op(cls, init_dist, steps, sde_pars, dt, sde_fn, size=None):
    """Build the symbolic random graph for the discretized SDE.

    The process is simulated with ``aesara.scan`` on dummy inputs, wrapped in
    an ``EulerMaruyamaRV`` and then applied to the real inputs.
    """
    # Init dist should have shape (*size,); without an explicit size, use the
    # broadcast shape of the SDE parameters and the initial distribution.
    batch_size = size if size is not None else at.broadcast_shape(*sde_pars, init_dist)
    init_dist = change_dist_size(init_dist, batch_size)

    # Create OpFromGraph representing random draws from SDE process
    # Variables with underscore suffix are dummy inputs into the OpFromGraph
    init_ = init_dist.type()
    sde_pars_ = [par.type() for par in sde_pars]
    steps_ = steps.type()

    noise_rng = aesara.shared(np.random.default_rng())

    def step(*prev_args):
        # Scan passes the previous state first, then the non-sequences
        # (SDE parameters followed by the RNG).
        prev_y, *prev_sde_pars, rng = prev_args
        drift, diffusion = sde_fn(prev_y, *prev_sde_pars)
        mu = prev_y + dt * drift
        sigma = at.sqrt(dt) * diffusion
        next_rng, next_y = Normal.dist(mu=mu, sigma=sigma, rng=rng).owner.outputs
        return next_y, {rng: next_rng}

    y_t, innov_updates_ = aesara.scan(
        fn=step,
        outputs_info=[init_],
        non_sequences=sde_pars_ + [noise_rng],
        n_steps=steps_,
        strict=True,
    )
    (noise_next_rng,) = tuple(innov_updates_.values())

    # Prepend the initial value, then move the leading (time) axis to the end.
    full_path = at.concatenate([init_[None, ...], y_t], axis=0)
    time_last = tuple(range(1, y_t.ndim)) + (0,)
    sde_out_ = full_path.dimshuffle(time_last)

    eulermaruyama_op = EulerMaruyamaRV(
        inputs=[init_, steps_] + sde_pars_,
        outputs=[noise_next_rng, sde_out_],
        dt=dt,
        sde_fn=sde_fn,
        ndim_supp=1,
    )

    return eulermaruyama_op(init_dist, steps, *sde_pars)
1022+
1023+
1024+
@_change_dist_size.register(EulerMaruyamaRV)
def change_eulermaruyama_size(op, dist, new_size, expand=False):
    """Rebuild an EulerMaruyama variable with a different batch size.

    With ``expand=True`` the current batch dimensions (everything but the last
    axis of ``dist``) are kept and ``new_size`` is prepended to them.
    """
    if expand:
        current_batch = dist.shape[:-1]
        new_size = tuple(new_size) + tuple(current_batch)

    # The last input is the noise RNG; a fresh one is created by `rv_op`.
    init_dist, steps, *sde_pars, _noise_rng = dist.owner.inputs
    rebuilt = EulerMaruyama.rv_op(
        init_dist,
        steps,
        sde_pars,
        dt=op.dt,
        sde_fn=op.sde_fn,
        size=new_size,
    )
    return rebuilt
1040+
1041+
1042+
@_logprob.register(EulerMaruyamaRV)
def eulermaruyama_logp(op, values, init_dist, steps, *sde_pars_noise_arg, **kwargs):
    """Log-probability of a path under the Euler-Maruyama discretization.

    The result is the logp of the first point under ``init_dist`` plus the sum
    over the last axis of the Normal transition logps.
    """
    (path,) = values
    # noise arg is unused, but is needed to make the logp signature match the rv_op signature
    *sde_pars, _noise = sde_pars_noise_arg
    # sde_fn is user provided and likely not broadcastable to additional time dimension,
    # since the input path is now [..., t], we need to broadcast each parameter to [..., None]
    # below as best effort attempt to make it work
    sde_pars_broadcast = [par[..., None] for par in sde_pars]

    previous = path[..., :-1]
    current = path[..., 1:]
    drift, diffusion = op.sde_fn(previous, *sde_pars_broadcast)
    mu = previous + op.dt * drift
    sigma = at.sqrt(op.dt) * diffusion

    # Compute and collapse logp across time dimension
    transition_logp = logp(Normal.dist(mu, sigma), current)
    sde_logp = at.sum(transition_logp, axis=-1)
    init_logp = logp(init_dist, path[..., 0])
    return init_logp + sde_logp

pymc/tests/distributions/test_timeseries.py

Lines changed: 117 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -830,37 +830,120 @@ def test_change_dist_size(self):
830830
assert new_dist.eval().shape == (4, 3, 10)
831831

832832

833-
def _gen_sde_path(sde, pars, dt, n, x0):
834-
xs = [x0]
835-
wt = np.random.normal(size=(n,) if isinstance(x0, float) else (n, x0.size))
836-
for i in range(n):
837-
f, g = sde(xs[-1], *pars)
838-
xs.append(xs[-1] + f * dt + np.sqrt(dt) * g * wt[i])
839-
return np.array(xs)
840-
841-
842-
@pytest.mark.xfail(reason="Euleryama not refactored", raises=NotImplementedError)
843-
def test_linear():
844-
lam = -0.78
845-
sig2 = 5e-3
846-
N = 300
847-
dt = 1e-1
848-
sde = lambda x, lam: (lam * x, sig2)
849-
x = floatX(_gen_sde_path(sde, (lam,), dt, N, 5.0))
850-
z = x + np.random.randn(x.size) * sig2
851-
# build model
852-
with Model() as model:
853-
lamh = Flat("lamh")
854-
xh = EulerMaruyama("xh", dt, sde, (lamh,), shape=N + 1, initval=x)
855-
Normal("zh", mu=xh, sigma=sig2, observed=z)
856-
# invert
857-
with model:
858-
trace = sample(init="advi+adapt_diag", chains=1)
859-
860-
ppc = sample_posterior_predictive(trace, model=model)
861-
862-
p95 = [2.5, 97.5]
863-
lo, hi = np.percentile(trace[lamh], p95, axis=0)
864-
assert (lo < lam) and (lam < hi)
865-
lo, hi = np.percentile(ppc["zh"], p95, axis=0)
866-
assert ((lo < z) * (z < hi)).mean() > 0.95
833+
class TestEulerMaruyama:
    """Tests for the refactored ``EulerMaruyama`` distribution."""

    @pytest.mark.parametrize("batched_param", [1, 2])
    @pytest.mark.parametrize("explicit_shape", (True, False))
    def test_batched_size(self, explicit_shape, batched_param):
        """A batched process has the right draw shape and the same logp as per-row models."""
        steps, batch_size = 100, 5
        param_val = np.square(np.random.randn(batch_size))
        # A path of length `steps` is requested either through `shape` or
        # through `steps - 1` explicit steps.
        kwargs = {"shape": (batch_size, steps)} if explicit_shape else {"steps": steps - 1}

        def sde_fn(x, k, d, s):
            drift = k - d * x
            return (drift, s)

        sde_pars = [1.0, 2.0, 0.1]
        sde_pars[batched_param] = sde_pars[batched_param] * param_val

        # Model with a single batched variable
        with Model() as t0:
            init_dist = pm.Normal.dist(0, 10, shape=(batch_size,))
            y = EulerMaruyama(
                "y", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars, init_dist=init_dist, **kwargs
            )

        y_eval = draw(y, draws=2)
        first_draw, second_draw = y_eval[0], y_eval[1]
        assert first_draw.shape == (batch_size, steps)
        assert not np.any(np.isclose(first_draw, second_draw))

        # Equivalent model with one unbatched variable per batch row
        if explicit_shape:
            kwargs["shape"] = steps
        with Model() as t1:
            for i in range(batch_size):
                sde_pars_slice = list(sde_pars)
                sde_pars_slice[batched_param] = sde_pars[batched_param][i]
                init_dist = pm.Normal.dist(0, 10)
                EulerMaruyama(
                    f"y_{i}",
                    dt=0.02,
                    sde_fn=sde_fn,
                    sde_pars=sde_pars_slice,
                    init_dist=init_dist,
                    **kwargs,
                )

        t0_init = t0.initial_point()
        t1_init = {f"y_{i}": t0_init["y"][i] for i in range(batch_size)}
        batched_logp = t0.compile_logp()(t0_init)
        split_logp = t1.compile_logp()(t1_init)
        np.testing.assert_allclose(batched_logp, split_logp)

    def test_change_dist_size1(self):
        """Resizing works for an SDE with three scalar parameters."""

        def sde1(x, k, d, s):
            return (k - d * x, s)

        base_dist = EulerMaruyama.dist(
            dt=0.01,
            sde_fn=sde1,
            sde_pars=(1, 2, 0.1),
            init_dist=pm.Normal.dist(0, 10),
            shape=(5, 10),
        )

        resized = change_dist_size(base_dist, (4,))
        assert resized.eval().shape == (4, 10)

        expanded = change_dist_size(base_dist, (4,), expand=True)
        assert expanded.eval().shape == (4, 5, 10)

    def test_change_dist_size2(self):
        """Resizing works for an SDE whose diffusion depends on the state."""

        def sde2(p, s):
            N = 500.0
            drift = s * p * (1 - p) / (1 + s * p)
            diffusion = pm.math.sqrt(p * (1 - p) / N)
            return drift, diffusion

        base_dist = EulerMaruyama.dist(
            dt=0.01, sde_fn=sde2, sde_pars=(0.1,), init_dist=pm.Normal.dist(0, 10), shape=(3, 10)
        )

        resized = change_dist_size(base_dist, (4,))
        assert resized.eval().shape == (4, 10)

        expanded = change_dist_size(base_dist, (4,), expand=True)
        assert expanded.eval().shape == (4, 3, 10)

    def test_linear_model(self):
        """Sampling recovers the drift coefficient of a simulated linear SDE."""
        lam = -0.78
        sig2 = 5e-3
        N = 300
        dt = 1e-1

        def _gen_sde_path(sde, pars, dt, n, x0):
            # Simulate `n` Euler-Maruyama steps with NumPy, starting from `x0`.
            noise_shape = (n,) if isinstance(x0, float) else (n, x0.size)
            wt = np.random.normal(size=noise_shape)
            xs = [x0]
            for noise in wt:
                f, g = sde(xs[-1], *pars)
                xs.append(xs[-1] + f * dt + np.sqrt(dt) * g * noise)
            return np.array(xs)

        def sde(x, lam):
            return (lam * x, sig2)

        x = floatX(_gen_sde_path(sde, (lam,), dt, N, 5.0))
        z = x + np.random.randn(x.size) * sig2

        # build model
        with Model() as model:
            lamh = Flat("lamh")
            xh = EulerMaruyama(
                "xh", dt, sde, (lamh,), steps=N, initval=x, init_dist=pm.Normal.dist(0, 10)
            )
            Normal("zh", mu=xh, sigma=sig2, observed=z)

        # invert
        with model:
            trace = sample(chains=1)
            ppc = sample_posterior_predictive(trace, model=model)

        p95 = [2.5, 97.5]
        # The true coefficient must fall inside the 95% posterior interval
        lo, hi = np.percentile(trace.posterior["lamh"], p95, axis=[0, 1])
        assert (lo < lam) and (lam < hi)
        # More than 95% of observations must fall inside the predictive interval
        lo, hi = np.percentile(ppc.posterior_predictive["zh"], p95, axis=[0, 1])
        assert ((lo < z) * (z < hi)).mean() > 0.95

0 commit comments

Comments
 (0)