27
27
from aesara .tensor .random .op import RandomVariable
28
28
29
29
from pymc .aesaraf import constant_fold , floatX , intX
30
- from pymc .distributions import distribution
31
30
from pymc .distributions .continuous import Normal , get_tau_sigma
32
31
from pymc .distributions .distribution import (
33
32
Distribution ,
@@ -461,7 +460,7 @@ class AR(Distribution):
461
460
process.
462
461
init_dist : unnamed distribution, optional
463
462
Scalar or vector distribution for initial values. Unnamed refers to distributions
464
- created with the ``.dist()`` API. Distributions should have shape (*shape[:-1], ar_order).
463
+ created with the ``.dist()`` API. Distributions should have shape (*shape[:-1], ar_order).
465
464
If not, it will be automatically resized. Defaults to pm.Normal.dist(0, 100, shape=...).
466
465
467
466
.. warning:: init_dist will be cloned, rendering it independent of the one passed as input.
@@ -914,7 +913,7 @@ class EulerMaruyama(Distribution):
914
913
parameters of the SDE, passed as ``*args`` to ``sde_fn``
915
914
init_dist : unnamed distribution, optional
916
915
Scalar or vector distribution for initial values. Unnamed refers to distributions
917
- created with the ``.dist()`` API. Distributions should have shape (*shape[:-1], ar_order ).
916
+ created with the ``.dist()`` API. Distributions should have shape (*shape[:-1]).
918
917
If not, it will be automatically resized. Defaults to pm.Normal.dist(0, 100, shape=...).
919
918
920
919
.. warning:: init_dist will be cloned, rendering it independent of the one passed as input.
@@ -953,7 +952,7 @@ def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs):
953
952
f"Init dist must be a distribution created via the `.dist()` API, "
954
953
f"got { type (init_dist )} "
955
954
)
956
- check_dist_not_registered (init_dist )
955
+ check_dist_not_registered (init_dist )
957
956
if init_dist .owner .op .ndim_supp > 1 :
958
957
raise ValueError (
959
958
"Init distribution must have a scalar or vector support dimension, " ,
@@ -970,17 +969,15 @@ def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs):
970
969
# Tell Aeppl to ignore init_dist, as it will be accounted for in the logp term
971
970
init_dist = ignore_logprob (init_dist )
972
971
973
- return super ().dist ([dt , sde_fn , sde_pars , init_dist , steps ], ** kwargs )
972
+ return super ().dist ([init_dist , steps , sde_pars , dt , sde_fn ], ** kwargs )
974
973
975
974
@classmethod
976
- def rv_op (cls , dt , sde_fn , sde_pars , init_dist , steps , size = None ):
977
- # Init dist should have shape (*size, ar_order )
975
+ def rv_op (cls , init_dist , steps , sde_pars , dt , sde_fn , size = None ):
976
+ # Init dist should have shape (*size,)
978
977
if size is not None :
979
978
batch_size = size
980
979
else :
981
- # In this case the size of the init_dist depends on the parameters shape
982
- # The last dimension of rho and init_dist does not matter
983
- batch_size = at .broadcast_shape (* sde_pars , at .atleast_1d (init_dist )[..., 0 ])
980
+ batch_size = at .broadcast_shape (* sde_pars , init_dist )
984
981
init_dist = change_dist_size (init_dist , batch_size )
985
982
986
983
# Create OpFromGraph representing random draws form AR process
@@ -1024,6 +1021,24 @@ def step(*prev_args):
1024
1021
return eulermaruyama
1025
1022
1026
1023
1024
+ @_change_dist_size .register (EulerMaruyamaRV )
1025
+ def change_eulermaruyama_size (op , dist , new_size , expand = False ):
1026
+
1027
+ if expand :
1028
+ old_size = dist .shape [:- 1 ]
1029
+ new_size = tuple (new_size ) + tuple (old_size )
1030
+
1031
+ init_dist , steps , * sde_pars , _ = dist .owner .inputs
1032
+ return EulerMaruyama .rv_op (
1033
+ init_dist ,
1034
+ steps ,
1035
+ sde_pars ,
1036
+ dt = op .dt ,
1037
+ sde_fn = op .sde_fn ,
1038
+ size = new_size ,
1039
+ )
1040
+
1041
+
1027
1042
@_logprob .register (EulerMaruyamaRV )
1028
1043
def eulermaruyama_logp (op , values , init_dist , steps , * sde_pars_noise_arg , ** kwargs ):
1029
1044
(x ,) = values
0 commit comments