@@ -980,7 +980,7 @@ def rv_op(cls, init_dist, steps, sde_pars, dt, sde_fn, size=None):
980
980
batch_size = at .broadcast_shape (* sde_pars , init_dist )
981
981
init_dist = change_dist_size (init_dist , batch_size )
982
982
983
- # Create OpFromGraph representing random draws form AR process
983
+ # Create OpFromGraph representing random draws from SDE process
984
984
# Variables with underscore suffix are dummy inputs into the OpFromGraph
985
985
init_ = init_dist .type ()
986
986
sde_pars_ = [x .type () for x in sde_pars ]
@@ -1044,14 +1044,16 @@ def eulermaruyama_logp(op, values, init_dist, steps, *sde_pars_noise_arg, **kwar
1044
1044
(x ,) = values
1045
1045
# noise arg is unused, but is needed to make the logp signature match the rv_op signature
1046
1046
* sde_pars , _ = sde_pars_noise_arg
1047
+ # sde_fn is user provided and likely not broadcastable to additional time dimension,
1048
+ # since the input x is now [..., t], we need to broadcast each input to [..., None]
1049
+ # below as best effort attempt to make it work
1050
+ sde_pars_broadcast = [x [..., None ] for x in sde_pars ]
1047
1051
xtm1 = x [..., :- 1 ]
1048
1052
xt = x [..., 1 :]
1049
- f , g = op .sde_fn (xtm1 , * sde_pars )
1053
+ f , g = op .sde_fn (xtm1 , * sde_pars_broadcast )
1050
1054
mu = xtm1 + op .dt * f
1051
1055
sigma = at .sqrt (op .dt ) * g
1052
1056
# Compute and collapse logp across time dimension
1053
1057
sde_logp = at .sum (logp (Normal .dist (mu , sigma ), xt ), axis = - 1 )
1054
- init_logp = logp (init_dist , x [..., :1 ])
1055
- if init_dist .owner .op .ndim_supp == 0 :
1056
- init_logp = at .sum (init_logp , axis = - 1 )
1058
+ init_logp = logp (init_dist , x [..., 0 ])
1057
1059
return init_logp + sde_logp
0 commit comments