Skip to content

Commit b7304f6

Browse files
Wrap jax MvNormal rewrite in try/except block
1 parent e95a670 commit b7304f6

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

pymc_experimental/statespace/filters/distributions.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import jax.random
21
import numpy as np
32
import pymc as pm
43
import pytensor
@@ -10,7 +9,6 @@
109
from pymc.distributions.shape_utils import get_support_shape, get_support_shape_1d
1110
from pymc.logprob.abstract import _logprob
1211
from pytensor.graph.basic import Node
13-
from pytensor.link.jax.dispatch.random import jax_sample_fn
1412
from pytensor.tensor.random.basic import MvNormalRV
1513

1614
floatX = pytensor.config.floatX
@@ -66,18 +64,25 @@ class MvNormalSVD(MvNormal):
6664
rv_op = MvNormalSVDRV()
6765

6866

69-
@jax_sample_fn.register(MvNormalSVDRV)
70-
def jax_sample_fn_mvnormal_svd(op, node):
71-
def sample_fn(rng, size, dtype, *parameters):
72-
rng_key = rng["jax_state"]
73-
rng_key, sampling_key = jax.random.split(rng_key, 2)
74-
sample = jax.random.multivariate_normal(
75-
sampling_key, *parameters, shape=size, dtype=dtype, method="svd"
76-
)
77-
rng["jax_state"] = rng_key
78-
return (rng, sample)
67+
try:
68+
import jax.random
69+
from pytensor.link.jax.dispatch.random import jax_sample_fn
70+
71+
@jax_sample_fn.register(MvNormalSVDRV)
72+
def jax_sample_fn_mvnormal_svd(op, node):
73+
def sample_fn(rng, size, dtype, *parameters):
74+
rng_key = rng["jax_state"]
75+
rng_key, sampling_key = jax.random.split(rng_key, 2)
76+
sample = jax.random.multivariate_normal(
77+
sampling_key, *parameters, shape=size, dtype=dtype, method="svd"
78+
)
79+
rng["jax_state"] = rng_key
80+
return (rng, sample)
81+
82+
return sample_fn
7983

80-
return sample_fn
84+
except ImportError:
85+
pass
8186

8287

8388
class LinearGaussianStateSpaceRV(SymbolicRandomVariable):
@@ -90,7 +95,6 @@ def update(self, node: Node):
9095

9196
class _LinearGaussianStateSpace(Continuous):
9297
rv_op = LinearGaussianStateSpaceRV
93-
ndim_supp = 2
9498

9599
def __new__(
96100
cls,

0 commit comments

Comments
 (0)