Skip to content

Commit dc966b7

Browse files
committed
fixed dirichlet and added example
1 parent 839786b commit dc966b7

File tree

2 files changed

+35
-5
lines changed

2 files changed

+35
-5
lines changed

examples/dirichlet.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import numpy as np
2+
from pymc import *
3+
from pymc.distributions.multivariate import Dirichlet
4+
5+
6+
model = Model()
7+
Var = model.Var
8+
Data = model.Data
9+
10+
k = 5
11+
a = constant(np.array([2,3.,4, 2,2]))
12+
13+
p, p_m1 = model.TransformedVar(
14+
'p', Dirichlet(k,a),
15+
transforms.simplex, shape = k - 1)
16+
17+
18+
H = model.d2logpc()
19+
20+
s = find_MAP(model)
21+
22+
step = hmc_step(model, model.vars, H(s))
23+
trace, _,t = sample(1000, step, s)
24+
25+

pymc/distributions/multivariate.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,12 @@ def Dirichlet(k, a):
4444
\cdot\left(1-\sum_{i=1}^{k-1}x_i\right)^\theta_k
4545
4646
:Parameters:
47-
theta : array
48-
An (n,k) or (1,k) array > 0.
47+
k : scalar int
48+
k > 1
49+
a : float tensor
50+
a > 0
51+
concentration parameters
52+
last index is the k index
4953
5054
:Support:
5155
x : vector
@@ -55,16 +59,17 @@ def Dirichlet(k, a):
5559
Only the first `k-1` elements of `x` are expected. Can be used
5660
as a parent of Multinomial and Categorical nevertheless.
5761
"""
62+
5863
support = 'continuous'
5964

60-
a = ones(k) * a
65+
a = ones([k]) * a
6166
def logp(value):
6267

6368
#only defined for sum(value) == 1
6469
return bound(
65-
sum((a -1)*log(value)) + gammaln(sum(a)) - sum(gammaln(a)),
70+
sum(logpow(value, a -1) - gammaln(a), axis = 0) + gammaln(sum(a)),
6671

67-
k > 2,
72+
k > 1,
6873
a > 0)
6974

7075
mean = a/sum(a)

0 commit comments

Comments
 (0)