19
19
20
20
from pymc .distributions import Dirichlet , Normal
21
21
from pymc .distributions .transforms import log
22
- from pymc .model import Model
22
+ from pymc .model import Deterministic , Model
23
23
from pymc .stats .log_density import compute_log_likelihood , compute_log_prior
24
24
from tests .distributions .test_multivariate import dirichlet_logpdf
25
25
@@ -41,7 +41,7 @@ def test_basic(self, transform):
41
41
assert m .rvs_to_transforms [x ] is transform
42
42
43
43
assert res is idata
44
- assert res .log_likelihood .dims == {"chain" : 4 , "draw" : 25 , "test_dim" : 3 }
44
+ assert res .log_likelihood .sizes == {"chain" : 4 , "draw" : 25 , "test_dim" : 3 }
45
45
46
46
np .testing .assert_allclose (
47
47
res .log_likelihood ["y" ].values ,
@@ -62,7 +62,7 @@ def test_multivariate(self):
62
62
idata = InferenceData (posterior = dict_to_dataset ({"p" : p_draws }))
63
63
res = compute_log_likelihood (idata )
64
64
65
- assert res .log_likelihood .dims == {"chain" : 4 , "draw" : 25 , "test_event_dim" : 10 }
65
+ assert res .log_likelihood .sizes == {"chain" : 4 , "draw" : 25 , "test_event_dim" : 10 }
66
66
67
67
np .testing .assert_allclose (
68
68
res .log_likelihood ["y" ].values ,
@@ -149,7 +149,26 @@ def test_basic_log_prior(self, transform):
149
149
assert m .rvs_to_transforms [x ] is transform
150
150
151
151
assert res is idata
152
- assert res .log_prior .dims == {"chain" : 4 , "draw" : 25 }
152
+ assert res .log_prior .sizes == {"chain" : 4 , "draw" : 25 }
153
+
154
+ np .testing .assert_allclose (
155
+ res .log_prior ["x" ].values ,
156
+ st .norm (0 , 1 ).logpdf (idata .posterior ["x" ].values ),
157
+ )
158
+
159
+ def test_deterministic_log_prior (self ):
160
+ with Model () as m :
161
+ x = Normal ("x" )
162
+ Deterministic ("d" , 2 * x )
163
+ Normal ("y" , x , observed = [0 , 1 , 2 ])
164
+
165
+ idata = InferenceData (posterior = dict_to_dataset ({"x" : np .arange (100 ).reshape (4 , 25 )}))
166
+ res = compute_log_prior (idata )
167
+
168
+ assert res is idata
169
+ assert "x" in res .log_prior
170
+ assert "d" not in res .log_prior
171
+ assert res .log_prior .sizes == {"chain" : 4 , "draw" : 25 }
153
172
154
173
np .testing .assert_allclose (
155
174
res .log_prior ["x" ].values ,
0 commit comments