From c1c6a8500305ddafd0aadef629966cb5107f6102 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Fri, 3 Jan 2014 15:21:07 -0500 Subject: [PATCH 01/18] Include tests_require argument in setup.py --- setup.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 95052d6128..89a160c38c 100755 --- a/setup.py +++ b/setup.py @@ -25,8 +25,9 @@ 'Topic :: Scientific/Engineering :: Mathematics', 'Operating System :: OS Independent'] -required = ['numpy>=1.7.1', 'scipy>=0.12.0', 'matplotlib>=1.2.1', - 'Theano==0.6.0'] +install_reqs = ['numpy>=1.7.1', 'scipy>=0.12.0', 'matplotlib>=1.2.1', + 'Theano==0.6.0'] +test_reqs = ['nose'] if __name__ == "__main__": setup(name=DISTNAME, @@ -41,4 +42,6 @@ 'pymc.step_methods', 'pymc.tuning', 'pymc.tests', 'pymc.glm'], classifiers=classifiers, - install_requires=required) + install_requires=install_reqs, + tests_require=test_reqs, + test_suite='nose.collector') From 3060c242c2fd78949109eae8292cbb4030f2364a Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Fri, 3 Jan 2014 15:21:56 -0500 Subject: [PATCH 02/18] Include mock as test dependency This is only needed for python 2 because mock is in stdlib for python 3 (unittest.mock). --- setup.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/setup.py b/setup.py index 89a160c38c..2a338bf31f 100755 --- a/setup.py +++ b/setup.py @@ -1,5 +1,6 @@ #!/usr/bin/env python from setuptools import setup +import sys DISTNAME = 'pymc' @@ -27,7 +28,10 @@ install_reqs = ['numpy>=1.7.1', 'scipy>=0.12.0', 'matplotlib>=1.2.1', 'Theano==0.6.0'] + test_reqs = ['nose'] +if sys.version_info[0] == 2: # py3 has mock in stdlib + test_reqs.append('mock') if __name__ == "__main__": setup(name=DISTNAME, From 6dda7e1ab641a40336bd86d53201ddbefe47c5e8 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Sun, 15 Dec 2013 03:15:48 -0500 Subject: [PATCH 03/18] Move summary to stats module --- pymc/__init__.py | 1 + pymc/stats.py | 163 +++++++++++++++++++++++++++++++++++++- pymc/tests/test_stats.py | 164 ++++++++++++++++++++++++++++++++++++--- pymc/tests/test_trace.py | 146 ---------------------------------- pymc/trace.py | 162 +------------------------------------- 5 files changed, 319 insertions(+), 317 deletions(-) diff --git a/pymc/__init__.py b/pymc/__init__.py index 44760fddf3..8dd1d80fc2 100644 --- a/pymc/__init__.py +++ b/pymc/__init__.py @@ -6,6 +6,7 @@ from .trace import * from .sample import * +from .stats import summary from .step_methods import * from .tuning import * diff --git a/pymc/stats.py b/pymc/stats.py index b028e35b54..732a38151b 100644 --- a/pymc/stats.py +++ b/pymc/stats.py @@ -1,8 +1,11 @@ """Utility functions for PyMC""" import numpy as np +from .trace import MultiTrace +import warnings -__all__ = ['autocorr', 'autocov', 'hpd', 'quantiles', 'mc_error'] + +__all__ = ['autocorr', 'autocov', 'hpd', 'quantiles', 'mc_error', 'summary'] def statfunc(f): """ @@ -237,3 +240,161 @@ def quantiles(x, qlist=(2.5, 25, 50, 75, 97.5)): except IndexError: print("Too few elements for quantile calculation") + + +def summary(trace, vars=None, alpha=0.05, start=0, batches=100, roundto=3): + """ + Generate a pretty-printed summary of the node. + + :Parameters: + trace : Trace object + Trace containing MCMC sample + + vars : list of strings + List of variables to summarize. Defaults to None, which results + in all variables summarized. + + alpha : float + The alpha level for generating posterior intervals. Defaults to + 0.05. + + start : int + The starting index from which to summarize (each) chain. Defaults + to zero. + + batches : int + Batch size for calculating standard deviation for non-independent + samples. Defaults to 100. + + roundto : int + The number of digits to round posterior statistics. + + """ + if vars is None: + vars = trace.varnames + if isinstance(trace, MultiTrace): + trace = trace.combined() + + stat_summ = _StatSummary(roundto, batches, alpha) + pq_summ = _PosteriorQuantileSummary(roundto, alpha) + + for var in vars: + # Extract sampled values + sample = trace[var][start:] + if sample.ndim == 1: + sample = sample[:, None] + elif sample.ndim > 2: + ## trace dimensions greater than 2 (variable greater than 1) + warnings.warn('Skipping {} (above 1 dimension)'.format(var)) + continue + + print('\n%s:' % var) + print(' ') + + stat_summ.print_output(sample) + pq_summ.print_output(sample) + + +class _Summary(object): + """Base class for summary output""" + def __init__(self, roundto): + self.roundto = roundto + self.header_lines = None + self.leader = ' ' + self.spaces = None + + def print_output(self, sample): + print('\n'.join(list(self._get_lines(sample))) + '\n') + + def _get_lines(self, sample): + for line in self.header_lines: + yield self.leader + line + summary_lines = self._calculate_values(sample) + for line in self._create_value_output(summary_lines): + yield self.leader + line + + def _create_value_output(self, lines): + for values in lines: + self._format_values(values) + yield self.value_line.format(pad=self.spaces, **values).strip() + + def _calculate_values(self, sample): + raise NotImplementedError + + def _format_values(self, summary_values): + for key, val in summary_values.items(): + summary_values[key] = '{:.{ndec}f}'.format( + float(val), ndec=self.roundto) + + +class _StatSummary(_Summary): + def __init__(self, roundto, batches, alpha): + super(_StatSummary, self).__init__(roundto) + spaces = 17 + hpd_name = '{}% HPD interval'.format(int(100 * (1 - alpha))) + value_line = '{mean:<{pad}}{sd:<{pad}}{mce:<{pad}}{hpd:<{pad}}' + header = value_line.format(mean='Mean', sd='SD', mce='MC Error', + hpd=hpd_name, pad=spaces).strip() + hline = '-' * len(header) + + self.header_lines = [header, hline] + self.spaces = spaces + self.value_line = value_line + self.batches = batches + self.alpha = alpha + + def _calculate_values(self, sample): + return _calculate_stats(sample, self.batches, self.alpha) + + def _format_values(self, summary_values): + roundto = self.roundto + for key, val in summary_values.items(): + if key == 'hpd': + summary_values[key] = '[{:.{ndec}f}, {:.{ndec}f}]'.format( + *val, ndec=roundto) + else: + summary_values[key] = '{:.{ndec}f}'.format( + float(val), ndec=roundto) + + +class _PosteriorQuantileSummary(_Summary): + def __init__(self, roundto, alpha): + super(_PosteriorQuantileSummary, self).__init__(roundto) + spaces = 15 + title = 'Posterior quantiles:' + value_line = '{lo:<{pad}}{q25:<{pad}}{q50:<{pad}}{q75:<{pad}}{hi:<{pad}}' + lo, hi = 100 * alpha / 2, 100 * (1. - alpha / 2) + qlist = (lo, 25, 50, 75, hi) + header = value_line.format(lo=lo, q25=25, q50=50, q75=75, hi=hi, + pad=spaces).strip() + hline = '|{thin}|{thick}|{thick}|{thin}|'.format( + thin='-' * (spaces - 1), thick='=' * (spaces - 1)) + + self.header_lines = [title, header, hline] + self.spaces = spaces + self.lo, self.hi = lo, hi + self.qlist = qlist + self.value_line = value_line + + def _calculate_values(self, sample): + return _calculate_posterior_quantiles(sample, self.qlist) + + +def _calculate_stats(sample, batches, alpha): + means = sample.mean(0) + sds = sample.std(0) + mces = mc_error(sample, batches) + intervals = hpd(sample, alpha) + for index in range(sample.shape[1]): + mean, sd, mce = [stat[index] for stat in (means, sds, mces)] + interval = intervals[index].squeeze().tolist() + yield {'mean': mean, 'sd': sd, 'mce': mce, 'hpd': interval} + + +def _calculate_posterior_quantiles(sample, qlist): + var_quantiles = quantiles(sample, qlist=qlist) + ## Replace ends of qlist with 'lo' and 'hi' + qends = {qlist[0]: 'lo', qlist[-1]: 'hi'} + qkeys = {q: qends[q] if q in qends else 'q{}'.format(q) for q in qlist} + for index in range(sample.shape[1]): + yield {qkeys[q]: var_quantiles[q][index] for q in qlist} diff --git a/pymc/tests/test_stats.py b/pymc/tests/test_stats.py index 3ec98f760f..371dfc736e 100644 --- a/pymc/tests/test_stats.py +++ b/pymc/tests/test_stats.py @@ -1,6 +1,10 @@ -from ..stats import * +import pymc as pm +from pymc import stats +import numpy as np from numpy.random import random, normal, seed from numpy.testing import assert_equal, assert_almost_equal, assert_array_almost_equal +import warnings +import nose seed(111) normal_sample = normal(0, 1, 1000000) @@ -8,37 +12,179 @@ def test_autocorr(): """Test autocorrelation and autocovariance functions""" - assert_almost_equal(autocorr(normal_sample), 0, 2) + assert_almost_equal(stats.autocorr(normal_sample), 0, 2) y = [(normal_sample[i-1] + normal_sample[i])/2 for i in range(1, len(normal_sample))] - assert_almost_equal(autocorr(y), 0.5, 2) + assert_almost_equal(stats.autocorr(y), 0.5, 2) def test_hpd(): """Test HPD calculation""" - interval = hpd(normal_sample) + interval = stats.hpd(normal_sample) assert_array_almost_equal(interval, [-1.96, 1.96], 2) def test_make_indices(): """Test make_indices function""" - from ..stats import make_indices - ind = [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)] - assert_equal(ind, make_indices((2, 3))) + assert_equal(ind, stats.make_indices((2, 3))) def test_mc_error(): """Test batch standard deviation function""" x = random(100000) - assert(mc_error(x) < 0.0025) + assert(stats.mc_error(x) < 0.0025) def test_quantiles(): """Test quantiles function""" - q = quantiles(normal_sample) + q = stats.quantiles(normal_sample) assert_array_almost_equal(sorted(q.values()), [-1.96, -0.67, 0, 0.67, 1.96], 2) + + +def test_summary_1_value_model(): + mu = -2.1 + tau = 1.3 + with pm.Model() as model: + x = pm.Normal('x', mu, tau, testval=.1) + step = pm.Metropolis(model.vars, np.diag([1.])) + trace = pm.sample(100, step=step) + stats.summary(trace) + + +def test_summary_2_value_model(): + mu = -2.1 + tau = 1.3 + with pm.Model() as model: + x = pm.Normal('x', mu, tau, shape=2, testval=[.1, .1]) + step = pm.Metropolis(model.vars, np.diag([1.])) + trace = pm.sample(100, step=step) + stats.summary(trace) + + +def test_summary_2dim_value_model(): + mu = -2.1 + tau = 1.3 + with pm.Model() as model: + x = pm.Normal('x', mu, tau, shape=(2, 2), + testval=np.tile(.1, (2, 2))) + step = pm.Metropolis(model.vars, np.diag([1.])) + trace = pm.sample(100, step=step) + + with warnings.catch_warnings(record=True) as wrn: + stats.summary(trace) + assert len(wrn) == 1 + assert str(wrn[0].message) == 'Skipping x (above 1 dimension)' + + +def test_summary_format_values(): + roundto = 2 + summ = stats._Summary(roundto) + d = {'nodec': 1, 'onedec': 1.0, 'twodec': 1.00, 'threedec': 1.000} + summ._format_values(d) + for val in d.values(): + assert val == '1.00' + + +def test_stat_summary_format_hpd_values(): + roundto = 2 + summ = stats._StatSummary(roundto, None, 0.05) + d = {'nodec': 1, 'hpd': [1, 1]} + summ._format_values(d) + for key, val in d.items(): + if key == 'hpd': + assert val == '[1.00, 1.00]' + else: + assert val == '1.00' + + +@nose.tools.raises(IndexError) +def test_calculate_stats_variable_size1_not_adjusted(): + sample = np.arange(10) + list(stats._calculate_stats(sample, 5, 0.05)) + + +def test_calculate_stats_variable_size1_adjusted(): + sample = np.arange(10)[:, None] + result_size = len(list(stats._calculate_stats(sample, 5, 0.05))) + assert result_size == 1 + +def test_calculate_stats_variable_size2(): + ## 2 traces of 5 + sample = np.arange(10).reshape(5, 2) + result_size = len(list(stats._calculate_stats(sample, 5, 0.05))) + assert result_size == 2 + + +@nose.tools.raises(IndexError) +def test_calculate_pquantiles_variable_size1_not_adjusted(): + sample = np.arange(10) + qlist = (0.25, 25, 50, 75, 0.98) + list(stats._calculate_posterior_quantiles(sample, + qlist)) + + +def test_calculate_pquantiles_variable_size1_adjusted(): + sample = np.arange(10)[:, None] + qlist = (0.25, 25, 50, 75, 0.98) + result_size = len(list(stats._calculate_posterior_quantiles(sample, + qlist))) + assert result_size == 1 + + +def test_stats_value_line(): + roundto = 1 + summ = stats._StatSummary(roundto, None, 0.05) + values = [{'mean': 0, 'sd': 1, 'mce': 2, 'hpd': [4, 4]}, + {'mean': 5, 'sd': 6, 'mce': 7, 'hpd': [8, 8]},] + + expected = ['0.0 1.0 2.0 [4.0, 4.0]', + '5.0 6.0 7.0 [8.0, 8.0]'] + result = list(summ._create_value_output(values)) + assert result == expected + + +def test_post_quantile_value_line(): + roundto = 1 + summ = stats._PosteriorQuantileSummary(roundto, 0.05) + values = [{'lo': 0, 'q25': 1, 'q50': 2, 'q75': 4, 'hi': 5}, + {'lo': 6, 'q25': 7, 'q50': 8, 'q75': 9, 'hi': 10},] + + expected = ['0.0 1.0 2.0 4.0 5.0', + '6.0 7.0 8.0 9.0 10.0'] + result = list(summ._create_value_output(values)) + assert result == expected + + +def test_stats_output_lines(): + roundto = 1 + x = np.arange(10).reshape(5, 2) + + summ = stats._StatSummary(roundto, 5, 0.05) + + expected = [' Mean SD MC Error 95% HPD interval', + ' -------------------------------------------------------------------', + ' 4.0 2.8 1.3 [0.0, 8.0]', + ' 5.0 2.8 1.3 [1.0, 9.0]',] + result = list(summ._get_lines(x)) + assert result == expected + + +def test_posterior_quantiles_output_lines(): + roundto = 1 + x = np.arange(10).reshape(5, 2) + + summ = stats._PosteriorQuantileSummary(roundto, 0.05) + + expected = [' Posterior quantiles:', + ' 2.5 25 50 75 97.5', + ' |--------------|==============|==============|--------------|', + ' 0.0 2.0 4.0 6.0 8.0', + ' 1.0 3.0 5.0 7.0 9.0'] + + result = list(summ._get_lines(x)) + assert result == expected diff --git a/pymc/tests/test_trace.py b/pymc/tests/test_trace.py index 35c03a740e..c3ee7db621 100644 --- a/pymc/tests/test_trace.py +++ b/pymc/tests/test_trace.py @@ -2,8 +2,6 @@ from .models import * import pymc as pm import numpy as np -import warnings -import nose # Test if multiprocessing is available import multiprocessing @@ -128,147 +126,3 @@ def test_multi_slice(): # Original trace did not change assert np.all([tr[v][0].shape == (iterations, start[v].size) for v in tr.varnames]) - - -def test_summary_1_value_model(): - mu = -2.1 - tau = 1.3 - with Model() as model: - x = Normal('x', mu, tau, testval=.1) - step = Metropolis(model.vars, np.diag([1.])) - trace = sample(100, step=step) - pm.summary(trace) - - -def test_summary_2_value_model(): - mu = -2.1 - tau = 1.3 - with Model() as model: - x = Normal('x', mu, tau, shape=2, testval=[.1, .1]) - step = Metropolis(model.vars, np.diag([1.])) - trace = sample(100, step=step) - pm.summary(trace) - - -def test_summary_2dim_value_model(): - mu = -2.1 - tau = 1.3 - with Model() as model: - x = Normal('x', mu, tau, shape=(2, 2), - testval=np.tile(.1, (2, 2))) - step = Metropolis(model.vars, np.diag([1.])) - trace = sample(100, step=step) - - with warnings.catch_warnings(record=True) as wrn: - pm.summary(trace) - assert len(wrn) == 1 - assert str(wrn[0].message) == 'Skipping x (above 1 dimension)' - - -def test_summary_format_values(): - roundto = 2 - summ = pm.trace._Summary(roundto) - d = {'nodec': 1, 'onedec': 1.0, 'twodec': 1.00, 'threedec': 1.000} - summ._format_values(d) - for val in d.values(): - assert val == '1.00' - - -def test_stat_summary_format_hpd_values(): - roundto = 2 - summ = pm.trace._StatSummary(roundto, None, 0.05) - d = {'nodec': 1, 'hpd': [1, 1]} - summ._format_values(d) - for key, val in d.items(): - if key == 'hpd': - assert val == '[1.00, 1.00]' - else: - assert val == '1.00' - - -@nose.tools.raises(IndexError) -def test_calculate_stats_variable_size1_not_adjusted(): - sample = np.arange(10) - list(pm.trace._calculate_stats(sample, 5, 0.05)) - - -def test_calculate_stats_variable_size1_adjusted(): - sample = np.arange(10)[:, None] - result_size = len(list(pm.trace._calculate_stats(sample, 5, 0.05))) - assert result_size == 1 - -def test_calculate_stats_variable_size2(): - ## 2 traces of 5 - sample = np.arange(10).reshape(5, 2) - result_size = len(list(pm.trace._calculate_stats(sample, 5, 0.05))) - assert result_size == 2 - - -@nose.tools.raises(IndexError) -def test_calculate_pquantiles_variable_size1_not_adjusted(): - sample = np.arange(10) - qlist = (0.25, 25, 50, 75, 0.98) - list(pm.trace._calculate_posterior_quantiles(sample, - qlist)) - - -def test_calculate_pquantiles_variable_size1_adjusted(): - sample = np.arange(10)[:, None] - qlist = (0.25, 25, 50, 75, 0.98) - result_size = len(list(pm.trace._calculate_posterior_quantiles(sample, - qlist))) - assert result_size == 1 - - -def test_stats_value_line(): - roundto = 1 - summ = pm.trace._StatSummary(roundto, None, 0.05) - values = [{'mean': 0, 'sd': 1, 'mce': 2, 'hpd': [4, 4]}, - {'mean': 5, 'sd': 6, 'mce': 7, 'hpd': [8, 8]},] - - expected = ['0.0 1.0 2.0 [4.0, 4.0]', - '5.0 6.0 7.0 [8.0, 8.0]'] - result = list(summ._create_value_output(values)) - assert result == expected - - -def test_post_quantile_value_line(): - roundto = 1 - summ = pm.trace._PosteriorQuantileSummary(roundto, 0.05) - values = [{'lo': 0, 'q25': 1, 'q50': 2, 'q75': 4, 'hi': 5}, - {'lo': 6, 'q25': 7, 'q50': 8, 'q75': 9, 'hi': 10},] - - expected = ['0.0 1.0 2.0 4.0 5.0', - '6.0 7.0 8.0 9.0 10.0'] - result = list(summ._create_value_output(values)) - assert result == expected - - -def test_stats_output_lines(): - roundto = 1 - x = np.arange(10).reshape(5, 2) - - summ = pm.trace._StatSummary(roundto, 5, 0.05) - - expected = [' Mean SD MC Error 95% HPD interval', - ' -------------------------------------------------------------------', - ' 4.0 2.8 1.3 [0.0, 8.0]', - ' 5.0 2.8 1.3 [1.0, 9.0]',] - result = list(summ._get_lines(x)) - assert result == expected - - -def test_posterior_quantiles_output_lines(): - roundto = 1 - x = np.arange(10).reshape(5, 2) - - summ = pm.trace._PosteriorQuantileSummary(roundto, 0.05) - - expected = [' Posterior quantiles:', - ' 2.5 25 50 75 97.5', - ' |--------------|==============|==============|--------------|', - ' 0.0 2.0 4.0 6.0 8.0', - ' 1.0 3.0 5.0 7.0 9.0'] - - result = list(summ._get_lines(x)) - assert result == expected diff --git a/pymc/trace.py b/pymc/trace.py index 7b174d7504..023580aa5b 100644 --- a/pymc/trace.py +++ b/pymc/trace.py @@ -1,11 +1,9 @@ import numpy as np from .core import * -from .stats import * import copy import types -import warnings -__all__ = ['NpTrace', 'MultiTrace', 'summary'] +__all__ = ['NpTrace', 'MultiTrace'] class NpTrace(object): """ @@ -112,161 +110,3 @@ def combined(self): for k in self.traces[0].samples: h.samples[k].vals = [s[k] for s in self.traces] return h - - -def summary(trace, vars=None, alpha=0.05, start=0, batches=100, roundto=3): - """ - Generate a pretty-printed summary of the node. - - :Parameters: - trace : Trace object - Trace containing MCMC sample - - vars : list of strings - List of variables to summarize. Defaults to None, which results - in all variables summarized. - - alpha : float - The alpha level for generating posterior intervals. Defaults to - 0.05. - - start : int - The starting index from which to summarize (each) chain. Defaults - to zero. - - batches : int - Batch size for calculating standard deviation for non-independent - samples. Defaults to 100. - - roundto : int - The number of digits to round posterior statistics. - - """ - if vars is None: - vars = trace.varnames - if isinstance(trace, MultiTrace): - trace = trace.combined() - - stat_summ = _StatSummary(roundto, batches, alpha) - pq_summ = _PosteriorQuantileSummary(roundto, alpha) - - for var in vars: - # Extract sampled values - sample = trace[var][start:] - if sample.ndim == 1: - sample = sample[:, None] - elif sample.ndim > 2: - ## trace dimensions greater than 2 (variable greater than 1) - warnings.warn('Skipping {} (above 1 dimension)'.format(var)) - continue - - print('\n%s:' % var) - print(' ') - - stat_summ.print_output(sample) - pq_summ.print_output(sample) - - -class _Summary(object): - """Base class for summary output""" - def __init__(self, roundto): - self.roundto = roundto - self.header_lines = None - self.leader = ' ' - self.spaces = None - - def print_output(self, sample): - print('\n'.join(list(self._get_lines(sample))) + '\n') - - def _get_lines(self, sample): - for line in self.header_lines: - yield self.leader + line - summary_lines = self._calculate_values(sample) - for line in self._create_value_output(summary_lines): - yield self.leader + line - - def _create_value_output(self, lines): - for values in lines: - self._format_values(values) - yield self.value_line.format(pad=self.spaces, **values).strip() - - def _calculate_values(self, sample): - raise NotImplementedError - - def _format_values(self, summary_values): - for key, val in summary_values.items(): - summary_values[key] = '{:.{ndec}f}'.format( - float(val), ndec=self.roundto) - - -class _StatSummary(_Summary): - def __init__(self, roundto, batches, alpha): - super(_StatSummary, self).__init__(roundto) - spaces = 17 - hpd_name = '{}% HPD interval'.format(int(100 * (1 - alpha))) - value_line = '{mean:<{pad}}{sd:<{pad}}{mce:<{pad}}{hpd:<{pad}}' - header = value_line.format(mean='Mean', sd='SD', mce='MC Error', - hpd=hpd_name, pad=spaces).strip() - hline = '-' * len(header) - - self.header_lines = [header, hline] - self.spaces = spaces - self.value_line = value_line - self.batches = batches - self.alpha = alpha - - def _calculate_values(self, sample): - return _calculate_stats(sample, self.batches, self.alpha) - - def _format_values(self, summary_values): - roundto = self.roundto - for key, val in summary_values.items(): - if key == 'hpd': - summary_values[key] = '[{:.{ndec}f}, {:.{ndec}f}]'.format( - *val, ndec=roundto) - else: - summary_values[key] = '{:.{ndec}f}'.format( - float(val), ndec=roundto) - - -class _PosteriorQuantileSummary(_Summary): - def __init__(self, roundto, alpha): - super(_PosteriorQuantileSummary, self).__init__(roundto) - spaces = 15 - title = 'Posterior quantiles:' - value_line = '{lo:<{pad}}{q25:<{pad}}{q50:<{pad}}{q75:<{pad}}{hi:<{pad}}' - lo, hi = 100 * alpha / 2, 100 * (1. - alpha / 2) - qlist = (lo, 25, 50, 75, hi) - header = value_line.format(lo=lo, q25=25, q50=50, q75=75, hi=hi, - pad=spaces).strip() - hline = '|{thin}|{thick}|{thick}|{thin}|'.format( - thin='-' * (spaces - 1), thick='=' * (spaces - 1)) - - self.header_lines = [title, header, hline] - self.spaces = spaces - self.lo, self.hi = lo, hi - self.qlist = qlist - self.value_line = value_line - - def _calculate_values(self, sample): - return _calculate_posterior_quantiles(sample, self.qlist) - - -def _calculate_stats(sample, batches, alpha): - means = sample.mean(0) - sds = sample.std(0) - mces = mc_error(sample, batches) - intervals = hpd(sample, alpha) - for index in range(sample.shape[1]): - mean, sd, mce = [stat[index] for stat in (means, sds, mces)] - interval = intervals[index].squeeze().tolist() - yield {'mean': mean, 'sd': sd, 'mce': mce, 'hpd': interval} - - -def _calculate_posterior_quantiles(sample, qlist): - var_quantiles = quantiles(sample, qlist=qlist) - ## Replace ends of qlist with 'lo' and 'hi' - qends = {qlist[0]: 'lo', qlist[-1]: 'hi'} - qkeys = {q: qends[q] if q in qends else 'q{}'.format(q) for q in qlist} - for index in range(sample.shape[1]): - yield {qkeys[q]: var_quantiles[q][index] for q in qlist} From 683507b08504c751a32dcc7e12bc2711bfbe2ecc Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Sun, 5 Jan 2014 17:41:58 -0500 Subject: [PATCH 04/18] Rename sample.py to sampling.py To avoid any confusing with function pymc.sample, which is the imported from model pymc.sample in __init__.py --- pymc/__init__.py | 2 +- pymc/{sample.py => sampling.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename pymc/{sample.py => sampling.py} (100%) diff --git a/pymc/__init__.py b/pymc/__init__.py index 8dd1d80fc2..b35f0f2ced 100644 --- a/pymc/__init__.py +++ b/pymc/__init__.py @@ -5,7 +5,7 @@ from .math import * from .trace import * -from .sample import * +from .sampling import * from .stats import summary from .step_methods import * from .tuning import * diff --git a/pymc/sample.py b/pymc/sampling.py similarity index 100% rename from pymc/sample.py rename to pymc/sampling.py From 2e80cef7ee947969d254605c6e5d70e461cbbeab Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Thu, 2 Jan 2014 19:05:42 -0500 Subject: [PATCH 05/18] Add base and NDArray backend This commit contains a new backend for sampling and selecting values. Non-backend files have been changed to work with the new backend. Everything seems to be working with the exception of two issues (marked with FIXME): 1. pymc.plots.forestplot has not been updated yet for the new backend. 2. The previous behavior of passing a trace object to sample is not the same. I updated stochastic_volatility to do this with the same trace object. This commit also introduces a change to `sample`/`psample`. Instead of having separate function, `sample` now takes a keyword argument `threads`, and if this is over one, the multiprocessing version is used. The method for selecting values has also been changed. Traces can still be indexed to return values, a new slice, or a point (depending on the index), but the handling of chains is different. The trace object is now manages multiple chains itself instead of having a separate class to manage the single trace object. `get_values` is the main method for selecting values. By default, it returns separate results for all the chains. The chains can be combine with the `combine` flags, and particular chains can be select with the `chains` argument. The motivation for both sample and selection changes above was to have a unified interface for dealing with multiple chains, as most people are likely going to take advantage of the parallel sampling. --- pymc/__init__.py | 2 +- pymc/backends/__init__.py | 1 + pymc/backends/base.py | 300 +++++++++++++++++++++++ pymc/backends/ndarray.py | 106 ++++++++ pymc/diagnostics.py | 20 +- pymc/examples/gelman_bioassay.py | 2 +- pymc/examples/stochastic_volatility.py | 10 +- pymc/plots.py | 63 ++--- pymc/sampling.py | 250 +++++++++---------- pymc/stats.py | 65 ++--- pymc/tests/checks.py | 2 +- pymc/tests/test_base_backend.py | 149 ++++++++++++ pymc/tests/test_diagnostics.py | 4 +- pymc/tests/test_glm.py | 16 +- pymc/tests/test_ndarray_backend.py | 320 +++++++++++++++++++++++++ pymc/tests/test_plots.py | 10 +- pymc/tests/test_sampling.py | 80 ++++++- pymc/tests/test_trace.py | 128 ---------- pymc/trace.py | 112 --------- setup.py | 2 +- 20 files changed, 1153 insertions(+), 489 deletions(-) create mode 100644 pymc/backends/__init__.py create mode 100644 pymc/backends/base.py create mode 100644 pymc/backends/ndarray.py create mode 100644 pymc/tests/test_base_backend.py create mode 100644 pymc/tests/test_ndarray_backend.py delete mode 100644 pymc/tests/test_trace.py delete mode 100644 pymc/trace.py diff --git a/pymc/__init__.py b/pymc/__init__.py index b35f0f2ced..1dae5edbbf 100644 --- a/pymc/__init__.py +++ b/pymc/__init__.py @@ -4,7 +4,7 @@ from .distributions import * from .math import * -from .trace import * + from .sampling import * from .stats import summary from .step_methods import * diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py new file mode 100644 index 0000000000..feebc00bb6 --- /dev/null +++ b/pymc/backends/__init__.py @@ -0,0 +1 @@ +from pymc.backends.ndarray import NDArray diff --git a/pymc/backends/base.py b/pymc/backends/base.py new file mode 100644 index 0000000000..5afe8ec371 --- /dev/null +++ b/pymc/backends/base.py @@ -0,0 +1,300 @@ +"""Base backend for traces + +These are the base classes for all trace backends. They define all the +required methods for sampling and value selection that should be +overridden or implementented in children classes. See the docstring for +pymc.backends for more information (includng creating custom backends). +""" +import numpy as np +from pymc.model import modelcontext + + +class Backend(object): + + def __init__(self, name, model=None, variables=None): + self.name = name + + ## model attributes + self.variables = None + self.var_names = None + self.var_shapes = None + self._fn = None + + model = modelcontext(model) + self.model = model + if model: + self._setup_model(model, variables) + + ## set by setup_samples + self.chain = None + self.trace = None + + self._draws = {} + + def _setup_model(self, model, variables): + if variables is None: + variables = model.unobserved_RVs + self.variables = variables + self.var_names = [str(var) for var in variables] + self._fn = model.fastfn(variables) + + var_values = zip(self.var_names, self._fn(model.test_point)) + self.var_shapes = {var: value.shape + for var, value in var_values} + + def setup_samples(self, draws, chain): + """Prepare structure to store traces + + Parameters + ---------- + draws : int + Number of sampling iterations + chain : int + Chain number to store trace under + """ + self.chain = chain + self._draws[chain] = draws + + if self.trace is None: + self.trace = self._initialize_trace() + trace = self.trace + trace._draws[chain] = draws + trace.backend = self + + trace.samples[chain] = {} + for var_name, var_shape in self.var_shapes.items(): + trace_shape = [draws] + list(var_shape) + trace.samples[chain][var_name] = self._create_trace(chain, + var_name, + trace_shape) + + def record(self, point, draw): + """Record the value of the current iteration + + Parameters + ---------- + point : dict + Map of point values to variable names + draw : int + Current sampling iteration + """ + for var_name, value in zip(self.var_names, self._fn(point)): + self._store_value(draw, + self.trace.samples[self.chain][var_name], + value) + + def clean_interrupt(self, current_draw): + """Clean up sampling after interruption + + Perform any clean up not taken care of by `close`. After + KeyboardInterrupt, `sample` calls `close`, so `close` should not + be called here. + """ + self.trace._draws[self.chain] = current_draw + + ## Sampling methods that children must define + + def _initialize_trace(self): + raise NotImplementedError + + def _create_trace(self, chain, var_name, shape): + """Create trace for a variable + + Parameters + ---------- + chain : int + Current chain number + var_name : str + Name of variable + shape : tuple + Shape of the trace. The first element corresponds to the + number of draws. + """ + raise NotImplementedError + + def _store_value(self, draw, var_trace, value): + raise NotImplementedError + + def commit(self): + """Commit samples to backend + + This is called at set intervals during sampling. + """ + raise NotImplementedError + + def close(self): + """Close the database backend + + This is called after sampling has finished. + """ + raise NotImplementedError + + +class Trace(object): + """ + Parameters + ---------- + var_names : list of strs + Sample variables names + backend : Backend object + + Attributes + ---------- + backend : Backend object + var_names + var_shapes : dict + Map of variables shape to variable names + samples : dict of dicts + Sample values keyed by chain and variable name + nchains : int + Number of sampling chains + chains : list of ints + List of sampling chain numbers + default_chain : int + Chain to be used if single chain requested + active_chains : list of ints + Values from chains to be used operations + """ + def __init__(self, var_names, backend=None): + self.var_names = var_names + + self.samples = {} + self._draws = {} + self.backend = backend + self._active_chains = [] + self._default_chain = None + + @property + def nchains(self): + """Number of chains + + A chain is created for each sample call (including parallel + threads). + """ + return len(self.samples) + + @property + def chains(self): + """All chains in trace""" + return list(self.samples.keys()) + + @property + def default_chain(self): + """Default chain to use for operations that require one chain (e.g., + `point`) + """ + if self._default_chain is None: + return self.active_chains[-1] + return self._default_chain + + @default_chain.setter + def default_chain(self, value): + self._default_chain = value + + @property + def active_chains(self): + """List of chains to be used. Defaults to all. + """ + if not self._active_chains: + return self.chains + return self._active_chains + + @active_chains.setter + def active_chains(self, values): + try: + self._active_chains = [chain for chain in values] + except TypeError: + self._active_chains = [values] + + def __len__(self): + return self._draws[self.default_chain] + + def __getitem__(self, idx): + if isinstance(idx, slice): + return self._slice(idx) + + try: + return self.point(idx) + except ValueError: + pass + except TypeError: + pass + return self.get_values(idx) + + ## Selection methods that children must define + + def get_values(self, var_name, burn=0, thin=1, combine=False, chains=None, + squeeze=True): + """Get values from samples + + Parameters + ---------- + var_name : str + burn : int + thin : int + combine : bool + If True, results from all chains will be concatenated. + chains : list + Chains to retrieve. If None, `active_chains` is used. + squeeze : bool + If `combine` is False, return a single array element if the + resulting list of values only has one element (even if + `combine` is True). + + Returns + ------- + A list of NumPy array of values + """ + raise NotImplementedError + + def _slice(self, idx): + """Slice trace object""" + raise NotImplementedError + + def point(self, idx, chain=None): + """Return dictionary of point values at `idx` for current chain + with variables names as keys. + + If `chain` is not specified, `default_chain` is used. + """ + raise NotImplementedError + + +def merge_chains(traces): + """Merge chains from trace instances + + Parameters + ---------- + traces : list + Backend trace instances. Each instance should have only one + chain, and all chain numbers should be unique. + + Raises + ------ + ValueError is raised if any traces have the same current chain + number. + + Returns + ------- + Backend instance with merge chains + """ + base_trace = traces[0] + for new_trace in traces[1:]: + new_chain = new_trace.chains[0] + if new_chain in base_trace.samples: + raise ValueError('Trace chain numbers conflict.') + base_trace.samples[new_chain] = new_trace.samples[new_chain] + return base_trace + + +def _squeeze_cat(results, combine, squeeze): + """Squeeze and concatenate the results dependending on values of + `combine` and `squeeze`""" + if combine: + results = np.concatenate(results) + if not squeeze: + results = [results] + else: + if squeeze and len(results) == 1: + results = results[0] + return results diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py new file mode 100644 index 0000000000..ebd8b7d10d --- /dev/null +++ b/pymc/backends/ndarray.py @@ -0,0 +1,106 @@ +"""NumPy trace backend + +Store sampling values in memory as a NumPy array. +""" +import numpy as np +from pymc.backends import base + + +class NDArray(base.Backend): + + ## make `name` an optional argument for NDArray + def __init__(self, name=None, model=None, variables=None): + super(NDArray, self).__init__(name, model, variables) + + def _initialize_trace(self): + return Trace(self.var_names) + + def _create_trace(self, chain, var_name, shape): + return np.zeros(shape) + + def _store_value(self, draw, var_trace, value): + var_trace[draw] = value + + def commit(self): + pass + + def close(self): + pass + + def clean_interrupt(self, current_draw): + super(NDArray, self).clean_interrupt(current_draw) + traces = self.trace.samples[self.chain] + ## get rid of trailing zeros + traces = {var: trace[:current_draw] for var, trace in traces.items()} + self.trace.samples[self.chain] = traces + + +class Trace(base.Trace): + + __doc__ = 'NumPy array trace\n' + base.Trace.__doc__ + + def __len__(self): + try: + return super(Trace, self).__len__() + except KeyError: + var_name = self.var_names[0] + draws = self.samples[self.default_chain][var_name].shape[0] + self._draws[self.default_chain] = draws + return draws + + def get_values(self, var_name, burn=0, thin=1, combine=False, chains=None, + squeeze=True): + """Get values from samples + + Parameters + ---------- + var_name : str + burn : int + thin : int + combine : bool + If True, results from all chains will be concatenated. + chains : list + Chains to retrieve. If None, `active_chains` is used. + squeeze : bool + If `combine` is False, return a single array element if the + resulting list of values only has one element (even if + `combine` is True). + + Returns + ------- + A list of NumPy array of values + """ + if chains is None: + chains = self.active_chains + + var_name = str(var_name) + results = (self.samples[chain][var_name] for chain in chains) + results = [arr[burn::thin] for arr in results] + + return base._squeeze_cat(results, combine, squeeze) + + def _slice(self, idx): + sliced = Trace(self.var_names) + sliced.backend = self.backend + sliced._active_chains = sliced._active_chains + sliced._default_chain = sliced._default_chain + + sliced.samples = {} + sliced._draws = {} + for chain, trace in self.samples.items(): + sliced_values = {var_name: values[idx] + for var_name, values in trace.items()} + sliced.samples[chain] = sliced_values + sliced._draws[chain] = sliced_values[self.var_names[0]].shape[0] + return sliced + + def point(self, idx, chain=None): + """Return dictionary of point values at `idx` for current chain + with variables names as keys. + + If `chain` is not specified, `default_chain` is used. + """ + if chain is None: + chain = self.default_chain + return {var_name: values[idx] + for var_name, values in self.samples[chain].items()} diff --git a/pymc/diagnostics.py b/pymc/diagnostics.py index f5f7669a43..9c81d8897c 100644 --- a/pymc/diagnostics.py +++ b/pymc/diagnostics.py @@ -87,7 +87,7 @@ def geweke(x, first=.1, last=.5, intervals=20): return np.array(zscores) -def gelman_rubin(mtrace): +def gelman_rubin(trace): """ Returns estimate of R for a set of traces. The Gelman-Rubin diagnostic tests for lack of convergence by comparing @@ -99,8 +99,8 @@ def gelman_rubin(mtrace): Parameters ---------- - mtrace : MultiTrace - A MultiTrace object containing parallel traces (minimum 2) + trace + A trace object containing parallel traces (minimum 2) of one or more stochastic parameters. Returns @@ -126,8 +126,7 @@ def gelman_rubin(mtrace): Brooks and Gelman (1998) Gelman and Rubin (1992)""" - m = len(mtrace.traces) - if m < 2: + if trace.nchains < 2: raise ValueError( 'Gelman-Rubin diagnostic requires multiple chains of the same length.') @@ -147,10 +146,11 @@ def calc_rhat(x): return np.sqrt(Vhat/W) Rhat = {} - for var in mtrace.varnames: + for var in trace.var_names: # Get all traces for var - x = np.array([mtrace.traces[i][var] for i in range(m)]) + x = np.array([values[var] for chain in trace.chains + for values in trace.samples.values()]) try: Rhat[var] = calc_rhat(x) @@ -159,9 +159,11 @@ def calc_rhat(x): return Rhat + def trace_to_dataframe(trace): """Convert a PyMC trace consisting of 1-D variables to a pandas DataFrame """ import pandas as pd - return pd.DataFrame({name: np.squeeze(trace_var.vals) - for name, trace_var in trace.samples.items()}) + return pd.DataFrame( + {var_name: np.squeeze(trace.get_values(var_name, combine=True)) + for var_name in trace.var_names}) diff --git a/pymc/examples/gelman_bioassay.py b/pymc/examples/gelman_bioassay.py index 4725ded00b..5aa23d038d 100644 --- a/pymc/examples/gelman_bioassay.py +++ b/pymc/examples/gelman_bioassay.py @@ -28,7 +28,7 @@ def run(n=1000): if n == "short": n = 50 with model: - trace = sample(n, step, trace=model.unobserved_RVs) + trace = sample(n, step) if __name__ == '__main__': run() diff --git a/pymc/examples/stochastic_volatility.py b/pymc/examples/stochastic_volatility.py index 2f0297bd63..1275ff3f91 100644 --- a/pymc/examples/stochastic_volatility.py +++ b/pymc/examples/stochastic_volatility.py @@ -119,12 +119,18 @@ def run(n=2000): if n == "short": n = 50 with model: - trace = sample(5, step, start, trace=model.vars + [sigma]) + trace = sample(5, step, start, variables=model.vars + [sigma]) + + ## FIXME: At the moment, there isn't a method for updating the + ## same trace. Below makes a new trace in the same backend that + ## has both the chains. The chain needs to be manually set to + ## avoid overwriting the previous chain. A check could be added + ## to override the chain argument to previous chain + 1. # Start next run at the last sampled position. start2 = trace.point(-1) step2 = HamiltonianMC(model.vars, hessian(start2, 6), path_length=4.) - trace = sample(2000, step2, trace=trace) + trace = sample(n, step2, start=start2, db=trace.backend, chain=1) # diff --git a/pymc/plots.py b/pymc/plots.py index ce90439913..5f299ddad0 100644 --- a/pymc/plots.py +++ b/pymc/plots.py @@ -7,12 +7,11 @@ import numpy as np from scipy.stats import kde from .stats import * -from .trace import * __all__ = ['traceplot', 'kdeplot', 'kde2plot', 'forestplot', 'autocorrplot'] -def traceplot(trace, vars=None, figsize=None, +def traceplot(trace, var_names=None, figsize=None, lines=None, combined=False, grid=True): """Plot samples histograms and values @@ -20,7 +19,7 @@ def traceplot(trace, vars=None, figsize=None, ---------- trace : result of MCMC run - vars : list of variable names + var_names : list of variable names Variables to be plotted, if None all variable are plotted figsize : figure size tuple If None, size is (12, num of variables * 2) inch @@ -29,8 +28,8 @@ def traceplot(trace, vars=None, figsize=None, lines to the posteriors and horizontal lines on sample values e.g. mean of posteriors, true values of a simulation combined : bool - Flag for combining MultiTrace into a single trace. If False (default) - traces will be plotted separately on the same set of axes. + Flag for combining multiple chains into a single chain. If False + (default), chains will be plotted separately. grid : bool Flag for adding gridlines to histogram. Defaults to True. @@ -40,30 +39,19 @@ def traceplot(trace, vars=None, figsize=None, fig : figure object """ + if var_names is None: + var_names = trace.var_names - if vars is None: - vars = trace.varnames - - if isinstance(trace, MultiTrace): - if combined: - traces = [trace.combined()] - else: - traces = trace.traces - else: - traces = [trace] - - n = len(vars) + n = len(var_names) if figsize is None: figsize = (12, n*2) fig, ax = plt.subplots(n, 2, squeeze=False, figsize=figsize) - for trace in traces: - for i, v in enumerate(vars): - d = np.squeeze(trace[v]) - - if trace[v].dtype.kind == 'i': + for i, v in enumerate(var_names): + for d in trace.get_values(v, combine=combined, squeeze=False): + if d.dtype.kind == 'i': histplot_op(ax[i, 0], d) else: kdeplot_op(ax[i, 0], d) @@ -138,36 +126,25 @@ def kde2plot(x, y, grid=200): def autocorrplot(trace, vars=None, fontmap=None, max_lag=100): """Bar plot of the autocorrelation function for a trace""" - - try: - # MultiTrace - traces = trace.traces - - except AttributeError: - # NpTrace - traces = [trace] - if fontmap is None: fontmap = {1: 10, 2: 8, 3: 6, 4: 5, 5: 4} if vars is None: - vars = traces[0].varnames + var_names = trace.var_names + else: + var_names = [str(var) for var in vars] - # Extract sample data - samples = [{v: trace[v] for v in vars} for trace in traces] + chains = trace.nchains - chains = len(traces) - - n = len(samples[0]) - f, ax = subplots(n, chains, squeeze=False) + # Extract sample data - max_lag = min(len(samples[0][vars[0]])-1, max_lag) + f, ax = subplots(len(var_names), chains, squeeze=False) - for i, v in enumerate(vars): + max_lag = min(len(trace) - 1, max_lag) + for i, v in enumerate(var_names): for j in range(chains): - - d = np.squeeze(samples[j][v]) + d = np.squeeze(trace.get_values(v, chains=[j])) ax[i, j].acorr(d, detrend=mlab.detrend_mean, maxlags=max_lag) @@ -205,7 +182,7 @@ def var_str(name, shape): names[0] = '%s %s' % (name, names[0]) return names - +## FIXME: This has not been updated to work with backends def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True, main=None, xtitle=None, xrange=None, ylabels=None, chain_spacing=0.05, vline=0): diff --git a/pymc/sampling.py b/pymc/sampling.py index 6220243efc..18c08cb902 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -1,5 +1,6 @@ from .point import * -from .trace import NpTrace, MultiTrace +from pymc.backends.ndarray import NDArray +from pymc.backends.base import merge_chains import multiprocessing as mp from time import time from .core import * @@ -7,103 +8,146 @@ from .progressbar import progress_bar from numpy.random import seed -__all__ = ['sample', 'psample', 'iter_sample'] +__all__ = ['sample', 'iter_sample'] -def sample(draws, step, start=None, trace=None, tune=None, progressbar=True, model=None, random_seed=None): - """ - Draw a number of samples using the given step method. - Multiple step methods supported via compound step method - returns the amount of time taken. +def sample(draws, step, start=None, db=None, chain=0, threads=1, tune=None, + progressbar=True, model=None, variables=None, random_seed=None): + """Draw samples using the given step method Parameters ---------- - draws : int The number of samples to draw - step : function - A step function + step : step method or list of step methods start : dict - Starting point in parameter space (or partial point) - Defaults to trace.point(-1)) if there is a trace provided and - model.test_point if not (defaults to empty dict) - trace : NpTrace or list - Either a trace of past values or a list of variables to track - (defaults to None) + Starting point in parameter space (or partial point). Defaults + to model.test_point. + db : backend + If None, NDArray is used. + chain : int + Chain number used to store sample in trace. If threads greater + than one, chain numbers will start here + threads : int + Number of parallel traces to start. If None, set to number of + cpus in the system - 2. tune : int - Number of iterations to tune, if applicable (defaults to None) + Number of iterations to tune, if applicable progressbar : bool Flag for progress bar model : Model (optional if in `with` context) + variables : list + Variables to sample. If None, defaults to model.unobserved_RVs. + Ignored if model argument is supplied. + random_seed : int or list of ints + List accepted if more than one thread. + Returns + ------- + Backend object with access to sampling values """ - progress = progress_bar(draws) + if threads is None: + threads = max(mp.cpu_count() - 2, 1) + if threads > 1: + try: + if not len(random_seed) == threads: + random_seeds = [random_seed] * threads + else: + random_seeds = random_seed + except TypeError: # None, int + random_seeds = [random_seed] * threads + + chains = list(range(chain, chain + threads)) + argset = zip([draws] * threads, + [step] * threads, + [start] * threads, + [db] * threads, + chains, + [tune] * threads, + [False] * threads, + [model] * threads, + [variables] * threads, + random_seeds) + sample_func = _thread_sample + sample_args = [threads, argset] + else: + sample_func = _sample + sample_args = [draws, step, start, db, chain, + tune, progressbar, model, variables, random_seed] + return sample_func(*sample_args) + + +def _sample(draws, step, start=None, db=None, chain=0, tune=None, + progressbar=True, model=None, variables=None, random_seed=None): + sampling = _iter_sample(draws, step, start, db, chain, + tune, model, variables, random_seed) + if progressbar: + sampling = enumerate_progress(sampling, draws) + else: + sampling = enumerate(sampling) try: - for i, trace in enumerate(iter_sample(draws, step, - start=start, - trace=trace, - tune=tune, - model=model, - random_seed=random_seed)): + for i, trace in enumerate(sampling): if progressbar: progress.update(i) except KeyboardInterrupt: - pass + trace.backend.clean_interrupt(i) + trace.backend.close() return trace -def iter_sample(draws, step, start=None, trace=None, tune=None, model=None, random_seed=None): + +def iter_sample(draws, step, start=None, db=None, chain=0, tune=None, + model=None, variables=None, random_seed=None): """ - Generator that returns a trace on each iteration using the given - step method. Multiple step methods supported via compound step - method returns the amount of time taken. + Generator that returns a trace on each iteration using the given step + method. Parameters ---------- draws : int The number of samples to draw - step : function - A step function + step : step method or list of step methods start : dict - Starting point in parameter space (or partial point) - Defaults to trace.point(-1)) if there is a trace provided and - model.test_point if not (defaults to empty dict) - trace : NpTrace or list - Either a trace of past values or a list of variables to track - (defaults to None) + Starting point in parameter space (or partial point). Defaults + to model.test_point. + db : backend + If None, NDArray is used. + chain : int + Chain number used to store sample in trace. If threads greater + than one, chain numbers will start here tune : int - Number of iterations to tune, if applicable (defaults to None) + Number of iterations to tune, if applicable model : Model (optional if in `with` context) + variables : list + Variables to sample. If None, defaults to model.unobserved_RVs. + Ignored if model argument is supplied. + random_seed : int or list of ints + List accepted if more than one thread. Example ------- for trace in iter_sample(500, step): ... - """ + sampling = _iter_sample(draws, step, start, db, chain, + tune, model, variables, random_seed) + for i, trace in enumerate(sampling): + yield trace[:i + 1] + + +def _iter_sample(draws, step, start=None, db=None, chain=0, tune=None, + model=None, variables=None, random_seed=None): + seed(random_seed) model = modelcontext(model) draws = int(draws) - seed(random_seed) + if draws < 1: + raise ValueError('Argument `draws` should be above 0') if start is None: start = {} - - if isinstance(trace, NpTrace) and len(trace) > 0: - trace_point = trace.point(-1) - trace_point.update(start) - start = trace_point - - else: - test_point = model.test_point.copy() - test_point.update(start) - start = test_point - - if not isinstance(trace, NpTrace): - if trace is None: - trace = model.unobserved_RVs - trace = NpTrace(trace) + _soft_update(start, model.test_point) try: step = step_methods.CompoundStep(step) @@ -112,12 +156,32 @@ def iter_sample(draws, step, start=None, trace=None, tune=None, model=None, rand point = Point(start, model=model) + if db is None: + db = NDArray(model=model, variables=variables) + db.setup_samples(draws, chain) + for i in range(draws): - if (i == tune): + if i == tune: step = stop_tuning(step) point = step.step(point) - trace.record(point) - yield trace + db.record(point, i) + if not i % 1000: + db.commit() + yield db.trace + else: + db.close() + + +def _thread_sample(threads, args): + p = mp.Pool(threads) + traces = p.map(_argsample, args) + p.close() + return merge_chains(traces) + + +def _argsample(args): + """Defined at top level so it can be pickled""" + return _sample(*args) def stop_tuning(step): @@ -132,71 +196,7 @@ def stop_tuning(step): return step -def argsample(args): - """ defined at top level so it can be pickled""" - return sample(*args) - - -def psample(draws, step, start=None, trace=None, tune=None, model=None, threads=None, - random_seeds=None): - """draw a number of samples using the given step method. - Multiple step methods supported via compound step method - returns the amount of time taken - - Parameters - ---------- - - draws : int - The number of samples to draw - step : function - A step function - start : dict - Starting point in parameter space (Defaults to trace.point(-1)) - trace : MultiTrace or list - Either a trace of past values or a list of variables to track (defaults to None) - tune : int - Number of iterations to tune, if applicable (defaults to None) - model : Model (optional if in `with` context) - threads : int - Number of parallel traces to start - - Examples - -------- - - >>> an example - +def _soft_update(a, b): + """As opposed to dict.update, don't overwrite keys if present. """ - - model = modelcontext(model) - - if not threads: - threads = max(mp.cpu_count() - 2, 1) - - if start is None: - start = {} - - if isinstance(start, dict): - start = threads * [start] - - if trace is None: - trace = model.vars - - if type(trace) is MultiTrace: - mtrace = trace - else: - mtrace = MultiTrace(threads, trace) - - p = mp.Pool(threads) - - if random_seeds is None: - random_seeds = [None] * threads - - argset = zip([draws] * threads, [step] * threads, start, mtrace.traces, - [tune] * threads, [False] * threads, [model] * threads, - random_seeds) - - traces = p.map(argsample, argset) - - p.close() - - return MultiTrace(traces) + a.update({k: v for k, v in b.items() if k not in a}) diff --git a/pymc/stats.py b/pymc/stats.py index 732a38151b..ff46bf8c3e 100644 --- a/pymc/stats.py +++ b/pymc/stats.py @@ -1,7 +1,6 @@ """Utility functions for PyMC""" import numpy as np -from .trace import MultiTrace import warnings @@ -14,39 +13,28 @@ def statfunc(f): """ def wrapped_f(pymc_obj, *args, **kwargs): + burn = kwargs.pop('burn', 0) + thin = kwargs.pop('thin', 1) + combine = kwargs.pop('combine', False) try: - burn = kwargs.pop('burn') - except KeyError: - burn = 0 - - try: - # MultiTrace - traces = pymc_obj.traces - - try: - vars = kwargs.pop('vars') - except KeyError: - vars = traces[0].varnames - - return [{v: f(trace[v][burn:], *args, **kwargs) for v in vars} for trace in traces] - + var_names = kwargs.pop('vars', pymc_obj.var_names) + chains = kwargs.pop('chains', pymc_obj.active_chains) except AttributeError: - pass + # If fails, assume that raw data is passed + return f(pymc_obj, *args, **kwargs) - try: - # NpTrace - try: - vars = kwargs.pop('vars') - except KeyError: - vars = pymc_obj.varnames - - return {v: f(pymc_obj[v][burn:], *args, **kwargs) for v in vars} - except AttributeError: - pass + results = {chain: {} for chain in chains} + for var_name in var_names: + samples = pymc_obj.get_values(var_name, chains=chains, burn=burn, + thin=thin, combine=combine, + squeeze=False) + for chain, data in zip(chains, samples): + results[chain][var_name] = f(np.squeeze(data), *args, **kwargs) - # If others fail, assume that raw data is passed - return f(pymc_obj, *args, **kwargs) + if len(chains) == 1 or combine: + results = results[chains[0]] + return results wrapped_f.__doc__ = f.__doc__ wrapped_f.__name__ = f.__name__ @@ -242,15 +230,16 @@ def quantiles(x, qlist=(2.5, 25, 50, 75, 97.5)): print("Too few elements for quantile calculation") -def summary(trace, vars=None, alpha=0.05, start=0, batches=100, roundto=3): +def summary(trace, var_names=None, alpha=0.05, start=0, batches=100, + roundto=3): """ Generate a pretty-printed summary of the node. :Parameters: trace : Trace object - Trace containing MCMC sample + Trace containing MCMC samples - vars : list of strings + var_names : list of strings List of variables to summarize. Defaults to None, which results in all variables summarized. @@ -270,17 +259,15 @@ def summary(trace, vars=None, alpha=0.05, start=0, batches=100, roundto=3): The number of digits to round posterior statistics. """ - if vars is None: - vars = trace.varnames - if isinstance(trace, MultiTrace): - trace = trace.combined() + if var_names is None: + var_names = trace.var_names stat_summ = _StatSummary(roundto, batches, alpha) pq_summ = _PosteriorQuantileSummary(roundto, alpha) - for var in vars: + for var_name in var_names: # Extract sampled values - sample = trace[var][start:] + sample = trace.get_values(var_name, burn=start, combine=True) if sample.ndim == 1: sample = sample[:, None] elif sample.ndim > 2: @@ -288,7 +275,7 @@ def summary(trace, vars=None, alpha=0.05, start=0, batches=100, roundto=3): warnings.warn('Skipping {} (above 1 dimension)'.format(var)) continue - print('\n%s:' % var) + print('\n%s:' % var_name) print(' ') stat_summ.print_output(sample) diff --git a/pymc/tests/checks.py b/pymc/tests/checks.py index 2e5470bc8f..6d4adf2180 100644 --- a/pymc/tests/checks.py +++ b/pymc/tests/checks.py @@ -1,7 +1,7 @@ import pymc as pm import numpy as np -from pymc import sample, psample +from pymc import sample from numpy.testing import assert_almost_equal diff --git a/pymc/tests/test_base_backend.py b/pymc/tests/test_base_backend.py new file mode 100644 index 0000000000..9493bbe537 --- /dev/null +++ b/pymc/tests/test_base_backend.py @@ -0,0 +1,149 @@ +import numpy as np +try: + import unittest.mock as mock # py3 +except ImportError: + import mock +import unittest +import nose + +from pymc.backends import base + + +class TestBaseInit(unittest.TestCase): + + def setUp(self): + self.variables = ['x', 'y'] + self.model = mock.Mock() + self.model.unobserved_RVs = self.variables + self.model.fastfn = mock.MagicMock() + + def test_base_init_just_name(self): + with mock.patch('pymc.backends.base.modelcontext') as context: + variables = self.variables + context.return_value = self.model + + db = base.Backend('name') + + context.assert_called_once_with(None) + self.assertEqual(db.variables, variables) + self.assertEqual(db.var_names, variables) + self.model.fastfn.assert_called_once_with(variables) + + def test_base_init_model_supplied(self): + db = base.Backend('name', model=self.model) + + self.assertEqual(db.variables, self.variables) + self.assertEqual(db.var_names, self.variables) + self.model.fastfn.assert_called_once_with(self.variables) + + def test_base_init_variables_supplied(self): + with mock.patch('pymc.backends.base.modelcontext') as context: + variables = ['a', 'b'] + context.return_value = self.model + + db = base.Backend('name', variables=variables) + + context.assert_called_once_with(None) + self.assertEqual(db.variables, variables) + self.assertEqual(db.var_names, variables) + self.model.fastfn.assert_called_once_with(variables) + + def test_base_setup_samples_default_chain(self): + with mock.patch('pymc.backends.base.modelcontext') as context: + variables = ['a', 'b'] + context.return_value = self.model + + db = base.Backend('name', variables=variables) + + db._create_trace = mock.Mock() + db.var_shapes = {'x': (), 'y': (10,)} + draws = 3 + + patch = mock.patch('pymc.backends.base.Backend._initialize_trace') + with patch as init_trace: + db.setup_samples(draws, 0) + + init_trace.assert_called_with() + db._create_trace.assert_any_call(0, 'x', [draws]) + db._create_trace.assert_any_call(0, 'y', [draws, 10]) + + +class TestBaseTrace(unittest.TestCase): + + def setUp(self): + var_names = ['x'] + self.trace = base.Trace(var_names) + self.trace.samples = {0: {'x': None}} + + def test_nchains(self): + + self.assertEqual(self.trace.nchains, 1) + + self.trace.samples[1] = {'y': None} + self.assertEqual(self.trace.nchains, 2) + + def test_chains(self): + self.assertEqual(self.trace.chains, [0]) + + self.trace.samples[1] = {'y': None} + self.assertEqual(self.trace.chains, [0, 1]) + + def test_chains_not_sequential(self): + self.trace.samples[4] = {'y': None} + self.assertEqual(self.trace.chains, [0, 4]) + + def test_default_chain_one_chain(self): + self.assertEqual(self.trace.default_chain, 0) + + def test_default_chain_multiple_chain(self): + self.trace.samples[1] = {'y': None} + self.assertEqual(self.trace.default_chain, 1) + + def test_default_chain_multiple_chains_set(self): + self.trace.samples[1] = {'y': None} + self.trace.default_chain = 0 + self.assertEqual(self.trace.default_chain, 0) + + def test_active_chains(self): + self.assertEqual(self.trace.chains, self.trace.active_chains) + self.trace.samples[1] = {'y': None} + self.assertEqual(self.trace.chains, self.trace.active_chains) + + def test_active_chains_set_with_int(self): + self.trace.samples[1] = {'y': None} + self.trace.active_chains = 0 + self.assertEqual(self.trace.active_chains, [0]) + + +class TestMergeChains(unittest.TestCase): + + def test_merge_chains_one_trace(self): + trace = mock.Mock() + trace.samples = {0: {'x': 0, 'y': 1}} + merged = base.merge_chains([trace]) + self.assertEqual(trace.samples, merged.samples) + + def test_merge_chains_two_traces(self): + trace1 = mock.Mock() + trace1.samples = {0: {'x': 0, 'y': 1}} + trace1.chains = [0] + + trace2 = mock.Mock() + trace2.samples = {1: {'x': 3, 'y': 4}} + trace2.chains = [1] + + merged = base.merge_chains([trace1, trace2]) + self.assertEqual(trace1.samples[0], merged.samples[0]) + self.assertEqual(trace2.samples[1], merged.samples[1]) + + def test_merge_chains_two_traces_same_slot(self): + trace1 = mock.Mock() + trace1.samples = {0: {'x': 0, 'y': 1}} + trace1.chains = [0] + + trace2 = mock.Mock() + trace2.samples = {0: {'x': 3, 'y': 4}} + trace2.chains = [0] + + with self.assertRaises(ValueError): + base.merge_chains([trace1, trace2]) diff --git a/pymc/tests/test_diagnostics.py b/pymc/tests/test_diagnostics.py index 8322a7ed27..4525fa2052 100644 --- a/pymc/tests/test_diagnostics.py +++ b/pymc/tests/test_diagnostics.py @@ -9,8 +9,8 @@ def test_gelman_rubin(n=1000): step1 = Slice([dm.early_mean, dm.late_mean]) step2 = Metropolis([dm.switchpoint]) start = {'early_mean': 2., 'late_mean': 3., 'switchpoint': 50} - ptrace = psample(n, [step1, step2], start, threads=2, - random_seeds=[1, 3]) + ptrace = sample(n, [step1, step2], start, threads=2, + random_seed=[1, 3]) rhat = gelman_rubin(ptrace) diff --git a/pymc/tests/test_glm.py b/pymc/tests/test_glm.py index 1882d6040c..0262261e45 100644 --- a/pymc/tests/test_glm.py +++ b/pymc/tests/test_glm.py @@ -44,9 +44,9 @@ def test_linear_component(self): step = Slice(model.vars) trace = sample(2000, step, start, progressbar=False) - self.assertAlmostEqual(np.mean(trace.samples['Intercept'].value), true_intercept, 1) - self.assertAlmostEqual(np.mean(trace.samples['x'].value), true_slope, 1) - self.assertAlmostEqual(np.mean(trace.samples['sigma'].value), true_sd, 1) + self.assertAlmostEqual(np.mean(trace['Intercept']), true_intercept, 1) + self.assertAlmostEqual(np.mean(trace['x']), true_slope, 1) + self.assertAlmostEqual(np.mean(trace['sigma']), true_sd, 1) @unittest.skip("Fails only on travis. Investigate") def test_glm(self): @@ -57,9 +57,9 @@ def test_glm(self): step = Slice(model.vars) trace = sample(2000, step, progressbar=False) - self.assertAlmostEqual(np.mean(trace.samples['Intercept'].value), true_intercept, 1) - self.assertAlmostEqual(np.mean(trace.samples['x'].value), true_slope, 1) - self.assertAlmostEqual(np.mean(trace.samples['sigma'].value), true_sd, 1) + self.assertAlmostEqual(np.mean(trace['Intercept']), true_intercept, 1) + self.assertAlmostEqual(np.mean(trace['x']), true_slope, 1) + self.assertAlmostEqual(np.mean(trace['sigma']), true_sd, 1) def test_glm_link_func(self): with Model() as model: @@ -71,5 +71,5 @@ def test_glm_link_func(self): step = Slice(model.vars) trace = sample(2000, step, progressbar=False) - self.assertAlmostEqual(np.mean(trace.samples['Intercept'].value), true_intercept, 1) - self.assertAlmostEqual(np.mean(trace.samples['x'].value), true_slope, 0) + self.assertAlmostEqual(np.mean(trace['Intercept']), true_intercept, 1) + self.assertAlmostEqual(np.mean(trace['x']), true_slope, 0) diff --git a/pymc/tests/test_ndarray_backend.py b/pymc/tests/test_ndarray_backend.py new file mode 100644 index 0000000000..0dba0055e6 --- /dev/null +++ b/pymc/tests/test_ndarray_backend.py @@ -0,0 +1,320 @@ +import numpy as np +import numpy.testing as npt +try: + import unittest.mock as mock # py3 +except ImportError: + import mock +import unittest + +from pymc.backends import ndarray + + +class TestNDArraySampling(unittest.TestCase): + + def setUp(self): + self.variables = ['x', 'y'] + self.model = mock.Mock() + self.model.unobserved_RVs = self.variables + self.model.fastfn = mock.MagicMock() + + with mock.patch('pymc.backends.base.modelcontext') as context: + context.return_value = self.model + self.db = ndarray.NDArray() + + def test_create_trace_scalar(self): + db = self.db + draws = 3 + trace = db._create_trace(chain=0, var_name=None, shape=[draws]) + npt.assert_equal(trace, np.zeros(draws)) + + def test_create_trace_1d(self): + db = self.db + draws = 3 + trace = db._create_trace(chain=0, var_name=None, shape=[draws, 2]) + npt.assert_equal(trace, np.zeros([draws, 2])) + + def test_setup_samples(self): + db = self.db + draws = 3 + + db.var_shapes = {'x': (), 'y': (4,)} + db.setup_samples(draws, chain=0) + + npt.assert_equal(db.trace['x'], np.zeros([draws])) + npt.assert_equal(db.trace['y'], np.zeros([draws, 4])) + + def test_record(self): + db = self.db + draws = 3 + + db.var_shapes = {'x': (), 'y': (4,)} + db.setup_samples(draws, chain=0) + + def just_ones(*args): + while True: + yield 1. + + db._fn = just_ones + + db.record(point=None, draw=0) + npt.assert_equal(1., db.trace.get_values('x', combine=True)[0]) + npt.assert_equal(np.ones(4), db.trace['y'][0]) + + def test_clean_interrupt(self): + db = self.db + db.setup_samples(draws=3, chain=0) + db.trace.samples = {0: {'x': np.zeros(10), 'y': np.zeros((10, 5))}} + db.clean_interrupt(3) + npt.assert_equal(np.zeros(3), db.trace['x']) + npt.assert_equal(np.zeros((3, 5)), db.trace['y']) + + +class TestNDArraySelection(unittest.TestCase): + + def setUp(self): + var_names = ['x', 'y'] + var_shapes = {'x': (), 'y': (2,)} + draws = 3 + self.trace = ndarray.Trace(var_names) + self.trace.samples = {0: + {'x': np.zeros(draws), + 'y': np.zeros((draws, 2))}} + self.draws = draws + self.var_names = var_names + self.var_shapes = var_shapes + + def test_get_values_default(self): + base_shape = (self.draws,) + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + xsample = self.trace.get_values('x') + npt.assert_equal(np.zeros(base_shape + xshape), xsample) + + ysample = self.trace.get_values('y') + npt.assert_equal(np.zeros(base_shape + yshape), ysample) + + def test_get_values_burn_keyword(self): + base_shape = (self.draws,) + burn = 2 + chain = 0 + + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + ## Make traces distinguishable + self.trace.samples[chain]['x'][:burn] = np.ones((burn,) + xshape) + self.trace.samples[chain]['y'][:burn] = np.ones((burn,) + yshape) + + xsample = self.trace.get_values('x', burn=burn) + npt.assert_equal(np.zeros(base_shape + xshape)[burn:], xsample) + + ysample = self.trace.get_values('y', burn=burn) + npt.assert_equal(np.zeros(base_shape + yshape)[burn:], ysample) + + def test_get_values_thin_keyword(self): + base_shape = (self.draws,) + thin = 2 + chain = 0 + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + ## Make traces distinguishable + xthin = np.ones((self.draws,) + xshape)[::thin] + ythin = np.ones((self.draws,) + yshape)[::thin] + self.trace.samples[chain]['x'][::thin] = xthin + self.trace.samples[chain]['y'][::thin] = ythin + + xsample = self.trace.get_values('x', thin=thin) + npt.assert_equal(xthin, xsample) + + ysample = self.trace.get_values('y', thin=thin) + npt.assert_equal(ythin, ysample) + + def test_point(self): + idx = 2 + chain = 0 + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + ## Make traces distinguishable + self.trace.samples[chain]['x'][idx] = 1. + self.trace.samples[chain]['y'][idx] = 1. + + point = self.trace.point(idx) + expected = {'x': np.squeeze(np.ones(xshape)), + 'y': np.squeeze(np.ones(yshape))} + + for var_name, value in expected.items(): + npt.assert_equal(value, point[var_name]) + + def test_slice(self): + base_shape = (self.draws,) + burn = 2 + chain = 0 + + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + ## Make traces distinguishable + self.trace.samples[chain]['x'][:burn] = np.ones((burn,) + xshape) + self.trace.samples[chain]['y'][:burn] = np.ones((burn,) + yshape) + + sliced = self.trace[burn:] + + expected = {'x': np.zeros(base_shape + xshape)[burn:], + 'y': np.zeros(base_shape + yshape)[burn:]} + + for var_name, var_shape in self.var_shapes.items(): + npt.assert_equal(sliced.samples[chain][var_name], + expected[var_name]) + + +class TestNDArrayMultipleChains(unittest.TestCase): + + def setUp(self): + var_names = ['x', 'y'] + var_shapes = {'x': (), 'y': (2,)} + draws = 3 + self.trace = ndarray.Trace(var_names) + self.trace.samples = {0: + {'x': np.zeros(draws), + 'y': np.zeros((draws, 2))}, + 1: + {'x': np.ones(draws), + 'y': np.ones((draws, 2))}} + self.draws = draws + self.var_names = var_names + self.var_shapes = var_shapes + self.total_draws = 2 * draws + + def test_get_values_multi_default(self): + sample = self.trace.get_values('x') + xshape = self.var_shapes['x'] + + expected = [np.zeros((self.draws,) + xshape), + np.ones((self.draws,) + xshape)] + npt.assert_equal(sample, expected) + + def test_get_values_multi_chains_only_one_element(self): + sample = self.trace.get_values('x', chains=[0]) + xshape = self.var_shapes['x'] + + expected = np.zeros((self.draws,) + xshape) + + npt.assert_equal(sample, expected) + + def test_get_values_multi_chains_two_element_reversed(self): + sample = self.trace.get_values('x', chains=[1, 0]) + xshape = self.var_shapes['x'] + + expected = [np.ones((self.draws,) + xshape), + np.zeros((self.draws,) + xshape)] + npt.assert_equal(sample, expected) + + def test_get_values_multi_combine(self): + sample = self.trace.get_values('x', combine=True) + xshape = self.var_shapes['x'] + + expected = np.concatenate([np.zeros((self.draws,) + xshape), + np.ones((self.draws,) + xshape)]) + npt.assert_equal(sample, expected) + + def test_get_values_multi_burn(self): + sample = self.trace.get_values('x', burn=2) + xshape = self.var_shapes['x'] + + expected = [np.zeros((self.draws,) + xshape)[2:], + np.ones((self.draws,) + xshape)[2:]] + npt.assert_equal(sample, expected) + + def test_get_values_multi_burn_combine(self): + sample = self.trace.get_values('x', burn=2, combine=True) + xshape = self.var_shapes['x'] + + expected = np.concatenate([np.zeros((self.draws,) + xshape)[2:], + np.ones((self.draws,) + xshape)[2:]]) + npt.assert_equal(sample, expected) + + def test_get_values_multi_burn_one_active_chain(self): + self.trace.active_chains = 0 + sample = self.trace.get_values('x', burn=2) + xshape = self.var_shapes['x'] + + expected = np.zeros((self.draws,) + xshape)[2:] + npt.assert_equal(sample, expected) + + def test_get_values_multi_thin(self): + sample = self.trace.get_values('x', thin=2) + xshape = self.var_shapes['x'] + + expected = [np.zeros((self.draws,) + xshape)[::2], + np.ones((self.draws,) + xshape)[::2]] + npt.assert_equal(sample, expected) + + def test_get_values_multi_thin_combine(self): + sample = self.trace.get_values('x', thin=2, combine=True) + xshape = self.var_shapes['x'] + + expected = np.concatenate([np.zeros((self.draws,) + xshape)[::2], + np.ones((self.draws,) + xshape)[::2]]) + npt.assert_equal(sample, expected) + + def test_multichain_point(self): + idx = 2 + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + point = self.trace.point(idx) + expected = {'x': np.squeeze(np.ones(xshape)), + 'y': np.squeeze(np.ones(yshape))} + + for var_name, value in expected.items(): + npt.assert_equal(value, point[var_name]) + + def test_multichain_point_chain_arg(self): + idx = 2 + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + point = self.trace.point(idx, chain=0) + expected = {'x': np.squeeze(np.zeros(xshape)), + 'y': np.squeeze(np.zeros(yshape))} + + for var_name, value in expected.items(): + npt.assert_equal(value, point[var_name]) + + def test_multichain_point_change_default_chain(self): + idx = 2 + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + self.trace.default_chain = 0 + + point = self.trace.point(idx) + expected = {'x': np.squeeze(np.zeros(xshape)), + 'y': np.squeeze(np.zeros(yshape))} + + for var_name, value in expected.items(): + npt.assert_equal(value, point[var_name]) + + def test_multichain_slice(self): + base_shapes = [(self.draws,)] * 2 + burn = 2 + + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + expected = {0: + {'x': np.zeros((self.draws, ) + xshape)[burn:], + 'y': np.zeros((self.draws, ) + yshape)[burn:]}, + 1: + {'x': np.ones((self.draws, ) + xshape)[burn:], + 'y': np.ones((self.draws, ) + yshape)[burn:]}} + + sliced = self.trace[burn:] + + for chain in self.trace.chains: + for var_name, var_shape in self.var_shapes.items(): + npt.assert_equal(sliced.samples[chain][var_name], + expected[chain][var_name]) diff --git a/pymc/tests/test_plots.py b/pymc/tests/test_plots.py index f31b9a5006..b782a969fe 100644 --- a/pymc/tests/test_plots.py +++ b/pymc/tests/test_plots.py @@ -1,6 +1,6 @@ #from ..plots import * from pymc.plots import * -from pymc import psample, Slice, Metropolis, find_hessian, sample +from pymc import Slice, Metropolis, find_hessian, sample def test_plots(): @@ -15,7 +15,8 @@ def test_plots(): step = Metropolis(model.vars, h) trace = sample(3000, step, start) - forestplot(trace) + # FIXME: forestplot has not been rewritten for backend + # forestplot(trace) autocorrplot(trace) @@ -29,8 +30,9 @@ def test_multichain_plots(): step1 = Slice([dm.early_mean, dm.late_mean]) step2 = Metropolis([dm.switchpoint]) start = {'early_mean': 2., 'late_mean': 3., 'switchpoint': 50} - ptrace = psample(1000, [step1, step2], start, threads=2) + ptrace = sample(1000, [step1, step2], start, threads=2) - forestplot(ptrace, vars=['early_mean', 'late_mean']) + # FIXME: forestplot has not been rewritten for backend + # forestplot(ptrace, vars=['early_mean', 'late_mean']) autocorrplot(ptrace, vars=['switchpoint']) diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index d4016c40f8..8b5a40ebf9 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -1,5 +1,14 @@ +import numpy as np +import numpy.testing as npt +try: + import unittest.mock as mock # py3 +except ImportError: + import mock +import unittest + import pymc -from pymc import sample, psample, iter_sample +from pymc import sampling +from pymc.sampling import sample from .models import simple_init # Test if multiprocessing is available @@ -11,28 +20,73 @@ test_parallel = False +@mock.patch('pymc.sampling._sample') +def test_sample_check_full_signature_single_thread(sample_func): + sample('draws', 'step', start='start', db='db', threads=1, chain=1, + tune='tune', progressbar='progressbar', model='model', + variables='variables', random_seed='random_seed') + sample_func.assert_called_with('draws', 'step', 'start', 'db', 1, + 'tune', 'progressbar', 'model', 'variables', + 'random_seed') + + +@mock.patch('pymc.sampling._thread_sample') +def test_sample_check_ful_signature_multithreads(sample_func): + sample('draws', 'step', start='start', db='db', threads=2, chain=1, + tune='tune', progressbar='progressbar', model='model', + variables='variables', random_seed=0) + + args = sample_func.call_args_list[0][0] + assert args[0] == 2 + + expected_argset = [('draws', 'step', 'start', 'db', 1, 'tune', + False, 'model', 'variables', 0), + ('draws', 'step', 'start', 'db', 2, 'tune', + False, 'model', 'variables', 0)] + argset = list(args[1]) + print(argset) + print(expected_argset) + assert argset == expected_argset + + +def test_soft_update_all_present(): + start = {'a': 1, 'b': 2} + test_point = {'a': 3, 'b': 4} + sampling._soft_update(start, test_point) + assert start == {'a': 1, 'b': 2} + + +def test_soft_update_one_missing(): + start = {'a': 1, } + test_point = {'a': 3, 'b': 4} + sampling._soft_update(start, test_point) + assert start == {'a': 1, 'b': 4} + + +def test_soft_update_empty(): + start = {} + test_point = {'a': 3, 'b': 4} + sampling._soft_update(start, test_point) + assert start == test_point + + def test_sample(): model, start, step, _ = simple_init() - test_samplers = [sample] - - tr = sample(5, step, start, model=model) - test_traces = [None, tr] + test_threads = [1] if test_parallel: - test_samplers.append(psample) + test_threads.append(2) with model: - for trace in test_traces: - for samplr in test_samplers: - for n in [0, 1, 10, 300]: + for threads in test_threads: + for n in [1, 10, 300]: + yield sample, n, step, {}, None, threads - yield samplr, n, step, {} - yield samplr, n, step, {}, trace - yield samplr, n, step, start def test_iter_sample(): model, start, step, _ = simple_init() - for i, trace in enumerate(iter_sample(5, step, start, model=model)): + samps = sampling.iter_sample(5, step, start, model=model) + for i, trace in enumerate(samps): assert i == len(trace) - 1, "Trace does not have correct length." diff --git a/pymc/tests/test_trace.py b/pymc/tests/test_trace.py deleted file mode 100644 index c3ee7db621..0000000000 --- a/pymc/tests/test_trace.py +++ /dev/null @@ -1,128 +0,0 @@ -from .checks import * -from .models import * -import pymc as pm -import numpy as np - -# Test if multiprocessing is available -import multiprocessing -try: - multiprocessing.Pool(2) - test_parallel = False -except: - test_parallel = False - - -def check_trace(model, trace, n, step, start): - # try using a trace object a few times - - for i in range(2): - trace = sample( - n, step, start, trace, progressbar=False, model=model) - - for (var, val) in start.items(): - - assert np.shape(trace[var]) == (n * (i + 1),) + np.shape(val) - -def test_trace(): - model, start, step, _ = simple_init() - - for h in [pm.NpTrace]: - for n in [20, 1000]: - for vars in [model.vars, model.vars + [model.vars[0] ** 2]]: - trace = h(vars) - - yield check_trace, model, trace, n, step, start - - -def test_multitrace(): - if not test_parallel: - return - model, start, step, _ = simple_init() - trace = None - for n in [20, 1000]: - - yield check_multi_trace, model, trace, n, step, start - - -def check_multi_trace(model, trace, n, step, start): - - for i in range(2): - trace = psample( - n, step, start, trace, model=model) - - for (var, val) in start.items(): - print([len(tr.samples[var].vals) for tr in trace.traces]) - for t in trace[var]: - assert np.shape(t) == (n * (i + 1),) + np.shape(val) - - ctrace = trace.combined() - for (var, val) in start.items(): - - assert np.shape( - ctrace[var]) == (len(trace.traces) * n * (i + 1),) + np.shape(val) - - -def test_get_point(): - - p, model = simple_2model() - p2 = p.copy() - p2['x'] *= 2. - - x = pm.NpTrace(model.vars) - x.record(p) - x.record(p2) - assert x.point(1) == x[1] - -def test_slice(): - - model, start, step, moments = simple_init() - - iterations = 100 - burn = 10 - - with model: - tr = sample(iterations, start=start, step=step, progressbar=False) - - burned = tr[burn:] - - # Slicing returns a trace - assert type(burned) is pm.trace.NpTrace - - # Indexing returns an array - assert type(tr[tr.varnames[0]]) is np.ndarray - - # Burned trace is of the correct length - assert np.all([burned[v].shape == (iterations-burn, start[v].size) for v in burned.varnames]) - - # Original trace did not change - assert np.all([tr[v].shape == (iterations, start[v].size) for v in tr.varnames]) - - # Now take more burn-in from the burned trace - burned_again = burned[burn:] - assert np.all([burned_again[v].shape == (iterations-2*burn, start[v].size) for v in burned_again.varnames]) - assert np.all([burned[v].shape == (iterations-burn, start[v].size) for v in burned.varnames]) - -def test_multi_slice(): - - model, start, step, moments = simple_init() - - iterations = 100 - burn = 10 - - with model: - tr = psample(iterations, start=start, step=step, threads=2) - - burned = tr[burn:] - - # Slicing returns a MultiTrace - assert type(burned) is pm.trace.MultiTrace - - # Indexing returns a list of arrays - assert type(tr[tr.varnames[0]]) is list - assert type(tr[tr.varnames[0]][0]) is np.ndarray - - # # Burned trace is of the correct length - assert np.all([burned[v][0].shape == (iterations-burn, start[v].size) for v in burned.varnames]) - - # Original trace did not change - assert np.all([tr[v][0].shape == (iterations, start[v].size) for v in tr.varnames]) diff --git a/pymc/trace.py b/pymc/trace.py deleted file mode 100644 index 023580aa5b..0000000000 --- a/pymc/trace.py +++ /dev/null @@ -1,112 +0,0 @@ -import numpy as np -from .core import * -import copy -import types - -__all__ = ['NpTrace', 'MultiTrace'] - -class NpTrace(object): - """ - encapsulates the recording of a process chain - """ - def __init__(self, vars): - vars = list(vars) - model = vars[0].model - self.f = model.fastfn(vars) - self.vars = vars - self.varnames = list(map(str, vars)) - self.samples = dict((v, ListArray()) for v in self.varnames) - - def record(self, point): - """ - Records the position of a chain at a certain point in time. - """ - for var, value in zip(self.varnames, self.f(point)): - self.samples[var].append(value) - return self - - def __getitem__(self, index_value): - """ - Return copy NpTrace with sliced sample values if a slice is passed, - or the array of samples if a varname is passed. - """ - - if isinstance(index_value, slice): - - sliced_trace = NpTrace(self.vars) - sliced_trace.samples = dict((name, vals[index_value]) for (name, vals) in self.samples.items()) - - return sliced_trace - - else: - try: - return self.point(index_value) - except ValueError: - pass - except TypeError: - pass - - return self.samples[str(index_value)].value - - def __len__(self): - return len(self.samples[self.varnames[0]]) - - def point(self, index): - return dict((k, v.value[index]) for (k, v) in self.samples.items()) - - -class ListArray(object): - def __init__(self, *args): - self.vals = list(args) - - @property - def value(self): - if len(self.vals) > 1: - self.vals = [np.concatenate(self.vals, axis=0)] - - return self.vals[0] - - def __getitem__(self, idx): - return ListArray(self.value[idx]) - - - def append(self, v): - self.vals.append(v[np.newaxis]) - - def __len__(self): - if self.vals: - return self.value.shape[0] - else: - return 0 - - -class MultiTrace(object): - def __init__(self, traces, vars=None): - try: - self.traces = list(traces) - except TypeError: - if vars is None: - raise ValueError("vars can't be None if trace count specified") - self.traces = [NpTrace(vars) for _ in range(traces)] - - def __getitem__(self, index_value): - - item_list = [h[index_value] for h in self.traces] - - if isinstance(index_value, slice): - return MultiTrace(item_list) - return item_list - - @property - def varnames(self): - return self.traces[0].varnames - - def point(self, index): - return [h.point(index) for h in self.traces] - - def combined(self): - # Returns a trace consisting of concatenated MultiTrace elements - h = NpTrace(self.traces[0].vars) - for k in self.traces[0].samples: - h.samples[k].vals = [s[k] for s in self.traces] - return h diff --git a/setup.py b/setup.py index 2a338bf31f..adcd609177 100755 --- a/setup.py +++ b/setup.py @@ -44,7 +44,7 @@ long_description=LONG_DESCRIPTION, packages=['pymc', 'pymc.distributions', 'pymc.step_methods', 'pymc.tuning', - 'pymc.tests', 'pymc.glm'], + 'pymc.tests', 'pymc.glm', 'pymc.backends'], classifiers=classifiers, install_requires=install_reqs, tests_require=test_reqs, From ae77f690916edd5dc6517c26d6100cf382b52e40 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Thu, 2 Jan 2014 19:06:12 -0500 Subject: [PATCH 06/18] Add Text backend --- pymc/backends/__init__.py | 1 + pymc/backends/text.py | 104 ++++++++++++++++++++ pymc/tests/test_text_backend.py | 169 ++++++++++++++++++++++++++++++++ 3 files changed, 274 insertions(+) create mode 100644 pymc/backends/text.py create mode 100644 pymc/tests/test_text_backend.py diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index feebc00bb6..b5357a5b4c 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -1 +1,2 @@ from pymc.backends.ndarray import NDArray +from pymc.backends.text import Text diff --git a/pymc/backends/text.py b/pymc/backends/text.py new file mode 100644 index 0000000000..a0bac969ba --- /dev/null +++ b/pymc/backends/text.py @@ -0,0 +1,104 @@ +"""Text file trace backend + +After sampling with NDArray backend, save results as text files. + +Database format +--------------- + +For each chain, a directory named `chain-N` is created. In this +directory, one file per variable is created containing the values of the +object. To deal with multidimensional variables, the array is reshaped +to one dimension before saving with `numpy.savetxt`. The shape +information is saved in a json file in the same directory and is used to +load the database back again using `numpy.loadtxt`. +""" +import os +import glob +import json +import numpy as np +from contextlib import contextmanager + +from pymc.backends import base +from pymc.backends.ndarray import NDArray, Trace + + +class Text(NDArray): + + def __init__(self, name, model=None, variables=None): + super(Text, self).__init__(name, model, variables) + if not os.path.exists(name): + os.mkdir(name) + + def close(self): + for chain in self.trace.chains: + chain_name = 'chain-{}'.format(chain) + chain_dir = os.path.join(self.name, chain_name) + os.mkdir(chain_dir) + + shapes = {} + for var_name in self.var_names: + data = self.trace.samples[chain][var_name] + var_file = os.path.join(chain_dir, var_name + '.txt') + np.savetxt(var_file, data.reshape(-1, data.size)) + shapes[var_name] = data.shape + ## Store shape information for reloading. + with _get_shape_fh(chain_dir, 'w') as sfh: + json.dump(shapes, sfh) + + +def load(name, chains=None, model=None): + """Load text database from name + + Parameters + ---------- + name : str + Path to root directory for text database + chains : list or None + Chains to load. If None, all chains are loaded. + model : Model + If None, the model is taken from the `with` context. The trace + can be loaded without connecting by passing False (although + connecting to the original model is recommended). + + Returns + ------- + ndarray.Trace instance + """ + chain_dirs = _get_chain_dirs(name) + if chains is None: + chains = list(chain_dirs.keys()) + + trace = Trace(None) + + for chain in chains: + chain_dir = chain_dirs[chain] + with _get_shape_fh(chain_dir, 'r') as sfh: + shapes = json.load(sfh) + samples = {} + for var_name, shape in shapes.items(): + var_file = os.path.join(chain_dir, var_name + '.txt') + samples[var_name] = np.loadtxt(var_file).reshape(shape) + trace.samples[chain] = samples + trace.var_names = list(trace.samples[chain].keys()) + return trace + + +## Not opening json directory in `Text.close` and `load` for testing +## convenience +@contextmanager +def _get_shape_fh(chain_dir, mode='r'): + fh = open(os.path.join(chain_dir, 'shapes.json'), mode) + try: + yield fh + finally: + fh.close() + + +def _get_chain_dirs(name): + """Return mapping of chain number to directory""" + return {_chain_dir_to_chain(chain_dir): chain_dir + for chain_dir in glob.glob(os.path.join(name, 'chain-*'))} + + +def _chain_dir_to_chain(chain_dir): + return int(os.path.basename(chain_dir).split('-')[1]) diff --git a/pymc/tests/test_text_backend.py b/pymc/tests/test_text_backend.py new file mode 100644 index 0000000000..94f609872c --- /dev/null +++ b/pymc/tests/test_text_backend.py @@ -0,0 +1,169 @@ +import sys +import numpy as np +import numpy.testing as npt +try: + import unittest.mock as mock # py3 +except ImportError: + import mock +import unittest +if sys.version_info[0] == 2: + from StringIO import StringIO +else: + from io import StringIO +import json + +from pymc.backends import text + + +class TextTestCase(unittest.TestCase): + + def setUp(self): + self.variables = ['x', 'y'] + self.model = mock.Mock() + self.model.unobserved_RVs = self.variables + self.model.fastfn = mock.MagicMock() + + shape_fh_patch = mock.patch('pymc.backends.text._get_shape_fh') + self.addCleanup(shape_fh_patch.stop) + self.shape_fh = shape_fh_patch.start() + + mkdir_patch = mock.patch('pymc.backends.text.os.mkdir') + self.addCleanup(mkdir_patch.stop) + mkdir_patch.start() + + +class TestTextWrite(TextTestCase): + + def setUp(self): + super(TestTextWrite, self).setUp() + + with mock.patch('pymc.backends.base.modelcontext') as context: + context.return_value = self.model + self.db = text.Text('textdb') + + self.draws = 5 + self.db.var_shapes = {'x': (), 'y': (4,)} + self.db.setup_samples(self.draws, chain=0) + + savetxt_patch = mock.patch('pymc.backends.text.np.savetxt') + self.addCleanup(savetxt_patch.stop) + self.savetxt = savetxt_patch.start() + + def test_close_args(self): + db = self.db + + db.close() + + self.assertEqual(self.savetxt.call_count, 2) + + for call, var_name in enumerate(db.var_names): + fname, data = self.savetxt.call_args_list[call][0] + self.assertEqual(fname, 'textdb/chain-0/{}.txt'.format(var_name)) + npt.assert_equal(data, db.trace[var_name].reshape(-1, data.size)) + + def test_close_shape(self): + db = self.db + + fh = StringIO() + self.shape_fh.return_value.__enter__.return_value = fh + db.close() + self.shape_fh.assert_called_with('textdb/chain-0', 'w') + + shape_result = fh.getvalue() + expected = {var_name: [self.draws] + list(var_shape) + for var_name, var_shape in db.var_shapes.items()} + self.assertEqual(json.loads(shape_result), expected) + + +def test__chain_dir_to_chain(): + assert text._chain_dir_to_chain('/path/to/chain-0') == 0 + assert text._chain_dir_to_chain('chain-0') == 0 + + +class TestTextLoad(TextTestCase): + + def setUp(self): + super(TestTextLoad, self).setUp() + + data = {'chain-1/x.txt': np.zeros(4), 'chain-1/y.txt': np.ones(2)} + loadtxt_patch = mock.patch('pymc.backends.text.np.loadtxt') + self.addCleanup(loadtxt_patch.stop) + self.loadtxt = loadtxt_patch.start() + + chain_patch = mock.patch('pymc.backends.text._get_chain_dirs') + self.addCleanup(chain_patch.stop) + self._get_chain_dirs = chain_patch.start() + + def test_load_model_supplied_scalar(self): + draws = 4 + self._get_chain_dirs.return_value = {0: 'chain-0'} + fh = StringIO(json.dumps({'x': (draws,)})) + self.shape_fh.return_value.__enter__.return_value = fh + + data = np.zeros(draws) + self.loadtxt.return_value = data + + trace = text.load('textdb', model=self.model) + npt.assert_equal(trace.samples[0]['x'], data) + + def test_load_model_supplied_1d(self): + draws = 4 + var_shape = (2,) + self._get_chain_dirs.return_value = {0: 'chain-0'} + fh = StringIO(json.dumps({'x': (draws,) + var_shape})) + self.shape_fh.return_value.__enter__.return_value = fh + + data = np.zeros((draws,) + var_shape) + self.loadtxt.return_value = data.reshape(-1, data.size) + + db = text.load('textdb', model=self.model) + npt.assert_equal(db['x'], data) + + def test_load_model_supplied_2d(self): + draws = 4 + var_shape = (2, 3) + self._get_chain_dirs.return_value = {0: 'chain-0'} + fh = StringIO(json.dumps({'x': (draws,) + var_shape})) + self.shape_fh.return_value.__enter__.return_value = fh + + data = np.zeros((draws,) + var_shape) + self.loadtxt.return_value = data.reshape(-1, data.size) + + db = text.load('textdb', model=self.model) + npt.assert_equal(db['x'], data) + + def test_load_model_supplied_multichain_chains(self): + draws = 4 + self._get_chain_dirs.return_value = {0: 'chain-0', 1: 'chain-1'} + + def chain_fhs(): + for chain in [0, 1]: + yield StringIO(json.dumps({'x': (draws,)})) + fhs = chain_fhs() + + self.shape_fh.return_value.__enter__ = lambda x: next(fhs) + + data = np.zeros(draws) + self.loadtxt.return_value = data + + trace = text.load('textdb', model=self.model) + + self.assertEqual(trace.chains, [0, 1]) + + def test_load_model_supplied_multichain_chains_select_one(self): + draws = 4 + self._get_chain_dirs.return_value = {0: 'chain-0', 1: 'chain-1'} + + def chain_fhs(): + for chain in [0, 1]: + yield StringIO(json.dumps({'x': (draws,)})) + fhs = chain_fhs() + + self.shape_fh.return_value.__enter__ = lambda x: next(fhs) + + data = np.zeros(draws) + self.loadtxt.return_value = data + + trace = text.load('textdb', model=self.model, chains=[1]) + + self.assertEqual(trace.chains, [1]) From e3152f7363661b79add489b83175bbfb711f5a69 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Thu, 2 Jan 2014 19:07:03 -0500 Subject: [PATCH 07/18] Add SQLite backend --- pymc/backends/__init__.py | 1 + pymc/backends/sqlite.py | 293 ++++++++++++++++++++++++++++++ pymc/tests/test_sqlite_backend.py | 252 +++++++++++++++++++++++++ 3 files changed, 546 insertions(+) create mode 100644 pymc/backends/sqlite.py create mode 100644 pymc/tests/test_sqlite_backend.py diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index b5357a5b4c..aa5f5fce6f 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -1,2 +1,3 @@ from pymc.backends.ndarray import NDArray from pymc.backends.text import Text +from pymc.backends.sqlite import SQLite diff --git a/pymc/backends/sqlite.py b/pymc/backends/sqlite.py new file mode 100644 index 0000000000..0d02b9ae42 --- /dev/null +++ b/pymc/backends/sqlite.py @@ -0,0 +1,293 @@ +"""SQLite trace backend + +Store sampling values in SQLite database file. + +Database format +--------------- +For each variable, a table is created with the following format: + + recid (INT), draw (INT), chain (INT), v1 (FLOAT), v2 (FLOAT), v3 (FLOAT) ... + +The variable column names are extended to reflect addition dimensions. +For example, a variable with the shape (2, 2) would be stored as + + key (INT), draw (INT), chain (INT), v1_1 (FLOAT), v1_2 (FLOAT), v2_1 (FLOAT) ... + +The key is autoincremented each time a new row is added to the table. +The chain column denotes the chain index, and starts at 0. +""" +import numpy as np +import sqlite3 +import warnings + +from pymc.backends import base + +QUERIES = { + 'table': ('CREATE TABLE IF NOT EXISTS [{table}] ' + '(recid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ' + 'draw INTEGER, chain INT(5), ' + '{value_cols})'), + 'insert': ('INSERT INTO [{table}] ' + '(recid, draw, chain, {value_cols}) ' + 'VALUES (NULL, {{draw}}, {chain}, {{value}})'), + 'select': ('SELECT * FROM [{table}] ' + 'WHERE (chain = {chain})'), + 'select_burn': ('SELECT * FROM [{table}] ' + 'WHERE (chain = {chain}) AND (draw > {burn})'), + 'select_thin': ('SELECT * FROM [{table}] ' + 'WHERE (chain = {chain}) AND ' + '(draw - (SELECT draw FROM [{table}] ' + 'WHERE chain = {chain} ' + 'ORDER BY draw LIMIT 1)) % {thin} = 0'), + 'select_burn_thin': ('SELECT * FROM [{table}] ' + 'WHERE (chain = {chain}) AND (draw > {burn}) ' + 'AND (draw - (SELECT draw FROM [{table}] ' + 'WHERE (chain = {chain}) AND (draw > {burn}) ' + 'ORDER BY draw LIMIT 1)) % {thin} = 0'), + 'select_point': ('SELECT * FROM [{{table}}] ' + 'WHERE (chain={chain}) AND (draw={draw})'), + 'max_draw': ('SELECT MAX(draw) FROM [{table}] ' + 'WHERE chain={chain}'), +} + + +class SQLite(base.Backend): + + def __init__(self, name, model=None, variables=None): + super(SQLite, self).__init__(name, model, variables) + ## initialized by _connect + self.conn = None + self.cursor = None + self.connected = False + + def _initialize_trace(self): + return Trace(self.var_names, self) + + def connect(self): + if self.connected: + return + + self.conn = sqlite3.connect(self.name, check_same_thread=False) + self.cursor = self.conn.cursor() + self.connected = True + + ## sampling methods + + def setup_samples(self, draws, chain): + ## make connection here (versus __init__) to handle parallel + ## chains + self.connect() + super(SQLite, self).setup_samples(draws, chain) + + def _create_trace(self, chain, var_name, shape): + ## first element of trace is number of draws + var_cols = create_colnames(shape[1:]) + var_float = ', '.join([v + ' FLOAT' for v in var_cols]) + self.cursor.execute(QUERIES['table'].format(table=var_name, + value_cols=var_float)) + return QUERIES['insert'].format(table=var_name, + value_cols=', '.join(var_cols), + chain=chain) + + def _store_value(self, draw, var_trace, value): + val_str = ', '.join(['{}'.format(val) for val in np.ravel(value)]) + query = var_trace.format(draw=draw, value=val_str) + self.cursor.execute(query) + + def commit(self): + self.conn.commit() + + def close(self): + if not self.connected: + return + + self.cursor.close() + self.commit() + self.conn.close() + self.connected = False + + +class Trace(base.Trace): + + __doc__ = 'SQLite trace\n' + base.Trace.__doc__ + + def __len__(self): + try: + return super(Trace, self).__len__() + except KeyError: # draws dictionary not set up + query = QUERIES['max_draw'].format(table=self.var_names[0], + chain=self.default_chain) + self.backend.connect() + draws = self.backend.cursor.execute(query).fetchall()[0][0] + 1 + self._draws[self.default_chain] = draws + return draws + + def get_values(self, var_name, burn=0, thin=1, combine=False, chains=None, + squeeze=True): + """Get values from samples + + Parameters + ---------- + var_name : str + burn : int + thin : int + combine : bool + If True, results from all chains will be concatenated. + chains : list + Chains to retrieve. If None, `active_chains` is used. + squeeze : bool + If `combine` is False, return a single array element if the + resulting list of values only has one element (even if + `combine` is True). + + Returns + ------- + A list of NumPy array of values + """ + if burn < 0: + raise ValueError('Negative burn values not supported ' + 'in SQLite backend.') + if thin < 1: + raise ValueError('Only positive thin values are supported ' + 'in SQLite backend.') + if chains is None: + chains = self.active_chains + + var_name = str(var_name) + + query_args = {} + if burn == 0 and thin == 1: + query = 'select' + elif thin == 1: + query = 'select_burn' + query_args = {'burn': burn - 1} + elif burn == 0: + query = 'select_thin' + query_args = {'thin': thin} + else: + query = 'select_burn_thin' + query_args = {'burn': burn - 1, 'thin': thin} + + self.backend.connect() + results = [] + for chain in chains: + call = QUERIES[query].format(table=var_name, chain=chain, + **query_args) + self.backend.cursor.execute(call) + results.append(_rows_to_ndarray(self.backend.cursor)) + + return base._squeeze_cat(results, combine, squeeze) + + def _slice(self, idx): + """Slice trace object + """ + warnings.warn('Slice for SQLite backend has no effect.') + + def point(self, idx, chain=None): + """Return dictionary of point values at `idx` for current chain + with variables names as keys. + + If `chain` is not specified, `default_chain` is used. + """ + if idx < 0: + raise ValueError('Negtive indexing is not supported ' + 'in SQLite backend.') + if chain is None: + chain = self.default_chain + + query = QUERIES['select_point'].format(chain=chain, + draw=idx) + self.backend.connect() + var_values = {} + for var_name in self.var_names: + self.backend.cursor.execute(query.format(table=var_name)) + var_values[var_name] = np.squeeze( + _rows_to_ndarray(self.backend.cursor)) + return var_values + + +def create_colnames(shape): + """Return column names based on `shape` + + Examples + -------- + >>> create_colnames((5,)) + ['v1', 'v2', 'v3', 'v4', 'v5'] + + >>> create_colnames((2,2)) + ['v1_1', 'v1_2', 'v2_1', 'v2_2'] + """ + if not shape: + return ['v1'] + + size = np.prod(shape) + indices = (np.indices(shape) + 1).reshape(-1, size) + return ['v' + '_'.join(map(str, i)) for i in zip(*indices)] + + +def load(name, model=None): + """Load SQLite database from file name + + Parameters + ---------- + name : str + Path to SQLite database file + model : Model + If None, the model is taken from the `with` context. The trace + can be loaded without connecting by passing False (although + connecting to the original model is recommended). + + Returns + ------- + SQLite backend instance + """ + db = SQLite(name, model=model) + db.connect() + + var_names = _get_table_list(db.cursor) + trace = Trace(var_names, db) + var_cols = {var_name: ', '.join(_get_var_strs(db.cursor, var_name)) + for var_name in var_names} + + ## Use first var_names element to get chain list. Chains should be + ## the same for all. + chains = _get_chain_list(db.cursor, var_names[0]) + + query = QUERIES['insert'] + for chain in chains: + samples = {} + for var_name in var_names: + samples[var_name] = query.format(table=var_name, + value_cols=var_cols[var_name], + chain=chain) + trace.samples[chain] = samples + return trace + + +def _get_table_list(cursor): + """Return a list of table names in the current database""" + ## Modified from Django. Skips the sqlite_sequence system table used + ## for autoincrement key generation. + cursor.execute("SELECT name FROM sqlite_master " + "WHERE type='table' AND NOT name='sqlite_sequence' " + "ORDER BY name") + return [row[0] for row in cursor.fetchall()] + + +def _get_var_strs(cursor, var_name): + cursor.execute('SELECT * FROM [{}]'.format(var_name)) + col_names = (col_descr[0] for col_descr in cursor.description) + return [name for name in col_names if name.startswith('v')] + + +def _get_chain_list(cursor, var_name): + """Return a list of sorted chains for `var_name`""" + cursor.execute('SELECT DISTINCT chain FROM [{}]'.format(var_name)) + chains = [chain[0] for chain in cursor.fetchall()] + chains.sort() + return chains + + +def _rows_to_ndarray(cursor): + """Convert SQL row to NDArray""" + return np.array([row[3:] for row in cursor.fetchall()]) diff --git a/pymc/tests/test_sqlite_backend.py b/pymc/tests/test_sqlite_backend.py new file mode 100644 index 0000000000..8aa5765421 --- /dev/null +++ b/pymc/tests/test_sqlite_backend.py @@ -0,0 +1,252 @@ +import numpy as np +import numpy.testing as npt +try: + import unittest.mock as mock # py3 +except ImportError: + import mock +import unittest +import warnings + +from pymc.backends import sqlite + + +class SQLiteTestCase(unittest.TestCase): + + def setUp(self): + self.variables = ['x', 'y'] + self.model = mock.Mock() + self.model.unobserved_RVs = self.variables + self.model.fastfn = mock.MagicMock() + + with mock.patch('pymc.backends.base.modelcontext') as context: + context.return_value = self.model + self.db = sqlite.SQLite('test.db') + self.db.cursor = mock.Mock() + + connect_patch = mock.patch('pymc.backends.sqlite.SQLite.connect') + self.addCleanup(connect_patch.stop) + self.connect = connect_patch.start() + self.draws = 5 + + +class TestSQLiteSample(SQLiteTestCase): + + def test_setup_trace(self): + self.db.setup_samples(self.draws, chain=0) + self.connect.assert_called_once_with() + + def test__create_trace_scalar(self): + db = self.db + var_trace = db._create_trace(chain=0, var_name='x', + shape=(self.draws,)) + + tbl_expected = ('CREATE TABLE IF NOT EXISTS [x] ' + '(recid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ' + 'draw INTEGER, ' + 'chain INT(5), v1 FLOAT)') + db.cursor.execute.assert_called_once_with(tbl_expected) + + trace_expected = ('INSERT INTO [x] (recid, draw, chain, v1) ' + 'VALUES (NULL, {draw}, 0, {value})') + self.assertEqual(var_trace, trace_expected) + + def test__create_trace_1d(self): + db = self.db + var_trace = db._create_trace(chain=0, var_name='x', + shape=(self.draws, 2)) + tbl_expected = ('CREATE TABLE IF NOT EXISTS [x] ' + '(recid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ' + 'draw INTEGER, ' + 'chain INT(5), v1 FLOAT, v2 FLOAT)') + db.cursor.execute.assert_called_once_with(tbl_expected) + + trace_expected = ('INSERT INTO [x] (recid, draw, chain, v1, v2) ' + 'VALUES (NULL, {draw}, 0, {value})') + self.assertEqual(var_trace, trace_expected) + + def test__create_trace_2d(self): + db = self.db + var_trace = db._create_trace(chain=0, var_name='x', + shape=(self.draws, 2, 3)) + tbl_expected = ('CREATE TABLE IF NOT EXISTS [x] ' + '(recid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ' + 'draw INTEGER, ' + 'chain INT(5), ' + 'v1_1 FLOAT, v1_2 FLOAT, v1_3 FLOAT, ' + 'v2_1 FLOAT, v2_2 FLOAT, v2_3 FLOAT)') + db.cursor.execute.assert_called_once_with(tbl_expected) + + trace_expected = ('INSERT INTO [x] (recid, draw, chain, ' + 'v1_1, v1_2, v1_3, ' + 'v2_1, v2_2, v2_3) ' + 'VALUES (NULL, {draw}, 0, {value})') + self.assertEqual(var_trace, trace_expected) + + def test__store_value_scalar(self): + db = self.db + db.setup_samples(draws=3, chain=0) + var_name = 'x' + query = sqlite.QUERIES['insert'].format(table=var_name, + value_cols='v1', + chain=0) + db.trace.samples[0] = {'x': query} + db._store_value(draw=0, var_trace=db.trace.samples[0][var_name], + value=3.) + expected = ('INSERT INTO [x] (recid, draw, chain, v1) ' + 'VALUES (NULL, 0, 0, 3.0)') + db.cursor.execute.assert_called_once_with(expected) + + def test__store_value_1d(self): + db = self.db + db.setup_samples(draws=3, chain=0) + var_name = 'x' + query = sqlite.QUERIES['insert'].format(table=var_name, + value_cols='v1, v2', + chain=0) + db.trace.samples[0] = {'x': query} + print(db) + db._store_value(draw=0, var_trace=db.trace.samples[0][var_name], + value=[3., 3.]) + expected = ('INSERT INTO [x] (recid, draw, chain, v1, v2) ' + 'VALUES (NULL, 0, 0, 3.0, 3.0)') + db.cursor.execute.assert_called_once_with(expected) + + +class SQLiteSelectionTestCase(SQLiteTestCase): + + def setUp(self): + super(SQLiteSelectionTestCase, self).setUp() + self.db.var_shapes = {'x': (), 'y': (4,)} + self.db.setup_samples(self.draws, chain=0) + + ndarray_patch = mock.patch('pymc.backends.sqlite._rows_to_ndarray') + self.addCleanup(ndarray_patch.stop) + ndarray_patch.start() + + self.draws = 5 + + +class TestSQLiteSelection(SQLiteSelectionTestCase): + + def test_get_values_default_keywords(self): + self.db.trace.get_values('x') + expected = 'SELECT * FROM [x] WHERE (chain = 0)' + self.db.cursor.execute.assert_called_with(expected) + + def test_get_values_burn_arg(self): + self.db.trace.get_values('x', burn=2) + expected = 'SELECT * FROM [x] WHERE (chain = 0) AND (draw > 1)' + self.db.cursor.execute.assert_called_with(expected) + + def test_get_values_thin_arg(self): + self.db.trace.get_values('x', thin=2) + expected = ('SELECT * FROM [x] ' + 'WHERE (chain = 0) AND ' + '(draw - (SELECT draw FROM [x] ' + 'WHERE chain = 0 ' + 'ORDER BY draw LIMIT 1)) % 2 = 0') + self.db.cursor.execute.assert_called_with(expected) + + def test_get_values_burn_thin_arg(self): + self.db.trace.get_values('x', thin=2, burn=1) + expected = ('SELECT * FROM [x] ' + 'WHERE (chain = 0) AND (draw > 0) ' + 'AND (draw - (SELECT draw FROM [x] ' + 'WHERE (chain = 0) AND (draw > 0) ' + 'ORDER BY draw LIMIT 1)) % 2 = 0') + self.db.cursor.execute.assert_called_with(expected) + + def test_point(self): + idx = 2 + + point = self.db.trace.point(idx) + expected = {'x': + 'SELECT * FROM [x] WHERE (chain=0) AND (draw=2)', + 'y': + 'SELECT * FROM [y] WHERE (chain=0) AND (draw=2)'} + + for var_name, value in expected.items(): + self.db.cursor.execute.assert_any_call(value) + + def test_slice(self): + with warnings.catch_warnings(record=True) as wrn: + self.db.trace[:10] + self.assertEqual(len(wrn), 1) + self.assertEqual(str(wrn[0].message), + 'Slice for SQLite backend has no effect.') + + +class TestSQLiteSelectionMultipleChains(SQLiteSelectionTestCase): + + def setUp(self): + super(TestSQLiteSelectionMultipleChains, self).setUp() + self.db.trace.samples[1] = self.db.trace.samples[0] + + def test_get_values_default_keywords(self): + self.db.trace.get_values('x') + expected = ['SELECT * FROM [x] WHERE (chain = 0)', + 'SELECT * FROM [x] WHERE (chain = 1)'] + for value in expected: + self.db.cursor.execute.assert_any_call(value) + + def test_get_values_chains_one_given(self): + self.db.trace.get_values('x', chains=[0]) + expected = 'SELECT * FROM [x] WHERE (chain = 0)' + ## If 0 chain is last call, 1 was not called + self.db.cursor.execute.assert_called_with(expected) + + +class TestSQLiteLoad(unittest.TestCase): + + def setUp(self): + db_patch = mock.patch('pymc.backends.sqlite.SQLite') + self.addCleanup(db_patch.stop) + self.db = db_patch.start() + + table_list_patch = mock.patch('pymc.backends.sqlite._get_table_list') + self.addCleanup(table_list_patch.stop) + self.table_list = table_list_patch.start() + self.table_list.return_value = ['x', 'y'] + + var_strs_list_patch = mock.patch('pymc.backends.sqlite._get_var_strs') + self.addCleanup(var_strs_list_patch.stop) + self.var_strs_list = var_strs_list_patch.start() + self.var_strs_list.return_value = ['v1', 'v2'] + + chain_list_patch = mock.patch('pymc.backends.sqlite._get_chain_list') + self.addCleanup(chain_list_patch.stop) + self.chain_list = chain_list_patch.start() + self.chain_list.return_value = [0, 1] + + def test_load(self): + trace = sqlite.load('test.db') + self.assertEqual(len(trace.samples), 2) + + self.assertTrue('x' in trace.samples[0]) + self.assertTrue('y' in trace.samples[0]) + + expected = ('INSERT INTO [{}] ' + '(recid, draw, chain, v1, v2) ' + 'VALUES (NULL, {{draw}}, {}, {{value}})') + for chain in [0, 1]: + for var_name in ['x', 'y']: + self.assertEqual(trace.samples[chain][var_name], + expected.format(var_name, chain)) + + +def test_create_column_empty(): + result = sqlite.create_colnames(()) + expected = ['v1'] + assert result == expected + + +def test_create_column_1d(): + result = sqlite.create_colnames((2, )) + expected = ['v1', 'v2'] + assert result == expected + + +def test_create_column_2d(): + result = sqlite.create_colnames((2, 2)) + expected = ['v1_1', 'v1_2', 'v2_1', 'v2_2'] + assert result == expected From d9b5d2f532ae8dc8dfc3c6fd3d3681128c89e7a8 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Thu, 2 Jan 2014 19:10:46 -0500 Subject: [PATCH 08/18] Add backend documentation --- pymc/backends/__init__.py | 125 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index aa5f5fce6f..92e6d2ef6f 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -1,3 +1,128 @@ +"""Backends for traces + +Available backends +------------------ + +1. NumPy array (pymc.backends.NDArray) +2. Text files (pymc.backends.Text) +3. SQLite (pymc.backends.SQLite) + +The NumPy arrays and text files both hold the entire trace in memory, +whereas SQLite commits the trace to the database while sampling. + +Selecting a backend +------------------- + +By default, a NumPy array is used as the backend. To specify a different +backend, pass a backend instance to `sample`. + +For example, the following would save traces to the file 'test.db'. + + >>> import pymc as pm + >>> db = pm.backends.SQLite('test.db') + >>> trace = pm.sample(..., db=db) + +Selecting values from a backend +------------------------------- + +After a backend is finished sampling, values can be accessed in a few +ways. The easiest way is to index the backend object with a variable or +variable name. + + >>> trace['x'] # or trace[x] + +The call will return a list containing the sampling values for all +chains of `x`. (Each call to `pymc.sample` creates a separate chain of +samples.) + +For more control is needed of which values are returned, the +`get_values` method can be used. The call below will return values from +all chains, burning the first 1000 iterations from each chain. + + >>> trace.get_values('x', burn=1000) + +Setting the `combined` flag will concatenate the results from all the +chains. + + >>> trace.get_values('x', burn=1000, combine=True) + +The `chains` parameter of `get_values` can be used to limit the chains +that are retrieved. To work with a subset of chains without having to +specify `chains` each call, you can set the `active_chains` attribute. + + >>> trace.chains + [0, 1, 2] + >>> trace.active_chains = [0, 2] + +After this, only chains 0 and 2 will be used in operations that work +with multiple chains. + +Similary, the `default_chain` attribute sets which chain is used for +functions that require a single chain (e.g., point). + + >>> trace.point(4) # or trace[4] + +Backends can also suppport slicing the trace object. For example, the +following call would return a new trace object without the first 1000 +sampling iterations for all variables. + + >>> sliced_trace = trace[1000:] + +Loading a saved backend +----------------------- + +Saved backends can be loaded using `load` function in the module for the +specific backend. + + >>> trace = pm.backends.sqlite.load('test.db') + +Writing custom backends +----------------------- + +To write a custom backend, two base classes should be inherited: +pymc.backends.base.Backend and pymc.backends.base.Trace. The first class +handles sampling, while the second provides access to the sampled +values. + +These following sampling-related methods of base.Backend should be +define in the child class: + +- _initialize_trace: Return a trace object for to store the sampled + values. + +- _create_trace: Create the trace object for a specific variable and + chain. For example, the NumPy array backend creates an array of zeros + shaped according to the number of planned iterations and the shape of + the given variable. + +- _store_value: Store the value for a draw of a particular variable + (using the trace from `_create_trace`). + +- commit: After a set amount of iterations, the sampling results will be + committed to the backend. In the case of in memory backends (NumPy and + Text), this doesn't do anything. + +- close: This method is called following sampling and should perform any + actions necessary for finalizing and cleaning up the backend. + +If backend-specific initialization is required, redefine `__init__` to +include this and the call the parent `__init__` method. + +In addition to sampling methods, several methods in base.Trace should +also be defined. + +- get_values: This is the core method for selecting values from the + backend. It can be called directly and is used by __getitem__ when the + backend is indexed with a variable name or object. + +- _slice: Defines how the backend returns a slice of itself. This + is called if the backend is indexed with a slice range. + +- point: Returns values for each variables at a single iteration. This + is called if the backend is indexed with a single integer. + +For specific examples, see pymc.backends.{ndarray,text,sqlite}.py. +""" from pymc.backends.ndarray import NDArray from pymc.backends.text import Text from pymc.backends.sqlite import SQLite From caab2e1cc671bc55c1a22c57e67c5c59a7d12508 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Thu, 2 Jan 2014 18:22:47 -0500 Subject: [PATCH 09/18] Add enumerate_progress function Wraps `enumerate` in progress bar update. This allows for checking the progress bar flag once and choosing `enumerate` or `enumerate_progress` as function versus checking progress bar flag each iteration. --- pymc/progressbar.py | 19 +++++++++++++++++++ pymc/sampling.py | 8 ++++---- pymc/tests/test_progressbar.py | 17 +++++++++++++++++ 3 files changed, 40 insertions(+), 4 deletions(-) create mode 100644 pymc/tests/test_progressbar.py diff --git a/pymc/progressbar.py b/pymc/progressbar.py index 3027c4ba07..1ca256414e 100644 --- a/pymc/progressbar.py +++ b/pymc/progressbar.py @@ -121,3 +121,22 @@ def progress_bar(iters): return TextProgressBar(iters, ipythonprint) else: return TextProgressBar(iters, consoleprint) + + +def enumerate_progress(iterable, total, meter=None): + """Wrapper `enumerate` in a progress bar update + + Parameters + ---------- + iterable + total : int + meter + Any object with an `update` method that accepts current + iteration as an argument. If none, `progress_bar` is used. + """ + if meter is None: + meter = progress_bar + progress = meter(total) + for i, item in enumerate(iterable): + progress.update(i) + yield i, item diff --git a/pymc/sampling.py b/pymc/sampling.py index 18c08cb902..30a183bb8e 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -5,7 +5,7 @@ from time import time from .core import * from . import step_methods -from .progressbar import progress_bar +from .progressbar import enumerate_progress from numpy.random import seed __all__ = ['sample', 'iter_sample'] @@ -81,15 +81,15 @@ def _sample(draws, step, start=None, db=None, chain=0, tune=None, progressbar=True, model=None, variables=None, random_seed=None): sampling = _iter_sample(draws, step, start, db, chain, tune, model, variables, random_seed) + if progressbar: sampling = enumerate_progress(sampling, draws) else: sampling = enumerate(sampling) try: - for i, trace in enumerate(sampling): - if progressbar: - progress.update(i) + for i, trace in sampling: + pass except KeyboardInterrupt: trace.backend.clean_interrupt(i) trace.backend.close() diff --git a/pymc/tests/test_progressbar.py b/pymc/tests/test_progressbar.py new file mode 100644 index 0000000000..ea6bab90a6 --- /dev/null +++ b/pymc/tests/test_progressbar.py @@ -0,0 +1,17 @@ +try: + import unittest.mock as mock # py3 +except ImportError: + import mock + +from pymc import progressbar + + +def test_enumerate_progress(): + iterable = list(range(5, 8)) + meter = mock.Mock() + results = list(progressbar.enumerate_progress(iterable, + len(iterable), + meter)) + for i, _ in enumerate(iterable): + assert meter.update.called_with(i) + assert list(zip(*results))[1] == tuple(iterable) From 40eb2b391c68601dc8f69fa0133e20fbc4eb469f Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Mon, 6 Jan 2014 02:54:30 -0500 Subject: [PATCH 10/18] Dump/load tests for text and SQLite --- pymc/examples/sqlite_dump_load.py | 50 +++++++++++++++++++++++++++++++ pymc/examples/text_dump_load.py | 47 +++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+) create mode 100644 pymc/examples/sqlite_dump_load.py create mode 100644 pymc/examples/text_dump_load.py diff --git a/pymc/examples/sqlite_dump_load.py b/pymc/examples/sqlite_dump_load.py new file mode 100644 index 0000000000..9a678e35ed --- /dev/null +++ b/pymc/examples/sqlite_dump_load.py @@ -0,0 +1,50 @@ +import os +import numpy as np +import numpy.testing as npt + +import pymc as pm + +# import pydevd +# pydevd.set_pm_excepthook() +np.seterr(invalid='raise') + +data = np.random.normal(size=(2, 20)) +model = pm.Model() + +with model: + x = pm.Normal('x', mu=.5, tau=2. ** -2, shape=(2, 1)) + z = pm.Beta('z', alpha=10, beta=5.5) + d = pm.Normal('data', mu=x, tau=.75 ** -2, observed=data) + + +def run(n=50): + if n == 'short': + n = 5 + with model: + try: + trace = pm.sample(n, step=pm.Metropolis(), + db=pm.backends.SQLite('test.db'), + threads=2) + dumped = pm.backends.sqlite.load('test.db') + + assert trace[x][0].shape[0] == n + assert trace[x][1].shape[0] == n + assert trace.get_values('z', burn=3, + combine=True).shape[0] == n * 2 - 3 * 2 + + assert trace.nchains == dumped.nchains + assert list(sorted(trace.var_names)) == list(sorted(dumped.var_names)) + + for chain in trace.chains: + for var_name in trace.var_names: + data = trace.samples[chain][var_name] + dumped_data = dumped.samples[chain][var_name] + npt.assert_equal(data, dumped_data) + finally: + try: + os.remove('test.db') + except FileNotFoundError: + pass + +if __name__ == '__main__': + run('short') diff --git a/pymc/examples/text_dump_load.py b/pymc/examples/text_dump_load.py new file mode 100644 index 0000000000..361d95949d --- /dev/null +++ b/pymc/examples/text_dump_load.py @@ -0,0 +1,47 @@ +import shutil +import numpy as np +import numpy.testing as npt + +import pymc as pm + +# import pydevd +# pydevd.set_pm_excepthook() +np.seterr(invalid='raise') + +data = np.random.normal(size=(2, 20)) +model = pm.Model() + +with model: + x = pm.Normal('x', mu=.5, tau=2. ** -2, shape=(2, 1)) + z = pm.Beta('z', alpha=10, beta=5.5) + d = pm.Normal('data', mu=x, tau=.75 ** -2, observed=data) + + +def run(n=50): + if n == 'short': + n = 5 + with model: + try: + trace = pm.sample(n, step=pm.Metropolis(), + db=pm.backends.Text('textdb'), + threads=2) + dumped = pm.backends.text.load('textdb') + + assert trace[x][0].shape[0] == n + assert trace[x][1].shape[0] == n + assert trace.get_values('z', burn=3, + combine=True).shape[0] == n * 2 - 3 * 2 + + assert trace.nchains == dumped.nchains + assert list(sorted(trace.var_names)) == list(sorted(dumped.var_names)) + + for chain in trace.chains: + for var_name in trace.var_names: + data = trace.samples[chain][var_name] + dumped_data = dumped.samples[chain][var_name] + npt.assert_equal(data, dumped_data) + finally: + shutil.rmtree('textdb') + +if __name__ == '__main__': + run('short') From 156fc425849e27785e21433389e41b5cffca11de Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Mon, 6 Jan 2014 15:17:11 -0500 Subject: [PATCH 11/18] Test equality of NDArray and SQLite selections This test compares of selection methods for NDArray to SQLite traces. --- pymc/tests/test_ndarray_sqlite_selection.py | 114 ++++++++++++++++++++ 1 file changed, 114 insertions(+) create mode 100644 pymc/tests/test_ndarray_sqlite_selection.py diff --git a/pymc/tests/test_ndarray_sqlite_selection.py b/pymc/tests/test_ndarray_sqlite_selection.py new file mode 100644 index 0000000000..3de2977ac1 --- /dev/null +++ b/pymc/tests/test_ndarray_sqlite_selection.py @@ -0,0 +1,114 @@ +import os +import numpy as np +import numpy.testing as npt +import unittest +import multiprocessing as mp + +import pymc as pm + + +def remove_db(db): + try: + os.remove(db) + except FileNotFoundError: + pass + + +class TestCompareNDArraySQLite(unittest.TestCase): + + @classmethod + def setUpClass(cls): + ## Use two threads if available + try: + mp.Pool(2) + threads = 2 + except: + threads = 1 + + data = np.random.normal(size=(3, 20)) + n = 1 + + model = pm.Model() + with model: + x = pm.Normal('x', 0, 1., shape=n) + + # start sampling at the MAP + start = {'x': 0.} + step = pm.Metropolis() + cls.db = 'test.db' + + try: + cls.draws = 10 + cls.ntrace = pm.sample(cls.draws, step=step, + threads=threads, random_seed=4) + cls.strace = pm.sample(cls.draws, step=step, + threads=threads, random_seed=4, + db=pm.backends.SQLite(cls.db)) + except: + remove_db(cls.db) + raise + + @classmethod + def tearDownClass(cls): + remove_db(cls.db) + + def test_chain_length(self): + assert self.ntrace.nchains == self.strace.nchains + assert len(self.ntrace) == len(self.strace) + + def test_get_item(self): + npt.assert_equal(self.ntrace['x'], self.strace['x']) + + def test_get_values(self): + for cf in [False, True]: + npt.assert_equal(self.ntrace.get_values('x', combine=cf), + self.strace.get_values('x', combine=cf)) + + def test_get_values_no_squeeze(self): + npt.assert_equal(self.ntrace.get_values('x', combine=False, + squeeze=False), + self.strace.get_values('x', combine=False, + squeeze=False)) + + def test_get_values_combine_and_no_squeeze(self): + npt.assert_equal(self.ntrace.get_values('x', combine=True, + squeeze=False), + self.strace.get_values('x', combine=True, + squeeze=False)) + + def test_get_values_with_burn(self): + for cf in [False, True]: + npt.assert_equal(self.ntrace.get_values('x', combine=cf, burn=3), + self.strace.get_values('x', combine=cf, burn=3)) + + ## burn so just one value left + npt.assert_equal(self.ntrace.get_values('x', combine=cf, + burn=self.draws - 1), + self.strace.get_values('x', combine=cf, + burn=self.draws - 1)) + + def test_get_values_with_thin(self): + for cf in [False, True]: + npt.assert_equal(self.ntrace.get_values('x', combine=cf, thin=2), + self.strace.get_values('x', combine=cf, thin=2)) + + def test_get_values_with_burn_and_thin(self): + for cf in [False, True]: + npt.assert_equal(self.ntrace.get_values('x', combine=cf, + burn=2, thin=2), + self.strace.get_values('x', combine=cf, + burn=2, thin=2)) + + def test_get_values_with_chains_arg(self): + for cf in [False, True]: + npt.assert_equal(self.ntrace.get_values('x', chains=[0]), + self.strace.get_values('x', chains=[0])) + + def test_point(self): + npoint, spoint = self.ntrace[4], self.strace[4] + npt.assert_equal(npoint['x'], spoint['x']) + + def test_point_with_chain_arg(self): + npoint = self.ntrace.point(4, chain=0) + spoint = self.strace.point(4, chain=0) + npt.assert_equal(npoint['x'], spoint['x']) From 978a5280fa7421e2fa8491c0fca0dac45840bad5 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Thu, 9 Jan 2014 11:59:28 -0500 Subject: [PATCH 12/18] Fix error in summary warning Summary was not updated to reflect variables name change (which is caught by 'test_summary_2dim_value_model'). --- pymc/stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/stats.py b/pymc/stats.py index ff46bf8c3e..8a166c920c 100644 --- a/pymc/stats.py +++ b/pymc/stats.py @@ -272,7 +272,7 @@ def summary(trace, var_names=None, alpha=0.05, start=0, batches=100, sample = sample[:, None] elif sample.ndim > 2: ## trace dimensions greater than 2 (variable greater than 1) - warnings.warn('Skipping {} (above 1 dimension)'.format(var)) + warnings.warn('Skipping {} (above 1 dimension)'.format(var_name)) continue print('\n%s:' % var_name) From 2e40ff585f5cf9dfa29af94e948620a99c39d690 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Thu, 9 Jan 2014 11:53:59 -0500 Subject: [PATCH 13/18] Revert "Move summary to stats module" This reverts commit 6dda7e1ab641a40336bd86d53201ddbefe47c5e8. The is being reverted to make it easier to compare changes with the master branch. Conflicts: pymc/__init__.py pymc/stats.py pymc/tests/test_trace.py pymc/trace.py --- pymc/__init__.py | 2 +- pymc/stats.py | 161 +------------------------------------- pymc/tests/test_stats.py | 164 +++----------------------------------- pymc/tests/test_trace.py | 150 +++++++++++++++++++++++++++++++++++ pymc/trace.py | 165 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 326 insertions(+), 316 deletions(-) create mode 100644 pymc/tests/test_trace.py create mode 100644 pymc/trace.py diff --git a/pymc/__init__.py b/pymc/__init__.py index 1dae5edbbf..cb0e468c69 100644 --- a/pymc/__init__.py +++ b/pymc/__init__.py @@ -6,7 +6,7 @@ from .sampling import * -from .stats import summary +from .trace import * from .step_methods import * from .tuning import * diff --git a/pymc/stats.py b/pymc/stats.py index 8a166c920c..60ca02fe71 100644 --- a/pymc/stats.py +++ b/pymc/stats.py @@ -1,10 +1,8 @@ """Utility functions for PyMC""" import numpy as np -import warnings - -__all__ = ['autocorr', 'autocov', 'hpd', 'quantiles', 'mc_error', 'summary'] +__all__ = ['autocorr', 'autocov', 'hpd', 'quantiles', 'mc_error'] def statfunc(f): """ @@ -228,160 +226,3 @@ def quantiles(x, qlist=(2.5, 25, 50, 75, 97.5)): except IndexError: print("Too few elements for quantile calculation") - - -def summary(trace, var_names=None, alpha=0.05, start=0, batches=100, - roundto=3): - """ - Generate a pretty-printed summary of the node. - - :Parameters: - trace : Trace object - Trace containing MCMC samples - - var_names : list of strings - List of variables to summarize. Defaults to None, which results - in all variables summarized. - - alpha : float - The alpha level for generating posterior intervals. Defaults to - 0.05. - - start : int - The starting index from which to summarize (each) chain. Defaults - to zero. - - batches : int - Batch size for calculating standard deviation for non-independent - samples. Defaults to 100. - - roundto : int - The number of digits to round posterior statistics. - - """ - if var_names is None: - var_names = trace.var_names - - stat_summ = _StatSummary(roundto, batches, alpha) - pq_summ = _PosteriorQuantileSummary(roundto, alpha) - - for var_name in var_names: - # Extract sampled values - sample = trace.get_values(var_name, burn=start, combine=True) - if sample.ndim == 1: - sample = sample[:, None] - elif sample.ndim > 2: - ## trace dimensions greater than 2 (variable greater than 1) - warnings.warn('Skipping {} (above 1 dimension)'.format(var_name)) - continue - - print('\n%s:' % var_name) - print(' ') - - stat_summ.print_output(sample) - pq_summ.print_output(sample) - - -class _Summary(object): - """Base class for summary output""" - def __init__(self, roundto): - self.roundto = roundto - self.header_lines = None - self.leader = ' ' - self.spaces = None - - def print_output(self, sample): - print('\n'.join(list(self._get_lines(sample))) + '\n') - - def _get_lines(self, sample): - for line in self.header_lines: - yield self.leader + line - summary_lines = self._calculate_values(sample) - for line in self._create_value_output(summary_lines): - yield self.leader + line - - def _create_value_output(self, lines): - for values in lines: - self._format_values(values) - yield self.value_line.format(pad=self.spaces, **values).strip() - - def _calculate_values(self, sample): - raise NotImplementedError - - def _format_values(self, summary_values): - for key, val in summary_values.items(): - summary_values[key] = '{:.{ndec}f}'.format( - float(val), ndec=self.roundto) - - -class _StatSummary(_Summary): - def __init__(self, roundto, batches, alpha): - super(_StatSummary, self).__init__(roundto) - spaces = 17 - hpd_name = '{}% HPD interval'.format(int(100 * (1 - alpha))) - value_line = '{mean:<{pad}}{sd:<{pad}}{mce:<{pad}}{hpd:<{pad}}' - header = value_line.format(mean='Mean', sd='SD', mce='MC Error', - hpd=hpd_name, pad=spaces).strip() - hline = '-' * len(header) - - self.header_lines = [header, hline] - self.spaces = spaces - self.value_line = value_line - self.batches = batches - self.alpha = alpha - - def _calculate_values(self, sample): - return _calculate_stats(sample, self.batches, self.alpha) - - def _format_values(self, summary_values): - roundto = self.roundto - for key, val in summary_values.items(): - if key == 'hpd': - summary_values[key] = '[{:.{ndec}f}, {:.{ndec}f}]'.format( - *val, ndec=roundto) - else: - summary_values[key] = '{:.{ndec}f}'.format( - float(val), ndec=roundto) - - -class _PosteriorQuantileSummary(_Summary): - def __init__(self, roundto, alpha): - super(_PosteriorQuantileSummary, self).__init__(roundto) - spaces = 15 - title = 'Posterior quantiles:' - value_line = '{lo:<{pad}}{q25:<{pad}}{q50:<{pad}}{q75:<{pad}}{hi:<{pad}}' - lo, hi = 100 * alpha / 2, 100 * (1. - alpha / 2) - qlist = (lo, 25, 50, 75, hi) - header = value_line.format(lo=lo, q25=25, q50=50, q75=75, hi=hi, - pad=spaces).strip() - hline = '|{thin}|{thick}|{thick}|{thin}|'.format( - thin='-' * (spaces - 1), thick='=' * (spaces - 1)) - - self.header_lines = [title, header, hline] - self.spaces = spaces - self.lo, self.hi = lo, hi - self.qlist = qlist - self.value_line = value_line - - def _calculate_values(self, sample): - return _calculate_posterior_quantiles(sample, self.qlist) - - -def _calculate_stats(sample, batches, alpha): - means = sample.mean(0) - sds = sample.std(0) - mces = mc_error(sample, batches) - intervals = hpd(sample, alpha) - for index in range(sample.shape[1]): - mean, sd, mce = [stat[index] for stat in (means, sds, mces)] - interval = intervals[index].squeeze().tolist() - yield {'mean': mean, 'sd': sd, 'mce': mce, 'hpd': interval} - - -def _calculate_posterior_quantiles(sample, qlist): - var_quantiles = quantiles(sample, qlist=qlist) - ## Replace ends of qlist with 'lo' and 'hi' - qends = {qlist[0]: 'lo', qlist[-1]: 'hi'} - qkeys = {q: qends[q] if q in qends else 'q{}'.format(q) for q in qlist} - for index in range(sample.shape[1]): - yield {qkeys[q]: var_quantiles[q][index] for q in qlist} diff --git a/pymc/tests/test_stats.py b/pymc/tests/test_stats.py index 371dfc736e..3ec98f760f 100644 --- a/pymc/tests/test_stats.py +++ b/pymc/tests/test_stats.py @@ -1,10 +1,6 @@ -import pymc as pm -from pymc import stats -import numpy as np +from ..stats import * from numpy.random import random, normal, seed from numpy.testing import assert_equal, assert_almost_equal, assert_array_almost_equal -import warnings -import nose seed(111) normal_sample = normal(0, 1, 1000000) @@ -12,179 +8,37 @@ def test_autocorr(): """Test autocorrelation and autocovariance functions""" - assert_almost_equal(stats.autocorr(normal_sample), 0, 2) + assert_almost_equal(autocorr(normal_sample), 0, 2) y = [(normal_sample[i-1] + normal_sample[i])/2 for i in range(1, len(normal_sample))] - assert_almost_equal(stats.autocorr(y), 0.5, 2) + assert_almost_equal(autocorr(y), 0.5, 2) def test_hpd(): """Test HPD calculation""" - interval = stats.hpd(normal_sample) + interval = hpd(normal_sample) assert_array_almost_equal(interval, [-1.96, 1.96], 2) def test_make_indices(): """Test make_indices function""" + from ..stats import make_indices + ind = [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)] - assert_equal(ind, stats.make_indices((2, 3))) + assert_equal(ind, make_indices((2, 3))) def test_mc_error(): """Test batch standard deviation function""" x = random(100000) - assert(stats.mc_error(x) < 0.0025) + assert(mc_error(x) < 0.0025) def test_quantiles(): """Test quantiles function""" - q = stats.quantiles(normal_sample) + q = quantiles(normal_sample) assert_array_almost_equal(sorted(q.values()), [-1.96, -0.67, 0, 0.67, 1.96], 2) - - -def test_summary_1_value_model(): - mu = -2.1 - tau = 1.3 - with pm.Model() as model: - x = pm.Normal('x', mu, tau, testval=.1) - step = pm.Metropolis(model.vars, np.diag([1.])) - trace = pm.sample(100, step=step) - stats.summary(trace) - - -def test_summary_2_value_model(): - mu = -2.1 - tau = 1.3 - with pm.Model() as model: - x = pm.Normal('x', mu, tau, shape=2, testval=[.1, .1]) - step = pm.Metropolis(model.vars, np.diag([1.])) - trace = pm.sample(100, step=step) - stats.summary(trace) - - -def test_summary_2dim_value_model(): - mu = -2.1 - tau = 1.3 - with pm.Model() as model: - x = pm.Normal('x', mu, tau, shape=(2, 2), - testval=np.tile(.1, (2, 2))) - step = pm.Metropolis(model.vars, np.diag([1.])) - trace = pm.sample(100, step=step) - - with warnings.catch_warnings(record=True) as wrn: - stats.summary(trace) - assert len(wrn) == 1 - assert str(wrn[0].message) == 'Skipping x (above 1 dimension)' - - -def test_summary_format_values(): - roundto = 2 - summ = stats._Summary(roundto) - d = {'nodec': 1, 'onedec': 1.0, 'twodec': 1.00, 'threedec': 1.000} - summ._format_values(d) - for val in d.values(): - assert val == '1.00' - - -def test_stat_summary_format_hpd_values(): - roundto = 2 - summ = stats._StatSummary(roundto, None, 0.05) - d = {'nodec': 1, 'hpd': [1, 1]} - summ._format_values(d) - for key, val in d.items(): - if key == 'hpd': - assert val == '[1.00, 1.00]' - else: - assert val == '1.00' - - -@nose.tools.raises(IndexError) -def test_calculate_stats_variable_size1_not_adjusted(): - sample = np.arange(10) - list(stats._calculate_stats(sample, 5, 0.05)) - - -def test_calculate_stats_variable_size1_adjusted(): - sample = np.arange(10)[:, None] - result_size = len(list(stats._calculate_stats(sample, 5, 0.05))) - assert result_size == 1 - -def test_calculate_stats_variable_size2(): - ## 2 traces of 5 - sample = np.arange(10).reshape(5, 2) - result_size = len(list(stats._calculate_stats(sample, 5, 0.05))) - assert result_size == 2 - - -@nose.tools.raises(IndexError) -def test_calculate_pquantiles_variable_size1_not_adjusted(): - sample = np.arange(10) - qlist = (0.25, 25, 50, 75, 0.98) - list(stats._calculate_posterior_quantiles(sample, - qlist)) - - -def test_calculate_pquantiles_variable_size1_adjusted(): - sample = np.arange(10)[:, None] - qlist = (0.25, 25, 50, 75, 0.98) - result_size = len(list(stats._calculate_posterior_quantiles(sample, - qlist))) - assert result_size == 1 - - -def test_stats_value_line(): - roundto = 1 - summ = stats._StatSummary(roundto, None, 0.05) - values = [{'mean': 0, 'sd': 1, 'mce': 2, 'hpd': [4, 4]}, - {'mean': 5, 'sd': 6, 'mce': 7, 'hpd': [8, 8]},] - - expected = ['0.0 1.0 2.0 [4.0, 4.0]', - '5.0 6.0 7.0 [8.0, 8.0]'] - result = list(summ._create_value_output(values)) - assert result == expected - - -def test_post_quantile_value_line(): - roundto = 1 - summ = stats._PosteriorQuantileSummary(roundto, 0.05) - values = [{'lo': 0, 'q25': 1, 'q50': 2, 'q75': 4, 'hi': 5}, - {'lo': 6, 'q25': 7, 'q50': 8, 'q75': 9, 'hi': 10},] - - expected = ['0.0 1.0 2.0 4.0 5.0', - '6.0 7.0 8.0 9.0 10.0'] - result = list(summ._create_value_output(values)) - assert result == expected - - -def test_stats_output_lines(): - roundto = 1 - x = np.arange(10).reshape(5, 2) - - summ = stats._StatSummary(roundto, 5, 0.05) - - expected = [' Mean SD MC Error 95% HPD interval', - ' -------------------------------------------------------------------', - ' 4.0 2.8 1.3 [0.0, 8.0]', - ' 5.0 2.8 1.3 [1.0, 9.0]',] - result = list(summ._get_lines(x)) - assert result == expected - - -def test_posterior_quantiles_output_lines(): - roundto = 1 - x = np.arange(10).reshape(5, 2) - - summ = stats._PosteriorQuantileSummary(roundto, 0.05) - - expected = [' Posterior quantiles:', - ' 2.5 25 50 75 97.5', - ' |--------------|==============|==============|--------------|', - ' 0.0 2.0 4.0 6.0 8.0', - ' 1.0 3.0 5.0 7.0 9.0'] - - result = list(summ._get_lines(x)) - assert result == expected diff --git a/pymc/tests/test_trace.py b/pymc/tests/test_trace.py new file mode 100644 index 0000000000..ef82ee29df --- /dev/null +++ b/pymc/tests/test_trace.py @@ -0,0 +1,150 @@ +from .checks import * +from .models import * +import pymc as pm +import numpy as np +import warnings +import nose + + +def test_summary_1_value_model(): + mu = -2.1 + tau = 1.3 + with Model() as model: + x = Normal('x', mu, tau, testval=.1) + step = Metropolis(model.vars, np.diag([1.])) + trace = sample(100, step=step) + pm.summary(trace) + + +def test_summary_2_value_model(): + mu = -2.1 + tau = 1.3 + with Model() as model: + x = Normal('x', mu, tau, shape=2, testval=[.1, .1]) + step = Metropolis(model.vars, np.diag([1.])) + trace = sample(100, step=step) + pm.summary(trace) + + +def test_summary_2dim_value_model(): + mu = -2.1 + tau = 1.3 + with Model() as model: + x = Normal('x', mu, tau, shape=(2, 2), + testval=np.tile(.1, (2, 2))) + step = Metropolis(model.vars, np.diag([1.])) + trace = sample(100, step=step) + + with warnings.catch_warnings(record=True) as wrn: + pm.summary(trace) + assert len(wrn) == 1 + assert str(wrn[0].message) == 'Skipping x (above 1 dimension)' + + +def test_summary_format_values(): + roundto = 2 + summ = pm.trace._Summary(roundto) + d = {'nodec': 1, 'onedec': 1.0, 'twodec': 1.00, 'threedec': 1.000} + summ._format_values(d) + for val in d.values(): + assert val == '1.00' + + +def test_stat_summary_format_hpd_values(): + roundto = 2 + summ = pm.trace._StatSummary(roundto, None, 0.05) + d = {'nodec': 1, 'hpd': [1, 1]} + summ._format_values(d) + for key, val in d.items(): + if key == 'hpd': + assert val == '[1.00, 1.00]' + else: + assert val == '1.00' + + +@nose.tools.raises(IndexError) +def test_calculate_stats_variable_size1_not_adjusted(): + sample = np.arange(10) + list(pm.trace._calculate_stats(sample, 5, 0.05)) + + +def test_calculate_stats_variable_size1_adjusted(): + sample = np.arange(10)[:, None] + result_size = len(list(pm.trace._calculate_stats(sample, 5, 0.05))) + assert result_size == 1 + +def test_calculate_stats_variable_size2(): + ## 2 traces of 5 + sample = np.arange(10).reshape(5, 2) + result_size = len(list(pm.trace._calculate_stats(sample, 5, 0.05))) + assert result_size == 2 + + +@nose.tools.raises(IndexError) +def test_calculate_pquantiles_variable_size1_not_adjusted(): + sample = np.arange(10) + qlist = (0.25, 25, 50, 75, 0.98) + list(pm.trace._calculate_posterior_quantiles(sample, + qlist)) + + +def test_calculate_pquantiles_variable_size1_adjusted(): + sample = np.arange(10)[:, None] + qlist = (0.25, 25, 50, 75, 0.98) + result_size = len(list(pm.trace._calculate_posterior_quantiles(sample, + qlist))) + assert result_size == 1 + + +def test_stats_value_line(): + roundto = 1 + summ = pm.trace._StatSummary(roundto, None, 0.05) + values = [{'mean': 0, 'sd': 1, 'mce': 2, 'hpd': [4, 4]}, + {'mean': 5, 'sd': 6, 'mce': 7, 'hpd': [8, 8]},] + + expected = ['0.0 1.0 2.0 [4.0, 4.0]', + '5.0 6.0 7.0 [8.0, 8.0]'] + result = list(summ._create_value_output(values)) + assert result == expected + + +def test_post_quantile_value_line(): + roundto = 1 + summ = pm.trace._PosteriorQuantileSummary(roundto, 0.05) + values = [{'lo': 0, 'q25': 1, 'q50': 2, 'q75': 4, 'hi': 5}, + {'lo': 6, 'q25': 7, 'q50': 8, 'q75': 9, 'hi': 10},] + + expected = ['0.0 1.0 2.0 4.0 5.0', + '6.0 7.0 8.0 9.0 10.0'] + result = list(summ._create_value_output(values)) + assert result == expected + + +def test_stats_output_lines(): + roundto = 1 + x = np.arange(10).reshape(5, 2) + + summ = pm.trace._StatSummary(roundto, 5, 0.05) + + expected = [' Mean SD MC Error 95% HPD interval', + ' -------------------------------------------------------------------', + ' 4.0 2.8 1.3 [0.0, 8.0]', + ' 5.0 2.8 1.3 [1.0, 9.0]',] + result = list(summ._get_lines(x)) + assert result == expected + + +def test_posterior_quantiles_output_lines(): + roundto = 1 + x = np.arange(10).reshape(5, 2) + + summ = pm.trace._PosteriorQuantileSummary(roundto, 0.05) + + expected = [' Posterior quantiles:', + ' 2.5 25 50 75 97.5', + ' |--------------|==============|==============|--------------|', + ' 0.0 2.0 4.0 6.0 8.0', + ' 1.0 3.0 5.0 7.0 9.0'] + + result = list(summ._get_lines(x)) + assert result == expected diff --git a/pymc/trace.py b/pymc/trace.py new file mode 100644 index 0000000000..cf43352bf5 --- /dev/null +++ b/pymc/trace.py @@ -0,0 +1,165 @@ +import numpy as np +from .core import * +from .stats import * +import copy +import types +import warnings + +__all__ = ['summary'] + + +def summary(trace, var_names=None, alpha=0.05, start=0, batches=100, + roundto=3): + """ + Generate a pretty-printed summary of the node. + + :Parameters: + trace : Trace object + Trace containing MCMC samples + + var_names : list of strings + List of variables to summarize. Defaults to None, which results + in all variables summarized. + + alpha : float + The alpha level for generating posterior intervals. Defaults to + 0.05. + + start : int + The starting index from which to summarize (each) chain. Defaults + to zero. + + batches : int + Batch size for calculating standard deviation for non-independent + samples. Defaults to 100. + + roundto : int + The number of digits to round posterior statistics. + + """ + if var_names is None: + var_names = trace.var_names + + stat_summ = _StatSummary(roundto, batches, alpha) + pq_summ = _PosteriorQuantileSummary(roundto, alpha) + + for var_name in var_names: + # Extract sampled values + sample = trace.get_values(var_name, burn=start, combine=True) + if sample.ndim == 1: + sample = sample[:, None] + elif sample.ndim > 2: + ## trace dimensions greater than 2 (variable greater than 1) + warnings.warn('Skipping {} (above 1 dimension)'.format(var_name)) + continue + + print('\n%s:' % var_name) + print(' ') + + stat_summ.print_output(sample) + pq_summ.print_output(sample) + + +class _Summary(object): + """Base class for summary output""" + def __init__(self, roundto): + self.roundto = roundto + self.header_lines = None + self.leader = ' ' + self.spaces = None + + def print_output(self, sample): + print('\n'.join(list(self._get_lines(sample))) + '\n') + + def _get_lines(self, sample): + for line in self.header_lines: + yield self.leader + line + summary_lines = self._calculate_values(sample) + for line in self._create_value_output(summary_lines): + yield self.leader + line + + def _create_value_output(self, lines): + for values in lines: + self._format_values(values) + yield self.value_line.format(pad=self.spaces, **values).strip() + + def _calculate_values(self, sample): + raise NotImplementedError + + def _format_values(self, summary_values): + for key, val in summary_values.items(): + summary_values[key] = '{:.{ndec}f}'.format( + float(val), ndec=self.roundto) + + +class _StatSummary(_Summary): + def __init__(self, roundto, batches, alpha): + super(_StatSummary, self).__init__(roundto) + spaces = 17 + hpd_name = '{}% HPD interval'.format(int(100 * (1 - alpha))) + value_line = '{mean:<{pad}}{sd:<{pad}}{mce:<{pad}}{hpd:<{pad}}' + header = value_line.format(mean='Mean', sd='SD', mce='MC Error', + hpd=hpd_name, pad=spaces).strip() + hline = '-' * len(header) + + self.header_lines = [header, hline] + self.spaces = spaces + self.value_line = value_line + self.batches = batches + self.alpha = alpha + + def _calculate_values(self, sample): + return _calculate_stats(sample, self.batches, self.alpha) + + def _format_values(self, summary_values): + roundto = self.roundto + for key, val in summary_values.items(): + if key == 'hpd': + summary_values[key] = '[{:.{ndec}f}, {:.{ndec}f}]'.format( + *val, ndec=roundto) + else: + summary_values[key] = '{:.{ndec}f}'.format( + float(val), ndec=roundto) + + +class _PosteriorQuantileSummary(_Summary): + def __init__(self, roundto, alpha): + super(_PosteriorQuantileSummary, self).__init__(roundto) + spaces = 15 + title = 'Posterior quantiles:' + value_line = '{lo:<{pad}}{q25:<{pad}}{q50:<{pad}}{q75:<{pad}}{hi:<{pad}}' + lo, hi = 100 * alpha / 2, 100 * (1. - alpha / 2) + qlist = (lo, 25, 50, 75, hi) + header = value_line.format(lo=lo, q25=25, q50=50, q75=75, hi=hi, + pad=spaces).strip() + hline = '|{thin}|{thick}|{thick}|{thin}|'.format( + thin='-' * (spaces - 1), thick='=' * (spaces - 1)) + + self.header_lines = [title, header, hline] + self.spaces = spaces + self.lo, self.hi = lo, hi + self.qlist = qlist + self.value_line = value_line + + def _calculate_values(self, sample): + return _calculate_posterior_quantiles(sample, self.qlist) + + +def _calculate_stats(sample, batches, alpha): + means = sample.mean(0) + sds = sample.std(0) + mces = mc_error(sample, batches) + intervals = hpd(sample, alpha) + for index in range(sample.shape[1]): + mean, sd, mce = [stat[index] for stat in (means, sds, mces)] + interval = intervals[index].squeeze().tolist() + yield {'mean': mean, 'sd': sd, 'mce': mce, 'hpd': interval} + + +def _calculate_posterior_quantiles(sample, qlist): + var_quantiles = quantiles(sample, qlist=qlist) + ## Replace ends of qlist with 'lo' and 'hi' + qends = {qlist[0]: 'lo', qlist[-1]: 'hi'} + qkeys = {q: qends[q] if q in qends else 'q{}'.format(q) for q in qlist} + for index in range(sample.shape[1]): + yield {qkeys[q]: var_quantiles[q][index] for q in qlist} From bcc7633007603d224c32f67d8ee584b2aad0ba30 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Thu, 9 Jan 2014 12:08:29 -0500 Subject: [PATCH 14/18] Revert "Rename sample.py to sampling.py" This reverts commit 683507b08504c751a32dcc7e12bc2711bfbe2ecc. The is being reverted to make it easier to compare changes with the master branch. However, this reintroduces the naming conflict that the commit fixed. 'pymc.sample' now refers to the function, but the module of the same name cannot be accessed. This makes it difficult to test anything other than the top-level functions that are imported into pymc/__init__.py (sample and iter_sample). Conflicts: pymc/__init__.py --- pymc/__init__.py | 3 +- pymc/{sampling.py => sample.py} | 0 pymc/tests/test_sampling.py | 56 ++------------------------------- 3 files changed, 3 insertions(+), 56 deletions(-) rename pymc/{sampling.py => sample.py} (100%) diff --git a/pymc/__init__.py b/pymc/__init__.py index cb0e468c69..44760fddf3 100644 --- a/pymc/__init__.py +++ b/pymc/__init__.py @@ -4,9 +4,8 @@ from .distributions import * from .math import * - -from .sampling import * from .trace import * +from .sample import * from .step_methods import * from .tuning import * diff --git a/pymc/sampling.py b/pymc/sample.py similarity index 100% rename from pymc/sampling.py rename to pymc/sample.py diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index 8b5a40ebf9..0e9b4d5bfd 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -6,9 +6,7 @@ import mock import unittest -import pymc -from pymc import sampling -from pymc.sampling import sample +from pymc import sample, iter_sample from .models import simple_init # Test if multiprocessing is available @@ -20,56 +18,6 @@ test_parallel = False -@mock.patch('pymc.sampling._sample') -def test_sample_check_full_signature_single_thread(sample_func): - sample('draws', 'step', start='start', db='db', threads=1, chain=1, - tune='tune', progressbar='progressbar', model='model', - variables='variables', random_seed='random_seed') - sample_func.assert_called_with('draws', 'step', 'start', 'db', 1, - 'tune', 'progressbar', 'model', 'variables', - 'random_seed') - - -@mock.patch('pymc.sampling._thread_sample') -def test_sample_check_ful_signature_multithreads(sample_func): - sample('draws', 'step', start='start', db='db', threads=2, chain=1, - tune='tune', progressbar='progressbar', model='model', - variables='variables', random_seed=0) - - args = sample_func.call_args_list[0][0] - assert args[0] == 2 - - expected_argset = [('draws', 'step', 'start', 'db', 1, 'tune', - False, 'model', 'variables', 0), - ('draws', 'step', 'start', 'db', 2, 'tune', - False, 'model', 'variables', 0)] - argset = list(args[1]) - print(argset) - print(expected_argset) - assert argset == expected_argset - - -def test_soft_update_all_present(): - start = {'a': 1, 'b': 2} - test_point = {'a': 3, 'b': 4} - sampling._soft_update(start, test_point) - assert start == {'a': 1, 'b': 2} - - -def test_soft_update_one_missing(): - start = {'a': 1, } - test_point = {'a': 3, 'b': 4} - sampling._soft_update(start, test_point) - assert start == {'a': 1, 'b': 4} - - -def test_soft_update_empty(): - start = {} - test_point = {'a': 3, 'b': 4} - sampling._soft_update(start, test_point) - assert start == test_point - - def test_sample(): model, start, step, _ = simple_init() @@ -87,6 +35,6 @@ def test_sample(): def test_iter_sample(): model, start, step, _ = simple_init() - samps = sampling.iter_sample(5, step, start, model=model) + samps = iter_sample(5, step, start, model=model) for i, trace in enumerate(samps): assert i == len(trace) - 1, "Trace does not have correct length." From ad02151b349162c83141379b0c198af4edb9a2c8 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Fri, 10 Jan 2014 15:38:44 -0500 Subject: [PATCH 15/18] Simplify backend classes Changes in this commit are intended to simplify the backend storage class. - Reduce the number of storage class methods that need to be overridden. Now only `record` must be defined. During sampling, `setup` and `close` are also called, so the object should have these methods, but they do not need to do anything. - Sampling returns the storage object's `trace` attribute. In all the backends provided, this is a base.Trace object inherited to define value access methods specific to that backend. - As long as the methods above are defined, the storage object will work. This gives more flexibility to implement the backend storage class, so long as the `record` method properly stores the values. However, the setup in backends.base.Backend.__init__ should be useful to most backends (because of access to model information). - The load functions have been modified so that they only work if a model is supplied or if within model context, which removes the option to load the values but not connect to the model. - The base Trace object still provides a lot of structure. This is meant to help create a child Trace object for a backend that behaves the same as Traces from other backends. This means that the user can select values in the same way regardless of the backend. However, it is still possible for the user to create a very different Trace backend (assigned to the sampling objects trace attribute), as long as it has a `merge_chains` method to combine the results from parallel sampling. --- pymc/backends/__init__.py | 66 ++++++---- pymc/backends/base.py | 203 +++++++++-------------------- pymc/backends/ndarray.py | 102 +++++++++++---- pymc/backends/sqlite.py | 139 +++++++++++--------- pymc/backends/text.py | 14 +- pymc/examples/sqlite_dump_load.py | 4 +- pymc/sample.py | 15 +-- pymc/tests/test_base_backend.py | 149 --------------------- pymc/tests/test_ndarray_backend.py | 85 ++++++++---- pymc/tests/test_sqlite_backend.py | 107 +++++++-------- pymc/tests/test_text_backend.py | 2 +- 11 files changed, 385 insertions(+), 501 deletions(-) delete mode 100644 pymc/tests/test_base_backend.py diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index 92e6d2ef6f..8998f26108 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -25,9 +25,9 @@ Selecting values from a backend ------------------------------- -After a backend is finished sampling, values can be accessed in a few -ways. The easiest way is to index the backend object with a variable or -variable name. +After a backend is finished sampling, it returns a Trace object. Values +can be accessed in a few ways. The easiest way is to index the backend +object with a variable or variable name. >>> trace['x'] # or trace[x] @@ -79,37 +79,34 @@ Writing custom backends ----------------------- -To write a custom backend, two base classes should be inherited: -pymc.backends.base.Backend and pymc.backends.base.Trace. The first class -handles sampling, while the second provides access to the sampled -values. +Backends consist of two classes: one that handles storing the sample +results (e.g., backends.ndarray.NDArray or backends.sqlite.SQLite) and +one that handles value selection (e.g., backends.ndarray.Trace or +backends.sqlite.Trace). -These following sampling-related methods of base.Backend should be -define in the child class: +Three methods of the storage class will be called: -- _initialize_trace: Return a trace object for to store the sampled - values. +- setup: Before sampling is started, the `setup` method will be called + with two arguments: the number of draws and the chain number. This is + useful setting up any structure for storing the sampling values that + require the above information. -- _create_trace: Create the trace object for a specific variable and - chain. For example, the NumPy array backend creates an array of zeros - shaped according to the number of planned iterations and the shape of - the given variable. - -- _store_value: Store the value for a draw of a particular variable - (using the trace from `_create_trace`). - -- commit: After a set amount of iterations, the sampling results will be - committed to the backend. In the case of in memory backends (NumPy and - Text), this doesn't do anything. +- record: Record the sampling results for the current draw. This method + will be called with a dictionary of values mapped to the variable + names. This is the only function that *must* do something to have a + meaningful backend. - close: This method is called following sampling and should perform any actions necessary for finalizing and cleaning up the backend. -If backend-specific initialization is required, redefine `__init__` to -include this and the call the parent `__init__` method. +The base storage class `backends.base.Backend` provides model setup that +is used by PyMC backends. -In addition to sampling methods, several methods in base.Trace should -also be defined. +After sampling has completed, the `trace` attribute of the storage +object will be returned. To have a consistent interface with the backend +trace objects in PyMC, this attribute should be an instance of a class +that inherits from pymc.backends.base.Trace, and several methods in the +inherited Trace object should be defined. - get_values: This is the core method for selecting values from the backend. It can be called directly and is used by __getitem__ when the @@ -121,6 +118,23 @@ - point: Returns values for each variables at a single iteration. This is called if the backend is indexed with a single integer. +- __len__: This should return the number of draws (for the default + chain). + +- chains: Property that returns a list of chains + +In addtion, a `merge_chains` method should be defined if the backend +will be used with parallel sampling. This method describes how to merge +sampling chains from a list of other traces. + +As mentioned above, the only method necessary to store the sampling +values is `record`. Other methods in the storage may consist of only a +pass statement. The storage object should have an attribute `trace` +(with a `merge_chains` method for parallel sampling), but this does not +have to do anything if storing the values is all that is desired. The +backends.base.Trace is provided for convenience in setting up a +consistent Trace object. + For specific examples, see pymc.backends.{ndarray,text,sqlite}.py. """ from pymc.backends.ndarray import NDArray diff --git a/pymc/backends/base.py b/pymc/backends/base.py index 5afe8ec371..34eda3fd48 100644 --- a/pymc/backends/base.py +++ b/pymc/backends/base.py @@ -1,124 +1,57 @@ """Base backend for traces -These are the base classes for all trace backends. They define all the -required methods for sampling and value selection that should be -overridden or implementented in children classes. See the docstring for -pymc.backends for more information (includng creating custom backends). +See the docstring for pymc.backends for more information (includng +creating custom backends). """ import numpy as np from pymc.model import modelcontext class Backend(object): + """Base storage class + Parameters + ---------- + name : str + Name of backend. + model : Model + If None, the model is taken from the `with` context. + variables : list of variable objects + Sampling values will be stored for these variables + """ def __init__(self, name, model=None, variables=None): self.name = name - ## model attributes - self.variables = None - self.var_names = None - self.var_shapes = None - self._fn = None - model = modelcontext(model) - self.model = model - if model: - self._setup_model(model, variables) - - ## set by setup_samples - self.chain = None - self.trace = None - - self._draws = {} - - def _setup_model(self, model, variables): if variables is None: variables = model.unobserved_RVs self.variables = variables self.var_names = [str(var) for var in variables] - self._fn = model.fastfn(variables) + self.fn = model.fastfn(variables) - var_values = zip(self.var_names, self._fn(model.test_point)) + ## get variable shapes. common enough that I think most backends + ## will use this + var_values = zip(self.var_names, self.fn(model.test_point)) self.var_shapes = {var: value.shape for var, value in var_values} + self.chain = None + self.trace = None - def setup_samples(self, draws, chain): - """Prepare structure to store traces + def setup(self, draws, chain): + """Perform chain-specific setup - Parameters - ---------- draws : int - Number of sampling iterations + Expected number of draws chain : int - Chain number to store trace under + chain number """ - self.chain = chain - self._draws[chain] = draws - - if self.trace is None: - self.trace = self._initialize_trace() - trace = self.trace - trace._draws[chain] = draws - trace.backend = self - - trace.samples[chain] = {} - for var_name, var_shape in self.var_shapes.items(): - trace_shape = [draws] + list(var_shape) - trace.samples[chain][var_name] = self._create_trace(chain, - var_name, - trace_shape) + pass - def record(self, point, draw): - """Record the value of the current iteration + def record(self, point): + """Record results of a sampling iteration - Parameters - ---------- point : dict - Map of point values to variable names - draw : int - Current sampling iteration - """ - for var_name, value in zip(self.var_names, self._fn(point)): - self._store_value(draw, - self.trace.samples[self.chain][var_name], - value) - - def clean_interrupt(self, current_draw): - """Clean up sampling after interruption - - Perform any clean up not taken care of by `close`. After - KeyboardInterrupt, `sample` calls `close`, so `close` should not - be called here. - """ - self.trace._draws[self.chain] = current_draw - - ## Sampling methods that children must define - - def _initialize_trace(self): - raise NotImplementedError - - def _create_trace(self, chain, var_name, shape): - """Create trace for a variable - - Parameters - ---------- - chain : int - Current chain number - var_name : str - Name of variable - shape : tuple - Shape of the trace. The first element corresponds to the - number of draws. - """ - raise NotImplementedError - - def _store_value(self, draw, var_trace, value): - raise NotImplementedError - - def commit(self): - """Commit samples to backend - - This is called at set intervals during sampling. + Values mappled to variable names """ raise NotImplementedError @@ -127,7 +60,7 @@ def close(self): This is called after sampling has finished. """ - raise NotImplementedError + pass class Trace(object): @@ -140,12 +73,8 @@ class Trace(object): Attributes ---------- - backend : Backend object var_names - var_shapes : dict - Map of variables shape to variable names - samples : dict of dicts - Sample values keyed by chain and variable name + backend : Backend object nchains : int Number of sampling chains chains : list of ints @@ -157,27 +86,10 @@ class Trace(object): """ def __init__(self, var_names, backend=None): self.var_names = var_names - - self.samples = {} - self._draws = {} self.backend = backend self._active_chains = [] self._default_chain = None - @property - def nchains(self): - """Number of chains - - A chain is created for each sample call (including parallel - threads). - """ - return len(self.samples) - - @property - def chains(self): - """All chains in trace""" - return list(self.samples.keys()) - @property def default_chain(self): """Default chain to use for operations that require one chain (e.g., @@ -206,8 +118,14 @@ def active_chains(self, values): except TypeError: self._active_chains = [values] - def __len__(self): - return self._draws[self.default_chain] + @property + def nchains(self): + """Number of chains + + A chain is created for each sample call (including parallel + threads). + """ + return len(self.chains) def __getitem__(self, idx): if isinstance(idx, slice): @@ -223,6 +141,14 @@ def __getitem__(self, idx): ## Selection methods that children must define + @property + def chains(self): + """All chains in trace""" + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + def get_values(self, var_name, burn=0, thin=1, combine=False, chains=None, squeeze=True): """Get values from samples @@ -259,32 +185,25 @@ def point(self, idx, chain=None): """ raise NotImplementedError + def merge_chains(traces): + """Merge chains from trace instances -def merge_chains(traces): - """Merge chains from trace instances + Parameters + ---------- + traces : list + Backend trace instances. Each instance should have only one + chain, and all chain numbers should be unique. - Parameters - ---------- - traces : list - Backend trace instances. Each instance should have only one - chain, and all chain numbers should be unique. - - Raises - ------ - ValueError is raised if any traces have the same current chain - number. - - Returns - ------- - Backend instance with merge chains - """ - base_trace = traces[0] - for new_trace in traces[1:]: - new_chain = new_trace.chains[0] - if new_chain in base_trace.samples: - raise ValueError('Trace chain numbers conflict.') - base_trace.samples[new_chain] = new_trace.samples[new_chain] - return base_trace + Raises + ------ + ValueError is raised if any traces have the same current chain + number. + + Returns + ------- + Backend instance with merge chains + """ + raise NotImplementedError def _squeeze_cat(results, combine, squeeze): diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py index ebd8b7d10d..284be35e22 100644 --- a/pymc/backends/ndarray.py +++ b/pymc/backends/ndarray.py @@ -7,31 +7,57 @@ class NDArray(base.Backend): - + """NDArray storage + + Parameters + ---------- + name : str + Name of backend. + model : Model + If None, the model is taken from the `with` context. + variables : list of variable objects + Sampling values will be stored for these variables + """ ## make `name` an optional argument for NDArray def __init__(self, name=None, model=None, variables=None): super(NDArray, self).__init__(name, model, variables) - def _initialize_trace(self): - return Trace(self.var_names) - - def _create_trace(self, chain, var_name, shape): - return np.zeros(shape) + self.trace = Trace(self.var_names) + self.draw_idx = 0 + self.draws = None - def _store_value(self, draw, var_trace, value): - var_trace[draw] = value + def setup(self, draws, chain): + """Perform chain-specific setup - def commit(self): - pass + draws : int + Expected number of draws + chain : int + chain number + """ + self.draws = draws + self.chain = chain + ## Make array of zeros for each variable + var_arrays = {} + for var_name, shape in self.var_shapes.items(): + var_arrays[var_name] = np.zeros((draws, ) + shape) + self.trace.samples[chain] = var_arrays + + def record(self, point): + """Record results of a sampling iteration + + point : dict + Values mappled to variable names + """ + for var_name, value in zip(self.var_names, self.fn(point)): + self.trace.samples[self.chain][var_name][self.draw_idx] = value + self.draw_idx += 1 def close(self): - pass - - def clean_interrupt(self, current_draw): - super(NDArray, self).clean_interrupt(current_draw) + if self.draw_idx == self.draws - 1: + return + ## Remove trailing zeros if interrupted before completed all draws traces = self.trace.samples[self.chain] - ## get rid of trailing zeros - traces = {var: trace[:current_draw] for var, trace in traces.items()} + traces = {var: trace[:self.draw_idx] for var, trace in traces.items()} self.trace.samples[self.chain] = traces @@ -39,14 +65,18 @@ class Trace(base.Trace): __doc__ = 'NumPy array trace\n' + base.Trace.__doc__ + def __init__(self, var_names, backend=None): + super(Trace, self).__init__(var_names, backend) + self.samples = {} # chain -> var name -> values + def __len__(self): - try: - return super(Trace, self).__len__() - except KeyError: - var_name = self.var_names[0] - draws = self.samples[self.default_chain][var_name].shape[0] - self._draws[self.default_chain] = draws - return draws + var_name = self.var_names[0] + return self.samples[self.default_chain][var_name].shape[0] + + @property + def chains(self): + """All chains in trace""" + return list(self.samples.keys()) def get_values(self, var_name, burn=0, thin=1, combine=False, chains=None, squeeze=True): @@ -86,12 +116,10 @@ def _slice(self, idx): sliced._default_chain = sliced._default_chain sliced.samples = {} - sliced._draws = {} for chain, trace in self.samples.items(): sliced_values = {var_name: values[idx] for var_name, values in trace.items()} sliced.samples[chain] = sliced_values - sliced._draws[chain] = sliced_values[self.var_names[0]].shape[0] return sliced def point(self, idx, chain=None): @@ -104,3 +132,27 @@ def point(self, idx, chain=None): chain = self.default_chain return {var_name: values[idx] for var_name, values in self.samples[chain].items()} + + def merge_chains(self, traces): + """Merge chains from trace instances + + Parameters + ---------- + traces : list + Backend trace instances. Each instance should have only one + chain, and all chain numbers should be unique. + + Raises + ------ + ValueError is raised if any traces have the same current chain + number. + + Returns + ------- + Backend instance with merge chains + """ + for new_trace in traces: + new_chain = new_trace.chains[0] + if new_chain in self.samples: + raise ValueError('Trace chain numbers conflict.') + self.samples[new_chain] = new_trace.samples[new_chain] diff --git a/pymc/backends/sqlite.py b/pymc/backends/sqlite.py index 0d02b9ae42..af6d7ce960 100644 --- a/pymc/backends/sqlite.py +++ b/pymc/backends/sqlite.py @@ -52,16 +52,55 @@ class SQLite(base.Backend): + """SQLite storage + + Parameters + ---------- + name : str + Name of database file + model : Model + If None, the model is taken from the `with` context. + variables : list of variable objects + Sampling values will be stored for these variables + """ def __init__(self, name, model=None, variables=None): super(SQLite, self).__init__(name, model, variables) - ## initialized by _connect + + self.trace = Trace(self.var_names, self) + + ## These are set in `setup` to avoid sqlite3.OperationalError + ## (Base Cursor.__init__ not called) when performing parallel + ## sampling self.conn = None self.cursor = None self.connected = False - def _initialize_trace(self): - return Trace(self.var_names, self) + self.var_inserts = {} # var_name -> insert query + self.draw_idx = 0 + + def setup(self, draws, chain): + """Perform chain-specific setup + + draws : int + Expected number of draws + chain : int + chain number + """ + self.connect() + table = QUERIES['table'] + insert = QUERIES['insert'] + for var_name, shape in self.var_shapes.items(): + var_cols = _create_colnames(shape) + var_float = ', '.join([v + ' FLOAT' for v in var_cols]) + ## Create table + self.cursor.execute(table.format(table=var_name, + value_cols=var_float)) + ## Create insert query for each variable + var_str = ', '.join(var_cols) + self.var_inserts[var_name] = insert.format(table=var_name, + value_cols=var_str, + chain=chain) def connect(self): if self.connected: @@ -71,38 +110,28 @@ def connect(self): self.cursor = self.conn.cursor() self.connected = True - ## sampling methods + def record(self, point): + """Record results of a sampling iteration - def setup_samples(self, draws, chain): - ## make connection here (versus __init__) to handle parallel - ## chains - self.connect() - super(SQLite, self).setup_samples(draws, chain) - - def _create_trace(self, chain, var_name, shape): - ## first element of trace is number of draws - var_cols = create_colnames(shape[1:]) - var_float = ', '.join([v + ' FLOAT' for v in var_cols]) - self.cursor.execute(QUERIES['table'].format(table=var_name, - value_cols=var_float)) - return QUERIES['insert'].format(table=var_name, - value_cols=', '.join(var_cols), - chain=chain) - - def _store_value(self, draw, var_trace, value): - val_str = ', '.join(['{}'.format(val) for val in np.ravel(value)]) - query = var_trace.format(draw=draw, value=val_str) - self.cursor.execute(query) - - def commit(self): - self.conn.commit() + point : dict + Values mappled to variable names + """ + for var_name, value in zip(self.var_names, self.fn(point)): + val_str = ', '.join(['{}'.format(val) for val in np.ravel(value)]) + query = self.var_inserts[var_name].format(draw=self.draw_idx, + value=val_str) + self.cursor.execute(query) + self.draw_idx += 1 + + if not self.draw_idx % 1000: + self.conn.commit() def close(self): if not self.connected: return self.cursor.close() - self.commit() + self.conn.commit() self.conn.close() self.connected = False @@ -111,16 +140,27 @@ class Trace(base.Trace): __doc__ = 'SQLite trace\n' + base.Trace.__doc__ + def __init__(self, var_names, backend=None): + super(Trace, self).__init__(var_names, backend) + self._len = None + self._chains = None + def __len__(self): - try: - return super(Trace, self).__len__() - except KeyError: # draws dictionary not set up + if self._len is None: query = QUERIES['max_draw'].format(table=self.var_names[0], chain=self.default_chain) self.backend.connect() - draws = self.backend.cursor.execute(query).fetchall()[0][0] + 1 - self._draws[self.default_chain] = draws - return draws + self._len = self.backend.cursor.execute(query).fetchall()[0][0] + 1 + return self._len + + @property + def chains(self): + """All chains in trace""" + if self._chains is None: + self.backend.connect() + var_name = self.var_names[0] # any variable should do + self._chains = _get_chain_list(self.backend.cursor, var_name) + return self._chains def get_values(self, var_name, burn=0, thin=1, combine=False, chains=None, squeeze=True): @@ -175,12 +215,9 @@ def get_values(self, var_name, burn=0, thin=1, combine=False, chains=None, **query_args) self.backend.cursor.execute(call) results.append(_rows_to_ndarray(self.backend.cursor)) - return base._squeeze_cat(results, combine, squeeze) def _slice(self, idx): - """Slice trace object - """ warnings.warn('Slice for SQLite backend has no effect.') def point(self, idx, chain=None): @@ -205,8 +242,11 @@ def point(self, idx, chain=None): _rows_to_ndarray(self.backend.cursor)) return var_values + def merge_chains(self, traces): + pass + -def create_colnames(shape): +def _create_colnames(shape): """Return column names based on `shape` Examples @@ -233,9 +273,7 @@ def load(name, model=None): name : str Path to SQLite database file model : Model - If None, the model is taken from the `with` context. The trace - can be loaded without connecting by passing False (although - connecting to the original model is recommended). + If None, the model is taken from the `with` context. Returns ------- @@ -243,25 +281,8 @@ def load(name, model=None): """ db = SQLite(name, model=model) db.connect() - var_names = _get_table_list(db.cursor) - trace = Trace(var_names, db) - var_cols = {var_name: ', '.join(_get_var_strs(db.cursor, var_name)) - for var_name in var_names} - - ## Use first var_names element to get chain list. Chains should be - ## the same for all. - chains = _get_chain_list(db.cursor, var_names[0]) - - query = QUERIES['insert'] - for chain in chains: - samples = {} - for var_name in var_names: - samples[var_name] = query.format(table=var_name, - value_cols=var_cols[var_name], - chain=chain) - trace.samples[chain] = samples - return trace + return Trace(var_names, db) def _get_table_list(cursor): diff --git a/pymc/backends/text.py b/pymc/backends/text.py index a0bac969ba..f7e4978c91 100644 --- a/pymc/backends/text.py +++ b/pymc/backends/text.py @@ -23,7 +23,17 @@ class Text(NDArray): + """Text storage + Parameters + ---------- + name : str + Name of directory to store text files. + model : Model + If None, the model is taken from the `with` context. + variables : list of variable objects + Sampling values will be stored for these variables + """ def __init__(self, name, model=None, variables=None): super(Text, self).__init__(name, model, variables) if not os.path.exists(name): @@ -56,9 +66,7 @@ def load(name, chains=None, model=None): chains : list or None Chains to load. If None, all chains are loaded. model : Model - If None, the model is taken from the `with` context. The trace - can be loaded without connecting by passing False (although - connecting to the original model is recommended). + If None, the model is taken from the `with` context. Returns ------- diff --git a/pymc/examples/sqlite_dump_load.py b/pymc/examples/sqlite_dump_load.py index 9a678e35ed..b46636d1f1 100644 --- a/pymc/examples/sqlite_dump_load.py +++ b/pymc/examples/sqlite_dump_load.py @@ -37,8 +37,8 @@ def run(n=50): for chain in trace.chains: for var_name in trace.var_names: - data = trace.samples[chain][var_name] - dumped_data = dumped.samples[chain][var_name] + data = trace.get_values(var_name, chains=[chain]) + dumped_data = dumped.get_values(var_name, chains=[chain]) npt.assert_equal(data, dumped_data) finally: try: diff --git a/pymc/sample.py b/pymc/sample.py index 30a183bb8e..f3138a0a29 100644 --- a/pymc/sample.py +++ b/pymc/sample.py @@ -1,6 +1,5 @@ from .point import * from pymc.backends.ndarray import NDArray -from pymc.backends.base import merge_chains import multiprocessing as mp from time import time from .core import * @@ -24,7 +23,7 @@ def sample(draws, step, start=None, db=None, chain=0, threads=1, tune=None, Starting point in parameter space (or partial point). Defaults to model.test_point. db : backend - If None, NDArray is used. + A storage backend. If None, NDArray is used. chain : int Chain number used to store sample in trace. If threads greater than one, chain numbers will start here @@ -44,7 +43,7 @@ def sample(draws, step, start=None, db=None, chain=0, threads=1, tune=None, Returns ------- - Backend object with access to sampling values + Trace object with access to sampling values """ if threads is None: threads = max(mp.cpu_count() - 2, 1) @@ -91,7 +90,6 @@ def _sample(draws, step, start=None, db=None, chain=0, tune=None, for i, trace in sampling: pass except KeyboardInterrupt: - trace.backend.clean_interrupt(i) trace.backend.close() return trace @@ -158,15 +156,13 @@ def _iter_sample(draws, step, start=None, db=None, chain=0, tune=None, if db is None: db = NDArray(model=model, variables=variables) - db.setup_samples(draws, chain) + db.setup(draws, chain) for i in range(draws): if i == tune: step = stop_tuning(step) point = step.step(point) - db.record(point, i) - if not i % 1000: - db.commit() + db.record(point) yield db.trace else: db.close() @@ -176,7 +172,8 @@ def _thread_sample(threads, args): p = mp.Pool(threads) traces = p.map(_argsample, args) p.close() - return merge_chains(traces) + traces[0].merge_chains(traces[1:]) + return traces[0] def _argsample(args): diff --git a/pymc/tests/test_base_backend.py b/pymc/tests/test_base_backend.py deleted file mode 100644 index 9493bbe537..0000000000 --- a/pymc/tests/test_base_backend.py +++ /dev/null @@ -1,149 +0,0 @@ -import numpy as np -try: - import unittest.mock as mock # py3 -except ImportError: - import mock -import unittest -import nose - -from pymc.backends import base - - -class TestBaseInit(unittest.TestCase): - - def setUp(self): - self.variables = ['x', 'y'] - self.model = mock.Mock() - self.model.unobserved_RVs = self.variables - self.model.fastfn = mock.MagicMock() - - def test_base_init_just_name(self): - with mock.patch('pymc.backends.base.modelcontext') as context: - variables = self.variables - context.return_value = self.model - - db = base.Backend('name') - - context.assert_called_once_with(None) - self.assertEqual(db.variables, variables) - self.assertEqual(db.var_names, variables) - self.model.fastfn.assert_called_once_with(variables) - - def test_base_init_model_supplied(self): - db = base.Backend('name', model=self.model) - - self.assertEqual(db.variables, self.variables) - self.assertEqual(db.var_names, self.variables) - self.model.fastfn.assert_called_once_with(self.variables) - - def test_base_init_variables_supplied(self): - with mock.patch('pymc.backends.base.modelcontext') as context: - variables = ['a', 'b'] - context.return_value = self.model - - db = base.Backend('name', variables=variables) - - context.assert_called_once_with(None) - self.assertEqual(db.variables, variables) - self.assertEqual(db.var_names, variables) - self.model.fastfn.assert_called_once_with(variables) - - def test_base_setup_samples_default_chain(self): - with mock.patch('pymc.backends.base.modelcontext') as context: - variables = ['a', 'b'] - context.return_value = self.model - - db = base.Backend('name', variables=variables) - - db._create_trace = mock.Mock() - db.var_shapes = {'x': (), 'y': (10,)} - draws = 3 - - patch = mock.patch('pymc.backends.base.Backend._initialize_trace') - with patch as init_trace: - db.setup_samples(draws, 0) - - init_trace.assert_called_with() - db._create_trace.assert_any_call(0, 'x', [draws]) - db._create_trace.assert_any_call(0, 'y', [draws, 10]) - - -class TestBaseTrace(unittest.TestCase): - - def setUp(self): - var_names = ['x'] - self.trace = base.Trace(var_names) - self.trace.samples = {0: {'x': None}} - - def test_nchains(self): - - self.assertEqual(self.trace.nchains, 1) - - self.trace.samples[1] = {'y': None} - self.assertEqual(self.trace.nchains, 2) - - def test_chains(self): - self.assertEqual(self.trace.chains, [0]) - - self.trace.samples[1] = {'y': None} - self.assertEqual(self.trace.chains, [0, 1]) - - def test_chains_not_sequential(self): - self.trace.samples[4] = {'y': None} - self.assertEqual(self.trace.chains, [0, 4]) - - def test_default_chain_one_chain(self): - self.assertEqual(self.trace.default_chain, 0) - - def test_default_chain_multiple_chain(self): - self.trace.samples[1] = {'y': None} - self.assertEqual(self.trace.default_chain, 1) - - def test_default_chain_multiple_chains_set(self): - self.trace.samples[1] = {'y': None} - self.trace.default_chain = 0 - self.assertEqual(self.trace.default_chain, 0) - - def test_active_chains(self): - self.assertEqual(self.trace.chains, self.trace.active_chains) - self.trace.samples[1] = {'y': None} - self.assertEqual(self.trace.chains, self.trace.active_chains) - - def test_active_chains_set_with_int(self): - self.trace.samples[1] = {'y': None} - self.trace.active_chains = 0 - self.assertEqual(self.trace.active_chains, [0]) - - -class TestMergeChains(unittest.TestCase): - - def test_merge_chains_one_trace(self): - trace = mock.Mock() - trace.samples = {0: {'x': 0, 'y': 1}} - merged = base.merge_chains([trace]) - self.assertEqual(trace.samples, merged.samples) - - def test_merge_chains_two_traces(self): - trace1 = mock.Mock() - trace1.samples = {0: {'x': 0, 'y': 1}} - trace1.chains = [0] - - trace2 = mock.Mock() - trace2.samples = {1: {'x': 3, 'y': 4}} - trace2.chains = [1] - - merged = base.merge_chains([trace1, trace2]) - self.assertEqual(trace1.samples[0], merged.samples[0]) - self.assertEqual(trace2.samples[1], merged.samples[1]) - - def test_merge_chains_two_traces_same_slot(self): - trace1 = mock.Mock() - trace1.samples = {0: {'x': 0, 'y': 1}} - trace1.chains = [0] - - trace2 = mock.Mock() - trace2.samples = {0: {'x': 3, 'y': 4}} - trace2.chains = [0] - - with self.assertRaises(ValueError): - base.merge_chains([trace1, trace2]) diff --git a/pymc/tests/test_ndarray_backend.py b/pymc/tests/test_ndarray_backend.py index 0dba0055e6..4859d098d2 100644 --- a/pymc/tests/test_ndarray_backend.py +++ b/pymc/tests/test_ndarray_backend.py @@ -21,50 +21,45 @@ def setUp(self): context.return_value = self.model self.db = ndarray.NDArray() - def test_create_trace_scalar(self): + def test_setup_scalar(self): db = self.db - draws = 3 - trace = db._create_trace(chain=0, var_name=None, shape=[draws]) - npt.assert_equal(trace, np.zeros(draws)) - - def test_create_trace_1d(self): - db = self.db - draws = 3 - trace = db._create_trace(chain=0, var_name=None, shape=[draws, 2]) - npt.assert_equal(trace, np.zeros([draws, 2])) + db.var_shapes = {'x': ()} + draws, chain = 3, 0 + db.setup(draws, chain) + npt.assert_equal(db.trace.samples[chain]['x'], np.zeros(draws)) - def test_setup_samples(self): + def test_setup_1d(self): db = self.db - draws = 3 - - db.var_shapes = {'x': (), 'y': (4,)} - db.setup_samples(draws, chain=0) - - npt.assert_equal(db.trace['x'], np.zeros([draws])) - npt.assert_equal(db.trace['y'], np.zeros([draws, 4])) + shape = (2,) + db.var_shapes = {'x': shape} + draws, chain = 3, 0 + db.setup(draws, chain) + npt.assert_equal(db.trace.samples[chain]['x'], np.zeros((draws,) + shape)) def test_record(self): db = self.db draws = 3 db.var_shapes = {'x': (), 'y': (4,)} - db.setup_samples(draws, chain=0) + db.setup(draws, chain=0) def just_ones(*args): while True: yield 1. - db._fn = just_ones + db.fn = just_ones + db.draw_idx = 0 - db.record(point=None, draw=0) - npt.assert_equal(1., db.trace.get_values('x', combine=True)[0]) + db.record(point=None) + npt.assert_equal(1., db.trace.get_values('x')[0]) npt.assert_equal(np.ones(4), db.trace['y'][0]) def test_clean_interrupt(self): db = self.db - db.setup_samples(draws=3, chain=0) + db.setup(draws=10, chain=0) db.trace.samples = {0: {'x': np.zeros(10), 'y': np.zeros((10, 5))}} - db.clean_interrupt(3) + db.draw_idx = 3 + db.close() npt.assert_equal(np.zeros(3), db.trace['x']) npt.assert_equal(np.zeros((3, 5)), db.trace['y']) @@ -83,6 +78,12 @@ def setUp(self): self.var_names = var_names self.var_shapes = var_shapes + def test_chains_single_chain(self): + self.trace.chains == [0] + + def test_nchains_single_chain(self): + self.trace.nchains == 1 + def test_get_values_default(self): base_shape = (self.draws,) xshape = self.var_shapes['x'] @@ -188,6 +189,12 @@ def setUp(self): self.var_shapes = var_shapes self.total_draws = 2 * draws + def test_chains_multichain(self): + self.trace.chains == [0, 1] + + def test_nchains_multichain(self): + self.trace.nchains == [0, 1] + def test_get_values_multi_default(self): sample = self.trace.get_values('x') xshape = self.var_shapes['x'] @@ -318,3 +325,33 @@ def test_multichain_slice(self): for var_name, var_shape in self.var_shapes.items(): npt.assert_equal(sliced.samples[chain][var_name], expected[chain][var_name]) + +class TestMergeChains(unittest.TestCase): + + def setUp(self): + var_names = ['x', 'y'] + var_shapes = {'x': (), 'y': (2,)} + draws = 3 + self.trace1 = ndarray.Trace(var_names) + self.trace1.samples = {0: + {'x': np.zeros(draws), + 'y': np.zeros((draws, 2))}} + + self.trace2 = ndarray.Trace(var_names) + self.trace2.samples = {1: + {'x': np.ones(draws), + 'y': np.ones((draws, 2))}} + self.draws = draws + self.var_names = var_names + self.var_shapes = var_shapes + self.total_draws = 2 * draws + + def test_merge_chains_two_traces(self): + self.trace1.merge_chains([self.trace2]) + self.assertEqual(self.trace1.samples[1], self.trace2.samples[1]) + + def test_merge_chains_two_traces_same_slot(self): + self.trace2.samples = self.trace1.samples + + with self.assertRaises(ValueError): + self.trace1.merge_chains([self.trace2]) diff --git a/pymc/tests/test_sqlite_backend.py b/pymc/tests/test_sqlite_backend.py index 8aa5765421..16b0da4556 100644 --- a/pymc/tests/test_sqlite_backend.py +++ b/pymc/tests/test_sqlite_backend.py @@ -22,6 +22,7 @@ def setUp(self): context.return_value = self.model self.db = sqlite.SQLite('test.db') self.db.cursor = mock.Mock() + self.db.var_shapes = {'x': (), 'y': (3,)} connect_patch = mock.patch('pymc.backends.sqlite.SQLite.connect') self.addCleanup(connect_patch.stop) @@ -32,84 +33,88 @@ def setUp(self): class TestSQLiteSample(SQLiteTestCase): def test_setup_trace(self): - self.db.setup_samples(self.draws, chain=0) + self.db.setup(self.draws, chain=0) self.connect.assert_called_once_with() - def test__create_trace_scalar(self): + def test_setup_scalar(self): db = self.db - var_trace = db._create_trace(chain=0, var_name='x', - shape=(self.draws,)) - + db.setup(draws=3, chain=0) tbl_expected = ('CREATE TABLE IF NOT EXISTS [x] ' '(recid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ' 'draw INTEGER, ' 'chain INT(5), v1 FLOAT)') - db.cursor.execute.assert_called_once_with(tbl_expected) + db.cursor.execute.assert_any_call(tbl_expected) trace_expected = ('INSERT INTO [x] (recid, draw, chain, v1) ' 'VALUES (NULL, {draw}, 0, {value})') - self.assertEqual(var_trace, trace_expected) + self.assertEqual(db.var_inserts['x'], trace_expected) - def test__create_trace_1d(self): + def test_setup_1d(self): db = self.db - var_trace = db._create_trace(chain=0, var_name='x', - shape=(self.draws, 2)) - tbl_expected = ('CREATE TABLE IF NOT EXISTS [x] ' + db.setup(draws=3, chain=0) + tbl_expected = ('CREATE TABLE IF NOT EXISTS [y] ' '(recid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ' 'draw INTEGER, ' - 'chain INT(5), v1 FLOAT, v2 FLOAT)') - db.cursor.execute.assert_called_once_with(tbl_expected) + 'chain INT(5), v1 FLOAT, v2 FLOAT, v3 FLOAT)') + db.cursor.execute.assert_any_call(tbl_expected) - trace_expected = ('INSERT INTO [x] (recid, draw, chain, v1, v2) ' + trace_expected = ('INSERT INTO [y] (recid, draw, chain, v1, v2, v3) ' 'VALUES (NULL, {draw}, 0, {value})') - self.assertEqual(var_trace, trace_expected) + self.assertEqual(db.var_inserts['y'], trace_expected) - def test__create_trace_2d(self): + def test_setup_2d(self): db = self.db - var_trace = db._create_trace(chain=0, var_name='x', - shape=(self.draws, 2, 3)) + db.var_shapes = {'x': (2, 3)} + db.setup(draws=3, chain=0) tbl_expected = ('CREATE TABLE IF NOT EXISTS [x] ' '(recid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ' 'draw INTEGER, ' 'chain INT(5), ' 'v1_1 FLOAT, v1_2 FLOAT, v1_3 FLOAT, ' 'v2_1 FLOAT, v2_2 FLOAT, v2_3 FLOAT)') - db.cursor.execute.assert_called_once_with(tbl_expected) + db.cursor.execute.assert_any_call(tbl_expected) trace_expected = ('INSERT INTO [x] (recid, draw, chain, ' 'v1_1, v1_2, v1_3, ' 'v2_1, v2_2, v2_3) ' 'VALUES (NULL, {draw}, 0, {value})') - self.assertEqual(var_trace, trace_expected) + self.assertEqual(db.var_inserts['x'], trace_expected) - def test__store_value_scalar(self): + def test_record_scalar(self): db = self.db - db.setup_samples(draws=3, chain=0) + db.setup(draws=3, chain=0) + + db.fn = mock.Mock(return_value=iter([3.])) var_name = 'x' + db.var_names = ['x'] + query = sqlite.QUERIES['insert'].format(table=var_name, value_cols='v1', chain=0) - db.trace.samples[0] = {'x': query} - db._store_value(draw=0, var_trace=db.trace.samples[0][var_name], - value=3.) + db.var_inserts = {'x': query} + db.draw_idx = 0 + db.record({'x': 3.}) expected = ('INSERT INTO [x] (recid, draw, chain, v1) ' 'VALUES (NULL, 0, 0, 3.0)') - db.cursor.execute.assert_called_once_with(expected) + db.cursor.execute.assert_any_call(expected) - def test__store_value_1d(self): + def test_record_1d(self): db = self.db - db.setup_samples(draws=3, chain=0) + db.setup(draws=3, chain=0) + + db.fn = mock.Mock(return_value=iter([[3., 3.]])) var_name = 'x' + db.var_names = ['x'] + query = sqlite.QUERIES['insert'].format(table=var_name, value_cols='v1, v2', chain=0) - db.trace.samples[0] = {'x': query} - print(db) - db._store_value(draw=0, var_trace=db.trace.samples[0][var_name], - value=[3., 3.]) + db.var_inserts = {'x': query} + db.draw_idx = 0 + db.record({'x': [3., 3.]}) expected = ('INSERT INTO [x] (recid, draw, chain, v1, v2) ' 'VALUES (NULL, 0, 0, 3.0, 3.0)') - db.cursor.execute.assert_called_once_with(expected) + db.cursor.execute.assert_any_call(expected) class SQLiteSelectionTestCase(SQLiteTestCase): @@ -117,13 +122,14 @@ class SQLiteSelectionTestCase(SQLiteTestCase): def setUp(self): super(SQLiteSelectionTestCase, self).setUp() self.db.var_shapes = {'x': (), 'y': (4,)} - self.db.setup_samples(self.draws, chain=0) + self.db.setup(self.draws, chain=0) ndarray_patch = mock.patch('pymc.backends.sqlite._rows_to_ndarray') self.addCleanup(ndarray_patch.stop) ndarray_patch.start() self.draws = 5 + self.db.trace.active_chains = 0 class TestSQLiteSelection(SQLiteSelectionTestCase): @@ -180,7 +186,7 @@ class TestSQLiteSelectionMultipleChains(SQLiteSelectionTestCase): def setUp(self): super(TestSQLiteSelectionMultipleChains, self).setUp() - self.db.trace.samples[1] = self.db.trace.samples[0] + self.db.trace.active_chains = [0, 1] def test_get_values_default_keywords(self): self.db.trace.get_values('x') @@ -208,45 +214,24 @@ def setUp(self): self.table_list = table_list_patch.start() self.table_list.return_value = ['x', 'y'] - var_strs_list_patch = mock.patch('pymc.backends.sqlite._get_var_strs') - self.addCleanup(var_strs_list_patch.stop) - self.var_strs_list = var_strs_list_patch.start() - self.var_strs_list.return_value = ['v1', 'v2'] - - chain_list_patch = mock.patch('pymc.backends.sqlite._get_chain_list') - self.addCleanup(chain_list_patch.stop) - self.chain_list = chain_list_patch.start() - self.chain_list.return_value = [0, 1] - def test_load(self): trace = sqlite.load('test.db') - self.assertEqual(len(trace.samples), 2) - - self.assertTrue('x' in trace.samples[0]) - self.assertTrue('y' in trace.samples[0]) - - expected = ('INSERT INTO [{}] ' - '(recid, draw, chain, v1, v2) ' - 'VALUES (NULL, {{draw}}, {}, {{value}})') - for chain in [0, 1]: - for var_name in ['x', 'y']: - self.assertEqual(trace.samples[chain][var_name], - expected.format(var_name, chain)) - + assert self.table_list.called + assert self.db.called def test_create_column_empty(): - result = sqlite.create_colnames(()) + result = sqlite._create_colnames(()) expected = ['v1'] assert result == expected def test_create_column_1d(): - result = sqlite.create_colnames((2, )) + result = sqlite._create_colnames((2, )) expected = ['v1', 'v2'] assert result == expected def test_create_column_2d(): - result = sqlite.create_colnames((2, 2)) + result = sqlite._create_colnames((2, 2)) expected = ['v1_1', 'v1_2', 'v2_1', 'v2_2'] assert result == expected diff --git a/pymc/tests/test_text_backend.py b/pymc/tests/test_text_backend.py index 94f609872c..4ccb0a580e 100644 --- a/pymc/tests/test_text_backend.py +++ b/pymc/tests/test_text_backend.py @@ -43,7 +43,7 @@ def setUp(self): self.draws = 5 self.db.var_shapes = {'x': (), 'y': (4,)} - self.db.setup_samples(self.draws, chain=0) + self.db.setup(self.draws, chain=0) savetxt_patch = mock.patch('pymc.backends.text.np.savetxt') self.addCleanup(savetxt_patch.stop) From a1c14585e0b724d3ba874a6ced6b999ad7b0a08c Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Sun, 12 Jan 2014 22:44:06 -0500 Subject: [PATCH 16/18] Call NDArray close when close text When cleaning interrupt was merged with NDArray close, NDArray close call should have been added here so that interrupt is cleaned for Text too. --- pymc/backends/text.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc/backends/text.py b/pymc/backends/text.py index f7e4978c91..5a753b4a7b 100644 --- a/pymc/backends/text.py +++ b/pymc/backends/text.py @@ -40,6 +40,7 @@ def __init__(self, name, model=None, variables=None): os.mkdir(name) def close(self): + super(Text, self).close() for chain in self.trace.chains: chain_name = 'chain-{}'.format(chain) chain_dir = os.path.join(self.name, chain_name) From 9133a4c14475df3cd0770f2dbab827205eb5667f Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Sun, 12 Jan 2014 22:47:42 -0500 Subject: [PATCH 17/18] Fix NDArray close check for interrupt This was off by one, which resulted in unnecessary slicing (although the result is the same). --- pymc/backends/ndarray.py | 2 +- pymc/tests/test_ndarray_backend.py | 9 +++++++++ pymc/tests/test_ndarray_sqlite_selection.py | 4 ++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py index 284be35e22..e4f5894ed7 100644 --- a/pymc/backends/ndarray.py +++ b/pymc/backends/ndarray.py @@ -53,7 +53,7 @@ def record(self, point): self.draw_idx += 1 def close(self): - if self.draw_idx == self.draws - 1: + if self.draw_idx == self.draws: return ## Remove trailing zeros if interrupted before completed all draws traces = self.trace.samples[self.chain] diff --git a/pymc/tests/test_ndarray_backend.py b/pymc/tests/test_ndarray_backend.py index 4859d098d2..6f6909ca44 100644 --- a/pymc/tests/test_ndarray_backend.py +++ b/pymc/tests/test_ndarray_backend.py @@ -63,6 +63,15 @@ def test_clean_interrupt(self): npt.assert_equal(np.zeros(3), db.trace['x']) npt.assert_equal(np.zeros((3, 5)), db.trace['y']) + def test_standard_close(self): + db = self.db + db.setup(draws=10, chain=0) + db.trace.samples = {0: {'x': np.zeros(10), 'y': np.zeros((10, 5))}} + db.draw_idx = 10 + db.close() + npt.assert_equal(np.zeros(10), db.trace['x']) + npt.assert_equal(np.zeros((10, 5)), db.trace['y']) + class TestNDArraySelection(unittest.TestCase): diff --git a/pymc/tests/test_ndarray_sqlite_selection.py b/pymc/tests/test_ndarray_sqlite_selection.py index 3de2977ac1..c0288087a5 100644 --- a/pymc/tests/test_ndarray_sqlite_selection.py +++ b/pymc/tests/test_ndarray_sqlite_selection.py @@ -56,6 +56,10 @@ def test_chain_length(self): assert self.ntrace.nchains == self.strace.nchains assert len(self.ntrace) == len(self.strace) + def test_number_of_draws(self): + assert self.ntrace['x'][0].shape[0] == self.draws + assert self.strace['x'][0].shape[0] == self.draws + def test_get_item(self): npt.assert_equal(self.ntrace['x'], self.strace['x']) From 8aa4d9a054986a91fe574670105f49b305fd9532 Mon Sep 17 00:00:00 2001 From: Kyle Meyer Date: Mon, 13 Jan 2014 15:08:47 -0500 Subject: [PATCH 18/18] Allow previous chains to be extended This commit enables sampling to extend the a chain from a previous call. This is done for the NumPy array by concatenating more zeros and setting the draw index to the right position. For SQLite, this involves setting the draw index to the right position for the given chain. --- pymc/backends/ndarray.py | 38 +++++++++++----- pymc/backends/sqlite.py | 43 ++++++++++++++---- pymc/examples/stochastic_volatility.py | 8 +--- pymc/tests/test_ndarray_backend.py | 50 ++++++++++++++++----- pymc/tests/test_ndarray_sqlite_selection.py | 25 +++++++---- pymc/tests/test_sqlite_backend.py | 11 ++++- pymc/tests/test_text_backend.py | 1 + 7 files changed, 128 insertions(+), 48 deletions(-) diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py index e4f5894ed7..3960325d97 100644 --- a/pymc/backends/ndarray.py +++ b/pymc/backends/ndarray.py @@ -22,7 +22,7 @@ class NDArray(base.Backend): def __init__(self, name=None, model=None, variables=None): super(NDArray, self).__init__(name, model, variables) - self.trace = Trace(self.var_names) + self.trace = Trace(self.var_names, self) self.draw_idx = 0 self.draws = None @@ -34,13 +34,24 @@ def setup(self, draws, chain): chain : int chain number """ - self.draws = draws self.chain = chain - ## Make array of zeros for each variable - var_arrays = {} - for var_name, shape in self.var_shapes.items(): - var_arrays[var_name] = np.zeros((draws, ) + shape) - self.trace.samples[chain] = var_arrays + ## Concatenate new array if chain is already present. + if chain in self.trace.samples: + chain_trace = self.trace.samples[chain] + old_draws = len(self.trace) + self.draws = old_draws + draws + self.draws_idx = old_draws + for var_name, shape in self.var_shapes.items(): + old_trace = chain_trace[var_name] + new_trace = np.zeros((draws, ) + shape) + chain_trace[var_name] = np.concatenate((old_trace, new_trace), + axis=0) + else: # Otherwise, make array of zeros for each variable. + self.draws = draws + var_arrays = {} + for var_name, shape in self.var_shapes.items(): + var_arrays[var_name] = np.zeros((draws, ) + shape) + self.trace.samples[chain] = var_arrays def record(self, point): """Record results of a sampling iteration @@ -151,8 +162,13 @@ def merge_chains(self, traces): ------- Backend instance with merge chains """ + var_name = self.var_names[0] # Select any variable. for new_trace in traces: - new_chain = new_trace.chains[0] - if new_chain in self.samples: - raise ValueError('Trace chain numbers conflict.') - self.samples[new_chain] = new_trace.samples[new_chain] + for new_chain in new_trace.chains: + if new_chain in self.chains: + ## Take the new chain if it has more draws. + draws = self.samples[new_chain][var_name].shape[0] + new_draws = new_trace.samples[new_chain][var_name].shape[0] + if draws >= new_draws: + continue + self.samples[new_chain] = new_trace.samples[new_chain] diff --git a/pymc/backends/sqlite.py b/pymc/backends/sqlite.py index af6d7ce960..4e66f35b75 100644 --- a/pymc/backends/sqlite.py +++ b/pymc/backends/sqlite.py @@ -48,6 +48,8 @@ 'WHERE (chain={chain}) AND (draw={draw})'), 'max_draw': ('SELECT MAX(draw) FROM [{table}] ' 'WHERE chain={chain}'), + 'draw_count': ('SELECT COUNT(*) FROM [{table}] ' + 'WHERE chain={chain}'), } @@ -76,6 +78,7 @@ def __init__(self, name, model=None, variables=None): self.cursor = None self.connected = False + self._var_cols = {} self.var_inserts = {} # var_name -> insert query self.draw_idx = 0 @@ -87,15 +90,30 @@ def setup(self, draws, chain): chain : int chain number """ + self.chain = chain + + if not self._var_cols: # Table has not been created. + self._var_cols = {var_name: _create_colnames(shape) + for var_name, shape in self.var_shapes.items()} + self._create_table() + else: + self.draw_idx = self.trace._get_max_draw(chain) + 1 + self.trace._len = None + + self._create_insert_queries(chain) + + def _create_table(self): self.connect() table = QUERIES['table'] - insert = QUERIES['insert'] - for var_name, shape in self.var_shapes.items(): - var_cols = _create_colnames(shape) + for var_name, var_cols in self._var_cols.items(): var_float = ', '.join([v + ' FLOAT' for v in var_cols]) - ## Create table self.cursor.execute(table.format(table=var_name, value_cols=var_float)) + + def _create_insert_queries(self, chain): + self.connect() + insert = QUERIES['insert'] + for var_name, var_cols in self._var_cols.items(): ## Create insert query for each variable var_str = ', '.join(var_cols) self.var_inserts[var_name] = insert.format(table=var_name, @@ -147,12 +165,21 @@ def __init__(self, var_names, backend=None): def __len__(self): if self._len is None: - query = QUERIES['max_draw'].format(table=self.var_names[0], - chain=self.default_chain) - self.backend.connect() - self._len = self.backend.cursor.execute(query).fetchall()[0][0] + 1 + self._len = self._get_number_draws(self.default_chain) return self._len + def _get_max_draw(self, chain): + self.backend.connect() + query = QUERIES['max_draw'].format(table=self.var_names[0], + chain=chain) + return self.backend.cursor.execute(query).fetchall()[0][0] + + def _get_number_draws(self, chain): + self.backend.connect() + query = QUERIES['draw_count'].format(table=self.var_names[0], + chain=chain) + return self.backend.cursor.execute(query).fetchall()[0][0] + @property def chains(self): """All chains in trace""" diff --git a/pymc/examples/stochastic_volatility.py b/pymc/examples/stochastic_volatility.py index 1275ff3f91..a62692317f 100644 --- a/pymc/examples/stochastic_volatility.py +++ b/pymc/examples/stochastic_volatility.py @@ -121,16 +121,10 @@ def run(n=2000): with model: trace = sample(5, step, start, variables=model.vars + [sigma]) - ## FIXME: At the moment, there isn't a method for updating the - ## same trace. Below makes a new trace in the same backend that - ## has both the chains. The chain needs to be manually set to - ## avoid overwriting the previous chain. A check could be added - ## to override the chain argument to previous chain + 1. - # Start next run at the last sampled position. start2 = trace.point(-1) step2 = HamiltonianMC(model.vars, hessian(start2, 6), path_length=4.) - trace = sample(n, step2, start=start2, db=trace.backend, chain=1) + trace = sample(n, step2, db=trace.backend) # diff --git a/pymc/tests/test_ndarray_backend.py b/pymc/tests/test_ndarray_backend.py index 6f6909ca44..e329ed5844 100644 --- a/pymc/tests/test_ndarray_backend.py +++ b/pymc/tests/test_ndarray_backend.py @@ -340,27 +340,53 @@ class TestMergeChains(unittest.TestCase): def setUp(self): var_names = ['x', 'y'] var_shapes = {'x': (), 'y': (2,)} - draws = 3 self.trace1 = ndarray.Trace(var_names) + + self.trace2 = ndarray.Trace(var_names) + self.var_names = var_names + self.var_shapes = var_shapes + + def test_merge_chains_two_traces(self): + draws = 3 self.trace1.samples = {0: {'x': np.zeros(draws), 'y': np.zeros((draws, 2))}} - - self.trace2 = ndarray.Trace(var_names) self.trace2.samples = {1: {'x': np.ones(draws), 'y': np.ones((draws, 2))}} - self.draws = draws - self.var_names = var_names - self.var_shapes = var_shapes - self.total_draws = 2 * draws - def test_merge_chains_two_traces(self): self.trace1.merge_chains([self.trace2]) self.assertEqual(self.trace1.samples[1], self.trace2.samples[1]) - def test_merge_chains_two_traces_same_slot(self): - self.trace2.samples = self.trace1.samples + def test_merge_chains_two_traces_same_slot_same_size(self): + draws = 3 + self.trace1.samples = {0: + {'x': np.zeros(draws), + 'y': np.zeros((draws, 2))}} + self.trace2.samples = {0: + {'x': np.ones(draws), + 'y': np.ones((draws, 2))}} + self.trace1.merge_chains([self.trace2]) + npt.assert_equal(self.trace1.samples[0]['x'], np.zeros(draws)) + + def test_merge_chains_two_traces_same_slot_base_longer(self): + draws = 3 + self.trace1.samples = {0: + {'x': np.zeros(draws), + 'y': np.zeros((draws, 2))}} + self.trace2.samples = {0: + {'x': np.ones(draws - 1), + 'y': np.ones((draws - 1, 2))}} + self.trace1.merge_chains([self.trace2]) + npt.assert_equal(self.trace1.samples[0]['x'], np.zeros(draws)) - with self.assertRaises(ValueError): - self.trace1.merge_chains([self.trace2]) + def test_merge_chains_two_traces_same_slot_new_longer(self): + draws = 3 + self.trace1.samples = {0: + {'x': np.zeros(draws - 1), + 'y': np.zeros((draws - 1, 2))}} + self.trace2.samples = {0: + {'x': np.ones(draws), + 'y': np.ones((draws, 2))}} + self.trace1.merge_chains([self.trace2]) + npt.assert_equal(self.trace1.samples[0]['x'], np.ones(draws)) diff --git a/pymc/tests/test_ndarray_sqlite_selection.py b/pymc/tests/test_ndarray_sqlite_selection.py index c0288087a5..d076a7ef35 100644 --- a/pymc/tests/test_ndarray_sqlite_selection.py +++ b/pymc/tests/test_ndarray_sqlite_selection.py @@ -29,21 +29,28 @@ def setUpClass(cls): n = 1 model = pm.Model() + draws = 5 with model: x = pm.Normal('x', 0, 1., shape=n) - # start sampling at the MAP start = {'x': 0.} step = pm.Metropolis() cls.db = 'test.db' try: - cls.draws = 10 - cls.ntrace = pm.sample(cls.draws, step=step, - threads=threads, random_seed=4) - cls.strace = pm.sample(cls.draws, step=step, - threads=threads, random_seed=4, + cls.ntrace = pm.sample(draws, step=step, + threads=threads, random_seed=9) + cls.strace = pm.sample(draws, step=step, + threads=threads, random_seed=9, db=pm.backends.SQLite(cls.db)) + ## Extend each trace. + cls.ntrace = pm.sample(draws, step=step, + threads=threads, random_seed=4, + db=cls.ntrace.backend) + cls.strace = pm.sample(draws, step=step, + threads=threads, random_seed=4, + db=cls.strace.backend) + cls.draws = draws * 2 # Account for extension. except: remove_db(cls.db) raise @@ -57,8 +64,10 @@ def test_chain_length(self): assert len(self.ntrace) == len(self.strace) def test_number_of_draws(self): - assert self.ntrace['x'][0].shape[0] == self.draws - assert self.strace['x'][0].shape[0] == self.draws + nvalues = self.ntrace.get_values('x', squeeze=False) + svalues = self.strace.get_values('x', squeeze=False) + assert nvalues[0].shape[0] == self.draws + assert svalues[0].shape[0] == self.draws def test_get_item(self): npt.assert_equal(self.ntrace['x'], self.strace['x']) diff --git a/pymc/tests/test_sqlite_backend.py b/pymc/tests/test_sqlite_backend.py index 16b0da4556..51e0f5e9dd 100644 --- a/pymc/tests/test_sqlite_backend.py +++ b/pymc/tests/test_sqlite_backend.py @@ -21,20 +21,25 @@ def setUp(self): with mock.patch('pymc.backends.base.modelcontext') as context: context.return_value = self.model self.db = sqlite.SQLite('test.db') + + self.draws = 5 + self.db.cursor = mock.Mock() self.db.var_shapes = {'x': (), 'y': (3,)} + self.db.trace._chains = [0] + self.db.trace._len = self.draws + connect_patch = mock.patch('pymc.backends.sqlite.SQLite.connect') self.addCleanup(connect_patch.stop) self.connect = connect_patch.start() - self.draws = 5 class TestSQLiteSample(SQLiteTestCase): def test_setup_trace(self): self.db.setup(self.draws, chain=0) - self.connect.assert_called_once_with() + assert self.connect.called def test_setup_scalar(self): db = self.db @@ -52,6 +57,8 @@ def test_setup_scalar(self): def test_setup_1d(self): db = self.db db.setup(draws=3, chain=0) + db.trace._chains = [] + tbl_expected = ('CREATE TABLE IF NOT EXISTS [y] ' '(recid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ' 'draw INTEGER, ' diff --git a/pymc/tests/test_text_backend.py b/pymc/tests/test_text_backend.py index 4ccb0a580e..520324fa82 100644 --- a/pymc/tests/test_text_backend.py +++ b/pymc/tests/test_text_backend.py @@ -44,6 +44,7 @@ def setUp(self): self.draws = 5 self.db.var_shapes = {'x': (), 'y': (4,)} self.db.setup(self.draws, chain=0) + self.db.draw_idx = self.draws savetxt_patch = mock.patch('pymc.backends.text.np.savetxt') self.addCleanup(savetxt_patch.stop)