Skip to content

Commit b6fa76b

Browse files
committed
Fix broadcasting
1 parent 95c784c commit b6fa76b

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

pymc/distributions/timeseries.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,7 @@ def rv_op(cls, init_dist, steps, sde_pars, dt, sde_fn, size=None):
980980
batch_size = at.broadcast_shape(*sde_pars, init_dist)
981981
init_dist = change_dist_size(init_dist, batch_size)
982982

983-
# Create OpFromGraph representing random draws form AR process
983+
# Create OpFromGraph representing random draws from SDE process
984984
# Variables with underscore suffix are dummy inputs into the OpFromGraph
985985
init_ = init_dist.type()
986986
sde_pars_ = [x.type() for x in sde_pars]
@@ -1044,14 +1044,16 @@ def eulermaruyama_logp(op, values, init_dist, steps, *sde_pars_noise_arg, **kwar
10441044
(x,) = values
10451045
# noise arg is unused, but is needed to make the logp signature match the rv_op signature
10461046
*sde_pars, _ = sde_pars_noise_arg
1047+
# sde_fn is user provided and likely not broadcastable to additional time dimension,
1048+
# since the input x is now [..., t], we need to broadcast each input to [..., None]
1049+
# below as best effort attempt to make it work
1050+
sde_pars_broadcast = [x[..., None] for x in sde_pars]
10471051
xtm1 = x[..., :-1]
10481052
xt = x[..., 1:]
1049-
f, g = op.sde_fn(xtm1, *sde_pars)
1053+
f, g = op.sde_fn(xtm1, *sde_pars_broadcast)
10501054
mu = xtm1 + op.dt * f
10511055
sigma = at.sqrt(op.dt) * g
10521056
# Compute and collapse logp across time dimension
10531057
sde_logp = at.sum(logp(Normal.dist(mu, sigma), xt), axis=-1)
1054-
init_logp = logp(init_dist, x[..., :1])
1055-
if init_dist.owner.op.ndim_supp == 0:
1056-
init_logp = at.sum(init_logp, axis=-1)
1058+
init_logp = logp(init_dist, x[..., 0])
10571059
return init_logp + sde_logp

pymc/tests/distributions/test_timeseries.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -859,11 +859,13 @@ def sde_fn(x, k, d, s):
859859
for i in range(batch_size):
860860
sde_pars_slice = sde_pars.copy()
861861
sde_pars_slice[batched_param] = sde_pars[batched_param][i]
862-
EulerMaruyama(f"y_{i}", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars, **kwargs)
862+
EulerMaruyama(f"y_{i}", dt=0.02, sde_fn=sde_fn, sde_pars=sde_pars_slice, **kwargs)
863863

864+
t0_init = t0.initial_point()
865+
t1_init = {f"y_{i}": t0_init["y"][i] for i in range(batch_size)}
864866
np.testing.assert_allclose(
865-
t0.compile_logp()(t0.initial_point()),
866-
t1.compile_logp()(t1.initial_point()),
867+
t0.compile_logp()(t0_init),
868+
t1.compile_logp()(t1_init),
867869
)
868870

869871
def test_change_dist_size1(self):

0 commit comments

Comments
 (0)