1
- import jax .random
2
1
import numpy as np
3
2
import pymc as pm
4
3
import pytensor
10
9
from pymc .distributions .shape_utils import get_support_shape , get_support_shape_1d
11
10
from pymc .logprob .abstract import _logprob
12
11
from pytensor .graph .basic import Node
13
- from pytensor .link .jax .dispatch .random import jax_sample_fn
14
12
from pytensor .tensor .random .basic import MvNormalRV
15
13
16
14
floatX = pytensor .config .floatX
@@ -66,18 +64,25 @@ class MvNormalSVD(MvNormal):
66
64
rv_op = MvNormalSVDRV ()
67
65
68
66
69
- @jax_sample_fn .register (MvNormalSVDRV )
70
- def jax_sample_fn_mvnormal_svd (op , node ):
71
- def sample_fn (rng , size , dtype , * parameters ):
72
- rng_key = rng ["jax_state" ]
73
- rng_key , sampling_key = jax .random .split (rng_key , 2 )
74
- sample = jax .random .multivariate_normal (
75
- sampling_key , * parameters , shape = size , dtype = dtype , method = "svd"
76
- )
77
- rng ["jax_state" ] = rng_key
78
- return (rng , sample )
67
+ try :
68
+ import jax .random
69
+ from pytensor .link .jax .dispatch .random import jax_sample_fn
70
+
71
+ @jax_sample_fn .register (MvNormalSVDRV )
72
+ def jax_sample_fn_mvnormal_svd (op , node ):
73
+ def sample_fn (rng , size , dtype , * parameters ):
74
+ rng_key = rng ["jax_state" ]
75
+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
76
+ sample = jax .random .multivariate_normal (
77
+ sampling_key , * parameters , shape = size , dtype = dtype , method = "svd"
78
+ )
79
+ rng ["jax_state" ] = rng_key
80
+ return (rng , sample )
81
+
82
+ return sample_fn
79
83
80
- return sample_fn
84
+ except ImportError :
85
+ pass
81
86
82
87
83
88
class LinearGaussianStateSpaceRV (SymbolicRandomVariable ):
@@ -90,7 +95,6 @@ def update(self, node: Node):
90
95
91
96
class _LinearGaussianStateSpace (Continuous ):
92
97
rv_op = LinearGaussianStateSpaceRV
93
- ndim_supp = 2
94
98
95
99
def __new__ (
96
100
cls ,
0 commit comments