15
15
import warnings
16
16
17
17
from abc import ABCMeta
18
- from typing import Optional
18
+ from typing import Callable , Optional
19
19
20
20
import aesara
21
21
import aesara .tensor as at
@@ -881,7 +881,26 @@ def garch11_moment(op, rv, omega, alpha_1, beta_1, initial_vol, init_dist, steps
881
881
return at .zeros_like (rv )
882
882
883
883
884
- class EulerMaruyama (distribution .Continuous ):
884
+ class EulerMaruyamaRV (SymbolicRandomVariable ):
885
+ """A placeholder used to specify a log-likelihood for a EulerMaruyama sub-graph."""
886
+
887
+ default_output = 1
888
+ dt : float
889
+ sde_fn : Callable
890
+ _print_name = ("EulerMaruyama" , "\\ operatorname{EulerMaruyama}" )
891
+
892
+ def __init__ (self , * args , dt , sde_fn , ** kwargs ):
893
+ self .dt = dt
894
+ self .sde_fn = sde_fn
895
+ super ().__init__ (* args , ** kwargs )
896
+
897
+ def update (self , node : Node ):
898
+ """Return the update mapping for the noise RV."""
899
+ # Since noise is a shared variable it shows up as the last node input
900
+ return {node .inputs [- 1 ]: node .outputs [0 ]}
901
+
902
+
903
+ class EulerMaruyama (Distribution ):
885
904
r"""
886
905
Stochastic differential equation discretized with the Euler-Maruyama method.
887
906
@@ -893,39 +912,131 @@ class EulerMaruyama(distribution.Continuous):
893
912
function returning the drift and diffusion coefficients of SDE
894
913
sde_pars: tuple
895
914
parameters of the SDE, passed as ``*args`` to ``sde_fn``
915
+ init_dist : unnamed distribution, optional
916
+ Scalar or vector distribution for initial values. Unnamed refers to distributions
917
+ created with the ``.dist()`` API. Distributions should have shape (*shape[:-1], ar_order).
918
+ If not, it will be automatically resized. Defaults to pm.Normal.dist(0, 100, shape=...).
919
+
920
+ .. warning:: init_dist will be cloned, rendering it independent of the one passed as input.
896
921
"""
897
922
898
- def __new__ (cls , * args , ** kwargs ):
899
- raise NotImplementedError (f"{ cls .__name__ } has not yet been ported to PyMC 4.0." )
923
+ rv_type = EulerMaruyamaRV
924
+
925
+ def __new__ (cls , name , dt , sde_fn , * args , steps = None , ** kwargs ):
926
+ dt = at .as_tensor_variable (floatX (dt ))
927
+ steps = get_support_shape_1d (
928
+ support_shape = steps ,
929
+ shape = None , # Shape will be checked in `cls.dist`
930
+ dims = kwargs .get ("dims" , None ),
931
+ observed = kwargs .get ("observed" , None ),
932
+ support_shape_offset = 1 ,
933
+ )
934
+ return super ().__new__ (cls , name , dt , sde_fn , * args , steps = steps , ** kwargs )
900
935
901
936
@classmethod
902
- def dist (cls , * args , ** kwargs ):
903
- raise NotImplementedError (f"{ cls .__name__ } has not yet been ported to PyMC 4.0." )
937
+ def dist (cls , dt , sde_fn , sde_pars , * , init_dist = None , steps = None , ** kwargs ):
938
+ steps = get_support_shape_1d (
939
+ support_shape = steps , shape = kwargs .get ("shape" , None ), support_shape_offset = 1
940
+ )
941
+ if steps is None :
942
+ raise ValueError ("Must specify steps or shape parameter" )
943
+ steps = at .as_tensor_variable (intX (steps ), ndim = 0 )
904
944
905
- def __init__ (self , dt , sde_fn , sde_pars , * args , ** kwds ):
906
- super ().__init__ (* args , ** kwds )
907
- self .dt = dt = at .as_tensor_variable (dt )
908
- self .sde_fn = sde_fn
909
- self .sde_pars = sde_pars
945
+ dt = at .as_tensor_variable (floatX (dt ))
946
+ sde_pars = [at .as_tensor_variable (x ) for x in sde_pars ]
910
947
911
- def logp (self , x ):
912
- """
913
- Calculate log-probability of EulerMaruyama distribution at specified value.
948
+ if init_dist is not None :
949
+ if not isinstance (init_dist , TensorVariable ) or not isinstance (
950
+ init_dist .owner .op , (RandomVariable , SymbolicRandomVariable )
951
+ ):
952
+ raise ValueError (
953
+ f"Init dist must be a distribution created via the `.dist()` API, "
954
+ f"got { type (init_dist )} "
955
+ )
956
+ check_dist_not_registered (init_dist )
957
+ if init_dist .owner .op .ndim_supp > 1 :
958
+ raise ValueError (
959
+ "Init distribution must have a scalar or vector support dimension, " ,
960
+ f"got ndim_supp={ init_dist .owner .op .ndim_supp } ." ,
961
+ )
962
+ else :
963
+ warnings .warn (
964
+ "Initial distribution not specified, defaulting to "
965
+ "`Normal.dist(0, 100, shape=...)`. You can specify an init_dist "
966
+ "manually to suppress this warning." ,
967
+ UserWarning ,
968
+ )
969
+ init_dist = Normal .dist (0 , 100 , shape = sde_pars [0 ].shape )
970
+ # Tell Aeppl to ignore init_dist, as it will be accounted for in the logp term
971
+ init_dist = ignore_logprob (init_dist )
914
972
915
- Parameters
916
- ----------
917
- x: numeric
918
- Value for which log-probability is calculated.
973
+ return super ().dist ([dt , sde_fn , sde_pars , init_dist , steps ], ** kwargs )
919
974
920
- Returns
921
- -------
922
- TensorVariable
923
- """
924
- xt = x [:- 1 ]
925
- f , g = self .sde_fn (x [:- 1 ], * self .sde_pars )
926
- mu = xt + self .dt * f
927
- sigma = at .sqrt (self .dt ) * g
928
- return at .sum (Normal .dist (mu = mu , sigma = sigma ).logp (x [1 :]))
929
-
930
- def _distr_parameters_for_repr (self ):
931
- return ["dt" ]
975
+ @classmethod
976
+ def rv_op (cls , dt , sde_fn , sde_pars , init_dist , steps , size = None ):
977
+ # Init dist should have shape (*size, ar_order)
978
+ if size is not None :
979
+ batch_size = size
980
+ else :
981
+ # In this case the size of the init_dist depends on the parameters shape
982
+ # The last dimension of rho and init_dist does not matter
983
+ batch_size = at .broadcast_shape (* sde_pars , at .atleast_1d (init_dist )[..., 0 ])
984
+ init_dist = change_dist_size (init_dist , batch_size )
985
+
986
+ # Create OpFromGraph representing random draws form AR process
987
+ # Variables with underscore suffix are dummy inputs into the OpFromGraph
988
+ init_ = init_dist .type ()
989
+ sde_pars_ = [x .type () for x in sde_pars ]
990
+ steps_ = steps .type ()
991
+
992
+ noise_rng = aesara .shared (np .random .default_rng ())
993
+
994
+ def step (* prev_args ):
995
+ prev_y , * prev_sde_pars , rng = prev_args
996
+ f , g = sde_fn (prev_y , * prev_sde_pars )
997
+ mu = prev_y + dt * f
998
+ sigma = at .sqrt (dt ) * g
999
+ next_rng , next_y = Normal .dist (mu = mu , sigma = sigma , rng = rng ).owner .outputs
1000
+ return next_y , {rng : next_rng }
1001
+
1002
+ y_t , innov_updates_ = aesara .scan (
1003
+ fn = step ,
1004
+ outputs_info = [init_ ],
1005
+ non_sequences = sde_pars_ + [noise_rng ],
1006
+ n_steps = steps_ ,
1007
+ strict = True ,
1008
+ )
1009
+ (noise_next_rng ,) = tuple (innov_updates_ .values ())
1010
+
1011
+ sde_out_ = at .concatenate ([init_ [None , ...], y_t ], axis = 0 ).dimshuffle (
1012
+ tuple (range (1 , y_t .ndim )) + (0 ,)
1013
+ )
1014
+
1015
+ eulermaruyama_op = EulerMaruyamaRV (
1016
+ inputs = [init_ , steps_ ] + sde_pars_ ,
1017
+ outputs = [noise_next_rng , sde_out_ ],
1018
+ dt = dt ,
1019
+ sde_fn = sde_fn ,
1020
+ ndim_supp = 1 ,
1021
+ )
1022
+
1023
+ eulermaruyama = eulermaruyama_op (init_dist , steps , * sde_pars )
1024
+ return eulermaruyama
1025
+
1026
+
1027
+ @_logprob .register (EulerMaruyamaRV )
1028
+ def eulermaruyama_logp (op , values , init_dist , steps , * sde_pars_noise_arg , ** kwargs ):
1029
+ (x ,) = values
1030
+ # noise arg is unused, but is needed to make the logp signature match the rv_op signature
1031
+ * sde_pars , _ = sde_pars_noise_arg
1032
+ xtm1 = x [..., :- 1 ]
1033
+ xt = x [..., 1 :]
1034
+ f , g = op .sde_fn (xtm1 , * sde_pars )
1035
+ mu = xtm1 + op .dt * f
1036
+ sigma = at .sqrt (op .dt ) * g
1037
+ # Compute and collapse logp across time dimension
1038
+ sde_logp = at .sum (logp (Normal .dist (mu , sigma ), xt ), axis = - 1 )
1039
+ init_logp = logp (init_dist , x [..., :1 ])
1040
+ if init_dist .owner .op .ndim_supp == 0 :
1041
+ init_logp = at .sum (init_logp , axis = - 1 )
1042
+ return init_logp + sde_logp
0 commit comments