Text and SQLite backends for PyMC3 (update) #500

Closed
wants to merge 15 commits into from
1 change: 1 addition & 0 deletions .travis.yml
@@ -16,6 +16,7 @@ install:
- conda create -n testenv --yes pip python=$TRAVIS_PYTHON_VERSION
- source activate testenv
- conda install --yes ipython==1.1.0 pyzmq numpy==1.8.0 scipy nose matplotlib pandas Cython patsy statsmodels
- if [ ${TRAVIS_PYTHON_VERSION:0:1} == "2" ]; then conda install --yes mock; fi
- pip install --no-deps numdifftools
- pip install git+https://github.com/Theano/Theano.git
- python setup.py build_ext --inplace
119 changes: 119 additions & 0 deletions pymc/backends/__init__.py
@@ -0,0 +1,119 @@
"""Backends for traces

Available backends
------------------

1. NumPy array (pymc.backends.NDArray)
2. Text files (pymc.backends.Text)
3. SQLite (pymc.backends.SQLite)

The NumPy arrays and text files both hold the entire trace in memory,
whereas SQLite commits the trace to the database while sampling.

Selecting a backend
-------------------

By default, a NumPy array is used as the backend. To specify a different
backend, pass a backend instance to `sample`.

For example, the following would save traces to the file 'test.db'.

>>> import pymc as pm
>>> db = pm.backends.SQLite('test.db')
>>> trace = pm.sample(..., trace=db)

Selecting values from a backend
-------------------------------

After sampling finishes, `pymc.sample` returns a MultiTrace object.
Values can be accessed in a few ways. The easiest way is to index the
MultiTrace object with a variable or variable name.

>>> trace['x'] # or trace[x]

The call will return a list containing the sampling values of `x` for
each chain, or a single array if there is only one chain. (Each call to
`pymc.sample` creates a separate chain of samples.)

For more control of which values are returned, the `get_values` method
can be used. The call below will return values from all chains, burning
the first 1000 iterations from each chain.

>>> trace.get_values('x', burn=1000)

Setting the `combine` flag will concatenate the results from all the
chains.

>>> trace.get_values('x', burn=1000, combine=True)

The `chains` parameter of `get_values` can be used to limit the chains
that are retrieved.

>>> trace.get_values('x', burn=1000, combine=True, chains=[0, 2])

Backends can also support slicing the MultiTrace object. For example,
the following call would return a new trace object without the first
1000 sampling iterations for all traces and variables.

>>> sliced_trace = trace[1000:]

Loading a saved backend
-----------------------

Saved backends can be loaded using the `load` function in the module for
the specific backend.

>>> trace = pm.backends.sqlite.load('test.db')

Writing custom backends
-----------------------

Backends consist of a class that handles sampling storage and value
selection. Three sampling methods of a backend will be called:

- setup: Before sampling is started, the `setup` method will be called
with two arguments: the number of draws and the chain number. This is
useful for setting up any structure for storing the sampling values that
requires the above information.

- record: Record the sampling results for the current draw. This method
will be called with a dictionary of values mapped to the variable
names. This is the only sampling function that *must* do something to
have a meaningful backend.

- close: This method is called following sampling and should perform any
actions necessary for finalizing and cleaning up the backend.

The base storage class `backends.base.BaseTrace` provides common model
setup that is used by all the PyMC backends.

Several selection methods must also be defined:

- get_values: This is the core method for selecting values from the
backend. It can be called directly and is used by __getitem__ when the
backend is indexed with a variable name or object.

- _slice: Defines how the backend returns a slice of itself. This
is called if the backend is indexed with a slice range.

- point: Returns values for each variable at a single iteration. This is
called if the backend is indexed with a single integer.

- __len__: This should return the number of draws (for the highest chain
number).

When `pymc.sample` finishes, it wraps all trace objects in a MultiTrace
object that provides a consistent selection interface for all backends.
If the traces are stored on disk, then a `load` function should also be
defined that returns a MultiTrace object.

For specific examples, see pymc.backends.{ndarray,text,sqlite}.py.
"""
from pymc.backends.ndarray import NDArray
from pymc.backends.text import Text
from pymc.backends.sqlite import SQLite

_shortcuts = {'text': {'backend': Text,
                       'name': 'mcmc'},
              'sqlite': {'backend': SQLite,
                         'name': 'mcmc.sqlite'}}
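
The "Writing custom backends" section of the docstring above can be made concrete with a small sketch. The class below is hypothetical (`ListTrace` is not part of this PR or of PyMC) and deliberately does not subclass `BaseTrace`, so it runs without a model. It stores draws in plain lists, so `get_values` returns a list here, whereas the docstring expects a NumPy array from a real backend.

```python
class ListTrace:
    """Hypothetical in-memory backend used only for illustration."""

    def __init__(self, varnames):
        self.varnames = list(varnames)
        self.chain = None
        self.samples = {name: [] for name in self.varnames}

    # Sampling methods

    def setup(self, draws, chain):
        # `draws` is only a size hint in this sketch; lists grow as needed.
        self.chain = chain
        self.samples = {name: [] for name in self.varnames}

    def record(self, point):
        # `point` maps variable names to the values of the current draw.
        for name in self.varnames:
            self.samples[name].append(point[name])

    def close(self):
        # Nothing to flush or release for an in-memory store.
        pass

    # Selection methods

    def __len__(self):
        if not self.varnames:
            return 0
        return len(self.samples[self.varnames[0]])

    def get_values(self, varname, burn=0, thin=1):
        # Drop the first `burn` draws, then keep every `thin`-th draw.
        return self.samples[varname][burn::thin]

    def _slice(self, idx):
        sliced = ListTrace(self.varnames)
        sliced.chain = self.chain
        sliced.samples = {name: values[idx]
                          for name, values in self.samples.items()}
        return sliced

    def point(self, idx):
        return {name: values[idx] for name, values in self.samples.items()}


# Drive the backend the way a sampler would: setup, record per draw, close.
trace = ListTrace(["x", "y"])
trace.setup(draws=4, chain=0)
for i in range(4):
    trace.record({"x": i, "y": 10 * i})
trace.close()
```

A backend that writes to disk would do its real work in `setup` (open the file or connection) and `close` (commit and release it). The list-based store only needs `record`.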
271 changes: 271 additions & 0 deletions pymc/backends/base.py
@@ -0,0 +1,271 @@
"""Base backend for traces

See the docstring for pymc.backends for more information (including
creating custom backends).
"""
import numpy as np
from pymc.model import modelcontext


class BaseTrace(object):
    """Base trace object

    Parameters
    ----------
    name : str
        Name of backend
    model : Model
        If None, the model is taken from the `with` context.
    vars : list of variables
        Sampling values will be stored for these variables. If None,
        `model.unobserved_RVs` is used.
    """
    def __init__(self, name, model=None, vars=None):
        self.name = name

        model = modelcontext(model)
        self.model = model
        ## `vars` is used throughout these backends to be consistent
        ## with other code, but I'd prefer to rename this since it is a
        ## built-in.
        if vars is None:
            vars = model.unobserved_RVs
        self.vars = vars
        self.varnames = [str(var) for var in vars]
        self.fn = model.fastfn(vars)

        ## Get variable shapes. Most backends will need this
        ## information.
        var_values = zip(self.varnames, self.fn(model.test_point))
        self.var_shapes = {var: value.shape
                           for var, value in var_values}
        self.chain = None

    ## Sampling methods

    def setup(self, draws, chain):
        """Perform chain-specific setup.

        Parameters
        ----------
        draws : int
            Expected number of draws
        chain : int
            Chain number
        """
        pass

    def record(self, point):
        """Record results of a sampling iteration.

        Parameters
        ----------
        point : dict
            Values mapped to variable names
        """
        raise NotImplementedError

    def close(self):
        """Close the database backend.

        This is called after sampling has finished.
        """
        pass

    ## Selection methods

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return self._slice(idx)

        try:
            return self.point(idx)
        except (ValueError, TypeError):
            pass
        return self.get_values(idx)

    def __len__(self):
        raise NotImplementedError

    def get_values(self, varname, burn=0, thin=1):
        """Get values from trace.

        Parameters
        ----------
        varname : str
        burn : int
        thin : int

        Returns
        -------
        A NumPy array
        """
        raise NotImplementedError

    def _slice(self, idx):
        """Slice trace object."""
        raise NotImplementedError

    def point(self, idx, chain=None):
        """Return dictionary of point values at `idx` for current chain
        with variable names as keys.
        """
        raise NotImplementedError


class MultiTrace(object):
    """MultiTrace provides the main interface for accessing values from
    trace objects.

    The core method to select values is `get_values`. Values can also be
    accessed by indexing the MultiTrace object. Indexing can behave in
    three ways:

    1. Indexing with a variable or variable name (str) returns all
       values for that variable.
    2. Indexing with an integer returns a dictionary with values for
       each variable at the given index (corresponding to a single
       sampling iteration).
    3. Slicing with a range returns a new trace with the number of draws
       corresponding to the range.

    For any methods that require a single trace (e.g., taking the length
    of the MultiTrace instance, which returns the number of draws), the
    trace with the highest chain number is always used.

    Parameters
    ----------
    traces : list of traces
        Each object must have a unique `chain` attribute.
    """
    def __init__(self, traces):
        self._traces = {}
        for trace in traces:
            if trace.chain in self._traces:
                raise ValueError("Chains are not unique.")
            self._traces[trace.chain] = trace

    @property
    def nchains(self):
        return len(self._traces)

    @property
    def chains(self):
        return sorted(self._traces.keys())

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return self._slice(idx)

        try:
            return self.point(idx)
        except (ValueError, TypeError):
            pass
        return self.get_values(idx)

    def __len__(self):
        chain = self.chains[-1]
        return len(self._traces[chain])

    @property
    def varnames(self):
        chain = self.chains[-1]
        return self._traces[chain].varnames

    def get_values(self, varname, burn=0, thin=1, combine=False, chains=None,
                   squeeze=True):
        """Get values from traces.

        Parameters
        ----------
        varname : str
        burn : int
        thin : int
        combine : bool
            If True, results from `chains` will be concatenated.
        chains : int or list of ints
            Chains to retrieve. If None, all chains are used. A single
            value is also accepted.
        squeeze : bool
            Return a single array element if the resulting list of
            values only has one element. If this is not true, the result
            will always be a list of arrays, even if `combine` is True.

        Returns
        -------
        A list of NumPy arrays or a single NumPy array (depending on
        `squeeze`).
        """
        if chains is None:
            chains = self.chains
        varname = str(varname)
        try:
            results = [self._traces[chain].get_values(varname, burn, thin)
                       for chain in chains]
        except TypeError:  # Single chain passed.
            results = [self._traces[chains].get_values(varname, burn, thin)]
        return _squeeze_cat(results, combine, squeeze)

    def _slice(self, idx):
        """Return a new MultiTrace object sliced according to `idx`."""
        new_traces = [trace._slice(idx) for trace in self._traces.values()]
        return MultiTrace(new_traces)

    def point(self, idx, chain=None):
        """Return a dictionary of point values at `idx`.

        Parameters
        ----------
        idx : int
        chain : int
            If a chain is not given, the highest chain number is used.
        """
        if chain is None:
            chain = self.chains[-1]
        return self._traces[chain].point(idx)


def merge_traces(mtraces):
    """Merge MultiTrace objects.

    Parameters
    ----------
    mtraces : list of MultiTraces
        Each instance should have unique chain numbers.

    Raises
    ------
    A ValueError is raised if any traces have overlapping chain numbers.

    Returns
    -------
    A MultiTrace instance with merged chains. The first MultiTrace in
    `mtraces` is updated in place and returned.
    """
    base_mtrace = mtraces[0]
    for new_mtrace in mtraces[1:]:
        for new_chain, trace in new_mtrace._traces.items():
            if new_chain in base_mtrace._traces:
                raise ValueError("Chains are not unique.")
            base_mtrace._traces[new_chain] = trace
    return base_mtrace


def _squeeze_cat(results, combine, squeeze):
    """Squeeze and concatenate the results depending on values of
    `combine` and `squeeze`."""
    if combine:
        results = np.concatenate(results)
        if not squeeze:
            results = [results]
    else:
        if squeeze and len(results) == 1:
            results = results[0]
    return results
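
The three indexing behaviours described in the `MultiTrace` docstring (name, integer, slice) all go through the same `__getitem__` fallback chain. The sketch below reproduces that dispatch order on a hypothetical stand-in class (`DispatchDemo` is an illustration, not code from this PR) that keeps draws in plain lists, so it runs without PyMC or NumPy.

```python
class DispatchDemo:
    """Hypothetical stand-in for a trace; only the indexing logic matters."""

    def __init__(self, samples):
        self.samples = samples  # dict: variable name -> list of draws

    def __getitem__(self, idx):
        # 1. A slice returns a new, shorter trace.
        if isinstance(idx, slice):
            return self._slice(idx)
        # 2. Otherwise try to treat `idx` as a draw number.
        try:
            return self.point(idx)
        except (ValueError, TypeError):
            pass
        # 3. Fall back to treating `idx` as a variable or variable name.
        return self.get_values(idx)

    def _slice(self, idx):
        return DispatchDemo({name: values[idx]
                             for name, values in self.samples.items()})

    def point(self, idx):
        # Indexing a list with a non-integer such as a name raises TypeError,
        # which is what triggers the fallback in __getitem__.
        return {name: values[idx] for name, values in self.samples.items()}

    def get_values(self, varname):
        return self.samples[str(varname)]


demo = DispatchDemo({"x": [0.1, 0.2, 0.3]})
by_name = demo["x"]    # falls through to get_values
by_index = demo[1]     # handled by point
by_slice = demo[1:]    # handled by _slice
```

One consequence of this design, at least in the sketch: a `TypeError` or `ValueError` raised inside `point` for an otherwise valid integer index is swallowed, and the failure then surfaces from `get_values` instead.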