Skip to content

Commit 5f60008

Browse files
committed
add test
1 parent b8ab9e2 commit 5f60008

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

pymc/variational/approximations.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,13 @@ def create_shared_params(self, trace=None, size=None, jitter=1, start=None):
228228
for j in range(len(trace)):
229229
histogram[i] = DictToArrayBijection.map(trace.point(j, t)).data
230230
i += 1
231-
return dict(histogram=pytensor.shared(pm.floatX(histogram), "histogram"))
231+
return dict(
232+
histogram=pytensor.shared(
233+
pm.floatX(histogram),
234+
"histogram",
235+
shape=histogram.shape,
236+
)
237+
)
232238

233239
def _check_trace(self):
234240
trace = self._kwargs.get("trace", None)

tests/sampling/test_jax.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,3 +471,10 @@ def test_sample_partially_observed():
471471
assert idata.observed_data["x_observed"].shape == (2,)
472472
assert idata.posterior["x_unobserved"].shape == (1, 10, 1)
473473
assert idata.posterior["x"].shape == (1, 10, 3)
474+
475+
476+
@pytest.mark.parametrize("method", ["advi", "fullrank_advi"])
477+
def test_vi_sampling_jax(method):
478+
with pm.Model() as model:
479+
x = pm.Normal("x")
480+
pm.fit(10, method=method, fn_kwargs=dict(mode="JAX"))

0 commit comments

Comments
 (0)