Skip to content

Commit f06cdb7

Browse files
author
Maxim Kochurov
committed
add an exceptionn to prevent using SVGD with minibatch
1 parent 70d4f5e commit f06cdb7

File tree

4 files changed

+18
-12
lines changed

4 files changed

+18
-12
lines changed

pymc/tests/variational/test_inference.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,10 @@ def fit_kwargs(inference, use_minibatch):
153153

154154

155155
def test_fit_oo(inference, fit_kwargs, simple_model_data):
156-
trace = inference.fit(**fit_kwargs).sample(10000)
156+
try:
157+
trace = inference.fit(**fit_kwargs).sample(10000)
158+
except opvi.NotImplementedInference as e:
159+
pytest.skip(str(e))
157160
mu_post = simple_model_data["mu_post"]
158161
d = simple_model_data["d"]
159162
np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_post, rtol=0.05)

pymc/variational/approximations.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,5 @@ def sample(post, *_):
388388
nodes, _ = aesara.scan(
389389
sample,
390390
self.histogram,
391-
non_sequences=opvi._known_scan_ignored_inputs(node),
392-
strict=True,
393391
)
394392
return nodes

pymc/variational/opvi.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ def _known_scan_ignored_inputs(terms):
169169

170170
return [
171171
n.owner.inputs[0]
172-
for n in find_rng_nodes(terms)
173-
if isinstance(n, (MinibatchIndexRV, SimulatorRV))
172+
for n in aesara.graph.ancestors(terms)
173+
if n.owner is not None and isinstance(n.owner.op, (MinibatchIndexRV, SimulatorRV))
174174
]
175175

176176

@@ -1025,9 +1025,7 @@ def symbolic_sample_over_posterior(self, node):
10251025
def sample(post, *_):
10261026
return aesara.clone_replace(node, {self.input: post})
10271027

1028-
nodes, _ = aesara.scan(
1029-
sample, random, non_sequences=_known_scan_ignored_inputs(node), strict=True
1030-
)
1028+
nodes, _ = aesara.scan(sample, random)
10311029
assert self.input not in set(aesara.graph.graph_inputs(as_list(nodes)))
10321030
return nodes
10331031

@@ -1388,8 +1386,6 @@ def sample(*post):
13881386
nodes, _ = aesara.scan(
13891387
sample,
13901388
self.symbolic_randoms,
1391-
non_sequences=_known_scan_ignored_inputs(node),
1392-
strict=True,
13931389
)
13941390
assert not (set(self.inputs) & set(aesara.graph.graph_inputs(as_list(nodes))))
13951391
return nodes

pymc/variational/stein.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717

1818
from pymc.aesaraf import floatX
1919
from pymc.util import WithMemoization, locally_cachedmethod
20-
from pymc.variational.opvi import node_property
20+
from pymc.variational.opvi import (
21+
NotImplementedInference,
22+
_known_scan_ignored_inputs,
23+
node_property,
24+
)
2125
from pymc.variational.test_functions import rbf
2226

2327
__all__ = ["Stein"]
@@ -46,7 +50,12 @@ def approx_symbolic_matrices(self):
4650

4751
@node_property
4852
def dlogp(self):
49-
grad = at.grad(self.logp_norm.sum(), self.approx_symbolic_matrices)
53+
logp = self.logp_norm.sum()
54+
if _known_scan_ignored_inputs([logp]):
55+
raise NotImplementedInference(
56+
"SVGD does not currently support Minibatch or Simulator RV"
57+
)
58+
grad = at.grad(logp, self.approx_symbolic_matrices)
5059

5160
def flatten2(tensor):
5261
return tensor.flatten(2)

0 commit comments

Comments
 (0)