Skip to content

Commit bc4b51d

Browse files
ferrinericardoV94
andauthored
Fix fgraph_from_model with multivariate transformed variables (#6924)
* fix do with transformed * Change test --------- Co-authored-by: Ricardo Vieira <ricardo.vieira1994@gmail.com>
1 parent a24acb9 commit bc4b51d

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

pymc/model/fgraph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def make_node(self, rv, value, *dims):
7070
dims = self._parse_dims(rv, *dims)
7171
if value is not None:
7272
assert isinstance(value, Variable)
73-
assert rv.type.in_same_class(value.type)
73+
assert rv.type.dtype == value.type.dtype
7474
return Apply(self, [rv, value, *dims], [rv.type(name=rv.name)])
7575

7676

tests/model/test_fgraph.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,3 +351,16 @@ def test_fgraph_rewrite(non_centered_rewrite):
351351
m_new.compile_logp()(ip),
352352
m_ref.compile_logp()(ip),
353353
)
354+
355+
356+
def test_multivariate_transform():
357+
with pm.Model() as m:
358+
x = pm.Dirichlet("x", a=[1, 1, 1])
359+
y, *_ = pm.LKJCholeskyCov("y", n=4, eta=1, sd_dist=pm.Exponential.dist(1))
360+
361+
new_m = clone_model(m)
362+
363+
ip = m.initial_point()
364+
new_ip = new_m.initial_point()
365+
np.testing.assert_allclose(ip["x_simplex__"], new_ip["x_simplex__"])
366+
np.testing.assert_allclose(ip["y_cholesky-cov-packed__"], new_ip["y_cholesky-cov-packed__"])

0 commit comments

Comments
 (0)