Skip to content

Commit 854b752

Browse files
committed
add dispatch for identity Op, use static shapes for parameters
1 parent a06081e commit 854b752

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

pymc/sampling/jax.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from pymc.distributions.multivariate import PosDefMatrix
4747
from pymc.initial_point import StartDict
4848
from pymc.logprob.utils import CheckParameterValue
49+
from pymc.pytensorf import IdentityOp
4950
from pymc.sampling.mcmc import _init_jitter
5051
from pymc.stats.convergence import log_warnings, run_convergence_checks
5152
from pymc.util import (
@@ -69,6 +70,14 @@
6970
)
7071

7172

73+
@jax_funcify.register(IdentityOp)
74+
def jax_funcify_Identity(op, **kwargs):
75+
def identity_fn(value):
76+
return value
77+
78+
return identity_fn
79+
80+
7281
@jax_funcify.register(Assert)
7382
@jax_funcify.register(CheckParameterValue)
7483
def jax_funcify_Assert(op, **kwargs):

pymc/variational/approximations.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ def create_shared_params(self, start=None, start_sigma=None):
9494
rho = rho1
9595

9696
return {
97-
"mu": pytensor.shared(pm.floatX(start), "mu"),
98-
"rho": pytensor.shared(pm.floatX(rho), "rho"),
97+
"mu": pytensor.shared(pm.floatX(start), "mu", shape=start.shape),
98+
"rho": pytensor.shared(pm.floatX(rho), "rho", shape=rho.shape),
9999
}
100100

101101
@node_property
@@ -138,7 +138,10 @@ def create_shared_params(self, start=None):
138138
start = self._prepare_start(start)
139139
n = self.ddim
140140
L_tril = np.eye(n)[np.tril_indices(n)].astype(pytensor.config.floatX)
141-
return {"mu": pytensor.shared(start, "mu"), "L_tril": pytensor.shared(L_tril, "L_tril")}
141+
return {
142+
"mu": pytensor.shared(start, "mu", shape=start.shape),
143+
"L_tril": pytensor.shared(L_tril, "L_tril", shape=L_tril.shape),
144+
}
142145

143146
@node_property
144147
def L(self):

0 commit comments

Comments
 (0)