Skip to content

Commit b82be63

Browse files
committed
Fix faulty tests for DM.
1 parent 26a4202 commit b82be63

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

pymc3/tests/test_distributions.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,8 +1070,9 @@ def test_dirichlet_multinomial_mode(self, alpha, n):
10701070
alpha = np.array(alpha)
10711071
n = np.array(n)
10721072
with Model() as model:
1073-
m = DirichletMultinomial('m', n, alpha)
1074-
assert_allclose(m.distribution.mode.eval().sum(), n)
1073+
m = DirichletMultinomial('m', n, alpha,
1074+
shape=alpha.shape)
1075+
assert_allclose(m.distribution.mode.eval().sum(axis=-1), n)
10751076

10761077
@pytest.mark.parametrize('alpha,n', [
10771078
[[[.25, .25, .25, .25]], [1]],
@@ -1082,10 +1083,11 @@ def test_dirichlet_multinomial_mode(self, alpha, n):
10821083
[10, 2]],
10831084
])
10841085
def test_dirichlet_multinomial_random(self, alpha, n):
1085-
alpha = np.asarray(alpha)
1086-
n = np.asarray(n)
1086+
alpha = np.array(alpha)
1087+
n = np.array(n)
10871088
with Model() as model:
1088-
m = DirichletMultinomial('m', n=n, alpha=alpha)
1089+
m = DirichletMultinomial('m', n=n, alpha=alpha,
1090+
shape=alpha.shape)
10891091
m.random()
10901092

10911093
def test_categorical_bounds(self):

0 commit comments

Comments
 (0)