@@ -15,7 +15,7 @@
import warnings

from abc import ABCMeta
- from typing import Optional
+ from typing import Callable, Optional

import aesara
import aesara.tensor as at
@@ -27,7 +27,6 @@
from aesara.tensor.random.op import RandomVariable

from pymc.aesaraf import constant_fold, floatX, intX
- from pymc.distributions import distribution
from pymc.distributions.continuous import Normal, get_tau_sigma
from pymc.distributions.distribution import (
    Distribution,
@@ -461,7 +460,7 @@ class AR(Distribution):
        process.
    init_dist : unnamed distribution, optional
        Scalar or vector distribution for initial values. Unnamed refers to distributions
-         created with the ``.dist()`` API. Distributions should have shape (*shape[:-1], ar_order).
+         created with the ``.dist()`` API. Distributions should have shape (*shape[:-1], ar_order).
        If not, it will be automatically resized. Defaults to pm.Normal.dist(0, 100, shape=...).

        .. warning:: init_dist will be cloned, rendering it independent of the one passed as input.
@@ -881,7 +880,26 @@ def garch11_moment(op, rv, omega, alpha_1, beta_1, initial_vol, init_dist, steps
    return at.zeros_like(rv)


- class EulerMaruyama(distribution.Continuous):
+ class EulerMaruyamaRV(SymbolicRandomVariable):
+     """A placeholder used to specify a log-likelihood for a EulerMaruyama sub-graph."""
+
+     default_output = 1
+     dt: float
+     sde_fn: Callable
+     _print_name = ("EulerMaruyama", "\\operatorname{EulerMaruyama}")
+
+     def __init__(self, *args, dt, sde_fn, **kwargs):
+         self.dt = dt
+         self.sde_fn = sde_fn
+         super().__init__(*args, **kwargs)
+
+     def update(self, node: Node):
+         """Return the update mapping for the noise RV."""
+         # Since noise is a shared variable it shows up as the last node input
+         return {node.inputs[-1]: node.outputs[0]}
+
+
+ class EulerMaruyama(Distribution):
    r"""
    Stochastic differential equation discretized with the Euler-Maruyama method.

@@ -893,39 +911,149 @@ class EulerMaruyama(distribution.Continuous):
        function returning the drift and diffusion coefficients of SDE
    sde_pars: tuple
        parameters of the SDE, passed as ``*args`` to ``sde_fn``
+     init_dist : unnamed distribution, optional
+         Scalar distribution for initial values. Unnamed refers to distributions created with
+         the ``.dist()`` API. Distributions should have shape (*shape[:-1]).
+         If not, it will be automatically resized. Defaults to pm.Normal.dist(0, 100, shape=...).
+
+         .. warning:: init_dist will be cloned, rendering it independent of the one passed as input.
    """

-     def __new__(cls, *args, **kwargs):
-         raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.")
+     rv_type = EulerMaruyamaRV
+
+     def __new__(cls, name, dt, sde_fn, *args, steps=None, **kwargs):
+         dt = at.as_tensor_variable(floatX(dt))
+         steps = get_support_shape_1d(
+             support_shape=steps,
+             shape=None,  # Shape will be checked in `cls.dist`
+             dims=kwargs.get("dims", None),
+             observed=kwargs.get("observed", None),
+             support_shape_offset=1,
+         )
+         return super().__new__(cls, name, dt, sde_fn, *args, steps=steps, **kwargs)

    @classmethod
-     def dist(cls, *args, **kwargs):
-         raise NotImplementedError(f"{cls.__name__} has not yet been ported to PyMC 4.0.")
+     def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs):
+         steps = get_support_shape_1d(
+             support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=1
+         )
+         if steps is None:
+             raise ValueError("Must specify steps or shape parameter")
+         steps = at.as_tensor_variable(intX(steps), ndim=0)

-     def __init__(self, dt, sde_fn, sde_pars, *args, **kwds):
-         super().__init__(*args, **kwds)
-         self.dt = dt = at.as_tensor_variable(dt)
-         self.sde_fn = sde_fn
-         self.sde_pars = sde_pars
+         dt = at.as_tensor_variable(floatX(dt))
+         sde_pars = [at.as_tensor_variable(x) for x in sde_pars]

-     def logp(self, x):
-         """
-         Calculate log-probability of EulerMaruyama distribution at specified value.
+         if init_dist is not None:
+             if not isinstance(init_dist, TensorVariable) or not isinstance(
+                 init_dist.owner.op, (RandomVariable, SymbolicRandomVariable)
+             ):
+                 raise ValueError(
+                     f"Init dist must be a distribution created via the `.dist()` API, "
+                     f"got {type(init_dist)}"
+                 )
+             check_dist_not_registered(init_dist)
+             if init_dist.owner.op.ndim_supp > 0:
+                 raise ValueError(
+                     "Init distribution must have a scalar support dimension, ",
+                     f"got ndim_supp={init_dist.owner.op.ndim_supp}.",
+                 )
+         else:
+             warnings.warn(
+                 "Initial distribution not specified, defaulting to "
+                 "`Normal.dist(0, 100, shape=...)`. You can specify an init_dist "
+                 "manually to suppress this warning.",
+                 UserWarning,
+             )
+             init_dist = Normal.dist(0, 100, shape=sde_pars[0].shape)
+         # Tell Aeppl to ignore init_dist, as it will be accounted for in the logp term
+         init_dist = ignore_logprob(init_dist)

-         Parameters
-         ----------
-         x: numeric
-             Value for which log-probability is calculated.
+         return super().dist([init_dist, steps, sde_pars, dt, sde_fn], **kwargs)

-         Returns
-         -------
-         TensorVariable
-         """
-         xt = x[:-1]
-         f, g = self.sde_fn(x[:-1], *self.sde_pars)
-         mu = xt + self.dt * f
-         sigma = at.sqrt(self.dt) * g
-         return at.sum(Normal.dist(mu=mu, sigma=sigma).logp(x[1:]))
-
-     def _distr_parameters_for_repr(self):
-         return ["dt"]
+     @classmethod
+     def rv_op(cls, init_dist, steps, sde_pars, dt, sde_fn, size=None):
+         # Init dist should have shape (*size,)
+         if size is not None:
+             batch_size = size
+         else:
+             batch_size = at.broadcast_shape(*sde_pars, init_dist)
+         init_dist = change_dist_size(init_dist, batch_size)
+
+         # Create OpFromGraph representing random draws from SDE process
+         # Variables with underscore suffix are dummy inputs into the OpFromGraph
+         init_ = init_dist.type()
+         sde_pars_ = [x.type() for x in sde_pars]
+         steps_ = steps.type()
+
+         noise_rng = aesara.shared(np.random.default_rng())
+
+         def step(*prev_args):
+             prev_y, *prev_sde_pars, rng = prev_args
+             f, g = sde_fn(prev_y, *prev_sde_pars)
+             mu = prev_y + dt * f
+             sigma = at.sqrt(dt) * g
+             next_rng, next_y = Normal.dist(mu=mu, sigma=sigma, rng=rng).owner.outputs
+             return next_y, {rng: next_rng}
+
+         y_t, innov_updates_ = aesara.scan(
+             fn=step,
+             outputs_info=[init_],
+             non_sequences=sde_pars_ + [noise_rng],
+             n_steps=steps_,
+             strict=True,
+         )
+         (noise_next_rng,) = tuple(innov_updates_.values())
+
+         sde_out_ = at.concatenate([init_[None, ...], y_t], axis=0).dimshuffle(
+             tuple(range(1, y_t.ndim)) + (0,)
+         )
+
+         eulermaruyama_op = EulerMaruyamaRV(
+             inputs=[init_, steps_] + sde_pars_,
+             outputs=[noise_next_rng, sde_out_],
+             dt=dt,
+             sde_fn=sde_fn,
+             ndim_supp=1,
+         )
+
+         eulermaruyama = eulermaruyama_op(init_dist, steps, *sde_pars)
+         return eulermaruyama
+
+
+ @_change_dist_size.register(EulerMaruyamaRV)
+ def change_eulermaruyama_size(op, dist, new_size, expand=False):
+
+     if expand:
+         old_size = dist.shape[:-1]
+         new_size = tuple(new_size) + tuple(old_size)
+
+     init_dist, steps, *sde_pars, _ = dist.owner.inputs
+     return EulerMaruyama.rv_op(
+         init_dist,
+         steps,
+         sde_pars,
+         dt=op.dt,
+         sde_fn=op.sde_fn,
+         size=new_size,
+     )
+
+
+ @_logprob.register(EulerMaruyamaRV)
+ def eulermaruyama_logp(op, values, init_dist, steps, *sde_pars_noise_arg, **kwargs):
+     (x,) = values
+     # noise arg is unused, but is needed to make the logp signature match the rv_op signature
+     *sde_pars, _ = sde_pars_noise_arg
+     # sde_fn is user provided and likely not broadcastable to additional time dimension,
+     # since the input x is now [..., t], we need to broadcast each input to [..., None]
+     # below as best effort attempt to make it work
+     sde_pars_broadcast = [x[..., None] for x in sde_pars]
+     xtm1 = x[..., :-1]
+     xt = x[..., 1:]
+     f, g = op.sde_fn(xtm1, *sde_pars_broadcast)
+     mu = xtm1 + op.dt * f
+     sigma = at.sqrt(op.dt) * g
+     # Compute and collapse logp across time dimension
+     sde_logp = at.sum(logp(Normal.dist(mu, sigma), xt), axis=-1)
+     init_logp = logp(init_dist, x[..., 0])
+     return init_logp + sde_logp
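
For reference, a minimal usage sketch of the ported distribution (not part of this commit; the model, priors, and data below are illustrative assumptions). Per the new `dist`/`rv_op`/`logp` pieces above, the user supplies `dt`, an `sde_fn` returning drift and diffusion, the SDE parameters as a tuple, an unnamed `init_dist`, and `steps`; each transition is then scored as `Normal(x_t + dt * f(x_t), sqrt(dt) * g(x_t))`, plus the `init_dist` term for the first value.

```python
import numpy as np
import pymc as pm


def sde_fn(x, lam, sig):
    # Drift and diffusion of a mean-reverting SDE: dX = -lam * X dt + sig * dW
    return -lam * x, sig


dt, steps = 0.1, 100
x_obs = np.random.default_rng(0).normal(size=steps + 1)  # placeholder data

with pm.Model():
    lam = pm.HalfNormal("lam", sigma=1.0)
    sig = pm.HalfNormal("sig", sigma=1.0)
    # sde_pars is forwarded as *args to sde_fn; init_dist must be an unnamed .dist() variable
    pm.EulerMaruyama(
        "x",
        dt,
        sde_fn,
        (lam, sig),
        init_dist=pm.Normal.dist(0.0, 1.0),
        steps=steps,
        observed=x_obs,
    )
    prior = pm.sample_prior_predictive()  # exercises the scan-based forward sampler in rv_op
```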