From 9266900db4ba14abe018453ca0791ccf030e8fbf Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Fri, 3 Jan 2014 15:21:07 -0500 Subject: [PATCH 01/15] Include tests_require argument in setup.py --- setup.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 95052d6128..89a160c38c 100755 --- a/setup.py +++ b/setup.py @@ -25,8 +25,9 @@ 'Topic :: Scientific/Engineering :: Mathematics', 'Operating System :: OS Independent'] -required = ['numpy>=1.7.1', 'scipy>=0.12.0', 'matplotlib>=1.2.1', - 'Theano==0.6.0'] +install_reqs = ['numpy>=1.7.1', 'scipy>=0.12.0', 'matplotlib>=1.2.1', + 'Theano==0.6.0'] +test_reqs = ['nose'] if __name__ == "__main__": setup(name=DISTNAME, @@ -41,4 +42,6 @@ 'pymc.step_methods', 'pymc.tuning', 'pymc.tests', 'pymc.glm'], classifiers=classifiers, - install_requires=required) + install_requires=install_reqs, + tests_require=test_reqs, + test_suite='nose.collector') From 2bd8e5bec862c1bf8f97f4b7f239dec3b270ff57 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Fri, 3 Jan 2014 15:21:56 -0500 Subject: [PATCH 02/15] Include mock as test dependency This is only needed for python 2 because mock is in stdlib for python 3 (unittest.mock). 
--- .travis.yml | 1 + setup.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/.travis.yml b/.travis.yml index 12ff62a264..f258fd05c0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,6 +16,7 @@ install: - conda create -n testenv --yes pip python=$TRAVIS_PYTHON_VERSION - source activate testenv - conda install --yes ipython==1.1.0 pyzmq numpy==1.8.0 scipy nose matplotlib pandas Cython patsy statsmodels + - if [ ${TRAVIS_PYTHON_VERSION:0:1} == "2" ]; then conda install --yes mock; fi - pip install --no-deps numdifftools - pip install git+https://github.com/Theano/Theano.git - python setup.py build_ext --inplace diff --git a/setup.py b/setup.py index 89a160c38c..2a338bf31f 100755 --- a/setup.py +++ b/setup.py @@ -1,5 +1,6 @@ #!/usr/bin/env python from setuptools import setup +import sys DISTNAME = 'pymc' @@ -27,7 +28,10 @@ install_reqs = ['numpy>=1.7.1', 'scipy>=0.12.0', 'matplotlib>=1.2.1', 'Theano==0.6.0'] + test_reqs = ['nose'] +if sys.version_info[0] == 2: # py3 has mock in stdlib + test_reqs.append('mock') if __name__ == "__main__": setup(name=DISTNAME, From e95718036727f041c449fed055f0b08badf805ee Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Thu, 20 Feb 2014 22:18:39 -0500 Subject: [PATCH 03/15] Add base and NDArray backend This commit contains a new backend for sampling and selecting values. Non-backend files have been changed to work with the new backend. This commit also merges the `sample` and `psample` functions. `sample` now takes a keyword argument `njobs`, and if this is over one, the multiprocessing version is used. 
--- pymc/backends/__init__.py | 1 + pymc/backends/base.py | 271 ++++++++++++++++++++++++ pymc/backends/ndarray.py | 113 ++++++++++ pymc/diagnostics.py | 19 +- pymc/examples/gelman_bioassay.py | 2 +- pymc/plots.py | 103 +++------ pymc/sampling.py | 205 +++++++++--------- pymc/stats.py | 53 ++--- pymc/tests/checks.py | 2 +- pymc/tests/test_base_backend.py | 56 +++++ pymc/tests/test_diagnostics.py | 4 +- pymc/tests/test_glm.py | 16 +- pymc/tests/test_ndarray_backend.py | 321 +++++++++++++++++++++++++++++ pymc/tests/test_plots.py | 4 +- pymc/tests/test_sampling.py | 89 ++++++-- pymc/tests/test_trace.py | 124 ----------- pymc/trace.py | 112 +--------- setup.py | 2 +- 18 files changed, 1020 insertions(+), 477 deletions(-) create mode 100644 pymc/backends/__init__.py create mode 100644 pymc/backends/base.py create mode 100644 pymc/backends/ndarray.py create mode 100644 pymc/tests/test_base_backend.py create mode 100644 pymc/tests/test_ndarray_backend.py diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py new file mode 100644 index 0000000000..feebc00bb6 --- /dev/null +++ b/pymc/backends/__init__.py @@ -0,0 +1 @@ +from pymc.backends.ndarray import NDArray diff --git a/pymc/backends/base.py b/pymc/backends/base.py new file mode 100644 index 0000000000..31855beced --- /dev/null +++ b/pymc/backends/base.py @@ -0,0 +1,271 @@ +"""Base backend for traces + +See the docstring for pymc.backends for more information (includng +creating custom backends). +""" +import numpy as np +from pymc.model import modelcontext + + +class BaseTrace(object): + """Base trace object + + Parameters + ---------- + name : str + Name of backend + model : Model + If None, the model is taken from the `with` context. + vars : list of variables + Sampling values will be stored for these variables. If None, + `model.unobserved_RVs` is used. 
+ """ + def __init__(self, name, model=None, vars=None): + self.name = name + + model = modelcontext(model) + self.model = model + ## `vars` is used throughout these backends to be consistent + ## with other code, but I'd prefer to rename this since it is a + ## built-in. + if vars is None: + vars = model.unobserved_RVs + self.vars = vars + self.varnames = [str(var) for var in vars] + self.fn = model.fastfn(vars) + + ## Get variable shapes. Most backends will need this + ## information. + var_values = zip(self.varnames, self.fn(model.test_point)) + self.var_shapes = {var: value.shape + for var, value in var_values} + self.chain = None + + ## Sampling methods + + def setup(self, draws, chain): + """Perform chain-specific setup. + + Parameters + ---------- + draws : int + Expected number of draws + chain : int + Chain number + """ + pass + + def record(self, point): + """Record results of a sampling iteration. + + Parameters + ---------- + point : dict + Values mapped to variable names + """ + raise NotImplementedError + + def close(self): + """Close the database backend. + + This is called after sampling has finished. + """ + pass + + ## Selection methods + + def __getitem__(self, idx): + if isinstance(idx, slice): + return self._slice(idx) + + try: + return self.point(idx) + except ValueError: + pass + except TypeError: + pass + return self.get_values(idx) + + def __len__(self): + raise NotImplementedError + + def get_values(self, varname, burn=0, thin=1): + """Get values from trace. + + Parameters + ---------- + varname : str + burn : int + thin : int + + Returns + ------- + A NumPy array + """ + raise NotImplementedError + + def _slice(self, idx): + """Slice trace object.""" + raise NotImplementedError + + def point(self, idx, chain=None): + """Return dictionary of point values at `idx` for current chain + with variables names as keys. 
+ """ + raise NotImplementedError + + +class MultiTrace(object): + """MultiTrace provides the main interface for accessing values from + traces objects. + + The core method to select values is `get_values`. Values can also be + accessed by indexing the MultiTrace object. Indexing can behave in + three ways: + + 1. Indexing with a variable or variable name (str) returns all + values for that variable. + 2. Indexing with an integer returns a dictionary with values for + each variable at the given index (corresponding to a single + sampling iteration). + 3. Slicing with a range returns a new trace with the number of draws + corresponding to the range. + + For any methods that require a single trace (e.g., taking the length + of the MultiTrace instance, which returns the number of draws), the + trace with the highest chain number is always used. + + Parameters + ---------- + traces : list of traces + Each object must have a unique `chain` attribute. + """ + def __init__(self, traces): + self._traces = {} + for trace in traces: + if trace.chain in self._traces: + raise ValueError("Chains are not unique.") + self._traces[trace.chain] = trace + + @property + def nchains(self): + return len(self._traces) + + @property + def chains(self): + return list(sorted(self._traces.keys())) + + def __getitem__(self, idx): + if isinstance(idx, slice): + return self._slice(idx) + + try: + return self.point(idx) + except ValueError: + pass + except TypeError: + pass + return self.get_values(idx) + + def __len__(self): + chain = self.chains[-1] + return len(self._traces[chain]) + + @property + def varnames(self): + chain = self.chains[-1] + return self._traces[chain].varnames + + def get_values(self, varname, burn=0, thin=1, combine=False, chains=None, + squeeze=True): + """Get values from traces. + + Parameters + ---------- + varname : str + burn : int + thin : int + combine : bool + If True, results from `chains` will be concatenated. 
+ chains : int or list of ints + Chains to retrieve. If None, all chains are used. A single + values can also accepted. + squeeze : bool + Return a single array element if the resulting list of + values only has one element. If this is not true, the result + will always be a list of arrays, even if `combine` is True. + + Returns + ------- + A list of NumPy arrays or a single NumPy array (depending on + `squeeze`). + """ + if chains is None: + chains = self.chains + varname = str(varname) + try: + results = [self._traces[chain].get_values(varname, burn, thin) + for chain in chains] + except TypeError: # Single chain passed. + results = [self._traces[chains].get_values(varname, burn, thin)] + return _squeeze_cat(results, combine, squeeze) + + def _slice(self, idx): + """Return a new MultiTrace object sliced according to `idx`.""" + chain = self.chains[-1] + model = self._traces[chain].model + vars = self._traces[chain].vars + + new_traces = [trace._slice(idx) for trace in self._traces.values()] + return MultiTrace(new_traces) + + def point(self, idx, chain=None): + """Return a dictionary of point values at `idx`. + + Parameters + ---------- + idx : int + chain : int + If a chain is not given, the highest chain number is used. + """ + if chain is None: + chain = self.chains[-1] + return self._traces[chain].point(idx) + + +def merge_traces(mtraces): + """Merge MultiTrace objects. + + Parameters + ---------- + mtraces : list of MultiTraces + Each instance should have unique chain numbers. + + Raises + ------ + A ValueError is raised if any traces have overlapping chain numbers. 
+ + Returns + ------- + A MultiTrace instance with merged chains + """ + base_mtrace = mtraces[0] + for new_mtrace in mtraces[1:]: + for new_chain, trace in new_mtrace._traces.items(): + if new_chain in base_mtrace._traces: + raise ValueError("Chains are not unique.") + base_mtrace._traces[new_chain] = trace + return base_mtrace + + +def _squeeze_cat(results, combine, squeeze): + """Squeeze and concatenate the results depending on values of + `combine` and `squeeze`.""" + if combine: + results = np.concatenate(results) + if not squeeze: + results = [results] + else: + if squeeze and len(results) == 1: + results = results[0] + return results diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py new file mode 100644 index 0000000000..27f55c5e9f --- /dev/null +++ b/pymc/backends/ndarray.py @@ -0,0 +1,113 @@ +"""NumPy array trace backend + +Store sampling values in memory as a NumPy array. +""" +import numpy as np +from pymc.backends import base + + +class NDArray(base.BaseTrace): + """NDArray trace object + + Parameters + ---------- + name : str + Name of backend. This has no meaning for the NDArray backend. + model : Model + If None, the model is taken from the `with` context. + vars : list of variables + Sampling values will be stored for these variables. If None, + `model.unobserved_RVs` is used. + """ + def __init__(self, name=None, model=None, vars=None): + super(NDArray, self).__init__(name, model, vars) + self.draw_idx = 0 + self.draws = None + self.samples = {} + + ## Sampling methods + + def setup(self, draws, chain): + """Perform chain-specific setup. + + Parameters + ---------- + draws : int + Expected number of draws + chain : int + Chain number + """ + self.chain = chain + if self.samples: # Concatenate new array if chain is already present. 
+ old_draws = len(self) + self.draws = old_draws + draws + self.draws_idx = old_draws + for varname, shape in self.var_shapes.items(): + old_trace = self.samples[varname] + new_trace = np.zeros((draws, ) + shape) + self.samples[varname] = np.concatenate((old_trace, new_trace), + axis=0) + else: # Otherwise, make array of zeros for each variable. + self.draws = draws + for varname, shape in self.var_shapes.items(): + self.samples[varname] = np.zeros((draws, ) + shape) + + def record(self, point): + """Record results of a sampling iteration. + + Parameters + ---------- + point : dict + Values mapped to variable names + """ + for varname, value in zip(self.varnames, self.fn(point)): + self.samples[varname][self.draw_idx] = value + self.draw_idx += 1 + + def close(self): + if self.draw_idx == self.draws: + return + ## Remove trailing zeros if interrupted before completed all + ## draws. + self.samples = {var: trace[:self.draw_idx] + for var, trace in self.samples.items()} + + ## Selection methods + + def __len__(self): + if not self.samples: # `setup` has not been called. + return 0 + varname = self.varnames[0] + return self.samples[varname].shape[0] + + def get_values(self, varname, burn=0, thin=1): + """Get values from trace. + + Parameters + ---------- + varname : str + burn : int + thin : int + + Returns + ------- + A NumPy array + """ + return self.samples[varname][burn::thin] + + def _slice(self, idx): + sliced = NDArray(model=self.model, vars=self.vars) + sliced.chain = self.chain + sliced.samples = {varname: values[idx] + for varname, values in self.samples.items()} + return sliced + + def point(self, idx): + """Return dictionary of point values at `idx` for current chain + with variables names as keys. + + If `chain` is not specified, `default_chain` is used. 
+ """ + idx = int(idx) + return {varname: values[idx] + for varname, values in self.samples.items()} diff --git a/pymc/diagnostics.py b/pymc/diagnostics.py index 88e15a83a7..1c91a68218 100644 --- a/pymc/diagnostics.py +++ b/pymc/diagnostics.py @@ -87,7 +87,7 @@ def geweke(x, first=.1, last=.5, intervals=20): return np.array(zscores) -def gelman_rubin(mtrace): +def gelman_rubin(trace): """ Returns estimate of R for a set of traces. The Gelman-Rubin diagnostic tests for lack of convergence by comparing @@ -99,8 +99,8 @@ def gelman_rubin(mtrace): Parameters ---------- - mtrace : MultiTrace - A MultiTrace object containing parallel traces (minimum 2) + trace + A trace object containing parallel traces (minimum 2) of one or more stochastic parameters. Returns @@ -126,8 +126,7 @@ def gelman_rubin(mtrace): Brooks and Gelman (1998) Gelman and Rubin (1992)""" - m = len(mtrace.traces) - if m < 2: + if trace.nchains < 2: raise ValueError( 'Gelman-Rubin diagnostic requires multiple chains of the same length.') @@ -157,10 +156,10 @@ def calc_rhat(x): return np.squeeze([calc_rhat(xi) for xi in x.transpose(rotated_indices)]) Rhat = {} - for var in mtrace.varnames: + for var in trace.varnames: # Get all traces for var - x = np.array([mtrace.traces[i][var] for i in range(m)]) + x = np.array(trace.get_values(var)) try: Rhat[var] = calc_rhat(x) @@ -169,9 +168,11 @@ def calc_rhat(x): return Rhat + def trace_to_dataframe(trace): """Convert a PyMC trace consisting of 1-D variables to a pandas DataFrame """ import pandas as pd - return pd.DataFrame({name: np.squeeze(trace_var.vals) - for name, trace_var in trace.samples.items()}) + return pd.DataFrame( + {varname: np.squeeze(trace.get_values(varname, combine=True)) + for varname in trace.varnames}) diff --git a/pymc/examples/gelman_bioassay.py b/pymc/examples/gelman_bioassay.py index 4725ded00b..5aa23d038d 100644 --- a/pymc/examples/gelman_bioassay.py +++ b/pymc/examples/gelman_bioassay.py @@ -28,7 +28,7 @@ def run(n=1000): if n == 
"short": n = 50 with model: - trace = sample(n, step, trace=model.unobserved_RVs) + trace = sample(n, step) if __name__ == '__main__': run() diff --git a/pymc/plots.py b/pymc/plots.py index c8008de66d..ce8d8fa556 100644 --- a/pymc/plots.py +++ b/pymc/plots.py @@ -1,7 +1,6 @@ import numpy as np from scipy.stats import kde from .stats import * -from .trace import * __all__ = ['traceplot', 'kdeplot', 'kde2plot', 'forestplot', 'autocorrplot'] @@ -23,8 +22,8 @@ def traceplot(trace, vars=None, figsize=None, lines to the posteriors and horizontal lines on sample values e.g. mean of posteriors, true values of a simulation combined : bool - Flag for combining MultiTrace into a single trace. If False (default) - traces will be plotted separately on the same set of axes. + Flag for combining multiple chains into a single chain. If False + (default), chains will be plotted separately. grid : bool Flag for adding gridlines to histogram. Defaults to True. @@ -38,14 +37,6 @@ def traceplot(trace, vars=None, figsize=None, if vars is None: vars = trace.varnames - if isinstance(trace, MultiTrace): - if combined: - traces = [trace.combined()] - else: - traces = trace.traces - else: - traces = [trace] - n = len(vars) if figsize is None: @@ -53,11 +44,10 @@ def traceplot(trace, vars=None, figsize=None, fig, ax = plt.subplots(n, 2, squeeze=False, figsize=figsize) - for trace in traces: - for i, v in enumerate(vars): - d = np.squeeze(trace[v]) - - if trace[v].dtype.kind == 'i': + for i, v in enumerate(vars): + for d in trace.get_values(v, combine=combined, squeeze=False): + d = np.squeeze(d) + if d.dtype.kind == 'i': histplot_op(ax[i, 0], d) else: kdeplot_op(ax[i, 0], d) @@ -133,35 +123,23 @@ def kde2plot(x, y, grid=200): def autocorrplot(trace, vars=None, fontmap=None, max_lag=100): """Bar plot of the autocorrelation function for a trace""" import matplotlib.pyplot as plt - try: - # MultiTrace - traces = trace.traces - - except AttributeError: - # NpTrace - traces = [trace] - if fontmap 
is None: fontmap = {1: 10, 2: 8, 3: 6, 4: 5, 5: 4} if vars is None: - vars = traces[0].varnames - - # Extract sample data - samples = [{v: trace[v] for v in vars} for trace in traces] + vars = trace.varnames + else: + vars = [str(var) for var in vars] - chains = len(traces) + chains = trace.nchains - n = len(samples[0]) - f, ax = plt.subplots(n, chains, squeeze=False) + f, ax = plt.subplots(len(vars), chains, squeeze=False) - max_lag = min(len(samples[0][vars[0]])-1, max_lag) + max_lag = min(len(trace) - 1, max_lag) for i, v in enumerate(vars): - for j in range(chains): - - d = np.squeeze(samples[j][v]) + d = np.squeeze(trace.get_values(v, chains=[j])) ax[i, j].acorr(d, detrend=plt.mlab.detrend_mean, maxlags=max_lag) @@ -279,41 +257,26 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True, interval_plot = None rhat_plot = None - try: - # First try MultiTrace type - traces = trace_obj.traces - - if rhat and len(traces) > 1: - - from .diagnostics import gelman_rubin - - R = gelman_rubin(trace_obj) - if vars is not None: - R = {v: R[v] for v in vars} - - else: - - rhat = False - - except AttributeError: - - # Single NpTrace - traces = [trace_obj] + nchains = trace_obj.nchains + if nchains > 1: + from .diagnostics import gelman_rubin + R = gelman_rubin(trace_obj) + if vars is not None: + R = {v: R[v] for v in vars} + else: # Can't calculate Gelman-Rubin with a single trace rhat = False if vars is None: - vars = traces[0].varnames + vars = trace_obj.varnames # Empty list for y-axis labels labels = [] - chains = len(traces) - if gs is None: # Initialize plot - if rhat and chains > 1: + if rhat and nchains > 1: gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1]) else: @@ -323,21 +286,18 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True, # Subplot for confidence intervals interval_plot = plt.subplot(gs[0]) - for j, tr in enumerate(traces): - # Get quantiles - trace_quantiles = quantiles(tr, qlist) - hpd_intervals = hpd(tr, 
alpha) + trace_quantiles = quantiles(trace_obj, qlist, squeeze=False) + hpd_intervals = hpd(trace_obj, alpha, squeeze=False) + for j, chain in enumerate(trace_obj.chains): # Counter for current variable var = 1 - for varname in vars: - - var_quantiles = trace_quantiles[varname] + var_quantiles = trace_quantiles[chain][varname] quants = list(var_quantiles.values()) - var_hpd = hpd_intervals[varname].T + var_hpd = hpd_intervals[chain][varname].T # Substitute HPD interval for quantile quants[0] = var_hpd[0].T @@ -354,7 +314,7 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True, plotrange = [np.min(quants), np.max(quants)] # Number of elements in current variable - value = tr[varname][0] + value = trace_obj.get_values(varname, chains=[chain])[0] k = np.size(value) # Append variable name(s) to list @@ -368,7 +328,7 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True, # Add spacing for each chain, if more than one e = [0] + [(chain_spacing * ((i + 2) / 2)) * - (-1) ** i for i in range(chains - 1)] + (-1) ** i for i in range(nchains - 1)] # Deal with multivariate nodes if k > 1: @@ -480,7 +440,7 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True, plt.axvline(vline, color='k', linestyle='--') # Genenerate Gelman-Rubin plot - if rhat and chains > 1: + if rhat and nchains > 1: # If there are multiple chains, calculate R-hat rhat_plot = plt.subplot(gs[1]) @@ -498,7 +458,8 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True, i = 1 for varname in vars: - value = traces[0][varname][0] + chain = trace_obj.chains[0] + value = trace_obj.get_values(varname, chains=[chain])[0] k = np.size(value) if k > 1: diff --git a/pymc/sampling.py b/pymc/sampling.py index f6df93689b..d46fc432cb 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -1,5 +1,6 @@ from .point import * -from .trace import NpTrace, MultiTrace +from pymc.backends.base import merge_traces, BaseTrace, MultiTrace 
+from pymc.backends.ndarray import NDArray import multiprocessing as mp from time import time from .core import * @@ -7,10 +8,11 @@ from .progressbar import progress_bar from numpy.random import seed -__all__ = ['sample', 'psample', 'iter_sample'] +__all__ = ['sample', 'iter_sample'] -def sample(draws, step, start=None, trace=None, tune=None, progressbar=True, model=None, random_seed=None): +def sample(draws, step, start=None, trace=None, chain=0, njobs=1, tune=None, + progressbar=True, model=None, random_seed=None): """ Draw a number of samples using the given step method. Multiple step methods supported via compound step method @@ -27,37 +29,84 @@ def sample(draws, step, start=None, trace=None, tune=None, progressbar=True, mod Starting point in parameter space (or partial point) Defaults to trace.point(-1)) if there is a trace provided and model.test_point if not (defaults to empty dict) - trace : NpTrace or list - Either a trace of past values or a list of variables to track - (defaults to None) + trace : backend, list, or MultiTrace + This should be a backend instance, a list of variables to track, + or a MultiTrace object with past values. If a MultiTrace object + is given, it must contain samples for the chain number `chain`. + If None or a list of variables, the NDArray backend is used. + chain : int + Chain number used to store sample in backend. If `njobs` is + greater than one, chain numbers will start here. + njobs : int + Number of parallel jobs to start. If None, set to number of cpus + in the system - 2. tune : int Number of iterations to tune, if applicable (defaults to None) progressbar : bool Flag for progress bar model : Model (optional if in `with` context) + random_seed : int or list of ints + A list is accepted if more if `njobs` is greater than one. 
+ Returns + ------- + MultiTrace object with access to sampling values """ - progress = progress_bar(draws) + if njobs is None: + njobs = max(mp.cpu_count() - 2, 1) + if njobs > 1: + try: + if not len(random_seed) == njobs: + random_seeds = [random_seed] * njobs + else: + random_seeds = random_seed + except TypeError: # None, int + random_seeds = [random_seed] * njobs + + chains = list(range(chain, chain + njobs)) + + pbars = [progressbar] + [False] * (njobs - 1) + + argset = zip([draws] * njobs, + [step] * njobs, + [start] * njobs, + [trace] * njobs, + chains, + [tune] * njobs, + pbars, + [model] * njobs, + random_seeds) + sample_func = _mp_sample + sample_args = [njobs, argset] + else: + sample_func = _sample + sample_args = [draws, step, start, trace, chain, + tune, progressbar, model, random_seed] + return sample_func(*sample_args) + +def _sample(draws, step, start=None, trace=None, chain=0, tune=None, + progressbar=True, model=None, random_seed=None): + sampling = _iter_sample(draws, step, start, trace, chain, + tune, model, random_seed) + progress = progress_bar(draws) try: - for i, trace in enumerate(iter_sample(draws, step, - start=start, - trace=trace, - tune=tune, - model=model, - random_seed=random_seed)): + for i, trace in enumerate(sampling): if progressbar: progress.update(i) except KeyboardInterrupt: - pass - return trace + trace.close() + return MultiTrace([trace]) + -def iter_sample(draws, step, start=None, trace=None, tune=None, model=None, random_seed=None): +def iter_sample(draws, step, start=None, trace=None, chain=0, tune=None, + model=None, random_seed=None): """ Generator that returns a trace on each iteration using the given step method. Multiple step methods supported via compound step method returns the amount of time taken. 
+ Parameters ---------- @@ -69,41 +118,52 @@ def iter_sample(draws, step, start=None, trace=None, tune=None, model=None, rand Starting point in parameter space (or partial point) Defaults to trace.point(-1)) if there is a trace provided and model.test_point if not (defaults to empty dict) - trace : NpTrace or list - Either a trace of past values or a list of variables to track - (defaults to None) + trace : backend, list, or MultiTrace + This should be a backend instance, a list of variables to track, + or a MultiTrace object with past values. If a MultiTrace object + is given, it must contain samples for the chain number `chain`. + If None or a list of variables, the NDArray backend is used. + chain : int + Chain number used to store sample in backend. If `njobs` is + greater than one, chain numbers will start here. tune : int Number of iterations to tune, if applicable (defaults to None) model : Model (optional if in `with` context) + random_seed : int or list of ints + A list is accepted if more if `njobs` is greater than one. Example ------- for trace in iter_sample(500, step): ... 
- """ + sampling = _iter_sample(draws, step, start, trace, chain, tune, + model, random_seed) + for i, trace in enumerate(sampling): + yield trace[:i + 1] + + +def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None, + model=None, random_seed=None): model = modelcontext(model) draws = int(draws) seed(random_seed) + if draws < 1: + raise ValueError('Argument `draws` should be above 0.') if start is None: start = {} - if isinstance(trace, NpTrace) and len(trace) > 0: - trace_point = trace.point(-1) - trace_point.update(start) - start = trace_point + if isinstance(trace, MultiTrace): + trace = trace._traces[chain] + elif not isinstance(trace, BaseTrace): + trace = NDArray(model=model, vars=trace) + if len(trace) > 0: + _soft_update(start, trace.point(-1)) else: - test_point = model.test_point.copy() - test_point.update(start) - start = test_point - - if not isinstance(trace, NpTrace): - if trace is None: - trace = model.unobserved_RVs - trace = NpTrace(trace) + _soft_update(start, model.test_point) try: step = step_methods.CompoundStep(step) @@ -112,12 +172,22 @@ def iter_sample(draws, step, start=None, trace=None, tune=None, model=None, rand point = Point(start, model=model) + trace.setup(draws, chain) for i in range(draws): - if (i == tune): + if i == tune: step = stop_tuning(step) point = step.step(point) trace.record(point) yield trace + else: + trace.close() + + +def _mp_sample(njobs, args): + p = mp.Pool(njobs) + traces = p.map(argsample, args) + p.close() + return merge_traces(traces) def stop_tuning(step): @@ -134,71 +204,10 @@ def stop_tuning(step): def argsample(args): """ defined at top level so it can be pickled""" - return sample(*args) + return _sample(*args) -def psample(draws, step, start=None, trace=None, tune=None, progressbar=True, - model=None, threads=None, random_seeds=None): - """draw a number of samples using the given step method. 
- Multiple step methods supported via compound step method - returns the amount of time taken - - Parameters - ---------- - - draws : int - The number of samples to draw - step : function - A step function - start : dict - Starting point in parameter space (Defaults to trace.point(-1)) - trace : MultiTrace or list - Either a trace of past values or a list of variables to track (defaults to None) - tune : int - Number of iterations to tune, if applicable (defaults to None) - progressbar : bool - Flag for progress bar - model : Model (optional if in `with` context) - threads : int - Number of parallel traces to start - - Examples - -------- - - >>> an example - +def _soft_update(a, b): + """As opposed to dict.update, don't overwrite keys if present. """ - - model = modelcontext(model) - - if not threads: - threads = max(mp.cpu_count() - 2, 1) - - if start is None: - start = {} - - if isinstance(start, dict): - start = threads * [start] - - if trace is None: - trace = model.vars - - if type(trace) is MultiTrace: - mtrace = trace - else: - mtrace = MultiTrace(threads, trace) - - p = mp.Pool(threads) - - if random_seeds is None: - random_seeds = [None] * threads - pbars = [progressbar] + [False] * (threads - 1) - - argset = zip([draws] * threads, [step] * threads, start, mtrace.traces, - [tune] * threads, pbars, [model] * threads, random_seeds) - - traces = p.map(argsample, argset) - - p.close() - - return MultiTrace(traces) + a.update({k: v for k, v in b.items() if k not in a}) diff --git a/pymc/stats.py b/pymc/stats.py index b028e35b54..22dbdc1901 100644 --- a/pymc/stats.py +++ b/pymc/stats.py @@ -11,39 +11,30 @@ def statfunc(f): """ def wrapped_f(pymc_obj, *args, **kwargs): - - try: - burn = kwargs.pop('burn') - except KeyError: - burn = 0 - try: - # MultiTrace - traces = pymc_obj.traces - - try: - vars = kwargs.pop('vars') - except KeyError: - vars = traces[0].varnames - - return [{v: f(trace[v][burn:], *args, **kwargs) for v in vars} for trace in traces] - + vars = 
kwargs.pop('vars', pymc_obj.varnames) + chains = kwargs.pop('chains', pymc_obj.chains) except AttributeError: - pass - - try: - # NpTrace - try: - vars = kwargs.pop('vars') - except KeyError: - vars = pymc_obj.varnames - - return {v: f(pymc_obj[v][burn:], *args, **kwargs) for v in vars} - except AttributeError: - pass - - # If others fail, assume that raw data is passed - return f(pymc_obj, *args, **kwargs) + # If fails, assume that raw data was passed. + return f(pymc_obj, *args, **kwargs) + + burn = kwargs.pop('burn', 0) + thin = kwargs.pop('thin', 1) + combine = kwargs.pop('combine', False) + ## Remove outer level chain keys if only one chain) + squeeze = kwargs.pop('squeeze', True) + + results = {chain: {} for chain in chains} + for var in vars: + samples = pymc_obj.get_values(var, chains=chains, burn=burn, + thin=thin, combine=combine, + squeeze=False) + for chain, data in zip(chains, samples): + results[chain][var] = f(np.squeeze(data), *args, **kwargs) + + if squeeze and (len(chains) == 1 or combine): + results = results[chains[0]] + return results wrapped_f.__doc__ = f.__doc__ wrapped_f.__name__ = f.__name__ diff --git a/pymc/tests/checks.py b/pymc/tests/checks.py index 2e5470bc8f..6d4adf2180 100644 --- a/pymc/tests/checks.py +++ b/pymc/tests/checks.py @@ -1,7 +1,7 @@ import pymc as pm import numpy as np -from pymc import sample, psample +from pymc import sample from numpy.testing import assert_almost_equal diff --git a/pymc/tests/test_base_backend.py b/pymc/tests/test_base_backend.py new file mode 100644 index 0000000000..16536e1823 --- /dev/null +++ b/pymc/tests/test_base_backend.py @@ -0,0 +1,56 @@ +import numpy as np +try: + import unittest.mock as mock # py3 +except ImportError: + import mock +import unittest + +from pymc.backends import base + + +class TestMultiTrace(unittest.TestCase): + + def test_multitrace_init_unique_chains(self): + trace0 = mock.Mock() + trace0.chain = 0 + trace1 = mock.Mock() + trace1.chain = 1 + mtrace = 
base.MultiTrace([trace0, trace1]) + self.assertEqual(mtrace._traces[0], trace0) + self.assertEqual(mtrace._traces[1], trace1) + + def test_multitrace_init_nonunique_chains(self): + trace0 = mock.Mock() + trace0.chain = 0 + trace1 = mock.Mock() + trace1.chain = 0 + self.assertRaises(ValueError, + base.MultiTrace, [trace0, trace1]) + + +class TestMergeChains(unittest.TestCase): + + def test_merge_traces_unique_chains(self): + trace0 = mock.Mock() + trace0.chain = 0 + mtrace0 = base.MultiTrace([trace0]) + + trace1 = mock.Mock() + trace1.chain = 1 + mtrace1 = base.MultiTrace([trace1]) + + merged = base.merge_traces([mtrace0, mtrace1]) + self.assertEqual(merged._traces[0], trace0) + self.assertEqual(merged._traces[1], trace1) + + def test_merge_traces_nonunique_chains(self): + trace0 = mock.Mock() + trace0.chain = 0 + mtrace0 = base.MultiTrace([trace0]) + + trace1 = mock.Mock() + trace1.chain = 0 + mtrace1 = base.MultiTrace([trace1]) + + self.assertRaises(ValueError, + base.merge_traces, [mtrace0, mtrace1]) diff --git a/pymc/tests/test_diagnostics.py b/pymc/tests/test_diagnostics.py index 8322a7ed27..6a04aeea28 100644 --- a/pymc/tests/test_diagnostics.py +++ b/pymc/tests/test_diagnostics.py @@ -9,8 +9,8 @@ def test_gelman_rubin(n=1000): step1 = Slice([dm.early_mean, dm.late_mean]) step2 = Metropolis([dm.switchpoint]) start = {'early_mean': 2., 'late_mean': 3., 'switchpoint': 50} - ptrace = psample(n, [step1, step2], start, threads=2, - random_seeds=[1, 3]) + ptrace = sample(n, [step1, step2], start, njobs=2, + random_seed=[1, 3]) rhat = gelman_rubin(ptrace) diff --git a/pymc/tests/test_glm.py b/pymc/tests/test_glm.py index 1882d6040c..0262261e45 100644 --- a/pymc/tests/test_glm.py +++ b/pymc/tests/test_glm.py @@ -44,9 +44,9 @@ def test_linear_component(self): step = Slice(model.vars) trace = sample(2000, step, start, progressbar=False) - self.assertAlmostEqual(np.mean(trace.samples['Intercept'].value), true_intercept, 1) - 
self.assertAlmostEqual(np.mean(trace.samples['x'].value), true_slope, 1) - self.assertAlmostEqual(np.mean(trace.samples['sigma'].value), true_sd, 1) + self.assertAlmostEqual(np.mean(trace['Intercept']), true_intercept, 1) + self.assertAlmostEqual(np.mean(trace['x']), true_slope, 1) + self.assertAlmostEqual(np.mean(trace['sigma']), true_sd, 1) @unittest.skip("Fails only on travis. Investigate") def test_glm(self): @@ -57,9 +57,9 @@ def test_glm(self): step = Slice(model.vars) trace = sample(2000, step, progressbar=False) - self.assertAlmostEqual(np.mean(trace.samples['Intercept'].value), true_intercept, 1) - self.assertAlmostEqual(np.mean(trace.samples['x'].value), true_slope, 1) - self.assertAlmostEqual(np.mean(trace.samples['sigma'].value), true_sd, 1) + self.assertAlmostEqual(np.mean(trace['Intercept']), true_intercept, 1) + self.assertAlmostEqual(np.mean(trace['x']), true_slope, 1) + self.assertAlmostEqual(np.mean(trace['sigma']), true_sd, 1) def test_glm_link_func(self): with Model() as model: @@ -71,5 +71,5 @@ def test_glm_link_func(self): step = Slice(model.vars) trace = sample(2000, step, progressbar=False) - self.assertAlmostEqual(np.mean(trace.samples['Intercept'].value), true_intercept, 1) - self.assertAlmostEqual(np.mean(trace.samples['x'].value), true_slope, 0) + self.assertAlmostEqual(np.mean(trace['Intercept']), true_intercept, 1) + self.assertAlmostEqual(np.mean(trace['x']), true_slope, 0) diff --git a/pymc/tests/test_ndarray_backend.py b/pymc/tests/test_ndarray_backend.py new file mode 100644 index 0000000000..e0319090a8 --- /dev/null +++ b/pymc/tests/test_ndarray_backend.py @@ -0,0 +1,321 @@ +import numpy as np +import numpy.testing as npt +try: + import unittest.mock as mock # py3 +except ImportError: + import mock +import unittest + +from pymc.backends import base, ndarray + + +class NDArrayTestCase(unittest.TestCase): + def setUp(self): + self.varnames = ['x', 'y'] + self.model = mock.Mock() + self.model.unobserved_RVs = self.varnames + 
self.model.fastfn = mock.MagicMock() + + with mock.patch('pymc.backends.base.modelcontext') as context: + context.return_value = self.model + self.trace = ndarray.NDArray() + + +class TestNDArraySampling(NDArrayTestCase): + + def test_setup_scalar(self): + trace = self.trace + trace.var_shapes = {'x': ()} + draws, chain = 3, 0 + trace.setup(draws, chain) + npt.assert_equal(trace.samples['x'], np.zeros(draws)) + + def test_setup_1d(self): + trace = self.trace + shape = (2,) + trace.var_shapes = {'x': shape} + draws, chain = 3, 0 + trace.setup(draws, chain) + npt.assert_equal(trace.samples['x'], np.zeros((draws,) + shape)) + + def test_record(self): + trace = self.trace + draws = 3 + + trace.var_shapes = {'x': (), 'y': (4,)} + trace.setup(draws, chain=0) + + def just_ones(*args): + while True: + yield 1. + + trace.fn = just_ones + trace.draw_idx = 0 + + trace.record(point=None) + npt.assert_equal(1., trace.get_values('x')[0]) + npt.assert_equal(np.ones(4), trace['y'][0]) + + def test_clean_interrupt(self): + trace = self.trace + trace.setup(draws=10, chain=0) + trace.samples = {'x': np.zeros(10), 'y': np.zeros((10, 5))} + trace.draw_idx = 3 + trace.close() + npt.assert_equal(np.zeros(3), trace['x']) + npt.assert_equal(np.zeros((3, 5)), trace['y']) + + def test_standard_close(self): + trace = self.trace + trace.setup(draws=10, chain=0) + trace.samples = {'x': np.zeros(10), 'y': np.zeros((10, 5))} + trace.draw_idx = 10 + trace.close() + npt.assert_equal(np.zeros(10), trace['x']) + npt.assert_equal(np.zeros((10, 5)), trace['y']) + + +class TestNDArraySelection(NDArrayTestCase): + + def setUp(self): + super(TestNDArraySelection, self).setUp() + draws = 3 + self.trace.samples = {'x': np.zeros(draws), + 'y': np.zeros((draws, 2))} + self.draws = draws + var_shapes = {'x': (), 'y': (2,)} + self.var_shapes = var_shapes + self.trace.var_shapes = var_shapes + + def test_get_values_default(self): + base_shape = (self.draws,) + xshape = self.var_shapes['x'] + yshape = 
self.var_shapes['y'] + + xsample = self.trace.get_values('x') + npt.assert_equal(np.zeros(base_shape + xshape), xsample) + + ysample = self.trace.get_values('y') + npt.assert_equal(np.zeros(base_shape + yshape), ysample) + + def test_get_values_burn_keyword(self): + base_shape = (self.draws,) + burn = 2 + chain = 0 + + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + ## Make traces distinguishable + self.trace.samples['x'][:burn] = np.ones((burn,) + xshape) + self.trace.samples['y'][:burn] = np.ones((burn,) + yshape) + + xsample = self.trace.get_values('x', burn=burn) + npt.assert_equal(np.zeros(base_shape + xshape)[burn:], xsample) + + ysample = self.trace.get_values('y', burn=burn) + npt.assert_equal(np.zeros(base_shape + yshape)[burn:], ysample) + + def test_get_values_thin_keyword(self): + base_shape = (self.draws,) + thin = 2 + chain = 0 + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + ## Make traces distinguishable + xthin = np.ones((self.draws,) + xshape)[::thin] + ythin = np.ones((self.draws,) + yshape)[::thin] + self.trace.samples['x'][::thin] = xthin + self.trace.samples['y'][::thin] = ythin + + xsample = self.trace.get_values('x', thin=thin) + npt.assert_equal(xthin, xsample) + + ysample = self.trace.get_values('y', thin=thin) + npt.assert_equal(ythin, ysample) + + def test_point(self): + idx = 2 + chain = 0 + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + ## Make traces distinguishable + self.trace.samples['x'][idx] = 1. + self.trace.samples['y'][idx] = 1. 
+ + point = self.trace.point(idx) + expected = {'x': np.squeeze(np.ones(xshape)), + 'y': np.squeeze(np.ones(yshape))} + + for varname, value in expected.items(): + npt.assert_equal(value, point[varname]) + + def test_slice(self): + base_shape = (self.draws,) + burn = 2 + chain = 0 + + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + ## Make traces distinguishable + self.trace.samples['x'][:burn] = np.ones((burn,) + xshape) + self.trace.samples['y'][:burn] = np.ones((burn,) + yshape) + + sliced = self.trace[burn:] + + expected = {'x': np.zeros(base_shape + xshape)[burn:], + 'y': np.zeros(base_shape + yshape)[burn:]} + + for varname, var_shape in self.var_shapes.items(): + npt.assert_equal(sliced.samples[varname], + expected[varname]) + + +class TestNDArrayMultipleChains(unittest.TestCase): + + def setUp(self): + varnames = ['x', 'y'] + var_shapes = {'x': (), 'y': (2,)} + draws = 3 + + self.varnames = varnames + self.var_shapes = var_shapes + self.draws = draws + self.total_draws = 2 * draws + + self.model = mock.Mock() + self.model.unobserved_RVs = varnames + self.model.fastfn = mock.MagicMock() + with mock.patch('pymc.backends.base.modelcontext') as context: + context.return_value = self.model + trace0 = ndarray.NDArray(varnames) + trace0.samples = {'x': np.zeros(draws), + 'y': np.zeros((draws, 2))} + trace0.chain = 0 + + trace1 = ndarray.NDArray(varnames) + trace1.samples = {'x': np.ones(draws), + 'y': np.ones((draws, 2))} + trace1.chain = 1 + + self.mtrace = base.MultiTrace([trace0, trace1]) + + def test_chains_multichain(self): + self.mtrace.chains == [0, 1] + + def test_nchains_multichain(self): + self.mtrace.nchains == 1 + + def test_get_values_multi_default(self): + sample = self.mtrace.get_values('x') + xshape = self.var_shapes['x'] + + expected = [np.zeros((self.draws,) + xshape), + np.ones((self.draws,) + xshape)] + npt.assert_equal(sample, expected) + + def test_get_values_multi_chains_one_chain_list_arg(self): + sample = 
self.mtrace.get_values('x', chains=[0]) + xshape = self.var_shapes['x'] + expected = np.zeros((self.draws,) + xshape) + npt.assert_equal(sample, expected) + + def test_get_values_multi_chains_one_chain_int_arg(self): + npt.assert_equal(self.mtrace.get_values('x', chains=[0]), + self.mtrace.get_values('x', chains=0)) + + def test_get_values_multi_chains_two_element_reversed(self): + sample = self.mtrace.get_values('x', chains=[1, 0]) + xshape = self.var_shapes['x'] + + expected = [np.ones((self.draws,) + xshape), + np.zeros((self.draws,) + xshape)] + npt.assert_equal(sample, expected) + + def test_get_values_multi_combine(self): + sample = self.mtrace.get_values('x', combine=True) + xshape = self.var_shapes['x'] + + expected = np.concatenate([np.zeros((self.draws,) + xshape), + np.ones((self.draws,) + xshape)]) + npt.assert_equal(sample, expected) + + def test_get_values_multi_burn(self): + sample = self.mtrace.get_values('x', burn=2) + xshape = self.var_shapes['x'] + + expected = [np.zeros((self.draws,) + xshape)[2:], + np.ones((self.draws,) + xshape)[2:]] + npt.assert_equal(sample, expected) + + def test_get_values_multi_burn_combine(self): + sample = self.mtrace.get_values('x', burn=2, combine=True) + xshape = self.var_shapes['x'] + + expected = np.concatenate([np.zeros((self.draws,) + xshape)[2:], + np.ones((self.draws,) + xshape)[2:]]) + npt.assert_equal(sample, expected) + + def test_get_values_multi_thin(self): + sample = self.mtrace.get_values('x', thin=2) + xshape = self.var_shapes['x'] + + expected = [np.zeros((self.draws,) + xshape)[::2], + np.ones((self.draws,) + xshape)[::2]] + npt.assert_equal(sample, expected) + + def test_get_values_multi_thin_combine(self): + sample = self.mtrace.get_values('x', thin=2, combine=True) + xshape = self.var_shapes['x'] + + expected = np.concatenate([np.zeros((self.draws,) + xshape)[::2], + np.ones((self.draws,) + xshape)[::2]]) + npt.assert_equal(sample, expected) + + def test_multichain_point(self): + idx = 2 + xshape 
= self.var_shapes['x'] + yshape = self.var_shapes['y'] + + point = self.mtrace.point(idx) + expected = {'x': np.squeeze(np.ones(xshape)), + 'y': np.squeeze(np.ones(yshape))} + + for varname, value in expected.items(): + npt.assert_equal(value, point[varname]) + + def test_multichain_point_chain_arg(self): + idx = 2 + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + point = self.mtrace.point(idx, chain=0) + expected = {'x': np.squeeze(np.zeros(xshape)), + 'y': np.squeeze(np.zeros(yshape))} + + for varname, value in expected.items(): + npt.assert_equal(value, point[varname]) + + def test_multichain_slice(self): + burn = 2 + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + expected = {0: + {'x': np.zeros((self.draws, ) + xshape)[burn:], + 'y': np.zeros((self.draws, ) + yshape)[burn:]}, + 1: + {'x': np.ones((self.draws, ) + xshape)[burn:], + 'y': np.ones((self.draws, ) + yshape)[burn:]}} + + sliced = self.mtrace[burn:] + + for chain in self.mtrace.chains: + for varname, var_shape in self.var_shapes.items(): + npt.assert_equal(sliced.get_values(varname, chains=[0]), + expected[0][varname]) + npt.assert_equal(sliced.get_values(varname, chains=[1]), + expected[1][varname]) diff --git a/pymc/tests/test_plots.py b/pymc/tests/test_plots.py index 3d5acc19f1..ae9c70c24e 100644 --- a/pymc/tests/test_plots.py +++ b/pymc/tests/test_plots.py @@ -2,7 +2,7 @@ matplotlib.use('Agg', warn=False) from pymc.plots import * -from pymc import psample, Slice, Metropolis, find_hessian, sample +from pymc import Slice, Metropolis, find_hessian, sample def test_plots(): @@ -31,7 +31,7 @@ def test_multichain_plots(): step1 = Slice([dm.early_mean, dm.late_mean]) step2 = Metropolis([dm.switchpoint]) start = {'early_mean': 2., 'late_mean': 3., 'switchpoint': 50} - ptrace = psample(1000, [step1, step2], start, threads=2) + ptrace = sample(1000, [step1, step2], start, njobs=2) forestplot(ptrace, vars=['early_mean', 'late_mean']) diff --git 
a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index d4016c40f8..c41a5aea5a 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -1,38 +1,85 @@ +import numpy as np +import numpy.testing as npt +try: + import unittest.mock as mock # py3 +except ImportError: + import mock +import unittest + import pymc -from pymc import sample, psample, iter_sample +from pymc import sampling +from pymc.sampling import sample from .models import simple_init -# Test if multiprocessing is available -import multiprocessing -try: - multiprocessing.Pool(2) - test_parallel = False -except: - test_parallel = False +## Set to False to keep effect of cea5659. Should this be set to True? +TEST_PARALLEL = False + + +@mock.patch('pymc.sampling._sample') +def test_sample_check_full_signature_single_process(sample_func): + sample('draws', 'step', start='start', trace='trace', njobs=1, chain=1, + tune='tune', progressbar='progressbar', model='model', + random_seed='random_seed') + sample_func.assert_called_with('draws', 'step', 'start', 'trace', 1, + 'tune', 'progressbar', 'model', + 'random_seed') + + +@mock.patch('pymc.sampling._mp_sample') +def test_sample_check_full_signature_mp(sample_func): + sample('draws', 'step', start='start', trace='trace', njobs=2, chain=1, + tune='tune', progressbar='progressbar', model='model', + random_seed=0) + + args = sample_func.call_args_list[0][0] + assert args[0] == 2 + + expected_argset = [('draws', 'step', 'start', 'trace', 1, 'tune', + 'progressbar', 'model', 0), + ('draws', 'step', 'start', 'trace', 2, 'tune', + False, 'model', 0)] + argset = list(args[1]) + assert argset == expected_argset + + +def test_soft_update_all_present(): + start = {'a': 1, 'b': 2} + test_point = {'a': 3, 'b': 4} + sampling._soft_update(start, test_point) + assert start == {'a': 1, 'b': 2} + + +def test_soft_update_one_missing(): + start = {'a': 1, } + test_point = {'a': 3, 'b': 4} + sampling._soft_update(start, test_point) + assert start == 
{'a': 1, 'b': 4} + + +def test_soft_update_empty(): + start = {} + test_point = {'a': 3, 'b': 4} + sampling._soft_update(start, test_point) + assert start == test_point def test_sample(): model, start, step, _ = simple_init() - test_samplers = [sample] - - tr = sample(5, step, start, model=model) - test_traces = [None, tr] + test_njobs = [1] - if test_parallel: - test_samplers.append(psample) + if TEST_PARALLEL: + test_njobs.append(2) with model: - for trace in test_traces: - for samplr in test_samplers: - for n in [0, 1, 10, 300]: + for njobs in test_njobs: + for n in [1, 10, 300]: + yield sample, n, step, {}, None, njobs - yield samplr, n, step, {} - yield samplr, n, step, {}, trace - yield samplr, n, step, start def test_iter_sample(): model, start, step, _ = simple_init() - for i, trace in enumerate(iter_sample(5, step, start, model=model)): + samps = sampling.iter_sample(5, step, start, model=model) + for i, trace in enumerate(samps): assert i == len(trace) - 1, "Trace does not have correct length." 
diff --git a/pymc/tests/test_trace.py b/pymc/tests/test_trace.py index 35c03a740e..ef82ee29df 100644 --- a/pymc/tests/test_trace.py +++ b/pymc/tests/test_trace.py @@ -5,130 +5,6 @@ import warnings import nose -# Test if multiprocessing is available -import multiprocessing -try: - multiprocessing.Pool(2) - test_parallel = False -except: - test_parallel = False - - -def check_trace(model, trace, n, step, start): - # try using a trace object a few times - - for i in range(2): - trace = sample( - n, step, start, trace, progressbar=False, model=model) - - for (var, val) in start.items(): - - assert np.shape(trace[var]) == (n * (i + 1),) + np.shape(val) - -def test_trace(): - model, start, step, _ = simple_init() - - for h in [pm.NpTrace]: - for n in [20, 1000]: - for vars in [model.vars, model.vars + [model.vars[0] ** 2]]: - trace = h(vars) - - yield check_trace, model, trace, n, step, start - - -def test_multitrace(): - if not test_parallel: - return - model, start, step, _ = simple_init() - trace = None - for n in [20, 1000]: - - yield check_multi_trace, model, trace, n, step, start - - -def check_multi_trace(model, trace, n, step, start): - - for i in range(2): - trace = psample( - n, step, start, trace, model=model) - - for (var, val) in start.items(): - print([len(tr.samples[var].vals) for tr in trace.traces]) - for t in trace[var]: - assert np.shape(t) == (n * (i + 1),) + np.shape(val) - - ctrace = trace.combined() - for (var, val) in start.items(): - - assert np.shape( - ctrace[var]) == (len(trace.traces) * n * (i + 1),) + np.shape(val) - - -def test_get_point(): - - p, model = simple_2model() - p2 = p.copy() - p2['x'] *= 2. 
- - x = pm.NpTrace(model.vars) - x.record(p) - x.record(p2) - assert x.point(1) == x[1] - -def test_slice(): - - model, start, step, moments = simple_init() - - iterations = 100 - burn = 10 - - with model: - tr = sample(iterations, start=start, step=step, progressbar=False) - - burned = tr[burn:] - - # Slicing returns a trace - assert type(burned) is pm.trace.NpTrace - - # Indexing returns an array - assert type(tr[tr.varnames[0]]) is np.ndarray - - # Burned trace is of the correct length - assert np.all([burned[v].shape == (iterations-burn, start[v].size) for v in burned.varnames]) - - # Original trace did not change - assert np.all([tr[v].shape == (iterations, start[v].size) for v in tr.varnames]) - - # Now take more burn-in from the burned trace - burned_again = burned[burn:] - assert np.all([burned_again[v].shape == (iterations-2*burn, start[v].size) for v in burned_again.varnames]) - assert np.all([burned[v].shape == (iterations-burn, start[v].size) for v in burned.varnames]) - -def test_multi_slice(): - - model, start, step, moments = simple_init() - - iterations = 100 - burn = 10 - - with model: - tr = psample(iterations, start=start, step=step, threads=2) - - burned = tr[burn:] - - # Slicing returns a MultiTrace - assert type(burned) is pm.trace.MultiTrace - - # Indexing returns a list of arrays - assert type(tr[tr.varnames[0]]) is list - assert type(tr[tr.varnames[0]][0]) is np.ndarray - - # # Burned trace is of the correct length - assert np.all([burned[v][0].shape == (iterations-burn, start[v].size) for v in burned.varnames]) - - # Original trace did not change - assert np.all([tr[v][0].shape == (iterations, start[v].size) for v in tr.varnames]) - def test_summary_1_value_model(): mu = -2.1 diff --git a/pymc/trace.py b/pymc/trace.py index a1a70c0a35..2d76549a94 100644 --- a/pymc/trace.py +++ b/pymc/trace.py @@ -5,113 +5,11 @@ import types import warnings -__all__ = ['NpTrace', 'MultiTrace', 'summary'] - -class NpTrace(object): - """ - encapsulates the 
recording of a process chain - """ - def __init__(self, vars): - vars = list(vars) - model = vars[0].model - self.f = model.fastfn(vars) - self.vars = vars - self.varnames = list(map(str, vars)) - self.samples = dict((v, ListArray()) for v in self.varnames) - - def record(self, point): - """ - Records the position of a chain at a certain point in time. - """ - for var, value in zip(self.varnames, self.f(point)): - self.samples[var].append(value) - return self - - def __getitem__(self, index_value): - """ - Return copy NpTrace with sliced sample values if a slice is passed, - or the array of samples if a varname is passed. - """ - - if isinstance(index_value, slice): - - sliced_trace = NpTrace(self.vars) - sliced_trace.samples = dict((name, vals[index_value]) for (name, vals) in self.samples.items()) - - return sliced_trace - - else: - try: - return self.point(index_value) - except (ValueError, TypeError, IndexError): - pass - - return self.samples[str(index_value)].value - - def __len__(self): - return len(self.samples[self.varnames[0]]) - - def point(self, index): - return dict((k, v.value[index]) for (k, v) in self.samples.items()) - - -class ListArray(object): - def __init__(self, *args): - self.vals = list(args) - - @property - def value(self): - if len(self.vals) > 1: - self.vals = [np.concatenate(self.vals, axis=0)] - - return self.vals[0] - - def __getitem__(self, idx): - return ListArray(self.value[idx]) - - - def append(self, v): - self.vals.append(v[np.newaxis]) - - def __len__(self): - if self.vals: - return self.value.shape[0] - else: - return 0 - - -class MultiTrace(object): - def __init__(self, traces, vars=None): - try: - self.traces = list(traces) - except TypeError: - if vars is None: - raise ValueError("vars can't be None if trace count specified") - self.traces = [NpTrace(vars) for _ in range(traces)] - - def __getitem__(self, index_value): - - item_list = [h[index_value] for h in self.traces] - - if isinstance(index_value, slice): - return 
MultiTrace(item_list) - return item_list - - @property - def varnames(self): - return self.traces[0].varnames - - def point(self, index): - return [h.point(index) for h in self.traces] - - def combined(self): - # Returns a trace consisting of concatenated MultiTrace elements - h = NpTrace(self.traces[0].vars) - for k in self.traces[0].samples: - h.samples[k].vals = [s[k] for s in self.traces] - return h +__all__ = ['summary'] +# TODO: Move this to pymc.stats. (It was left here for diffing +# purposes). def summary(trace, vars=None, alpha=0.05, start=0, batches=100, roundto=3): """ Generate a pretty-printed summary of the node. @@ -142,15 +40,13 @@ def summary(trace, vars=None, alpha=0.05, start=0, batches=100, roundto=3): """ if vars is None: vars = trace.varnames - if isinstance(trace, MultiTrace): - trace = trace.combined() stat_summ = _StatSummary(roundto, batches, alpha) pq_summ = _PosteriorQuantileSummary(roundto, alpha) for var in vars: # Extract sampled values - sample = trace[var][start:] + sample = trace.get_values(var, burn=start, combine=True) if sample.ndim == 1: sample = sample[:, None] elif sample.ndim > 2: diff --git a/setup.py b/setup.py index 2a338bf31f..adcd609177 100755 --- a/setup.py +++ b/setup.py @@ -44,7 +44,7 @@ long_description=LONG_DESCRIPTION, packages=['pymc', 'pymc.distributions', 'pymc.step_methods', 'pymc.tuning', - 'pymc.tests', 'pymc.glm'], + 'pymc.tests', 'pymc.glm', 'pymc.backends'], classifiers=classifiers, install_requires=install_reqs, tests_require=test_reqs, From 3a83062816af674edd268f6d44b48e9da2ad2bc1 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Thu, 20 Feb 2014 22:16:17 -0500 Subject: [PATCH 04/15] Add Text backend --- pymc/backends/__init__.py | 1 + pymc/backends/text.py | 114 +++++++++++++++++++++ pymc/tests/test_text_backend.py | 170 ++++++++++++++++++++++++++++++++ 3 files changed, 285 insertions(+) create mode 100644 pymc/backends/text.py create mode 100644 pymc/tests/test_text_backend.py diff --git 
a/pymc/backends/__init__.py b/pymc/backends/__init__.py index feebc00bb6..b5357a5b4c 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -1 +1,2 @@ from pymc.backends.ndarray import NDArray +from pymc.backends.text import Text diff --git a/pymc/backends/text.py b/pymc/backends/text.py new file mode 100644 index 0000000000..40ef1a24d6 --- /dev/null +++ b/pymc/backends/text.py @@ -0,0 +1,114 @@ +"""Text file trace backend + +After sampling with NDArray backend, save results as text files. + +Database format +--------------- + +For each chain, a directory named `chain-N` is created. In this +directory, one file per variable is created containing the values of the +object. To deal with multidimensional variables, the array is reshaped +to one dimension before saving with `numpy.savetxt`. The shape +information is saved in a json file in the same directory and is used to +load the database back again using `numpy.loadtxt`. +""" +import os +import glob +import json +import numpy as np +from contextlib import contextmanager + +from pymc.backends import base +from pymc.backends.ndarray import NDArray + + +class Text(NDArray): + """Text storage + + Parameters + ---------- + name : str + Name of directory to store text files. + model : Model + If None, the model is taken from the `with` context. + vars : list of variables + Sampling values will be stored for these variables. If None, + `model.unobserved_RVs` is used. 
+ """ + def __init__(self, name, model=None, vars=None): + super(Text, self).__init__(name, model, vars) + if not os.path.exists(name): + os.mkdir(name) + + def close(self): + super(Text, self).close() + chain_name = 'chain-{}'.format(self.chain) + chain_dir = os.path.join(self.name, chain_name) + os.mkdir(chain_dir) + + shapes = {} + for varname in self.varnames: + data = self.samples[varname] + var_file = os.path.join(chain_dir, varname + '.txt') + np.savetxt(var_file, data.reshape(-1, data.size)) + shapes[varname] = data.shape + ## Store shape information for reloading. + with _get_shape_fh(chain_dir, 'w') as sfh: + json.dump(shapes, sfh) + + +def load(name, chains=None, model=None): + """Load text database. + + Parameters + ---------- + name : str + Path to root directory for text database + chains : list + Chains to load. If None, all chains are loaded. + model : Model + If None, the model is taken from the `with` context. + + Returns + ------- + base.MultiTrace instance + """ + chain_dirs = _get_chain_dirs(name) + if chains is None: + chains = list(chain_dirs.keys()) + + traces = [] + for chain in chains: + chain_dir = chain_dirs[chain] + with _get_shape_fh(chain_dir, 'r') as sfh: + shapes = json.load(sfh) + samples = {} + for varname, shape in shapes.items(): + var_file = os.path.join(chain_dir, varname + '.txt') + samples[varname] = np.loadtxt(var_file).reshape(shape) + trace = NDArray(model=model) + trace.samples = samples + trace.chain = chain + traces.append(trace) + return base.MultiTrace(traces) + + +## The json file is opened here instead of `Text.close` and `load` for +## testing convenience.
+@contextmanager +def _get_shape_fh(chain_dir, mode='r'): + fh = open(os.path.join(chain_dir, 'shapes.json'), mode) + try: + yield fh + finally: + fh.close() + + +def _get_chain_dirs(name): + """Return mapping of chain number to directory.""" + return {_chain_dir_to_chain(chain_dir): chain_dir + for chain_dir in glob.glob(os.path.join(name, 'chain-*'))} + + +def _chain_dir_to_chain(chain_dir): + return int(os.path.basename(chain_dir).split('-')[1]) diff --git a/pymc/tests/test_text_backend.py b/pymc/tests/test_text_backend.py new file mode 100644 index 0000000000..2e428bf3b8 --- /dev/null +++ b/pymc/tests/test_text_backend.py @@ -0,0 +1,170 @@ +import sys +import numpy as np +import numpy.testing as npt +try: + import unittest.mock as mock # py3 +except ImportError: + import mock +import unittest +if sys.version_info[0] == 2: + from StringIO import StringIO +else: + from io import StringIO +import json + +from pymc.backends import text + + +class TextTestCase(unittest.TestCase): + + def setUp(self): + self.variables = ['x', 'y'] + self.model = mock.Mock() + self.model.unobserved_RVs = self.variables + self.model.fastfn = mock.MagicMock() + + shape_fh_patch = mock.patch('pymc.backends.text._get_shape_fh') + self.addCleanup(shape_fh_patch.stop) + self.shape_fh = shape_fh_patch.start() + + mkdir_patch = mock.patch('pymc.backends.text.os.mkdir') + self.addCleanup(mkdir_patch.stop) + mkdir_patch.start() + + +class TestTextWrite(TextTestCase): + + def setUp(self): + super(TestTextWrite, self).setUp() + + with mock.patch('pymc.backends.base.modelcontext') as context: + context.return_value = self.model + self.trace = text.Text('textdb') + + self.draws = 5 + self.trace.var_shapes = {'x': (), 'y': (4,)} + self.trace.setup(self.draws, chain=0) + self.trace.draw_idx = self.draws + + savetxt_patch = mock.patch('pymc.backends.text.np.savetxt') + self.addCleanup(savetxt_patch.stop) + self.savetxt = savetxt_patch.start() + + def test_close_args(self): + trace = self.trace + + 
trace.close() + + self.assertEqual(self.savetxt.call_count, 2) + + for call, varname in enumerate(trace.varnames): + fname, data = self.savetxt.call_args_list[call][0] + self.assertEqual(fname, 'textdb/chain-0/{}.txt'.format(varname)) + npt.assert_equal(data, trace[varname].reshape(-1, data.size)) + + def test_close_shape(self): + trace = self.trace + + fh = StringIO() + self.shape_fh.return_value.__enter__.return_value = fh + trace.close() + self.shape_fh.assert_called_with('textdb/chain-0', 'w') + + shape_result = fh.getvalue() + expected = {varname: [self.draws] + list(var_shape) + for varname, var_shape in trace.var_shapes.items()} + self.assertEqual(json.loads(shape_result), expected) + + +def test__chain_dir_to_chain(): + assert text._chain_dir_to_chain('/path/to/chain-0') == 0 + assert text._chain_dir_to_chain('chain-0') == 0 + + +class TestTextLoad(TextTestCase): + + def setUp(self): + super(TestTextLoad, self).setUp() + + data = {'chain-1/x.txt': np.zeros(4), 'chain-1/y.txt': np.ones(2)} + loadtxt_patch = mock.patch('pymc.backends.text.np.loadtxt') + self.addCleanup(loadtxt_patch.stop) + self.loadtxt = loadtxt_patch.start() + + chain_patch = mock.patch('pymc.backends.text._get_chain_dirs') + self.addCleanup(chain_patch.stop) + self._get_chain_dirs = chain_patch.start() + + def test_load_model_supplied_scalar(self): + draws = 4 + self._get_chain_dirs.return_value = {0: 'chain-0'} + fh = StringIO(json.dumps({'x': (draws,)})) + self.shape_fh.return_value.__enter__.return_value = fh + + data = np.zeros(draws) + self.loadtxt.return_value = data + + mtrace = text.load('textdb', model=self.model) + npt.assert_equal(mtrace.get_values('x', chains=[0]), data) + + def test_load_model_supplied_1d(self): + draws = 4 + var_shape = (2,) + self._get_chain_dirs.return_value = {0: 'chain-0'} + fh = StringIO(json.dumps({'x': (draws,) + var_shape})) + self.shape_fh.return_value.__enter__.return_value = fh + + data = np.zeros((draws,) + var_shape) + self.loadtxt.return_value = 
data.reshape(-1, data.size) + + mtrace = text.load('textdb', model=self.model) + npt.assert_equal(mtrace['x'], data) + + def test_load_model_supplied_2d(self): + draws = 4 + var_shape = (2, 3) + self._get_chain_dirs.return_value = {0: 'chain-0'} + fh = StringIO(json.dumps({'x': (draws,) + var_shape})) + self.shape_fh.return_value.__enter__.return_value = fh + + data = np.zeros((draws,) + var_shape) + self.loadtxt.return_value = data.reshape(-1, data.size) + + mtrace = text.load('textdb', model=self.model) + npt.assert_equal(mtrace['x'], data) + + def test_load_model_supplied_multichain_chains(self): + draws = 4 + self._get_chain_dirs.return_value = {0: 'chain-0', 1: 'chain-1'} + + def chain_fhs(): + for chain in [0, 1]: + yield StringIO(json.dumps({'x': (draws,)})) + fhs = chain_fhs() + + self.shape_fh.return_value.__enter__ = lambda x: next(fhs) + + data = np.zeros(draws) + self.loadtxt.return_value = data + + mtrace = text.load('textdb', model=self.model) + + self.assertEqual(mtrace.chains, [0, 1]) + + def test_load_model_supplied_multichain_chains_select_one(self): + draws = 4 + self._get_chain_dirs.return_value = {0: 'chain-0', 1: 'chain-1'} + + def chain_fhs(): + for chain in [0, 1]: + yield StringIO(json.dumps({'x': (draws,)})) + fhs = chain_fhs() + + self.shape_fh.return_value.__enter__ = lambda x: next(fhs) + + data = np.zeros(draws) + self.loadtxt.return_value = data + + mtrace = text.load('textdb', model=self.model, chains=[1]) + + self.assertEqual(mtrace.chains, [1]) From d5295a47dd78bb22c29b1be65cb56d2886a5a070 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Thu, 20 Feb 2014 22:16:36 -0500 Subject: [PATCH 05/15] Add SQLite backend --- pymc/backends/__init__.py | 1 + pymc/backends/sqlite.py | 341 ++++++++++++++++++++++++++++++ pymc/tests/test_sqlite_backend.py | 267 +++++++++++++++++++++++ 3 files changed, 609 insertions(+) create mode 100644 pymc/backends/sqlite.py create mode 100644 pymc/tests/test_sqlite_backend.py diff --git 
a/pymc/backends/__init__.py b/pymc/backends/__init__.py index b5357a5b4c..aa5f5fce6f 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -1,2 +1,3 @@ from pymc.backends.ndarray import NDArray from pymc.backends.text import Text +from pymc.backends.sqlite import SQLite diff --git a/pymc/backends/sqlite.py b/pymc/backends/sqlite.py new file mode 100644 index 0000000000..24dd968c31 --- /dev/null +++ b/pymc/backends/sqlite.py @@ -0,0 +1,341 @@ +"""SQLite trace backend + +Store and retrieve sampling values in SQLite database file. + +Database format +--------------- +For each variable, a table is created with the following format: + + recid (INT), draw (INT), chain (INT), v1 (FLOAT), v2 (FLOAT), v3 (FLOAT) ... + +The variable column names are extended to reflect addition dimensions. +For example, a variable with the shape (2, 2) would be stored as + + key (INT), draw (INT), chain (INT), v1_1 (FLOAT), v1_2 (FLOAT), v2_1 (FLOAT) ... + +The key is autoincremented each time a new row is added to the table. +The chain column denotes the chain index, and starts at 0. +""" +import numpy as np +import sqlite3 +import warnings + +from pymc.backends import base + +TEMPLATES = { + 'table': ('CREATE TABLE IF NOT EXISTS [{table}] ' + '(recid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ' + 'draw INTEGER, chain INT(5), ' + '{value_cols})'), + 'insert': ('INSERT INTO [{table}] ' + '(recid, draw, chain, {value_cols}) ' + 'VALUES (NULL, ?, ?, {values})'), + 'max_draw': ('SELECT MAX(draw) FROM [{table}] ' + 'WHERE chain = ?'), + 'draw_count': ('SELECT COUNT(*) FROM [{table}] ' + 'WHERE chain = ?'), + ## Named placeholders are used in the selection templates because + ## some values occur more than once in the same template. 
+ 'select': ('SELECT * FROM [{table}] ' + 'WHERE (chain = :chain)'), + 'select_burn': ('SELECT * FROM [{table}] ' + 'WHERE (chain = :chain) AND (draw > :burn)'), + 'select_thin': ('SELECT * FROM [{table}] ' + 'WHERE (chain = :chain) AND ' + '(draw - (SELECT draw FROM [{table}] ' + 'WHERE chain = :chain ' + 'ORDER BY draw LIMIT 1)) % :thin = 0'), + 'select_burn_thin': ('SELECT * FROM [{table}] ' + 'WHERE (chain = :chain) AND (draw > :burn) ' + 'AND (draw - (SELECT draw FROM [{table}] ' + 'WHERE (chain = :chain) AND (draw > :burn) ' + 'ORDER BY draw LIMIT 1)) % :thin = 0'), + 'select_point': ('SELECT * FROM [{table}] ' + 'WHERE (chain = :chain) AND (draw = :draw)'), +} + + +class SQLite(base.BaseTrace): + """SQLite trace object + + Parameters + ---------- + name : str + Name of database file + model : Model + If None, the model is taken from the `with` context. + vars : list of variables + Sampling values will be stored for these variables. If None, + `model.unobserved_RVs` is used. + """ + def __init__(self, name, model=None, vars=None): + super(SQLite, self).__init__(name, model, vars) + self._var_cols = {} + self.var_inserts = {} # varname -> insert statement + self.draw_idx = 0 + self._is_setup = False + self._len = None + + self.db = _SQLiteDB(name) + ## Inserting sampling information is queued to avoid locks + ## caused by hitting the database with transactions each + ## iteration. + self._queue = {varname: [] for varname in self.varnames} + self._queue_limit = 5000 + + ## Sampling methods + + def setup(self, draws, chain): + """Perform chain-specific setup. + + Parameters + ---------- + draws : int + Expected number of draws + chain : int + Chain number + """ + self.db.connect() + self.chain = chain + + if not self._is_setup: # Table has not been created. 
+ self._var_cols = {varname: _create_colnames(shape) + for varname, shape in self.var_shapes.items()} + self._create_table() + self._is_setup = True + else: + self.draw_idx = self._get_max_draw(chain) + 1 + self._len = None + self._create_insert_queries(chain) + + def _create_table(self): + template = TEMPLATES['table'] + with self.db.con: + for varname, var_cols in self._var_cols.items(): + var_float = ', '.join([v + ' FLOAT' for v in var_cols]) + statement = template.format(table=varname, + value_cols=var_float) + self.db.cursor.execute(statement) + + def _create_insert_queries(self, chain): + template = TEMPLATES['insert'] + for varname, var_cols in self._var_cols.items(): + ## Create insert statement for each variable. + var_str = ', '.join(var_cols) + placeholders = ', '.join(['?'] * len(var_cols)) + statement = template.format(table=varname, + value_cols=var_str, + values=placeholders) + self.var_inserts[varname] = statement + + def record(self, point): + """Record results of a sampling iteration. 
+ + Parameters + ---------- + point : dict + Values mapped to variable names + """ + for varname, value in zip(self.varnames, self.fn(point)): + values = (self.draw_idx, self.chain) + tuple(np.ravel(value)) + self._queue[varname].append(values) + + if len(self._queue[varname]) > self._queue_limit: + self._execute_queue() + self.draw_idx += 1 + + def _execute_queue(self): + with self.db.con: + for varname in self.varnames: + if not self._queue[varname]: + continue + self.db.cursor.executemany(self.var_inserts[varname], + self._queue[varname]) + self._queue[varname] = [] + + def close(self): + self._execute_queue() + self.db.close() + + ## Selection methods + + def __len__(self): + if not self._is_setup: + return 0 + if self._len is None: + self._len = self._get_number_draws() + return self._len + + def _get_number_draws(self): + self.db.connect() + statement = TEMPLATES['draw_count'].format(table=self.varnames[0]) + self.db.cursor.execute(statement, (self.chain,)) + return self.db.cursor.fetchall()[0][0] + + def _get_max_draw(self, chain): + self.db.connect() + statement = TEMPLATES['max_draw'].format(table=self.varnames[0]) + self.db.cursor.execute(statement, (chain, )) + return self.db.cursor.fetchall()[0][0] + + def get_values(self, varname, burn=0, thin=1): + """Get values from trace. 
+ + Parameters + ---------- + varname : str + burn : int + thin : int + + Returns + ------- + A NumPy array + """ + if burn < 0: + raise ValueError('Negative burn values not supported ' + 'in SQLite backend.') + if thin < 1: + raise ValueError('Only positive thin values are supported ' + 'in SQLite backend.') + varname = str(varname) + + statement_args = {'chain': self.chain} + if burn == 0 and thin == 1: + action = 'select' + elif thin == 1: + action = 'select_burn' + statement_args['burn'] = burn - 1 + elif burn == 0: + action = 'select_thin' + statement_args['thin'] = thin + else: + action = 'select_burn_thin' + statement_args['burn'] = burn - 1 + statement_args['thin'] = thin + + self.db.connect() + statement = TEMPLATES[action].format(table=varname) + self.db.cursor.execute(statement, statement_args) + return _rows_to_ndarray(self.db.cursor) + + def _slice(self, idx): + warnings.warn('Slice for SQLite backend has no effect.') + + def point(self, idx): + """Return dictionary of point values at `idx` for current chain + with variables names as keys. + """ + idx = int(idx) + if idx < 0: + idx = self._get_max_draw(self.chain) - idx - 1 + statement = TEMPLATES['select_point'] + self.db.connect() + var_values = {} + statement_args = {'chain': self.chain, 'draw': idx} + for varname in self.varnames: + self.db.cursor.execute(statement.format(table=varname), + statement_args) + var_values[varname] = np.squeeze( + _rows_to_ndarray(self.db.cursor)) + return var_values + + +class _SQLiteDB(object): + def __init__(self, name): + self.name = name + self.con = None + self.cursor = None + self.connected = False + + def connect(self): + if self.connected: + return + self.con = sqlite3.connect(self.name) + self.connected = True + self.cursor = self.con.cursor() + + def close(self): + if not self.connected: + return + self.con.commit() + self.cursor.close() + self.con.close() + self.connected = False + + +def _create_colnames(shape): + """Return column names based on `shape`. 
+ + Examples + -------- + >>> create_colnames((5,)) + ['v1', 'v2', 'v3', 'v4', 'v5'] + + >>> create_colnames((2,2)) + ['v1_1', 'v1_2', 'v2_1', 'v2_2'] + """ + if not shape: + return ['v1'] + + size = np.prod(shape) + indices = (np.indices(shape) + 1).reshape(-1, size) + return ['v' + '_'.join(map(str, i)) for i in zip(*indices)] + + +def load(name, model=None): + """Load SQLite database. + + Parameters + ---------- + name : str + Path to SQLite database file + model : Model + If None, the model is taken from the `with` context. + + Returns + ------- + A MultiTrace instance + """ + db = _SQLiteDB(name) + db.connect() + varnames = _get_table_list(db.cursor) + chains = _get_chain_list(db.cursor, varnames[0]) + + traces = [] + for chain in chains: + trace = SQLite(name, model=model) + trace.varnames = varnames + trace.chain = chain + trace._is_setup = True + trace.db = db # Share the db with all traces. + traces.append(trace) + return base.MultiTrace(traces) + + +def _get_table_list(cursor): + """Return a list of table names in the current database.""" + ## Modified from Django. Skips the sqlite_sequence system table used + ## for autoincrement key generation. 
+ cursor.execute("SELECT name FROM sqlite_master " + "WHERE type='table' AND NOT name='sqlite_sequence' " + "ORDER BY name") + return [row[0] for row in cursor.fetchall()] + + +def _get_var_strs(cursor, varname): + cursor.execute('SELECT * FROM [{}]'.format(varname)) + col_names = (col_descr[0] for col_descr in cursor.description) + return [name for name in col_names if name.startswith('v')] + + +def _get_chain_list(cursor, varname): + """Return a list of sorted chains for `varname`.""" + cursor.execute('SELECT DISTINCT chain FROM [{}]'.format(varname)) + chains = [chain[0] for chain in cursor.fetchall()] + chains.sort() + return chains + + +def _rows_to_ndarray(cursor): + """Convert SQL row to NDArray.""" + return np.array([row[3:] for row in cursor.fetchall()]) diff --git a/pymc/tests/test_sqlite_backend.py b/pymc/tests/test_sqlite_backend.py new file mode 100644 index 0000000000..d531fd8ab5 --- /dev/null +++ b/pymc/tests/test_sqlite_backend.py @@ -0,0 +1,267 @@ +import numpy as np +import numpy.testing as npt +try: + import unittest.mock as mock # py3 +except ImportError: + import mock +import unittest +import warnings + +from pymc.backends import base, sqlite + + +class SQLiteTestCase(unittest.TestCase): + + def setUp(self): + self.variables = ['x', 'y'] + self.model = mock.Mock() + self.model.unobserved_RVs = self.variables + self.model.fastfn = mock.MagicMock() + + db_patch = mock.patch('pymc.backends.sqlite._SQLiteDB') + self.addCleanup(db_patch.stop) + self.db = db_patch.start() + + with mock.patch('pymc.backends.base.modelcontext') as context: + context.return_value = self.model + self.trace = sqlite.SQLite('test.db') + + self.draws = 5 + + self.trace.var_shapes = {'x': (), 'y': (3,)} + + self.trace._chains = [0] + self.trace._len = self.draws + + +class TestSQLiteSample(SQLiteTestCase): + + def test_setup_trace(self): + self.trace.setup(self.draws, chain=0) + assert self.trace.db.connect.called + + def test_setup_scalar(self): + trace = self.trace + 
trace.setup(draws=3, chain=0) + tbl_expected = ('CREATE TABLE IF NOT EXISTS [x] ' + '(recid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ' + 'draw INTEGER, ' + 'chain INT(5), v1 FLOAT)') + trace.db.cursor.execute.assert_any_call(tbl_expected) + + trace_expected = ('INSERT INTO [x] (recid, draw, chain, v1) ' + 'VALUES (NULL, ?, ?, ?)') + self.assertEqual(trace.var_inserts['x'], trace_expected) + + def test_setup_1d(self): + trace = self.trace + trace.setup(draws=3, chain=0) + trace._chains = [] + + tbl_expected = ('CREATE TABLE IF NOT EXISTS [y] ' + '(recid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ' + 'draw INTEGER, ' + 'chain INT(5), v1 FLOAT, v2 FLOAT, v3 FLOAT)') + trace.db.cursor.execute.assert_any_call(tbl_expected) + + trace_expected = ('INSERT INTO [y] (recid, draw, chain, v1, v2, v3) ' + 'VALUES (NULL, ?, ?, ?, ?, ?)') + self.assertEqual(trace.var_inserts['y'], trace_expected) + + def test_setup_2d(self): + trace = self.trace + trace.var_shapes = {'x': (2, 3)} + trace.setup(draws=3, chain=0) + tbl_expected = ('CREATE TABLE IF NOT EXISTS [x] ' + '(recid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ' + 'draw INTEGER, ' + 'chain INT(5), ' + 'v1_1 FLOAT, v1_2 FLOAT, v1_3 FLOAT, ' + 'v2_1 FLOAT, v2_2 FLOAT, v2_3 FLOAT)') + + trace.db.cursor.execute.assert_any_call(tbl_expected) + trace_expected = ('INSERT INTO [x] (recid, draw, chain, ' + 'v1_1, v1_2, v1_3, ' + 'v2_1, v2_2, v2_3) ' + 'VALUES (NULL, ?, ?, ?, ?, ?, ?, ?, ?)') + self.assertEqual(trace.var_inserts['x'], trace_expected) + + def test_record_scalar(self): + trace = self.trace + trace.setup(draws=3, chain=0) + varname = 'x' + trace.varnames = ['x'] + + trace.draw_idx = 0 + trace.fn = mock.Mock(return_value=iter([3.])) + trace.record({'x': None}) + expected = (0, 0, 3.) 
+ self.assertTrue(expected in self.trace._queue['x']) + + def test_record_1d(self): + trace = self.trace + trace.setup(draws=3, chain=0) + varname = 'x' + trace.varnames = ['x'] + + trace.draw_idx = 0 + trace.fn = mock.Mock(return_value=iter([[3., 3.]])) + trace.record({'x': None}) + expected = (0, 0, 3., 3.) + self.assertTrue(expected in self.trace._queue['x']) + + +class TestSQLiteSelection(SQLiteTestCase): + + def setUp(self): + super(TestSQLiteSelection, self).setUp() + self.trace.var_shapes = {'x': (), 'y': (4,)} + self.trace.setup(self.draws, chain=0) + + ndarray_patch = mock.patch('pymc.backends.sqlite._rows_to_ndarray') + self.addCleanup(ndarray_patch.stop) + ndarray_patch.start() + + self.draws = 5 + + def test_get_values_default_keywords(self): + self.trace.get_values('x') + statement = sqlite.TEMPLATES['select'].format(table='x') + expected = (statement, {'chain': 0}) + self.trace.db.cursor.execute.assert_called_with(*expected) + + def test_get_values_burn_arg(self): + self.trace.get_values('x', burn=2).format(table='x') + statement = sqlite.TEMPLATES['select_burn'].format(table='x') + expected = (statement, {'chain': 0, 'burn': 1}) + self.trace.db.cursor.execute.assert_called_with(*expected) + + def test_get_values_thin_arg(self): + self.trace.get_values('x', thin=2) + statement = sqlite.TEMPLATES['select_thin'].format(table='x') + expected = (statement, {'chain': 0, 'thin': 2}) + self.trace.db.cursor.execute.assert_called_with(*expected) + + def test_get_values_burn_thin_arg(self): + self.trace.get_values('x', thin=2, burn=1) + statement = sqlite.TEMPLATES['select_burn_thin'].format(table='x') + expected = (statement, {'chain': 0, 'burn': 0, 'thin': 2}) + self.trace.db.cursor.execute.assert_called_with(*expected) + + def test_point(self): + idx = 2 + + point = self.trace.point(idx) + statement = sqlite.TEMPLATES['select_point'].format(table='x') + statement_args = {'chain': 0, 'draw': idx} + expected = {'x': (statement, statement_args), + 'y': 
(statement, statement_args)} + + for varname, value in expected.items(): + self.trace.db.cursor.execute.assert_any_call(*value) + + def test_slice(self): + with warnings.catch_warnings(record=True) as wrn: + self.trace[:10] + self.assertEqual(len(wrn), 1) + self.assertEqual(str(wrn[0].message), + 'Slice for SQLite backend has no effect.') + + +class TestSQLiteSelectionMultipleChains(SQLiteTestCase): + + def setUp(self): + self.variables = ['x', 'y'] + self.model = mock.Mock() + self.model.unobserved_RVs = self.variables + self.model.fastfn = mock.MagicMock() + + db_patch = mock.patch('pymc.backends.sqlite._SQLiteDB') + self.addCleanup(db_patch.stop) + self.db = db_patch.start() + + with mock.patch('pymc.backends.base.modelcontext') as context: + context.return_value = self.model + self.trace0 = sqlite.SQLite('test.db') + self.trace1 = sqlite.SQLite('test.db') + + self.draws = 5 + + self.trace0.var_shapes = {'x': (), 'y': (3,)} + self.trace0.chain = 0 + self.trace0._len = self.draws + + self.trace1.var_shapes = {'x': (), 'y': (3,)} + self.trace1.chain = 1 + self.trace1._len = self.draws + + self.mtrace = base.MultiTrace([self.trace0, self.trace1]) + + ndarray_patch = mock.patch('pymc.backends.sqlite._rows_to_ndarray') + self.addCleanup(ndarray_patch.stop) + ndarray_patch.start() + + self.draws = 5 + + def test_get_values_default_keywords(self): + self.mtrace.get_values('x') + + db = self.mtrace._traces[0].db + self.assertEqual(db.cursor.execute.call_count, 2) + + statement = 'SELECT * FROM [x] WHERE (chain = :chain)' + expected = [mock.call(statement, {'chain': chain}) + for chain in (0, 1)] + db.cursor.execute.assert_has_calls(expected) + + def test_get_values_chains_one_given(self): + self.mtrace.get_values('x', chains=[0]) + ## If 0 chain is last call, 1 was not called. 
+ statement = sqlite.TEMPLATES['select'].format(table='x') + expected = (statement, {'chain': 0}) + trace = self.mtrace._traces[0] + trace.db.cursor.execute.assert_called_with(*expected) + + def test_get_values_chains_one_chain_arg(self): + self.mtrace.get_values('x', chains=[0]) + ## If 0 chain is last call, 1 was not called. + statement = sqlite.TEMPLATES['select'].format(table='x') + expected = (statement, {'chain': 0}) + trace = self.mtrace._traces[0] + trace.db.cursor.execute.assert_called_with(*expected) + + +class TestSQLiteLoad(unittest.TestCase): + + def setUp(self): + db_patch = mock.patch('pymc.backends.sqlite._SQLiteDB') + self.addCleanup(db_patch.stop) + self.db = db_patch.start() + + table_list_patch = mock.patch('pymc.backends.sqlite._get_table_list') + self.addCleanup(table_list_patch.stop) + self.table_list = table_list_patch.start() + self.table_list.return_value = ['x', 'y'] + + def test_load(self): + trace = sqlite.load('test.db') + assert self.table_list.called + assert self.db.called + + +def test_create_column_empty(): + result = sqlite._create_colnames(()) + expected = ['v1'] + assert result == expected + + +def test_create_column_1d(): + result = sqlite._create_colnames((2, )) + expected = ['v1', 'v2'] + assert result == expected + + +def test_create_column_2d(): + result = sqlite._create_colnames((2, 2)) + expected = ['v1_1', 'v1_2', 'v2_1', 'v2_2'] + assert result == expected From 3c01fbaa5237a9cad7ec628e6c732d01a7f552e9 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Thu, 20 Feb 2014 22:16:49 -0500 Subject: [PATCH 06/15] Add backend documentation --- pymc/backends/__init__.py | 111 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index aa5f5fce6f..c225f1c994 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -1,3 +1,114 @@ +"""Backends for traces + +Available backends +------------------ + +1. 
NumPy array (pymc.backends.NDArray) +2. Text files (pymc.backends.Text) +3. SQLite (pymc.backends.SQLite) + +The NumPy arrays and text files both hold the entire trace in memory, +whereas SQLite commits the trace to the database while sampling. + +Selecting a backend +------------------- + +By default, a NumPy array is used as the backend. To specify a different +backend, pass a backend instance to `sample`. + +For example, the following would save traces to the file 'test.db'. + + >>> import pymc as pm + >>> db = pm.backends.SQLite('test.db') + >>> trace = pm.sample(..., trace=db) + +Selecting values from a backend +------------------------------- + +After a backend is finished sampling, it returns a MultiTrace object. +Values can be accessed in a few ways. The easiest way is to index the +backend object with a variable or variable name. + + >>> trace['x'] # or trace[x] + +The call will return a list containing the sampling values for all +chains of `x`. (Each call to `pymc.sample` creates a separate chain of +samples.) + +For more control of which values are returned, the `get_values` method +can be used. The call below will return values from all chains, burning +the first 1000 iterations from each chain. + + >>> trace.get_values('x', burn=1000) + +Setting the `combine` flag will concatenate the results from all the +chains. + + >>> trace.get_values('x', burn=1000, combine=True) + +The `chains` parameter of `get_values` can be used to limit the chains +that are retrieved. + + >>> trace.get_values('x', burn=1000, combine=True, chains=[0, 2]) + +Backends can also support slicing the MultiTrace object. For example, +the following call would return a new trace object without the first +1000 sampling iterations for all traces and variables. + + >>> sliced_trace = trace[1000:] + +Loading a saved backend +----------------------- + +Saved backends can be loaded using the `load` function in the module for the +specific backend. 
+ + >>> trace = pm.backends.sqlite.load('test.db') + +Writing custom backends +----------------------- + +Backends consist of a class that handles sampling storage and value +selection. Three sampling methods of the backend will be called: + +- setup: Before sampling is started, the `setup` method will be called + with two arguments: the number of draws and the chain number. This is + useful for setting up any structure for storing the sampling values that + requires the above information. + +- record: Record the sampling results for the current draw. This method + will be called with a dictionary of values mapped to the variable + names. This is the only sampling function that *must* do something to + have a meaningful backend. + +- close: This method is called following sampling and should perform any + actions necessary for finalizing and cleaning up the backend. + +The base storage class `backends.base.BaseTrace` provides common model +setup that is used by all the PyMC backends. + +Several selection methods must also be defined: + +- get_values: This is the core method for selecting values from the + backend. It can be called directly and is used by __getitem__ when the + backend is indexed with a variable name or object. + +- _slice: Defines how the backend returns a slice of itself. This + is called if the backend is indexed with a slice range. + +- point: Returns values for each variable at a single iteration. This is + called if the backend is indexed with a single integer. + +- __len__: This should return the number of draws (for the highest chain + number). + +When `pymc.sample` finishes, it wraps all trace objects in a MultiTrace +object that provides a consistent selection interface for all backends. +If the traces are stored on disk, then a `load` function should also be +defined that returns a MultiTrace object. + +For specific examples, see pymc.backends.{ndarray,text,sqlite}.py. 
+""" from pymc.backends.ndarray import NDArray from pymc.backends.text import Text from pymc.backends.sqlite import SQLite From 640be845d682ab181df576ce6743dfe71ded9bf0 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Thu, 20 Feb 2014 22:18:48 -0500 Subject: [PATCH 07/15] Test equality of NDArray and SQLite selections --- pymc/tests/test_ndarray_sqlite_selection.py | 128 ++++++++++++++++++++ 1 file changed, 128 insertions(+) create mode 100644 pymc/tests/test_ndarray_sqlite_selection.py diff --git a/pymc/tests/test_ndarray_sqlite_selection.py b/pymc/tests/test_ndarray_sqlite_selection.py new file mode 100644 index 0000000000..2856ba6450 --- /dev/null +++ b/pymc/tests/test_ndarray_sqlite_selection.py @@ -0,0 +1,128 @@ +import os +import numpy as np +import numpy.testing as npt +import unittest +import multiprocessing as mp + +import pymc as pm + +## Set to False to keep effect of cea5659. Should this be set to True? +TEST_PARALLEL = False + + +def remove_file_or_directory(name): + try: + os.remove(name) + except OSError: + shutil.rmtree(name, ignore_errors=True) + + +class TestCompareNDArraySQLite(unittest.TestCase): + + @classmethod + def setUpClass(cls): + if TEST_PARALLEL: + njobs = 2 + else: + njobs = 1 + + data = np.random.normal(size=(3, 20)) + n = 1 + + model = pm.Model() + draws = 5 + with model: + x = pm.Normal('x', 0, 1., shape=n) + + start = {'x': 0.} + step = pm.Metropolis() + cls.db = 'test.db' + + try: + cls.ntrace = pm.sample(draws, step=step, + njobs=njobs, random_seed=9) + cls.strace = pm.sample(draws, step=step, + njobs=njobs, random_seed=9, + trace=pm.backends.SQLite(cls.db)) + ## Extend each trace. + cls.ntrace = pm.sample(draws, step=step, + njobs=njobs, random_seed=4, + trace=cls.ntrace) + cls.strace = pm.sample(draws, step=step, + njobs=njobs, random_seed=4, + trace=cls.strace) + cls.draws = draws * 2 # Account for extension. 
+ except: + remove_file_or_directory(cls.db) + raise + + @classmethod + def tearDownClass(cls): + remove_file_or_directory(cls.db) + + def test_chain_length(self): + assert self.ntrace.nchains == self.strace.nchains + assert len(self.ntrace) == len(self.strace) + + def test_number_of_draws(self): + nvalues = self.ntrace.get_values('x', squeeze=False) + svalues = self.strace.get_values('x', squeeze=False) + assert nvalues[0].shape[0] == self.draws + assert svalues[0].shape[0] == self.draws + + def test_get_item(self): + npt.assert_equal(self.ntrace['x'], self.strace['x']) + + def test_get_values(self): + for cf in [False, True]: + npt.assert_equal(self.ntrace.get_values('x', combine=cf), + self.strace.get_values('x', combine=cf)) + + def test_get_values_no_squeeze(self): + npt.assert_equal(self.ntrace.get_values('x', combine=False, + squeeze=False), + self.strace.get_values('x', combine=False, + squeeze=False)) + + def test_get_values_combine_and_no_squeeze(self): + npt.assert_equal(self.ntrace.get_values('x', combine=True, + squeeze=False), + self.strace.get_values('x', combine=True, + squeeze=False)) + + def test_get_values_with_burn(self): + for cf in [False, True]: + npt.assert_equal(self.ntrace.get_values('x', combine=cf, burn=3), + self.strace.get_values('x', combine=cf, burn=3)) + + ## Burn to one value. 
+ npt.assert_equal(self.ntrace.get_values('x', combine=cf, + burn=self.draws - 1), + self.strace.get_values('x', combine=cf, + burn=self.draws - 1)) + + def test_get_values_with_thin(self): + for cf in [False, True]: + npt.assert_equal(self.ntrace.get_values('x', combine=cf, thin=2), + self.strace.get_values('x', combine=cf, thin=2)) + + def test_get_values_with_burn_and_thin(self): + for cf in [False, True]: + npt.assert_equal(self.ntrace.get_values('x', combine=cf, + burn=2, thin=2), + self.strace.get_values('x', combine=cf, + burn=2, thin=2)) + + def test_get_values_with_chains_arg(self): + for cf in [False, True]: + npt.assert_equal(self.ntrace.get_values('x', chains=[0]), + self.strace.get_values('x', chains=[0])) + + def test_point(self): + npoint, spoint = self.ntrace[4], self.strace[4] + npt.assert_equal(npoint['x'], spoint['x']) + + def test_point_with_chain_arg(self): + npoint = self.ntrace.point(4, chain=0) + spoint = self.strace.point(4, chain=0) + npt.assert_equal(npoint['x'], spoint['x']) From 6890fe96643296dd4b43db47454db762bbd707c7 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Thu, 20 Feb 2014 22:17:19 -0500 Subject: [PATCH 08/15] Dump and load tests for text and SQLite --- pymc/tests/test_backend_dump_load.py | 102 +++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 pymc/tests/test_backend_dump_load.py diff --git a/pymc/tests/test_backend_dump_load.py b/pymc/tests/test_backend_dump_load.py new file mode 100644 index 0000000000..a9ab5b63ba --- /dev/null +++ b/pymc/tests/test_backend_dump_load.py @@ -0,0 +1,102 @@ +import os +import shutil +import numpy as np +import numpy.testing as npt +import unittest +import multiprocessing as mp + +import pymc as pm + +## Set to False to keep effect of cea5659. Should this be set to True? 
+TEST_PARALLEL = False + + +def remove_file_or_directory(name): + try: + os.remove(name) + except OSError: + shutil.rmtree(name, ignore_errors=True) + + +class DumpLoadTestCase(unittest.TestCase): + + @classmethod + def setUpClass(cls): + if TEST_PARALLEL: + njobs = 2 + else: + njobs = 1 + + data = np.random.normal(size=(2, 20)) + model = pm.Model() + with model: + x = pm.Normal('x', mu=.5, tau=2. ** -2, shape=(2, 1)) + z = pm.Beta('z', alpha=10, beta=5.5) + d = pm.Normal('data', mu=x, tau=.75 ** -2, observed=data) + data = np.random.normal(size=(3, 20)) + n = 1 + + draws = 5 + cls.draws = draws + + with model: + try: + cls.trace = pm.sample(n, step=pm.Metropolis(), + trace=cls.backend(cls.db), + njobs=2) + cls.dumped = cls.load_func(cls.db) + except: + remove_file_or_directory(cls.db) + raise + + @classmethod + def tearDownClass(cls): + remove_file_or_directory(cls.db) + + +class TestTextDumpLoad(DumpLoadTestCase): + + backend = pm.backends.Text + load_func = staticmethod(pm.backends.text.load) + db = 'text-db' + + def test_nchains(self): + self.assertEqual(self.trace.nchains, self.dumped.nchains) + + def test_varnames(self): + trace_names = list(sorted(self.trace.varnames)) + dumped_names = list(sorted(self.dumped.varnames)) + self.assertEqual(trace_names, dumped_names) + + def test_values(self): + trace = self.trace + dumped = self.dumped + for chain in trace.chains: + for varname in trace.varnames: + data = trace.get_values(varname, chains=[chain]) + dumped_data = dumped.get_values(varname, chains=[chain]) + npt.assert_equal(data, dumped_data) + + +class TestSQLiteDumpLoad(DumpLoadTestCase): + + backend = pm.backends.SQLite + load_func = staticmethod(pm.backends.sqlite.load) + db = 'test.db' + + def test_nchains(self): + self.assertEqual(self.trace.nchains, self.dumped.nchains) + + def test_varnames(self): + trace_names = list(sorted(self.trace.varnames)) + dumped_names = list(sorted(self.dumped.varnames)) + self.assertEqual(trace_names, dumped_names) + + def 
test_values(self): + trace = self.trace + dumped = self.dumped + for chain in trace.chains: + for varname in trace.varnames: + data = trace.get_values(varname, chains=[chain]) + dumped_data = dumped.get_values(varname, chains=[chain]) + npt.assert_equal(data, dumped_data) From 764cdd1a2b58c2b5f6a67120747cbb104bb8e0ac Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Thu, 20 Mar 2014 00:32:02 -0400 Subject: [PATCH 09/15] Remove unused multiprocessing from tests --- pymc/tests/test_backend_dump_load.py | 1 - pymc/tests/test_ndarray_sqlite_selection.py | 1 - 2 files changed, 2 deletions(-) diff --git a/pymc/tests/test_backend_dump_load.py b/pymc/tests/test_backend_dump_load.py index a9ab5b63ba..6dc75d6814 100644 --- a/pymc/tests/test_backend_dump_load.py +++ b/pymc/tests/test_backend_dump_load.py @@ -3,7 +3,6 @@ import numpy as np import numpy.testing as npt import unittest -import multiprocessing as mp import pymc as pm diff --git a/pymc/tests/test_ndarray_sqlite_selection.py b/pymc/tests/test_ndarray_sqlite_selection.py index 2856ba6450..d016043f70 100644 --- a/pymc/tests/test_ndarray_sqlite_selection.py +++ b/pymc/tests/test_ndarray_sqlite_selection.py @@ -2,7 +2,6 @@ import numpy as np import numpy.testing as npt import unittest -import multiprocessing as mp import pymc as pm From c3533e06c79e913c4865560e9207af363f7e3464 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Thu, 20 Mar 2014 00:32:25 -0400 Subject: [PATCH 10/15] Add missing shutil import to test --- pymc/tests/test_ndarray_sqlite_selection.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc/tests/test_ndarray_sqlite_selection.py b/pymc/tests/test_ndarray_sqlite_selection.py index d016043f70..0864039668 100644 --- a/pymc/tests/test_ndarray_sqlite_selection.py +++ b/pymc/tests/test_ndarray_sqlite_selection.py @@ -1,4 +1,5 @@ import os +import shutil import numpy as np import numpy.testing as npt import unittest From 0d5c5f9194ed4757201f01959a1233f6841166b8 Mon Sep 17 00:00:00 2001 From: Kyle Meyer 
Date: Thu, 20 Mar 2014 11:46:06 -0400 Subject: [PATCH 11/15] Clean up obsolete information in docstring --- pymc/backends/ndarray.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py index 27f55c5e9f..a59457f0ce 100644 --- a/pymc/backends/ndarray.py +++ b/pymc/backends/ndarray.py @@ -105,8 +105,6 @@ def _slice(self, idx): def point(self, idx): """Return dictionary of point values at `idx` for current chain with variables names as keys. - - If `chain` is not specified, `default_chain` is used. """ idx = int(idx) return {varname: values[idx] From 92d813f9b37cc117529c2eed3e1bb11ac39cdff9 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Wed, 19 Mar 2014 18:43:32 -0400 Subject: [PATCH 12/15] Fix hardcoded `njobs` argument in dump test --- pymc/tests/test_backend_dump_load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/tests/test_backend_dump_load.py b/pymc/tests/test_backend_dump_load.py index 6dc75d6814..8f1b705302 100644 --- a/pymc/tests/test_backend_dump_load.py +++ b/pymc/tests/test_backend_dump_load.py @@ -42,7 +42,7 @@ def setUpClass(cls): try: cls.trace = pm.sample(n, step=pm.Metropolis(), trace=cls.backend(cls.db), - njobs=2) + njobs=njobs) cls.dumped = cls.load_func(cls.db) except: remove_file_or_directory(cls.db) From 59788594f98a494bc39e72a04f9e8b82386b6dd0 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Wed, 19 Mar 2014 19:50:21 -0400 Subject: [PATCH 13/15] Add shortcuts for Text and SQLite backends --- pymc/backends/__init__.py | 5 +++++ pymc/sampling.py | 30 ++++++++++++++++++++++++++---- pymc/tests/test_sampling.py | 25 +++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 4 deletions(-) diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index c225f1c994..6e087dc66d 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -112,3 +112,8 @@ from pymc.backends.ndarray import NDArray from pymc.backends.text import Text from 
pymc.backends.sqlite import SQLite + +_shortcuts = {'text': {'backend': Text, + 'name': 'mcmc'}, + 'sqlite': {'backend': SQLite, + 'name': 'mcmc.sqlite'}} diff --git a/pymc/sampling.py b/pymc/sampling.py index d46fc432cb..162d105788 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -1,4 +1,5 @@ from .point import * +from pymc import backends from pymc.backends.base import merge_traces, BaseTrace, MultiTrace from pymc.backends.ndarray import NDArray import multiprocessing as mp @@ -34,6 +35,9 @@ def sample(draws, step, start=None, trace=None, chain=0, njobs=1, tune=None, or a MultiTrace object with past values. If a MultiTrace object is given, it must contain samples for the chain number `chain`. If None or a list of variables, the NDArray backend is used. + Passing either "text" or "sqlite" is taken as a shortcut to set + up the corresponding backend (with "mcmc" used as the base + name). chain : int Chain number used to store sample in backend. If `njobs` is greater than one, chain numbers will start here. 
@@ -155,10 +159,7 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None, if start is None: start = {} - if isinstance(trace, MultiTrace): - trace = trace._traces[chain] - elif not isinstance(trace, BaseTrace): - trace = NDArray(model=model, vars=trace) + trace = _choose_backend(trace, chain, model=model) if len(trace) > 0: _soft_update(start, trace.point(-1)) @@ -183,6 +184,27 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None, trace.close() +def _choose_backend(trace, chain, shortcuts=None, **kwds): + if isinstance(trace, BaseTrace): + return trace + if isinstance(trace, MultiTrace): + return trace._traces[chain] + if trace is None: + return NDArray(**kwds) + + if shortcuts is None: + shortcuts = backends._shortcuts + + try: + backend = shortcuts[trace]['backend'] + name = shortcuts[trace]['name'] + return backend(name, **kwds) + except TypeError: + return NDArray(vars=trace, **kwds) + except KeyError: + raise ValueError('Argument `trace` is invalid.') + + def _mp_sample(njobs, args): p = mp.Pool(njobs) traces = p.map(argsample, args) diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index c41a5aea5a..6a15c4725a 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -83,3 +83,28 @@ def test_iter_sample(): samps = sampling.iter_sample(5, step, start, model=model) for i, trace in enumerate(samps): assert i == len(trace) - 1, "Trace does not have correct length." 
+ + +class TestChooseBackend(unittest.TestCase): + + def test_choose_backend_none(self): + with mock.patch('pymc.sampling.NDArray') as nd: + sampling._choose_backend(None, 'chain') + self.assertTrue(nd.called) + + def test_choose_backend_list_of_variables(self): + with mock.patch('pymc.sampling.NDArray') as nd: + sampling._choose_backend(['var1', 'var2'], 'chain') + nd.assert_called_with(vars=['var1', 'var2']) + + def test_choose_backend_invalid(self): + self.assertRaises(ValueError, + sampling._choose_backend, + 'invalid', 'chain') + + def test_choose_backend_shortcut(self): + backend = mock.Mock() + shortcuts = {'test_backend': {'backend': backend, + 'name': None}} + sampling._choose_backend('test_backend', 'chain', shortcuts=shortcuts) + self.assertTrue(backend.called) From ec92cf24a0ab62b5421e4e908b7b7b78c1e99e16 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Thu, 20 Mar 2014 00:10:59 -0400 Subject: [PATCH 14/15] Add SQLite backend example for hierarchical.py --- pymc/examples/hierarchical_sqlite.py | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 pymc/examples/hierarchical_sqlite.py diff --git a/pymc/examples/hierarchical_sqlite.py b/pymc/examples/hierarchical_sqlite.py new file mode 100644 index 0000000000..e39ebdb4a7 --- /dev/null +++ b/pymc/examples/hierarchical_sqlite.py @@ -0,0 +1,5 @@ +import pymc as pm +import pymc.examples.hierarchical as hier + +with hier.model: + trace = pm.sample(3000, hier.step, hier.start, trace='sqlite') From e6a32c66714990e30987381a880e7915733d81c1 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Thu, 20 Mar 2014 17:10:07 -0400 Subject: [PATCH 15/15] Hide hierarchical SQLite example under main I'm moving this under main so that it doesn't run with "test_examples". This could be set up like the other examples, with a run definition that allows for a short version, but I'd prefer not to for a couple of reasons. 1. 
Everything aside from the SQLite backend is the same as "hierarchical.py", so it isn't testing much more for the time added to the run. 2. This results in an SQLite file, so a cleanup should be added somewhere if it is run with "test_examples". --- pymc/examples/hierarchical_sqlite.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pymc/examples/hierarchical_sqlite.py b/pymc/examples/hierarchical_sqlite.py index e39ebdb4a7..d8850e8dae 100644 --- a/pymc/examples/hierarchical_sqlite.py +++ b/pymc/examples/hierarchical_sqlite.py @@ -1,5 +1,8 @@ -import pymc as pm -import pymc.examples.hierarchical as hier -with hier.model: - trace = pm.sample(3000, hier.step, hier.start, trace='sqlite') +if __name__ == '__main__': + ## Avoid loading during tests. + import pymc as pm + import pymc.examples.hierarchical as hier + + with hier.model: + trace = pm.sample(3000, hier.step, hier.start, trace='sqlite')