Trouble sampling from Dirichlet with multiple groups #1153

Closed
@wcbeard

Description

I would like to define 2 (and eventually more) Categorical distributions in a single pymc3 variable, which I believe should have shape 2xK, where K is the number of categories. In my data set, each observation comes from one of the two groups, so the data looks like the following:

import numpy as np  # needed for np.ones in the model below
import numpy.random as nr
import pymc3 as pm
from pandas import DataFrame

nr.seed(0)
N_groups = 2
N_cats = 5
n_obs = 10
data = DataFrame({'Data': nr.randint(0, N_cats, n_obs),
                  'Group': nr.randint(0, N_groups, n_obs)})
data  # =>
''' 
       Data  Group
0     4      0
1     0      0
2     3      0
3     3      1
4     3      0
5     1      1
6     3      1
7     2      0
8     4      0
9     0      1
'''
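
To make the intended 2xK layout concrete, here is a minimal NumPy-only sketch of how indexing a per-group probability matrix by the `Group` column yields one row of category probabilities per observation. The arrays are copied from the table above; `p` here is a hypothetical placeholder with uniform rows, not a sampled pymc3 value, and the names `data_vals`, `group_idx`, and `p_per_obs` are introduced only for illustration.

```python
import numpy as np

# Values copied from the table above (not regenerated from the seed)
data_vals = np.array([4, 0, 3, 3, 3, 1, 3, 2, 4, 0])
group_idx = np.array([0, 0, 0, 1, 0, 1, 1, 0, 0, 1])

N_groups, N_cats = 2, 5

# Hypothetical placeholder: one uniform probability row per group
p = np.full((N_groups, N_cats), 1.0 / N_cats)

# Integer-array indexing selects the probability row for each observation
p_per_obs = p[group_idx]

print(p.shape)          # (2, 5)
print(p_per_obs.shape)  # (10, 5)
```

Under these assumptions, the indexed probabilities have one row per observation (10x5), while the parameter itself stays 2x5.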

When I try to sample from the following model, however, it hangs: sampling has still not finished after at least 10 minutes, which seems far too long for such a small data set.

mod = pm.Model()

cat_shape = (N_groups, N_cats)

with mod:
    a = np.ones(cat_shape, dtype=float)
    p = pm.Dirichlet('p', a, shape=cat_shape, testval=None)
    c = pm.Categorical('c', p[data.Group], shape=cat_shape, observed=data.Data)

n = 1000

with mod:
    step = pm.Slice()
    trace = pm.sample(n, step)

Should I expect it to take that long, or am I doing something wrong?
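
As a point of comparison, a Dirichlet prior is conjugate to a Categorical likelihood, so for independent groups the posterior of each row of `p` is available in closed form as Dirichlet(a + counts). The sketch below is not the pymc3 model above; it is a NumPy-only cross-check, assuming the flat prior `a = 1` from the model and the data values shown in the table. The names `counts`, `alpha_post`, and `post_mean` are introduced here for illustration.

```python
import numpy as np

# Values copied from the table in the description
data_vals = np.array([4, 0, 3, 3, 3, 1, 3, 2, 4, 0])
group_idx = np.array([0, 0, 0, 1, 0, 1, 1, 0, 0, 1])

N_groups, N_cats = 2, 5

# Count how often each category occurs within each group
counts = np.zeros((N_groups, N_cats))
np.add.at(counts, (group_idx, data_vals), 1)

# Flat Dirichlet prior (a = 1) plus counts gives the posterior concentration
alpha_post = 1.0 + counts

# Posterior mean of each group's category probabilities
post_mean = alpha_post / alpha_post.sum(axis=1, keepdims=True)

print(counts)
print(post_mean)
```

This runs instantly, so it can serve as a baseline for what a working sampler's posterior means for `p` should roughly look like on this data.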