Skip to content

Commit b8ab9e2

Browse files
committed
add dispatch for identity Op, use static shapes for parameters
1 parent 2c355f1 commit b8ab9e2

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

pymc/sampling/jax.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from pymc.distributions.multivariate import PosDefMatrix
4646
from pymc.initial_point import StartDict
4747
from pymc.logprob.utils import CheckParameterValue
48+
from pymc.pytensorf import IdentityOp
4849
from pymc.sampling.mcmc import _init_jitter
4950
from pymc.util import (
5051
RandomSeed,
@@ -67,6 +68,14 @@
6768
)
6869

6970

71+
@jax_funcify.register(IdentityOp)
72+
def jax_funcify_Identity(op, **kwargs):
73+
def identity_fn(value):
74+
return value
75+
76+
return identity_fn
77+
78+
7079
@jax_funcify.register(Assert)
7180
@jax_funcify.register(CheckParameterValue)
7281
def jax_funcify_Assert(op, **kwargs):

pymc/variational/approximations.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ def create_shared_params(self, start=None, start_sigma=None):
9393
rho = rho1
9494

9595
return {
96-
"mu": pytensor.shared(pm.floatX(start), "mu"),
97-
"rho": pytensor.shared(pm.floatX(rho), "rho"),
96+
"mu": pytensor.shared(pm.floatX(start), "mu", shape=start.shape),
97+
"rho": pytensor.shared(pm.floatX(rho), "rho", shape=rho.shape),
9898
}
9999

100100
@node_property
@@ -137,7 +137,10 @@ def create_shared_params(self, start=None):
137137
start = self._prepare_start(start)
138138
n = self.ddim
139139
L_tril = np.eye(n)[np.tril_indices(n)].astype(pytensor.config.floatX)
140-
return {"mu": pytensor.shared(start, "mu"), "L_tril": pytensor.shared(L_tril, "L_tril")}
140+
return {
141+
"mu": pytensor.shared(start, "mu", shape=start.shape),
142+
"L_tril": pytensor.shared(L_tril, "L_tril", shape=L_tril.shape),
143+
}
141144

142145
@node_property
143146
def L(self):

0 commit comments

Comments
 (0)