diff --git a/.travis.yml b/.travis.yml index 12ff62a264..f258fd05c0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,6 +16,7 @@ install: - conda create -n testenv --yes pip python=$TRAVIS_PYTHON_VERSION - source activate testenv - conda install --yes ipython==1.1.0 pyzmq numpy==1.8.0 scipy nose matplotlib pandas Cython patsy statsmodels + - if [ ${TRAVIS_PYTHON_VERSION:0:1} == "2" ]; then conda install --yes mock; fi - pip install --no-deps numdifftools - pip install git+https://github.com/Theano/Theano.git - python setup.py build_ext --inplace diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py new file mode 100644 index 0000000000..6e087dc66d --- /dev/null +++ b/pymc/backends/__init__.py @@ -0,0 +1,119 @@ +"""Backends for traces + +Available backends +------------------ + +1. NumPy array (pymc.backends.NDArray) +2. Text files (pymc.backends.Text) +3. SQLite (pymc.backends.SQLite) + +The NumPy arrays and text files both hold the entire trace in memory, +whereas SQLite commits the trace to the database while sampling. + +Selecting a backend +------------------- + +By default, a NumPy array is used as the backend. To specify a different +backend, pass a backend instance to `sample`. + +For example, the following would save traces to the file 'test.db'. + + >>> import pymc as pm + >>> db = pm.backends.SQLite('test.db') + >>> trace = pm.sample(..., trace=db) + +Selecting values from a backend +------------------------------- + +After a backend is finished sampling, it returns a MultiTrace object. +Values can be accessed in a few ways. The easiest way is to index the +backend object with a variable or variable name. + + >>> trace['x'] # or trace[x] + +The call will return a list containing the sampling values for all +chains of `x`. (Each call to `pymc.sample` creates a separate chain of +samples.) + +For more control of which values are returned, the `get_values` method +can be used. The call below will return values from all chains, burning +the first 1000 iterations from each chain. + + >>> trace.get_values('x', burn=1000) + +Setting the `combined` flag will concatenate the results from all the +chains. + + >>> trace.get_values('x', burn=1000, combine=True) + +The `chains` parameter of `get_values` can be used to limit the chains +that are retrieved. + + >>> trace.get_values('x', burn=1000, combine=True, chains=[0, 2]) + +Backends can also suppport slicing the MultiTrace object. For example, +the following call would return a new trace object without the first +1000 sampling iterations for all traces and variables. + + >>> sliced_trace = trace[1000:] + +Loading a saved backend +----------------------- + +Saved backends can be loaded using `load` function in the module for the +specific backend. + + >>> trace = pm.backends.sqlite.load('test.db') + +Writing custom backends +----------------------- + +Backends consist of a class that handles sampling storage and value +selection. Three sampling methods of backend will be called: + +- setup: Before sampling is started, the `setup` method will be called + with two arguments: the number of draws and the chain number. This is + useful setting up any structure for storing the sampling values that + require the above information. + +- record: Record the sampling results for the current draw. This method + will be called with a dictionary of values mapped to the variable + names. This is the only sampling function that *must* do something to + have a meaningful backend. + +- close: This method is called following sampling and should perform any + actions necessary for finalizing and cleaning up the backend. + +The base storage class `backends.base.BaseTrace` provides common model +setup that is used by all the PyMC backends. + +Several selection methods must also be defined: + +- get_values: This is the core method for selecting values from the + backend. It can be called directly and is used by __getitem__ when the + backend is indexed with a variable name or object. + +- _slice: Defines how the backend returns a slice of itself. This + is called if the backend is indexed with a slice range. + +- point: Returns values for each variable at a single iteration. This is + called if the backend is indexed with a single integer. + +- __len__: This should return the number of draws (for the highest chain + number). + +When `pymc.sample` finishes, it wraps all trace objects in a MultiTrace +object that provides a consistent selection interface for all backends. +If the traces are stored on disk, then a `load` function should also be +defined that returns a MultiTrace object. + +For specific examples, see pymc.backends.{ndarray,text,sqlite}.py. +""" +from pymc.backends.ndarray import NDArray +from pymc.backends.text import Text +from pymc.backends.sqlite import SQLite + +_shortcuts = {'text': {'backend': Text, + 'name': 'mcmc'}, + 'sqlite': {'backend': SQLite, + 'name': 'mcmc.sqlite'}} diff --git a/pymc/backends/base.py b/pymc/backends/base.py new file mode 100644 index 0000000000..31855beced --- /dev/null +++ b/pymc/backends/base.py @@ -0,0 +1,271 @@ +"""Base backend for traces + +See the docstring for pymc.backends for more information (includng +creating custom backends). +""" +import numpy as np +from pymc.model import modelcontext + + +class BaseTrace(object): + """Base trace object + + Parameters + ---------- + name : str + Name of backend + model : Model + If None, the model is taken from the `with` context. + vars : list of variables + Sampling values will be stored for these variables. If None, + `model.unobserved_RVs` is used. + """ + def __init__(self, name, model=None, vars=None): + self.name = name + + model = modelcontext(model) + self.model = model + ## `vars` is used throughout these backends to be consistent + ## with other code, but I'd prefer to rename this since it is a + ## built-in. + if vars is None: + vars = model.unobserved_RVs + self.vars = vars + self.varnames = [str(var) for var in vars] + self.fn = model.fastfn(vars) + + ## Get variable shapes. Most backends will need this + ## information. + var_values = zip(self.varnames, self.fn(model.test_point)) + self.var_shapes = {var: value.shape + for var, value in var_values} + self.chain = None + + ## Sampling methods + + def setup(self, draws, chain): + """Perform chain-specific setup. + + Parameters + ---------- + draws : int + Expected number of draws + chain : int + Chain number + """ + pass + + def record(self, point): + """Record results of a sampling iteration. + + Parameters + ---------- + point : dict + Values mapped to variable names + """ + raise NotImplementedError + + def close(self): + """Close the database backend. + + This is called after sampling has finished. + """ + pass + + ## Selection methods + + def __getitem__(self, idx): + if isinstance(idx, slice): + return self._slice(idx) + + try: + return self.point(idx) + except ValueError: + pass + except TypeError: + pass + return self.get_values(idx) + + def __len__(self): + raise NotImplementedError + + def get_values(self, varname, burn=0, thin=1): + """Get values from trace. + + Parameters + ---------- + varname : str + burn : int + thin : int + + Returns + ------- + A NumPy array + """ + raise NotImplementedError + + def _slice(self, idx): + """Slice trace object.""" + raise NotImplementedError + + def point(self, idx, chain=None): + """Return dictionary of point values at `idx` for current chain + with variables names as keys. + """ + raise NotImplementedError + + +class MultiTrace(object): + """MultiTrace provides the main interface for accessing values from + traces objects. + + The core method to select values is `get_values`. Values can also be + accessed by indexing the MultiTrace object. Indexing can behave in + three ways: + + 1. Indexing with a variable or variable name (str) returns all + values for that variable. + 2. Indexing with an integer returns a dictionary with values for + each variable at the given index (corresponding to a single + sampling iteration). + 3. Slicing with a range returns a new trace with the number of draws + corresponding to the range. + + For any methods that require a single trace (e.g., taking the length + of the MultiTrace instance, which returns the number of draws), the + trace with the highest chain number is always used. + + Parameters + ---------- + traces : list of traces + Each object must have a unique `chain` attribute. + """ + def __init__(self, traces): + self._traces = {} + for trace in traces: + if trace.chain in self._traces: + raise ValueError("Chains are not unique.") + self._traces[trace.chain] = trace + + @property + def nchains(self): + return len(self._traces) + + @property + def chains(self): + return list(sorted(self._traces.keys())) + + def __getitem__(self, idx): + if isinstance(idx, slice): + return self._slice(idx) + + try: + return self.point(idx) + except ValueError: + pass + except TypeError: + pass + return self.get_values(idx) + + def __len__(self): + chain = self.chains[-1] + return len(self._traces[chain]) + + @property + def varnames(self): + chain = self.chains[-1] + return self._traces[chain].varnames + + def get_values(self, varname, burn=0, thin=1, combine=False, chains=None, + squeeze=True): + """Get values from traces. + + Parameters + ---------- + varname : str + burn : int + thin : int + combine : bool + If True, results from `chains` will be concatenated. + chains : int or list of ints + Chains to retrieve. If None, all chains are used. A single + values can also accepted. + squeeze : bool + Return a single array element if the resulting list of + values only has one element. If this is not true, the result + will always be a list of arrays, even if `combine` is True. + + Returns + ------- + A list of NumPy arrays or a single NumPy array (depending on + `squeeze`). + """ + if chains is None: + chains = self.chains + varname = str(varname) + try: + results = [self._traces[chain].get_values(varname, burn, thin) + for chain in chains] + except TypeError: # Single chain passed. + results = [self._traces[chains].get_values(varname, burn, thin)] + return _squeeze_cat(results, combine, squeeze) + + def _slice(self, idx): + """Return a new MultiTrace object sliced according to `idx`.""" + chain = self.chains[-1] + model = self._traces[chain].model + vars = self._traces[chain].vars + + new_traces = [trace._slice(idx) for trace in self._traces.values()] + return MultiTrace(new_traces) + + def point(self, idx, chain=None): + """Return a dictionary of point values at `idx`. + + Parameters + ---------- + idx : int + chain : int + If a chain is not given, the highest chain number is used. + """ + if chain is None: + chain = self.chains[-1] + return self._traces[chain].point(idx) + + +def merge_traces(mtraces): + """Merge MultiTrace objects. + + Parameters + ---------- + mtraces : list of MultiTraces + Each instance should have unique chain numbers. + + Raises + ------ + A ValueError is raised if any traces have overlapping chain numbers. + + Returns + ------- + A MultiTrace instance with merged chains + """ + base_mtrace = mtraces[0] + for new_mtrace in mtraces[1:]: + for new_chain, trace in new_mtrace._traces.items(): + if new_chain in base_mtrace._traces: + raise ValueError("Chains are not unique.") + base_mtrace._traces[new_chain] = trace + return base_mtrace + + +def _squeeze_cat(results, combine, squeeze): + """Squeeze and concatenate the results depending on values of + `combine` and `squeeze`.""" + if combine: + results = np.concatenate(results) + if not squeeze: + results = [results] + else: + if squeeze and len(results) == 1: + results = results[0] + return results diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py new file mode 100644 index 0000000000..a59457f0ce --- /dev/null +++ b/pymc/backends/ndarray.py @@ -0,0 +1,111 @@ +"""NumPy array trace backend + +Store sampling values in memory as a NumPy array. +""" +import numpy as np +from pymc.backends import base + + +class NDArray(base.BaseTrace): + """NDArray trace object + + Parameters + ---------- + name : str + Name of backend. This has no meaning for the NDArray backend. + model : Model + If None, the model is taken from the `with` context. + vars : list of variables + Sampling values will be stored for these variables. If None, + `model.unobserved_RVs` is used. + """ + def __init__(self, name=None, model=None, vars=None): + super(NDArray, self).__init__(name, model, vars) + self.draw_idx = 0 + self.draws = None + self.samples = {} + + ## Sampling methods + + def setup(self, draws, chain): + """Perform chain-specific setup. + + Parameters + ---------- + draws : int + Expected number of draws + chain : int + Chain number + """ + self.chain = chain + if self.samples: # Concatenate new array if chain is already present. + old_draws = len(self) + self.draws = old_draws + draws + self.draws_idx = old_draws + for varname, shape in self.var_shapes.items(): + old_trace = self.samples[varname] + new_trace = np.zeros((draws, ) + shape) + self.samples[varname] = np.concatenate((old_trace, new_trace), + axis=0) + else: # Otherwise, make array of zeros for each variable. + self.draws = draws + for varname, shape in self.var_shapes.items(): + self.samples[varname] = np.zeros((draws, ) + shape) + + def record(self, point): + """Record results of a sampling iteration. + + Parameters + ---------- + point : dict + Values mapped to variable names + """ + for varname, value in zip(self.varnames, self.fn(point)): + self.samples[varname][self.draw_idx] = value + self.draw_idx += 1 + + def close(self): + if self.draw_idx == self.draws: + return + ## Remove trailing zeros if interrupted before completed all + ## draws. + self.samples = {var: trace[:self.draw_idx] + for var, trace in self.samples.items()} + + ## Selection methods + + def __len__(self): + if not self.samples: # `setup` has not been called. + return 0 + varname = self.varnames[0] + return self.samples[varname].shape[0] + + def get_values(self, varname, burn=0, thin=1): + """Get values from trace. + + Parameters + ---------- + varname : str + burn : int + thin : int + + Returns + ------- + A NumPy array + """ + return self.samples[varname][burn::thin] + + def _slice(self, idx): + sliced = NDArray(model=self.model, vars=self.vars) + sliced.chain = self.chain + sliced.samples = {varname: values[idx] + for varname, values in self.samples.items()} + return sliced + + def point(self, idx): + """Return dictionary of point values at `idx` for current chain + with variables names as keys. + """ + idx = int(idx) + return {varname: values[idx] + for varname, values in self.samples.items()} diff --git a/pymc/backends/sqlite.py b/pymc/backends/sqlite.py new file mode 100644 index 0000000000..24dd968c31 --- /dev/null +++ b/pymc/backends/sqlite.py @@ -0,0 +1,341 @@ +"""SQLite trace backend + +Store and retrieve sampling values in SQLite database file. + +Database format +--------------- +For each variable, a table is created with the following format: + + recid (INT), draw (INT), chain (INT), v1 (FLOAT), v2 (FLOAT), v3 (FLOAT) ... + +The variable column names are extended to reflect addition dimensions. +For example, a variable with the shape (2, 2) would be stored as + + key (INT), draw (INT), chain (INT), v1_1 (FLOAT), v1_2 (FLOAT), v2_1 (FLOAT) ... + +The key is autoincremented each time a new row is added to the table. +The chain column denotes the chain index, and starts at 0. +""" +import numpy as np +import sqlite3 +import warnings + +from pymc.backends import base + +TEMPLATES = { + 'table': ('CREATE TABLE IF NOT EXISTS [{table}] ' + '(recid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ' + 'draw INTEGER, chain INT(5), ' + '{value_cols})'), + 'insert': ('INSERT INTO [{table}] ' + '(recid, draw, chain, {value_cols}) ' + 'VALUES (NULL, ?, ?, {values})'), + 'max_draw': ('SELECT MAX(draw) FROM [{table}] ' + 'WHERE chain = ?'), + 'draw_count': ('SELECT COUNT(*) FROM [{table}] ' + 'WHERE chain = ?'), + ## Named placeholders are used in the selection templates because + ## some values occur more than once in the same template. + 'select': ('SELECT * FROM [{table}] ' + 'WHERE (chain = :chain)'), + 'select_burn': ('SELECT * FROM [{table}] ' + 'WHERE (chain = :chain) AND (draw > :burn)'), + 'select_thin': ('SELECT * FROM [{table}] ' + 'WHERE (chain = :chain) AND ' + '(draw - (SELECT draw FROM [{table}] ' + 'WHERE chain = :chain ' + 'ORDER BY draw LIMIT 1)) % :thin = 0'), + 'select_burn_thin': ('SELECT * FROM [{table}] ' + 'WHERE (chain = :chain) AND (draw > :burn) ' + 'AND (draw - (SELECT draw FROM [{table}] ' + 'WHERE (chain = :chain) AND (draw > :burn) ' + 'ORDER BY draw LIMIT 1)) % :thin = 0'), + 'select_point': ('SELECT * FROM [{table}] ' + 'WHERE (chain = :chain) AND (draw = :draw)'), +} + + +class SQLite(base.BaseTrace): + """SQLite trace object + + Parameters + ---------- + name : str + Name of database file + model : Model + If None, the model is taken from the `with` context. + vars : list of variables + Sampling values will be stored for these variables. If None, + `model.unobserved_RVs` is used. + """ + def __init__(self, name, model=None, vars=None): + super(SQLite, self).__init__(name, model, vars) + self._var_cols = {} + self.var_inserts = {} # varname -> insert statement + self.draw_idx = 0 + self._is_setup = False + self._len = None + + self.db = _SQLiteDB(name) + ## Inserting sampling information is queued to avoid locks + ## caused by hitting the database with transactions each + ## iteration. + self._queue = {varname: [] for varname in self.varnames} + self._queue_limit = 5000 + + ## Sampling methods + + def setup(self, draws, chain): + """Perform chain-specific setup. + + Parameters + ---------- + draws : int + Expected number of draws + chain : int + Chain number + """ + self.db.connect() + self.chain = chain + + if not self._is_setup: # Table has not been created. + self._var_cols = {varname: _create_colnames(shape) + for varname, shape in self.var_shapes.items()} + self._create_table() + self._is_setup = True + else: + self.draw_idx = self._get_max_draw(chain) + 1 + self._len = None + self._create_insert_queries(chain) + + def _create_table(self): + template = TEMPLATES['table'] + with self.db.con: + for varname, var_cols in self._var_cols.items(): + var_float = ', '.join([v + ' FLOAT' for v in var_cols]) + statement = template.format(table=varname, + value_cols=var_float) + self.db.cursor.execute(statement) + + def _create_insert_queries(self, chain): + template = TEMPLATES['insert'] + for varname, var_cols in self._var_cols.items(): + ## Create insert statement for each variable. + var_str = ', '.join(var_cols) + placeholders = ', '.join(['?'] * len(var_cols)) + statement = template.format(table=varname, + value_cols=var_str, + values=placeholders) + self.var_inserts[varname] = statement + + def record(self, point): + """Record results of a sampling iteration. + + Parameters + ---------- + point : dict + Values mapped to variable names + """ + for varname, value in zip(self.varnames, self.fn(point)): + values = (self.draw_idx, self.chain) + tuple(np.ravel(value)) + self._queue[varname].append(values) + + if len(self._queue[varname]) > self._queue_limit: + self._execute_queue() + self.draw_idx += 1 + + def _execute_queue(self): + with self.db.con: + for varname in self.varnames: + if not self._queue[varname]: + continue + self.db.cursor.executemany(self.var_inserts[varname], + self._queue[varname]) + self._queue[varname] = [] + + def close(self): + self._execute_queue() + self.db.close() + + ## Selection methods + + def __len__(self): + if not self._is_setup: + return 0 + if self._len is None: + self._len = self._get_number_draws() + return self._len + + def _get_number_draws(self): + self.db.connect() + statement = TEMPLATES['draw_count'].format(table=self.varnames[0]) + self.db.cursor.execute(statement, (self.chain,)) + return self.db.cursor.fetchall()[0][0] + + def _get_max_draw(self, chain): + self.db.connect() + statement = TEMPLATES['max_draw'].format(table=self.varnames[0]) + self.db.cursor.execute(statement, (chain, )) + return self.db.cursor.fetchall()[0][0] + + def get_values(self, varname, burn=0, thin=1): + """Get values from trace. + + Parameters + ---------- + varname : str + burn : int + thin : int + + Returns + ------- + A NumPy array + """ + if burn < 0: + raise ValueError('Negative burn values not supported ' + 'in SQLite backend.') + if thin < 1: + raise ValueError('Only positive thin values are supported ' + 'in SQLite backend.') + varname = str(varname) + + statement_args = {'chain': self.chain} + if burn == 0 and thin == 1: + action = 'select' + elif thin == 1: + action = 'select_burn' + statement_args['burn'] = burn - 1 + elif burn == 0: + action = 'select_thin' + statement_args['thin'] = thin + else: + action = 'select_burn_thin' + statement_args['burn'] = burn - 1 + statement_args['thin'] = thin + + self.db.connect() + statement = TEMPLATES[action].format(table=varname) + self.db.cursor.execute(statement, statement_args) + return _rows_to_ndarray(self.db.cursor) + + def _slice(self, idx): + warnings.warn('Slice for SQLite backend has no effect.') + + def point(self, idx): + """Return dictionary of point values at `idx` for current chain + with variables names as keys. + """ + idx = int(idx) + if idx < 0: + idx = self._get_max_draw(self.chain) - idx - 1 + statement = TEMPLATES['select_point'] + self.db.connect() + var_values = {} + statement_args = {'chain': self.chain, 'draw': idx} + for varname in self.varnames: + self.db.cursor.execute(statement.format(table=varname), + statement_args) + var_values[varname] = np.squeeze( + _rows_to_ndarray(self.db.cursor)) + return var_values + + +class _SQLiteDB(object): + def __init__(self, name): + self.name = name + self.con = None + self.cursor = None + self.connected = False + + def connect(self): + if self.connected: + return + self.con = sqlite3.connect(self.name) + self.connected = True + self.cursor = self.con.cursor() + + def close(self): + if not self.connected: + return + self.con.commit() + self.cursor.close() + self.con.close() + self.connected = False + + +def _create_colnames(shape): + """Return column names based on `shape`. + + Examples + -------- + >>> create_colnames((5,)) + ['v1', 'v2', 'v3', 'v4', 'v5'] + + >>> create_colnames((2,2)) + ['v1_1', 'v1_2', 'v2_1', 'v2_2'] + """ + if not shape: + return ['v1'] + + size = np.prod(shape) + indices = (np.indices(shape) + 1).reshape(-1, size) + return ['v' + '_'.join(map(str, i)) for i in zip(*indices)] + + +def load(name, model=None): + """Load SQLite database. + + Parameters + ---------- + name : str + Path to SQLite database file + model : Model + If None, the model is taken from the `with` context. + + Returns + ------- + A MultiTrace instance + """ + db = _SQLiteDB(name) + db.connect() + varnames = _get_table_list(db.cursor) + chains = _get_chain_list(db.cursor, varnames[0]) + + traces = [] + for chain in chains: + trace = SQLite(name, model=model) + trace.varnames = varnames + trace.chain = chain + trace._is_setup = True + trace.db = db # Share the db with all traces. + traces.append(trace) + return base.MultiTrace(traces) + + +def _get_table_list(cursor): + """Return a list of table names in the current database.""" + ## Modified from Django. Skips the sqlite_sequence system table used + ## for autoincrement key generation. + cursor.execute("SELECT name FROM sqlite_master " + "WHERE type='table' AND NOT name='sqlite_sequence' " + "ORDER BY name") + return [row[0] for row in cursor.fetchall()] + + +def _get_var_strs(cursor, varname): + cursor.execute('SELECT * FROM [{}]'.format(varname)) + col_names = (col_descr[0] for col_descr in cursor.description) + return [name for name in col_names if name.startswith('v')] + + +def _get_chain_list(cursor, varname): + """Return a list of sorted chains for `varname`.""" + cursor.execute('SELECT DISTINCT chain FROM [{}]'.format(varname)) + chains = [chain[0] for chain in cursor.fetchall()] + chains.sort() + return chains + + +def _rows_to_ndarray(cursor): + """Convert SQL row to NDArray.""" + return np.array([row[3:] for row in cursor.fetchall()]) diff --git a/pymc/backends/text.py b/pymc/backends/text.py new file mode 100644 index 0000000000..40ef1a24d6 --- /dev/null +++ b/pymc/backends/text.py @@ -0,0 +1,114 @@ +"""Text file trace backend + +After sampling with NDArray backend, save results as text files. + +Database format +--------------- + +For each chain, a directory named `chain-N` is created. In this +directory, one file per variable is created containing the values of the +object. To deal with multidimensional variables, the array is reshaped +to one dimension before saving with `numpy.savetxt`. The shape +information is saved in a json file in the same directory and is used to +load the database back again using `numpy.loadtxt`. +""" +import os +import glob +import json +import numpy as np +from contextlib import contextmanager + +from pymc.backends import base +from pymc.backends.ndarray import NDArray + + +class Text(NDArray): + """Text storage + + Parameters + ---------- + name : str + Name of directory to store text files. + model : Model + If None, the model is taken from the `with` context. + vars : list of variables + Sampling values will be stored for these variables. If None, + `model.unobserved_RVs` is used. + """ + def __init__(self, name, model=None, vars=None): + super(Text, self).__init__(name, model, vars) + if not os.path.exists(name): + os.mkdir(name) + + def close(self): + super(Text, self).close() + chain_name = 'chain-{}'.format(self.chain) + chain_dir = os.path.join(self.name, chain_name) + os.mkdir(chain_dir) + + shapes = {} + for varname in self.varnames: + data = self.samples[varname] + var_file = os.path.join(chain_dir, varname + '.txt') + np.savetxt(var_file, data.reshape(-1, data.size)) + shapes[varname] = data.shape + ## Store shape information for reloading. + with _get_shape_fh(chain_dir, 'w') as sfh: + json.dump(shapes, sfh) + + +def load(name, chains=None, model=None): + """Load text database. + + Parameters + ---------- + name : str + Path to root directory for text database + chains : list + Chains to load. If None, all chains are loaded. + model : Model + If None, the model is taken from the `with` context. + + Returns + ------- + ndarray.Trace instance + """ + chain_dirs = _get_chain_dirs(name) + if chains is None: + chains = list(chain_dirs.keys()) + + traces = [] + for chain in chains: + chain_dir = chain_dirs[chain] + with _get_shape_fh(chain_dir, 'r') as sfh: + shapes = json.load(sfh) + samples = {} + for varname, shape in shapes.items(): + var_file = os.path.join(chain_dir, varname + '.txt') + samples[varname] = np.loadtxt(var_file).reshape(shape) + trace = NDArray(model=model) + trace.samples = samples + trace.chain = chain + traces.append(trace) + return base.MultiTrace(traces) + + +## The json file is opened here instead of `Text.close` and `load` for +## testing convenience. +@contextmanager +def _get_shape_fh(chain_dir, mode='r'): + fh = open(os.path.join(chain_dir, 'shapes.json'), mode) + try: + yield fh + finally: + fh.close() + + +def _get_chain_dirs(name): + """Return mapping of chain number to directory.""" + return {_chain_dir_to_chain(chain_dir): chain_dir + for chain_dir in glob.glob(os.path.join(name, 'chain-*'))} + + +def _chain_dir_to_chain(chain_dir): + return int(os.path.basename(chain_dir).split('-')[1]) diff --git a/pymc/diagnostics.py b/pymc/diagnostics.py index 88e15a83a7..1c91a68218 100644 --- a/pymc/diagnostics.py +++ b/pymc/diagnostics.py @@ -87,7 +87,7 @@ def geweke(x, first=.1, last=.5, intervals=20): return np.array(zscores) -def gelman_rubin(mtrace): +def gelman_rubin(trace): """ Returns estimate of R for a set of traces. The Gelman-Rubin diagnostic tests for lack of convergence by comparing @@ -99,8 +99,8 @@ def gelman_rubin(mtrace): Parameters ---------- - mtrace : MultiTrace - A MultiTrace object containing parallel traces (minimum 2) + trace + A trace object containing parallel traces (minimum 2) of one or more stochastic parameters. Returns @@ -126,8 +126,7 @@ def gelman_rubin(mtrace): Brooks and Gelman (1998) Gelman and Rubin (1992)""" - m = len(mtrace.traces) - if m < 2: + if trace.nchains < 2: raise ValueError( 'Gelman-Rubin diagnostic requires multiple chains of the same length.') @@ -157,10 +156,10 @@ def calc_rhat(x): return np.squeeze([calc_rhat(xi) for xi in x.transpose(rotated_indices)]) Rhat = {} - for var in mtrace.varnames: + for var in trace.varnames: # Get all traces for var - x = np.array([mtrace.traces[i][var] for i in range(m)]) + x = np.array(trace.get_values(var)) try: Rhat[var] = calc_rhat(x) @@ -169,9 +168,11 @@ def calc_rhat(x): return Rhat + def trace_to_dataframe(trace): """Convert a PyMC trace consisting of 1-D variables to a pandas DataFrame """ import pandas as pd - return pd.DataFrame({name: np.squeeze(trace_var.vals) - for name, trace_var in trace.samples.items()}) + return pd.DataFrame( + {varname: np.squeeze(trace.get_values(varname, combine=True)) + for varname in trace.varnames}) diff --git a/pymc/examples/gelman_bioassay.py b/pymc/examples/gelman_bioassay.py index 4725ded00b..5aa23d038d 100644 --- a/pymc/examples/gelman_bioassay.py +++ b/pymc/examples/gelman_bioassay.py @@ -28,7 +28,7 @@ def run(n=1000): if n == "short": n = 50 with model: - trace = sample(n, step, trace=model.unobserved_RVs) + trace = sample(n, step) if __name__ == '__main__': run() diff --git a/pymc/examples/hierarchical_sqlite.py b/pymc/examples/hierarchical_sqlite.py new file mode 100644 index 0000000000..d8850e8dae --- /dev/null +++ b/pymc/examples/hierarchical_sqlite.py @@ -0,0 +1,8 @@ + +if __name__ == '__main__': + ## Avoid loading during tests. + import pymc as pm + import pymc.examples.hierarchical as hier + + with hier.model: + trace = pm.sample(3000, hier.step, hier.start, trace='sqlite') diff --git a/pymc/plots.py b/pymc/plots.py index c8008de66d..ce8d8fa556 100644 --- a/pymc/plots.py +++ b/pymc/plots.py @@ -1,7 +1,6 @@ import numpy as np from scipy.stats import kde from .stats import * -from .trace import * __all__ = ['traceplot', 'kdeplot', 'kde2plot', 'forestplot', 'autocorrplot'] @@ -23,8 +22,8 @@ def traceplot(trace, vars=None, figsize=None, lines to the posteriors and horizontal lines on sample values e.g. mean of posteriors, true values of a simulation combined : bool - Flag for combining MultiTrace into a single trace. If False (default) - traces will be plotted separately on the same set of axes. + Flag for combining multiple chains into a single chain. If False + (default), chains will be plotted separately. grid : bool Flag for adding gridlines to histogram. Defaults to True. @@ -38,14 +37,6 @@ def traceplot(trace, vars=None, figsize=None, if vars is None: vars = trace.varnames - if isinstance(trace, MultiTrace): - if combined: - traces = [trace.combined()] - else: - traces = trace.traces - else: - traces = [trace] - n = len(vars) if figsize is None: @@ -53,11 +44,10 @@ def traceplot(trace, vars=None, figsize=None, fig, ax = plt.subplots(n, 2, squeeze=False, figsize=figsize) - for trace in traces: - for i, v in enumerate(vars): - d = np.squeeze(trace[v]) - - if trace[v].dtype.kind == 'i': + for i, v in enumerate(vars): + for d in trace.get_values(v, combine=combined, squeeze=False): + d = np.squeeze(d) + if d.dtype.kind == 'i': histplot_op(ax[i, 0], d) else: kdeplot_op(ax[i, 0], d) @@ -133,35 +123,23 @@ def kde2plot(x, y, grid=200): def autocorrplot(trace, vars=None, fontmap=None, max_lag=100): """Bar plot of the autocorrelation function for a trace""" import matplotlib.pyplot as plt - try: - # MultiTrace - traces = trace.traces - - except AttributeError: - # NpTrace - traces = [trace] - if fontmap is None: fontmap = {1: 10, 2: 8, 3: 6, 4: 5, 5: 4} if vars is None: - vars = traces[0].varnames - - # Extract sample data - samples = [{v: trace[v] for v in vars} for trace in traces] + vars = trace.varnames + else: + vars = [str(var) for var in vars] - chains = len(traces) + chains = trace.nchains - n = len(samples[0]) - f, ax = plt.subplots(n, chains, squeeze=False) + f, ax = plt.subplots(len(vars), chains, squeeze=False) - max_lag = min(len(samples[0][vars[0]])-1, max_lag) + max_lag = min(len(trace) - 1, max_lag) for i, v in enumerate(vars): - for j in range(chains): - - d = np.squeeze(samples[j][v]) + d = np.squeeze(trace.get_values(v, chains=[j])) ax[i, j].acorr(d, detrend=plt.mlab.detrend_mean, maxlags=max_lag) @@ -279,41 +257,26 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True, interval_plot = None rhat_plot = None - try: - # First try MultiTrace type - traces = trace_obj.traces - - if rhat and len(traces) > 1: - - from .diagnostics import gelman_rubin - - R = gelman_rubin(trace_obj) - if vars is not None: - R = {v: R[v] for v in vars} - - else: - - rhat = False - - except AttributeError: - - # Single NpTrace - traces = [trace_obj] + nchains = trace_obj.nchains + if nchains > 1: + from .diagnostics import gelman_rubin + R = gelman_rubin(trace_obj) + if vars is not None: + R = {v: R[v] for v in vars} + else: # Can't calculate Gelman-Rubin with a single trace rhat = False if vars is None: - vars = traces[0].varnames + vars = trace_obj.varnames # Empty list for y-axis labels labels = [] - chains = len(traces) - if gs is None: # Initialize plot - if rhat and chains > 1: + if rhat and nchains > 1: gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1]) else: @@ -323,21 +286,18 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True, # Subplot for confidence intervals interval_plot = plt.subplot(gs[0]) - for j, tr in enumerate(traces): - # Get quantiles - trace_quantiles = quantiles(tr, qlist) - hpd_intervals = hpd(tr, alpha) + trace_quantiles = quantiles(trace_obj, qlist, squeeze=False) + hpd_intervals = hpd(trace_obj, alpha, squeeze=False) + for j, chain in enumerate(trace_obj.chains): # Counter for current variable var = 1 - for varname in vars: - - var_quantiles = trace_quantiles[varname] + var_quantiles = trace_quantiles[chain][varname] quants = list(var_quantiles.values()) - var_hpd = hpd_intervals[varname].T + var_hpd = hpd_intervals[chain][varname].T # Substitute HPD interval for quantile quants[0] = var_hpd[0].T @@ -354,7 +314,7 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True, plotrange = [np.min(quants), np.max(quants)] # Number of elements in current variable - value = tr[varname][0] + value = trace_obj.get_values(varname, chains=[chain])[0] k = np.size(value) # Append variable name(s) to list @@ -368,7 +328,7 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True, # Add spacing for each chain, if more than one e = [0] + [(chain_spacing * ((i + 2) / 2)) * - (-1) ** i for i in range(chains - 1)] + (-1) ** i for i in range(nchains - 1)] # Deal with multivariate nodes if k > 1: @@ -480,7 +440,7 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True, plt.axvline(vline, color='k', linestyle='--') # Genenerate Gelman-Rubin plot - if rhat and chains > 1: + if rhat and nchains > 1: # If there are multiple chains, calculate R-hat rhat_plot = plt.subplot(gs[1]) @@ -498,7 +458,8 @@ def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True, i = 1 for varname in vars: - value = traces[0][varname][0] + chain = trace_obj.chains[0] + value = trace_obj.get_values(varname, chains=[chain])[0] k = np.size(value) if k > 1: diff --git a/pymc/sampling.py b/pymc/sampling.py index f6df93689b..162d105788 100644 --- a/pymc/sampling.py +++ b/pymc/sampling.py @@ -1,5 +1,7 @@ from .point import * -from .trace import NpTrace, MultiTrace +from pymc import backends +from pymc.backends.base import merge_traces, BaseTrace, MultiTrace +from pymc.backends.ndarray import NDArray import multiprocessing as mp from time import time from .core import * @@ -7,10 +9,11 @@ from .progressbar import progress_bar from numpy.random import seed -__all__ = ['sample', 'psample', 'iter_sample'] +__all__ = ['sample', 'iter_sample'] -def sample(draws, step, start=None, trace=None, tune=None, progressbar=True, model=None, random_seed=None): +def sample(draws, step, start=None, trace=None, chain=0, njobs=1, tune=None, + progressbar=True, model=None, random_seed=None): """ Draw a number of samples using the given step method. Multiple step methods supported via compound step method @@ -27,37 +30,87 @@ def sample(draws, step, start=None, trace=None, tune=None, progressbar=True, mod Starting point in parameter space (or partial point) Defaults to trace.point(-1)) if there is a trace provided and model.test_point if not (defaults to empty dict) - trace : NpTrace or list - Either a trace of past values or a list of variables to track - (defaults to None) + trace : backend, list, or MultiTrace + This should be a backend instance, a list of variables to track, + or a MultiTrace object with past values. If a MultiTrace object + is given, it must contain samples for the chain number `chain`. + If None or a list of variables, the NDArray backend is used. + Passing either "text" or "sqlite" is taken as a shortcut to set + up the corresponding backend (with "mcmc" used as the base + name). + chain : int + Chain number used to store sample in backend. If `njobs` is + greater than one, chain numbers will start here. + njobs : int + Number of parallel jobs to start. If None, set to number of cpus + in the system - 2. tune : int Number of iterations to tune, if applicable (defaults to None) progressbar : bool Flag for progress bar model : Model (optional if in `with` context) + random_seed : int or list of ints + A list is accepted if more if `njobs` is greater than one. + Returns + ------- + MultiTrace object with access to sampling values """ - progress = progress_bar(draws) + if njobs is None: + njobs = max(mp.cpu_count() - 2, 1) + if njobs > 1: + try: + if not len(random_seed) == njobs: + random_seeds = [random_seed] * njobs + else: + random_seeds = random_seed + except TypeError: # None, int + random_seeds = [random_seed] * njobs + + chains = list(range(chain, chain + njobs)) + + pbars = [progressbar] + [False] * (njobs - 1) + + argset = zip([draws] * njobs, + [step] * njobs, + [start] * njobs, + [trace] * njobs, + chains, + [tune] * njobs, + pbars, + [model] * njobs, + random_seeds) + sample_func = _mp_sample + sample_args = [njobs, argset] + else: + sample_func = _sample + sample_args = [draws, step, start, trace, chain, + tune, progressbar, model, random_seed] + return sample_func(*sample_args) + +def _sample(draws, step, start=None, trace=None, chain=0, tune=None, + progressbar=True, model=None, random_seed=None): + sampling = _iter_sample(draws, step, start, trace, chain, + tune, model, random_seed) + progress = progress_bar(draws) try: - for i, trace in enumerate(iter_sample(draws, step, - start=start, - trace=trace, - tune=tune, - model=model, - random_seed=random_seed)): + for i, trace in enumerate(sampling): if progressbar: progress.update(i) except KeyboardInterrupt: - pass - return trace + trace.close() + return MultiTrace([trace]) -def iter_sample(draws, step, start=None, trace=None, tune=None, model=None, random_seed=None): + +def iter_sample(draws, step, start=None, trace=None, chain=0, tune=None, + model=None, random_seed=None): """ Generator that returns a trace on each iteration using the given step method. Multiple step methods supported via compound step method returns the amount of time taken. + Parameters ---------- @@ -69,41 +122,49 @@ def iter_sample(draws, step, start=None, trace=None, tune=None, model=None, rand Starting point in parameter space (or partial point) Defaults to trace.point(-1)) if there is a trace provided and model.test_point if not (defaults to empty dict) - trace : NpTrace or list - Either a trace of past values or a list of variables to track - (defaults to None) + trace : backend, list, or MultiTrace + This should be a backend instance, a list of variables to track, + or a MultiTrace object with past values. If a MultiTrace object + is given, it must contain samples for the chain number `chain`. + If None or a list of variables, the NDArray backend is used. + chain : int + Chain number used to store sample in backend. If `njobs` is + greater than one, chain numbers will start here. tune : int Number of iterations to tune, if applicable (defaults to None) model : Model (optional if in `with` context) + random_seed : int or list of ints + A list is accepted if more if `njobs` is greater than one. Example ------- for trace in iter_sample(500, step): ... - """ + sampling = _iter_sample(draws, step, start, trace, chain, tune, + model, random_seed) + for i, trace in enumerate(sampling): + yield trace[:i + 1] + + +def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None, + model=None, random_seed=None): model = modelcontext(model) draws = int(draws) seed(random_seed) + if draws < 1: + raise ValueError('Argument `draws` should be above 0.') if start is None: start = {} - if isinstance(trace, NpTrace) and len(trace) > 0: - trace_point = trace.point(-1) - trace_point.update(start) - start = trace_point + trace = _choose_backend(trace, chain, model=model) + if len(trace) > 0: + _soft_update(start, trace.point(-1)) else: - test_point = model.test_point.copy() - test_point.update(start) - start = test_point - - if not isinstance(trace, NpTrace): - if trace is None: - trace = model.unobserved_RVs - trace = NpTrace(trace) + _soft_update(start, model.test_point) try: step = step_methods.CompoundStep(step) @@ -112,12 +173,43 @@ def iter_sample(draws, step, start=None, trace=None, tune=None, model=None, rand point = Point(start, model=model) + trace.setup(draws, chain) for i in range(draws): - if (i == tune): + if i == tune: step = stop_tuning(step) point = step.step(point) trace.record(point) yield trace + else: + trace.close() + + +def _choose_backend(trace, chain, shortcuts=None, **kwds): + if isinstance(trace, BaseTrace): + return trace + if isinstance(trace, MultiTrace): + return trace._traces[chain] + if trace is None: + return NDArray(**kwds) + + if shortcuts is None: + shortcuts = backends._shortcuts + + try: + backend = shortcuts[trace]['backend'] + name = shortcuts[trace]['name'] + return backend(name, **kwds) + except TypeError: + return NDArray(vars=trace, **kwds) + except KeyError: + raise ValueError('Argument `trace` is invalid.') + + +def _mp_sample(njobs, args): + p = mp.Pool(njobs) + traces = p.map(argsample, args) + p.close() + return merge_traces(traces) def stop_tuning(step): @@ -134,71 +226,10 @@ def stop_tuning(step): def argsample(args): """ defined at top level so it can be pickled""" - return sample(*args) + return _sample(*args) -def psample(draws, step, start=None, trace=None, tune=None, progressbar=True, - model=None, threads=None, random_seeds=None): - """draw a number of samples using the given step method. - Multiple step methods supported via compound step method - returns the amount of time taken - - Parameters - ---------- - - draws : int - The number of samples to draw - step : function - A step function - start : dict - Starting point in parameter space (Defaults to trace.point(-1)) - trace : MultiTrace or list - Either a trace of past values or a list of variables to track (defaults to None) - tune : int - Number of iterations to tune, if applicable (defaults to None) - progressbar : bool - Flag for progress bar - model : Model (optional if in `with` context) - threads : int - Number of parallel traces to start - - Examples - -------- - - >>> an example - +def _soft_update(a, b): + """As opposed to dict.update, don't overwrite keys if present. """ - - model = modelcontext(model) - - if not threads: - threads = max(mp.cpu_count() - 2, 1) - - if start is None: - start = {} - - if isinstance(start, dict): - start = threads * [start] - - if trace is None: - trace = model.vars - - if type(trace) is MultiTrace: - mtrace = trace - else: - mtrace = MultiTrace(threads, trace) - - p = mp.Pool(threads) - - if random_seeds is None: - random_seeds = [None] * threads - pbars = [progressbar] + [False] * (threads - 1) - - argset = zip([draws] * threads, [step] * threads, start, mtrace.traces, - [tune] * threads, pbars, [model] * threads, random_seeds) - - traces = p.map(argsample, argset) - - p.close() - - return MultiTrace(traces) + a.update({k: v for k, v in b.items() if k not in a}) diff --git a/pymc/stats.py b/pymc/stats.py index b028e35b54..22dbdc1901 100644 --- a/pymc/stats.py +++ b/pymc/stats.py @@ -11,39 +11,30 @@ def statfunc(f): """ def wrapped_f(pymc_obj, *args, **kwargs): - - try: - burn = kwargs.pop('burn') - except KeyError: - burn = 0 - try: - # MultiTrace - traces = pymc_obj.traces - - try: - vars = kwargs.pop('vars') - except KeyError: - vars = traces[0].varnames - - return [{v: f(trace[v][burn:], *args, **kwargs) for v in vars} for trace in traces] - + vars = kwargs.pop('vars', pymc_obj.varnames) + chains = kwargs.pop('chains', pymc_obj.chains) except AttributeError: - pass - - try: - # NpTrace - try: - vars = kwargs.pop('vars') - except KeyError: - vars = pymc_obj.varnames - - return {v: f(pymc_obj[v][burn:], *args, **kwargs) for v in vars} - except AttributeError: - pass - - # If others fail, assume that raw data is passed - return f(pymc_obj, *args, **kwargs) + # If fails, assume that raw data was passed. + return f(pymc_obj, *args, **kwargs) + + burn = kwargs.pop('burn', 0) + thin = kwargs.pop('thin', 1) + combine = kwargs.pop('combine', False) + ## Remove outer level chain keys if only one chain) + squeeze = kwargs.pop('squeeze', True) + + results = {chain: {} for chain in chains} + for var in vars: + samples = pymc_obj.get_values(var, chains=chains, burn=burn, + thin=thin, combine=combine, + squeeze=False) + for chain, data in zip(chains, samples): + results[chain][var] = f(np.squeeze(data), *args, **kwargs) + + if squeeze and (len(chains) == 1 or combine): + results = results[chains[0]] + return results wrapped_f.__doc__ = f.__doc__ wrapped_f.__name__ = f.__name__ diff --git a/pymc/tests/checks.py b/pymc/tests/checks.py index 2e5470bc8f..6d4adf2180 100644 --- a/pymc/tests/checks.py +++ b/pymc/tests/checks.py @@ -1,7 +1,7 @@ import pymc as pm import numpy as np -from pymc import sample, psample +from pymc import sample from numpy.testing import assert_almost_equal diff --git a/pymc/tests/test_backend_dump_load.py b/pymc/tests/test_backend_dump_load.py new file mode 100644 index 0000000000..8f1b705302 --- /dev/null +++ b/pymc/tests/test_backend_dump_load.py @@ -0,0 +1,101 @@ +import os +import shutil +import numpy as np +import numpy.testing as npt +import unittest + +import pymc as pm + +## Set to False to keep effect of cea5659. Should this be set to True? +TEST_PARALLEL = False + + +def remove_file_or_directory(name): + try: + os.remove(name) + except OSError: + shutil.rmtree(name, ignore_errors=True) + + +class DumpLoadTestCase(unittest.TestCase): + + @classmethod + def setUpClass(cls): + if TEST_PARALLEL: + njobs = 2 + else: + njobs = 1 + + data = np.random.normal(size=(2, 20)) + model = pm.Model() + with model: + x = pm.Normal('x', mu=.5, tau=2. ** -2, shape=(2, 1)) + z = pm.Beta('z', alpha=10, beta=5.5) + d = pm.Normal('data', mu=x, tau=.75 ** -2, observed=data) + data = np.random.normal(size=(3, 20)) + n = 1 + + draws = 5 + cls.draws = draws + + with model: + try: + cls.trace = pm.sample(n, step=pm.Metropolis(), + trace=cls.backend(cls.db), + njobs=njobs) + cls.dumped = cls.load_func(cls.db) + except: + remove_file_or_directory(cls.db) + raise + + @classmethod + def tearDownClass(cls): + remove_file_or_directory(cls.db) + + +class TestTextDumpLoad(DumpLoadTestCase): + + backend = pm.backends.Text + load_func = staticmethod(pm.backends.text.load) + db = 'text-db' + + def test_nchains(self): + self.assertEqual(self.trace.nchains, self.dumped.nchains) + + def test_varnames(self): + trace_names = list(sorted(self.trace.varnames)) + dumped_names = list(sorted(self.dumped.varnames)) + self.assertEqual(trace_names, dumped_names) + + def test_values(self): + trace = self.trace + dumped = self.dumped + for chain in trace.chains: + for varname in trace.varnames: + data = trace.get_values(varname, chains=[chain]) + dumped_data = dumped.get_values(varname, chains=[chain]) + npt.assert_equal(data, dumped_data) + + +class TestSQLiteDumpLoad(DumpLoadTestCase): + + backend = pm.backends.SQLite + load_func = staticmethod(pm.backends.sqlite.load) + db = 'test.db' + + def test_nchains(self): + self.assertEqual(self.trace.nchains, self.dumped.nchains) + + def test_varnames(self): + trace_names = list(sorted(self.trace.varnames)) + dumped_names = list(sorted(self.dumped.varnames)) + self.assertEqual(trace_names, dumped_names) + + def test_values(self): + trace = self.trace + dumped = self.dumped + for chain in trace.chains: + for varname in trace.varnames: + data = trace.get_values(varname, chains=[chain]) + dumped_data = dumped.get_values(varname, chains=[chain]) + npt.assert_equal(data, dumped_data) diff --git a/pymc/tests/test_base_backend.py b/pymc/tests/test_base_backend.py new file mode 100644 index 0000000000..16536e1823 --- /dev/null +++ b/pymc/tests/test_base_backend.py @@ -0,0 +1,56 @@ +import numpy as np +try: + import unittest.mock as mock # py3 +except ImportError: + import mock +import unittest + +from pymc.backends import base + + +class TestMultiTrace(unittest.TestCase): + + def test_multitrace_init_unique_chains(self): + trace0 = mock.Mock() + trace0.chain = 0 + trace1 = mock.Mock() + trace1.chain = 1 + mtrace = base.MultiTrace([trace0, trace1]) + self.assertEqual(mtrace._traces[0], trace0) + self.assertEqual(mtrace._traces[1], trace1) + + def test_multitrace_init_nonunique_chains(self): + trace0 = mock.Mock() + trace0.chain = 0 + trace1 = mock.Mock() + trace1.chain = 0 + self.assertRaises(ValueError, + base.MultiTrace, [trace0, trace1]) + + +class TestMergeChains(unittest.TestCase): + + def test_merge_traces_unique_chains(self): + trace0 = mock.Mock() + trace0.chain = 0 + mtrace0 = base.MultiTrace([trace0]) + + trace1 = mock.Mock() + trace1.chain = 1 + mtrace1 = base.MultiTrace([trace1]) + + merged = base.merge_traces([mtrace0, mtrace1]) + self.assertEqual(merged._traces[0], trace0) + self.assertEqual(merged._traces[1], trace1) + + def test_merge_traces_nonunique_chains(self): + trace0 = mock.Mock() + trace0.chain = 0 + mtrace0 = base.MultiTrace([trace0]) + + trace1 = mock.Mock() + trace1.chain = 0 + mtrace1 = base.MultiTrace([trace1]) + + self.assertRaises(ValueError, + base.merge_traces, [mtrace0, mtrace1]) diff --git a/pymc/tests/test_diagnostics.py b/pymc/tests/test_diagnostics.py index 8322a7ed27..6a04aeea28 100644 --- a/pymc/tests/test_diagnostics.py +++ b/pymc/tests/test_diagnostics.py @@ -9,8 +9,8 @@ def test_gelman_rubin(n=1000): step1 = Slice([dm.early_mean, dm.late_mean]) step2 = Metropolis([dm.switchpoint]) start = {'early_mean': 2., 'late_mean': 3., 'switchpoint': 50} - ptrace = psample(n, [step1, step2], start, threads=2, - random_seeds=[1, 3]) + ptrace = sample(n, [step1, step2], start, njobs=2, + random_seed=[1, 3]) rhat = gelman_rubin(ptrace) diff --git a/pymc/tests/test_glm.py b/pymc/tests/test_glm.py index 1882d6040c..0262261e45 100644 --- a/pymc/tests/test_glm.py +++ b/pymc/tests/test_glm.py @@ -44,9 +44,9 @@ def test_linear_component(self): step = Slice(model.vars) trace = sample(2000, step, start, progressbar=False) - self.assertAlmostEqual(np.mean(trace.samples['Intercept'].value), true_intercept, 1) - self.assertAlmostEqual(np.mean(trace.samples['x'].value), true_slope, 1) - self.assertAlmostEqual(np.mean(trace.samples['sigma'].value), true_sd, 1) + self.assertAlmostEqual(np.mean(trace['Intercept']), true_intercept, 1) + self.assertAlmostEqual(np.mean(trace['x']), true_slope, 1) + self.assertAlmostEqual(np.mean(trace['sigma']), true_sd, 1) @unittest.skip("Fails only on travis. Investigate") def test_glm(self): @@ -57,9 +57,9 @@ def test_glm(self): step = Slice(model.vars) trace = sample(2000, step, progressbar=False) - self.assertAlmostEqual(np.mean(trace.samples['Intercept'].value), true_intercept, 1) - self.assertAlmostEqual(np.mean(trace.samples['x'].value), true_slope, 1) - self.assertAlmostEqual(np.mean(trace.samples['sigma'].value), true_sd, 1) + self.assertAlmostEqual(np.mean(trace['Intercept']), true_intercept, 1) + self.assertAlmostEqual(np.mean(trace['x']), true_slope, 1) + self.assertAlmostEqual(np.mean(trace['sigma']), true_sd, 1) def test_glm_link_func(self): with Model() as model: @@ -71,5 +71,5 @@ def test_glm_link_func(self): step = Slice(model.vars) trace = sample(2000, step, progressbar=False) - self.assertAlmostEqual(np.mean(trace.samples['Intercept'].value), true_intercept, 1) - self.assertAlmostEqual(np.mean(trace.samples['x'].value), true_slope, 0) + self.assertAlmostEqual(np.mean(trace['Intercept']), true_intercept, 1) + self.assertAlmostEqual(np.mean(trace['x']), true_slope, 0) diff --git a/pymc/tests/test_ndarray_backend.py b/pymc/tests/test_ndarray_backend.py new file mode 100644 index 0000000000..e0319090a8 --- /dev/null +++ b/pymc/tests/test_ndarray_backend.py @@ -0,0 +1,321 @@ +import numpy as np +import numpy.testing as npt +try: + import unittest.mock as mock # py3 +except ImportError: + import mock +import unittest + +from pymc.backends import base, ndarray + + +class NDArrayTestCase(unittest.TestCase): + def setUp(self): + self.varnames = ['x', 'y'] + self.model = mock.Mock() + self.model.unobserved_RVs = self.varnames + self.model.fastfn = mock.MagicMock() + + with mock.patch('pymc.backends.base.modelcontext') as context: + context.return_value = self.model + self.trace = ndarray.NDArray() + + +class TestNDArraySampling(NDArrayTestCase): + + def test_setup_scalar(self): + trace = self.trace + trace.var_shapes = {'x': ()} + draws, chain = 3, 0 + trace.setup(draws, chain) + npt.assert_equal(trace.samples['x'], np.zeros(draws)) + + def test_setup_1d(self): + trace = self.trace + shape = (2,) + trace.var_shapes = {'x': shape} + draws, chain = 3, 0 + trace.setup(draws, chain) + npt.assert_equal(trace.samples['x'], np.zeros((draws,) + shape)) + + def test_record(self): + trace = self.trace + draws = 3 + + trace.var_shapes = {'x': (), 'y': (4,)} + trace.setup(draws, chain=0) + + def just_ones(*args): + while True: + yield 1. + + trace.fn = just_ones + trace.draw_idx = 0 + + trace.record(point=None) + npt.assert_equal(1., trace.get_values('x')[0]) + npt.assert_equal(np.ones(4), trace['y'][0]) + + def test_clean_interrupt(self): + trace = self.trace + trace.setup(draws=10, chain=0) + trace.samples = {'x': np.zeros(10), 'y': np.zeros((10, 5))} + trace.draw_idx = 3 + trace.close() + npt.assert_equal(np.zeros(3), trace['x']) + npt.assert_equal(np.zeros((3, 5)), trace['y']) + + def test_standard_close(self): + trace = self.trace + trace.setup(draws=10, chain=0) + trace.samples = {'x': np.zeros(10), 'y': np.zeros((10, 5))} + trace.draw_idx = 10 + trace.close() + npt.assert_equal(np.zeros(10), trace['x']) + npt.assert_equal(np.zeros((10, 5)), trace['y']) + + +class TestNDArraySelection(NDArrayTestCase): + + def setUp(self): + super(TestNDArraySelection, self).setUp() + draws = 3 + self.trace.samples = {'x': np.zeros(draws), + 'y': np.zeros((draws, 2))} + self.draws = draws + var_shapes = {'x': (), 'y': (2,)} + self.var_shapes = var_shapes + self.trace.var_shapes = var_shapes + + def test_get_values_default(self): + base_shape = (self.draws,) + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + xsample = self.trace.get_values('x') + npt.assert_equal(np.zeros(base_shape + xshape), xsample) + + ysample = self.trace.get_values('y') + npt.assert_equal(np.zeros(base_shape + yshape), ysample) + + def test_get_values_burn_keyword(self): + base_shape = (self.draws,) + burn = 2 + chain = 0 + + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + ## Make traces distinguishable + self.trace.samples['x'][:burn] = np.ones((burn,) + xshape) + self.trace.samples['y'][:burn] = np.ones((burn,) + yshape) + + xsample = self.trace.get_values('x', burn=burn) + npt.assert_equal(np.zeros(base_shape + xshape)[burn:], xsample) + + ysample = self.trace.get_values('y', burn=burn) + npt.assert_equal(np.zeros(base_shape + yshape)[burn:], ysample) + + def test_get_values_thin_keyword(self): + base_shape = (self.draws,) + thin = 2 + chain = 0 + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + ## Make traces distinguishable + xthin = np.ones((self.draws,) + xshape)[::thin] + ythin = np.ones((self.draws,) + yshape)[::thin] + self.trace.samples['x'][::thin] = xthin + self.trace.samples['y'][::thin] = ythin + + xsample = self.trace.get_values('x', thin=thin) + npt.assert_equal(xthin, xsample) + + ysample = self.trace.get_values('y', thin=thin) + npt.assert_equal(ythin, ysample) + + def test_point(self): + idx = 2 + chain = 0 + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + ## Make traces distinguishable + self.trace.samples['x'][idx] = 1. + self.trace.samples['y'][idx] = 1. + + point = self.trace.point(idx) + expected = {'x': np.squeeze(np.ones(xshape)), + 'y': np.squeeze(np.ones(yshape))} + + for varname, value in expected.items(): + npt.assert_equal(value, point[varname]) + + def test_slice(self): + base_shape = (self.draws,) + burn = 2 + chain = 0 + + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + ## Make traces distinguishable + self.trace.samples['x'][:burn] = np.ones((burn,) + xshape) + self.trace.samples['y'][:burn] = np.ones((burn,) + yshape) + + sliced = self.trace[burn:] + + expected = {'x': np.zeros(base_shape + xshape)[burn:], + 'y': np.zeros(base_shape + yshape)[burn:]} + + for varname, var_shape in self.var_shapes.items(): + npt.assert_equal(sliced.samples[varname], + expected[varname]) + + +class TestNDArrayMultipleChains(unittest.TestCase): + + def setUp(self): + varnames = ['x', 'y'] + var_shapes = {'x': (), 'y': (2,)} + draws = 3 + + self.varnames = varnames + self.var_shapes = var_shapes + self.draws = draws + self.total_draws = 2 * draws + + self.model = mock.Mock() + self.model.unobserved_RVs = varnames + self.model.fastfn = mock.MagicMock() + with mock.patch('pymc.backends.base.modelcontext') as context: + context.return_value = self.model + trace0 = ndarray.NDArray(varnames) + trace0.samples = {'x': np.zeros(draws), + 'y': np.zeros((draws, 2))} + trace0.chain = 0 + + trace1 = ndarray.NDArray(varnames) + trace1.samples = {'x': np.ones(draws), + 'y': np.ones((draws, 2))} + trace1.chain = 1 + + self.mtrace = base.MultiTrace([trace0, trace1]) + + def test_chains_multichain(self): + self.mtrace.chains == [0, 1] + + def test_nchains_multichain(self): + self.mtrace.nchains == 1 + + def test_get_values_multi_default(self): + sample = self.mtrace.get_values('x') + xshape = self.var_shapes['x'] + + expected = [np.zeros((self.draws,) + xshape), + np.ones((self.draws,) + xshape)] + npt.assert_equal(sample, expected) + + def test_get_values_multi_chains_one_chain_list_arg(self): + sample = self.mtrace.get_values('x', chains=[0]) + xshape = self.var_shapes['x'] + expected = np.zeros((self.draws,) + xshape) + npt.assert_equal(sample, expected) + + def test_get_values_multi_chains_one_chain_int_arg(self): + npt.assert_equal(self.mtrace.get_values('x', chains=[0]), + self.mtrace.get_values('x', chains=0)) + + def test_get_values_multi_chains_two_element_reversed(self): + sample = self.mtrace.get_values('x', chains=[1, 0]) + xshape = self.var_shapes['x'] + + expected = [np.ones((self.draws,) + xshape), + np.zeros((self.draws,) + xshape)] + npt.assert_equal(sample, expected) + + def test_get_values_multi_combine(self): + sample = self.mtrace.get_values('x', combine=True) + xshape = self.var_shapes['x'] + + expected = np.concatenate([np.zeros((self.draws,) + xshape), + np.ones((self.draws,) + xshape)]) + npt.assert_equal(sample, expected) + + def test_get_values_multi_burn(self): + sample = self.mtrace.get_values('x', burn=2) + xshape = self.var_shapes['x'] + + expected = [np.zeros((self.draws,) + xshape)[2:], + np.ones((self.draws,) + xshape)[2:]] + npt.assert_equal(sample, expected) + + def test_get_values_multi_burn_combine(self): + sample = self.mtrace.get_values('x', burn=2, combine=True) + xshape = self.var_shapes['x'] + + expected = np.concatenate([np.zeros((self.draws,) + xshape)[2:], + np.ones((self.draws,) + xshape)[2:]]) + npt.assert_equal(sample, expected) + + def test_get_values_multi_thin(self): + sample = self.mtrace.get_values('x', thin=2) + xshape = self.var_shapes['x'] + + expected = [np.zeros((self.draws,) + xshape)[::2], + np.ones((self.draws,) + xshape)[::2]] + npt.assert_equal(sample, expected) + + def test_get_values_multi_thin_combine(self): + sample = self.mtrace.get_values('x', thin=2, combine=True) + xshape = self.var_shapes['x'] + + expected = np.concatenate([np.zeros((self.draws,) + xshape)[::2], + np.ones((self.draws,) + xshape)[::2]]) + npt.assert_equal(sample, expected) + + def test_multichain_point(self): + idx = 2 + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + point = self.mtrace.point(idx) + expected = {'x': np.squeeze(np.ones(xshape)), + 'y': np.squeeze(np.ones(yshape))} + + for varname, value in expected.items(): + npt.assert_equal(value, point[varname]) + + def test_multichain_point_chain_arg(self): + idx = 2 + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + point = self.mtrace.point(idx, chain=0) + expected = {'x': np.squeeze(np.zeros(xshape)), + 'y': np.squeeze(np.zeros(yshape))} + + for varname, value in expected.items(): + npt.assert_equal(value, point[varname]) + + def test_multichain_slice(self): + burn = 2 + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + expected = {0: + {'x': np.zeros((self.draws, ) + xshape)[burn:], + 'y': np.zeros((self.draws, ) + yshape)[burn:]}, + 1: + {'x': np.ones((self.draws, ) + xshape)[burn:], + 'y': np.ones((self.draws, ) + yshape)[burn:]}} + + sliced = self.mtrace[burn:] + + for chain in self.mtrace.chains: + for varname, var_shape in self.var_shapes.items(): + npt.assert_equal(sliced.get_values(varname, chains=[0]), + expected[0][varname]) + npt.assert_equal(sliced.get_values(varname, chains=[1]), + expected[1][varname]) diff --git a/pymc/tests/test_ndarray_sqlite_selection.py b/pymc/tests/test_ndarray_sqlite_selection.py new file mode 100644 index 0000000000..0864039668 --- /dev/null +++ b/pymc/tests/test_ndarray_sqlite_selection.py @@ -0,0 +1,128 @@ +import os +import shutil +import numpy as np +import numpy.testing as npt +import unittest + +import pymc as pm + +## Set to False to keep effect of cea5659. Should this be set to True? +TEST_PARALLEL = False + + +def remove_file_or_directory(name): + try: + os.remove(name) + except OSError: + shutil.rmtree(name, ignore_errors=True) + + +class TestCompareNDArraySQLite(unittest.TestCase): + + @classmethod + def setUpClass(cls): + if TEST_PARALLEL: + njobs = 2 + else: + njobs = 1 + + data = np.random.normal(size=(3, 20)) + n = 1 + + model = pm.Model() + draws = 5 + with model: + x = pm.Normal('x', 0, 1., shape=n) + + start = {'x': 0.} + step = pm.Metropolis() + cls.db = 'test.db' + + try: + cls.ntrace = pm.sample(draws, step=step, + njobs=njobs, random_seed=9) + cls.strace = pm.sample(draws, step=step, + njobs=njobs, random_seed=9, + trace=pm.backends.SQLite(cls.db)) + ## Extend each trace. + cls.ntrace = pm.sample(draws, step=step, + njobs=njobs, random_seed=4, + trace=cls.ntrace) + cls.strace = pm.sample(draws, step=step, + njobs=njobs, random_seed=4, + trace=cls.strace) + cls.draws = draws * 2 # Account for extension. + except: + remove_file_or_directory(cls.db) + raise + + @classmethod + def tearDownClass(cls): + remove_file_or_directory(cls.db) + + def test_chain_length(self): + assert self.ntrace.nchains == self.strace.nchains + assert len(self.ntrace) == len(self.strace) + + def test_number_of_draws(self): + nvalues = self.ntrace.get_values('x', squeeze=False) + svalues = self.strace.get_values('x', squeeze=False) + assert nvalues[0].shape[0] == self.draws + assert svalues[0].shape[0] == self.draws + + def test_get_item(self): + npt.assert_equal(self.ntrace['x'], self.strace['x']) + + def test_get_values(self): + for cf in [False, True]: + npt.assert_equal(self.ntrace.get_values('x', combine=cf), + self.strace.get_values('x', combine=cf)) + + def test_get_values_no_squeeze(self): + npt.assert_equal(self.ntrace.get_values('x', combine=False, + squeeze=False), + self.strace.get_values('x', combine=False, + squeeze=False)) + + def test_get_values_combine_and_no_squeeze(self): + npt.assert_equal(self.ntrace.get_values('x', combine=True, + squeeze=False), + self.strace.get_values('x', combine=True, + squeeze=False)) + + def test_get_values_with_burn(self): + for cf in [False, True]: + npt.assert_equal(self.ntrace.get_values('x', combine=cf, burn=3), + self.strace.get_values('x', combine=cf, burn=3)) + + ## Burn to one value. + npt.assert_equal(self.ntrace.get_values('x', combine=cf, + burn=self.draws - 1), + self.strace.get_values('x', combine=cf, + burn=self.draws - 1)) + + def test_get_values_with_thin(self): + for cf in [False, True]: + npt.assert_equal(self.ntrace.get_values('x', combine=cf, thin=2), + self.strace.get_values('x', combine=cf, thin=2)) + + def test_get_values_with_burn_and_thin(self): + for cf in [False, True]: + npt.assert_equal(self.ntrace.get_values('x', combine=cf, + burn=2, thin=2), + self.strace.get_values('x', combine=cf, + burn=2, thin=2)) + + def test_get_values_with_chains_arg(self): + for cf in [False, True]: + npt.assert_equal(self.ntrace.get_values('x', chains=[0]), + self.strace.get_values('x', chains=[0])) + + def test_point(self): + npoint, spoint = self.ntrace[4], self.strace[4] + npt.assert_equal(npoint['x'], spoint['x']) + + def test_point_with_chain_arg(self): + npoint = self.ntrace.point(4, chain=0) + spoint = self.strace.point(4, chain=0) + npt.assert_equal(npoint['x'], spoint['x']) diff --git a/pymc/tests/test_plots.py b/pymc/tests/test_plots.py index 3d5acc19f1..ae9c70c24e 100644 --- a/pymc/tests/test_plots.py +++ b/pymc/tests/test_plots.py @@ -2,7 +2,7 @@ matplotlib.use('Agg', warn=False) from pymc.plots import * -from pymc import psample, Slice, Metropolis, find_hessian, sample +from pymc import Slice, Metropolis, find_hessian, sample def test_plots(): @@ -31,7 +31,7 @@ def test_multichain_plots(): step1 = Slice([dm.early_mean, dm.late_mean]) step2 = Metropolis([dm.switchpoint]) start = {'early_mean': 2., 'late_mean': 3., 'switchpoint': 50} - ptrace = psample(1000, [step1, step2], start, threads=2) + ptrace = sample(1000, [step1, step2], start, njobs=2) forestplot(ptrace, vars=['early_mean', 'late_mean']) diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index d4016c40f8..6a15c4725a 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -1,38 +1,110 @@ +import numpy as np +import numpy.testing as npt +try: + import unittest.mock as mock # py3 +except ImportError: + import mock +import unittest + import pymc -from pymc import sample, psample, iter_sample +from pymc import sampling +from pymc.sampling import sample from .models import simple_init -# Test if multiprocessing is available -import multiprocessing -try: - multiprocessing.Pool(2) - test_parallel = False -except: - test_parallel = False +## Set to False to keep effect of cea5659. Should this be set to True? +TEST_PARALLEL = False + + +@mock.patch('pymc.sampling._sample') +def test_sample_check_full_signature_single_process(sample_func): + sample('draws', 'step', start='start', trace='trace', njobs=1, chain=1, + tune='tune', progressbar='progressbar', model='model', + random_seed='random_seed') + sample_func.assert_called_with('draws', 'step', 'start', 'trace', 1, + 'tune', 'progressbar', 'model', + 'random_seed') + + +@mock.patch('pymc.sampling._mp_sample') +def test_sample_check_full_signature_mp(sample_func): + sample('draws', 'step', start='start', trace='trace', njobs=2, chain=1, + tune='tune', progressbar='progressbar', model='model', + random_seed=0) + + args = sample_func.call_args_list[0][0] + assert args[0] == 2 + + expected_argset = [('draws', 'step', 'start', 'trace', 1, 'tune', + 'progressbar', 'model', 0), + ('draws', 'step', 'start', 'trace', 2, 'tune', + False, 'model', 0)] + argset = list(args[1]) + assert argset == expected_argset + + +def test_soft_update_all_present(): + start = {'a': 1, 'b': 2} + test_point = {'a': 3, 'b': 4} + sampling._soft_update(start, test_point) + assert start == {'a': 1, 'b': 2} + + +def test_soft_update_one_missing(): + start = {'a': 1, } + test_point = {'a': 3, 'b': 4} + sampling._soft_update(start, test_point) + assert start == {'a': 1, 'b': 4} + + +def test_soft_update_empty(): + start = {} + test_point = {'a': 3, 'b': 4} + sampling._soft_update(start, test_point) + assert start == test_point def test_sample(): model, start, step, _ = simple_init() - test_samplers = [sample] - - tr = sample(5, step, start, model=model) - test_traces = [None, tr] + test_njobs = [1] - if test_parallel: - test_samplers.append(psample) + if TEST_PARALLEL: + test_njobs.append(2) with model: - for trace in test_traces: - for samplr in test_samplers: - for n in [0, 1, 10, 300]: + for njobs in test_njobs: + for n in [1, 10, 300]: + yield sample, n, step, {}, None, njobs - yield samplr, n, step, {} - yield samplr, n, step, {}, trace - yield samplr, n, step, start def test_iter_sample(): model, start, step, _ = simple_init() - for i, trace in enumerate(iter_sample(5, step, start, model=model)): + samps = sampling.iter_sample(5, step, start, model=model) + for i, trace in enumerate(samps): assert i == len(trace) - 1, "Trace does not have correct length." + + +class TestChooseBackend(unittest.TestCase): + + def test_choose_backend_none(self): + with mock.patch('pymc.sampling.NDArray') as nd: + sampling._choose_backend(None, 'chain') + self.assertTrue(nd.called) + + def test_choose_backend_list_of_variables(self): + with mock.patch('pymc.sampling.NDArray') as nd: + sampling._choose_backend(['var1', 'var2'], 'chain') + nd.assert_called_with(vars=['var1', 'var2']) + + def test_choose_backend_invalid(self): + self.assertRaises(ValueError, + sampling._choose_backend, + 'invalid', 'chain') + + def test_choose_backend_shortcut(self): + backend = mock.Mock() + shortcuts = {'test_backend': {'backend': backend, + 'name': None}} + sampling._choose_backend('test_backend', 'chain', shortcuts=shortcuts) + self.assertTrue(backend.called) diff --git a/pymc/tests/test_sqlite_backend.py b/pymc/tests/test_sqlite_backend.py new file mode 100644 index 0000000000..d531fd8ab5 --- /dev/null +++ b/pymc/tests/test_sqlite_backend.py @@ -0,0 +1,267 @@ +import numpy as np +import numpy.testing as npt +try: + import unittest.mock as mock # py3 +except ImportError: + import mock +import unittest +import warnings + +from pymc.backends import base, sqlite + + +class SQLiteTestCase(unittest.TestCase): + + def setUp(self): + self.variables = ['x', 'y'] + self.model = mock.Mock() + self.model.unobserved_RVs = self.variables + self.model.fastfn = mock.MagicMock() + + db_patch = mock.patch('pymc.backends.sqlite._SQLiteDB') + self.addCleanup(db_patch.stop) + self.db = db_patch.start() + + with mock.patch('pymc.backends.base.modelcontext') as context: + context.return_value = self.model + self.trace = sqlite.SQLite('test.db') + + self.draws = 5 + + self.trace.var_shapes = {'x': (), 'y': (3,)} + + self.trace._chains = [0] + self.trace._len = self.draws + + +class TestSQLiteSample(SQLiteTestCase): + + def test_setup_trace(self): + self.trace.setup(self.draws, chain=0) + assert self.trace.db.connect.called + + def test_setup_scalar(self): + trace = self.trace + trace.setup(draws=3, chain=0) + tbl_expected = ('CREATE TABLE IF NOT EXISTS [x] ' + '(recid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ' + 'draw INTEGER, ' + 'chain INT(5), v1 FLOAT)') + trace.db.cursor.execute.assert_any_call(tbl_expected) + + trace_expected = ('INSERT INTO [x] (recid, draw, chain, v1) ' + 'VALUES (NULL, ?, ?, ?)') + self.assertEqual(trace.var_inserts['x'], trace_expected) + + def test_setup_1d(self): + trace = self.trace + trace.setup(draws=3, chain=0) + trace._chains = [] + + tbl_expected = ('CREATE TABLE IF NOT EXISTS [y] ' + '(recid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ' + 'draw INTEGER, ' + 'chain INT(5), v1 FLOAT, v2 FLOAT, v3 FLOAT)') + trace.db.cursor.execute.assert_any_call(tbl_expected) + + trace_expected = ('INSERT INTO [y] (recid, draw, chain, v1, v2, v3) ' + 'VALUES (NULL, ?, ?, ?, ?, ?)') + self.assertEqual(trace.var_inserts['y'], trace_expected) + + def test_setup_2d(self): + trace = self.trace + trace.var_shapes = {'x': (2, 3)} + trace.setup(draws=3, chain=0) + tbl_expected = ('CREATE TABLE IF NOT EXISTS [x] ' + '(recid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ' + 'draw INTEGER, ' + 'chain INT(5), ' + 'v1_1 FLOAT, v1_2 FLOAT, v1_3 FLOAT, ' + 'v2_1 FLOAT, v2_2 FLOAT, v2_3 FLOAT)') + + trace.db.cursor.execute.assert_any_call(tbl_expected) + trace_expected = ('INSERT INTO [x] (recid, draw, chain, ' + 'v1_1, v1_2, v1_3, ' + 'v2_1, v2_2, v2_3) ' + 'VALUES (NULL, ?, ?, ?, ?, ?, ?, ?, ?)') + self.assertEqual(trace.var_inserts['x'], trace_expected) + + def test_record_scalar(self): + trace = self.trace + trace.setup(draws=3, chain=0) + varname = 'x' + trace.varnames = ['x'] + + trace.draw_idx = 0 + trace.fn = mock.Mock(return_value=iter([3.])) + trace.record({'x': None}) + expected = (0, 0, 3.) + self.assertTrue(expected in self.trace._queue['x']) + + def test_record_1d(self): + trace = self.trace + trace.setup(draws=3, chain=0) + varname = 'x' + trace.varnames = ['x'] + + trace.draw_idx = 0 + trace.fn = mock.Mock(return_value=iter([[3., 3.]])) + trace.record({'x': None}) + expected = (0, 0, 3., 3.) + self.assertTrue(expected in self.trace._queue['x']) + + +class TestSQLiteSelection(SQLiteTestCase): + + def setUp(self): + super(TestSQLiteSelection, self).setUp() + self.trace.var_shapes = {'x': (), 'y': (4,)} + self.trace.setup(self.draws, chain=0) + + ndarray_patch = mock.patch('pymc.backends.sqlite._rows_to_ndarray') + self.addCleanup(ndarray_patch.stop) + ndarray_patch.start() + + self.draws = 5 + + def test_get_values_default_keywords(self): + self.trace.get_values('x') + statement = sqlite.TEMPLATES['select'].format(table='x') + expected = (statement, {'chain': 0}) + self.trace.db.cursor.execute.assert_called_with(*expected) + + def test_get_values_burn_arg(self): + self.trace.get_values('x', burn=2).format(table='x') + statement = sqlite.TEMPLATES['select_burn'].format(table='x') + expected = (statement, {'chain': 0, 'burn': 1}) + self.trace.db.cursor.execute.assert_called_with(*expected) + + def test_get_values_thin_arg(self): + self.trace.get_values('x', thin=2) + statement = sqlite.TEMPLATES['select_thin'].format(table='x') + expected = (statement, {'chain': 0, 'thin': 2}) + self.trace.db.cursor.execute.assert_called_with(*expected) + + def test_get_values_burn_thin_arg(self): + self.trace.get_values('x', thin=2, burn=1) + statement = sqlite.TEMPLATES['select_burn_thin'].format(table='x') + expected = (statement, {'chain': 0, 'burn': 0, 'thin': 2}) + self.trace.db.cursor.execute.assert_called_with(*expected) + + def test_point(self): + idx = 2 + + point = self.trace.point(idx) + statement = sqlite.TEMPLATES['select_point'].format(table='x') + statement_args = {'chain': 0, 'draw': idx} + expected = {'x': (statement, statement_args), + 'y': (statement, statement_args)} + + for varname, value in expected.items(): + self.trace.db.cursor.execute.assert_any_call(*value) + + def test_slice(self): + with warnings.catch_warnings(record=True) as wrn: + self.trace[:10] + self.assertEqual(len(wrn), 1) + self.assertEqual(str(wrn[0].message), + 'Slice for SQLite backend has no effect.') + + +class TestSQLiteSelectionMultipleChains(SQLiteTestCase): + + def setUp(self): + self.variables = ['x', 'y'] + self.model = mock.Mock() + self.model.unobserved_RVs = self.variables + self.model.fastfn = mock.MagicMock() + + db_patch = mock.patch('pymc.backends.sqlite._SQLiteDB') + self.addCleanup(db_patch.stop) + self.db = db_patch.start() + + with mock.patch('pymc.backends.base.modelcontext') as context: + context.return_value = self.model + self.trace0 = sqlite.SQLite('test.db') + self.trace1 = sqlite.SQLite('test.db') + + self.draws = 5 + + self.trace0.var_shapes = {'x': (), 'y': (3,)} + self.trace0.chain = 0 + self.trace0._len = self.draws + + self.trace1.var_shapes = {'x': (), 'y': (3,)} + self.trace1.chain = 1 + self.trace1._len = self.draws + + self.mtrace = base.MultiTrace([self.trace0, self.trace1]) + + ndarray_patch = mock.patch('pymc.backends.sqlite._rows_to_ndarray') + self.addCleanup(ndarray_patch.stop) + ndarray_patch.start() + + self.draws = 5 + + def test_get_values_default_keywords(self): + self.mtrace.get_values('x') + + db = self.mtrace._traces[0].db + self.assertEqual(db.cursor.execute.call_count, 2) + + statement = 'SELECT * FROM [x] WHERE (chain = :chain)' + expected = [mock.call(statement, {'chain': chain}) + for chain in (0, 1)] + db.cursor.execute.assert_has_calls(expected) + + def test_get_values_chains_one_given(self): + self.mtrace.get_values('x', chains=[0]) + ## If 0 chain is last call, 1 was not called. + statement = sqlite.TEMPLATES['select'].format(table='x') + expected = (statement, {'chain': 0}) + trace = self.mtrace._traces[0] + trace.db.cursor.execute.assert_called_with(*expected) + + def test_get_values_chains_one_chain_arg(self): + self.mtrace.get_values('x', chains=[0]) + ## If 0 chain is last call, 1 was not called. + statement = sqlite.TEMPLATES['select'].format(table='x') + expected = (statement, {'chain': 0}) + trace = self.mtrace._traces[0] + trace.db.cursor.execute.assert_called_with(*expected) + + +class TestSQLiteLoad(unittest.TestCase): + + def setUp(self): + db_patch = mock.patch('pymc.backends.sqlite._SQLiteDB') + self.addCleanup(db_patch.stop) + self.db = db_patch.start() + + table_list_patch = mock.patch('pymc.backends.sqlite._get_table_list') + self.addCleanup(table_list_patch.stop) + self.table_list = table_list_patch.start() + self.table_list.return_value = ['x', 'y'] + + def test_load(self): + trace = sqlite.load('test.db') + assert self.table_list.called + assert self.db.called + + +def test_create_column_empty(): + result = sqlite._create_colnames(()) + expected = ['v1'] + assert result == expected + + +def test_create_column_1d(): + result = sqlite._create_colnames((2, )) + expected = ['v1', 'v2'] + assert result == expected + + +def test_create_column_2d(): + result = sqlite._create_colnames((2, 2)) + expected = ['v1_1', 'v1_2', 'v2_1', 'v2_2'] + assert result == expected diff --git a/pymc/tests/test_text_backend.py b/pymc/tests/test_text_backend.py new file mode 100644 index 0000000000..2e428bf3b8 --- /dev/null +++ b/pymc/tests/test_text_backend.py @@ -0,0 +1,170 @@ +import sys +import numpy as np +import numpy.testing as npt +try: + import unittest.mock as mock # py3 +except ImportError: + import mock +import unittest +if sys.version_info[0] == 2: + from StringIO import StringIO +else: + from io import StringIO +import json + +from pymc.backends import text + + +class TextTestCase(unittest.TestCase): + + def setUp(self): + self.variables = ['x', 'y'] + self.model = mock.Mock() + self.model.unobserved_RVs = self.variables + self.model.fastfn = mock.MagicMock() + + shape_fh_patch = mock.patch('pymc.backends.text._get_shape_fh') + self.addCleanup(shape_fh_patch.stop) + self.shape_fh = shape_fh_patch.start() + + mkdir_patch = mock.patch('pymc.backends.text.os.mkdir') + self.addCleanup(mkdir_patch.stop) + mkdir_patch.start() + + +class TestTextWrite(TextTestCase): + + def setUp(self): + super(TestTextWrite, self).setUp() + + with mock.patch('pymc.backends.base.modelcontext') as context: + context.return_value = self.model + self.trace = text.Text('textdb') + + self.draws = 5 + self.trace.var_shapes = {'x': (), 'y': (4,)} + self.trace.setup(self.draws, chain=0) + self.trace.draw_idx = self.draws + + savetxt_patch = mock.patch('pymc.backends.text.np.savetxt') + self.addCleanup(savetxt_patch.stop) + self.savetxt = savetxt_patch.start() + + def test_close_args(self): + trace = self.trace + + trace.close() + + self.assertEqual(self.savetxt.call_count, 2) + + for call, varname in enumerate(trace.varnames): + fname, data = self.savetxt.call_args_list[call][0] + self.assertEqual(fname, 'textdb/chain-0/{}.txt'.format(varname)) + npt.assert_equal(data, trace[varname].reshape(-1, data.size)) + + def test_close_shape(self): + trace = self.trace + + fh = StringIO() + self.shape_fh.return_value.__enter__.return_value = fh + trace.close() + self.shape_fh.assert_called_with('textdb/chain-0', 'w') + + shape_result = fh.getvalue() + expected = {varname: [self.draws] + list(var_shape) + for varname, var_shape in trace.var_shapes.items()} + self.assertEqual(json.loads(shape_result), expected) + + +def test__chain_dir_to_chain(): + assert text._chain_dir_to_chain('/path/to/chain-0') == 0 + assert text._chain_dir_to_chain('chain-0') == 0 + + +class TestTextLoad(TextTestCase): + + def setUp(self): + super(TestTextLoad, self).setUp() + + data = {'chain-1/x.txt': np.zeros(4), 'chain-1/y.txt': np.ones(2)} + loadtxt_patch = mock.patch('pymc.backends.text.np.loadtxt') + self.addCleanup(loadtxt_patch.stop) + self.loadtxt = loadtxt_patch.start() + + chain_patch = mock.patch('pymc.backends.text._get_chain_dirs') + self.addCleanup(chain_patch.stop) + self._get_chain_dirs = chain_patch.start() + + def test_load_model_supplied_scalar(self): + draws = 4 + self._get_chain_dirs.return_value = {0: 'chain-0'} + fh = StringIO(json.dumps({'x': (draws,)})) + self.shape_fh.return_value.__enter__.return_value = fh + + data = np.zeros(draws) + self.loadtxt.return_value = data + + mtrace = text.load('textdb', model=self.model) + npt.assert_equal(mtrace.get_values('x', chains=[0]), data) + + def test_load_model_supplied_1d(self): + draws = 4 + var_shape = (2,) + self._get_chain_dirs.return_value = {0: 'chain-0'} + fh = StringIO(json.dumps({'x': (draws,) + var_shape})) + self.shape_fh.return_value.__enter__.return_value = fh + + data = np.zeros((draws,) + var_shape) + self.loadtxt.return_value = data.reshape(-1, data.size) + + mtrace = text.load('textdb', model=self.model) + npt.assert_equal(mtrace['x'], data) + + def test_load_model_supplied_2d(self): + draws = 4 + var_shape = (2, 3) + self._get_chain_dirs.return_value = {0: 'chain-0'} + fh = StringIO(json.dumps({'x': (draws,) + var_shape})) + self.shape_fh.return_value.__enter__.return_value = fh + + data = np.zeros((draws,) + var_shape) + self.loadtxt.return_value = data.reshape(-1, data.size) + + mtrace = text.load('textdb', model=self.model) + npt.assert_equal(mtrace['x'], data) + + def test_load_model_supplied_multichain_chains(self): + draws = 4 + self._get_chain_dirs.return_value = {0: 'chain-0', 1: 'chain-1'} + + def chain_fhs(): + for chain in [0, 1]: + yield StringIO(json.dumps({'x': (draws,)})) + fhs = chain_fhs() + + self.shape_fh.return_value.__enter__ = lambda x: next(fhs) + + data = np.zeros(draws) + self.loadtxt.return_value = data + + mtrace = text.load('textdb', model=self.model) + + self.assertEqual(mtrace.chains, [0, 1]) + + def test_load_model_supplied_multichain_chains_select_one(self): + draws = 4 + self._get_chain_dirs.return_value = {0: 'chain-0', 1: 'chain-1'} + + def chain_fhs(): + for chain in [0, 1]: + yield StringIO(json.dumps({'x': (draws,)})) + fhs = chain_fhs() + + self.shape_fh.return_value.__enter__ = lambda x: next(fhs) + + data = np.zeros(draws) + self.loadtxt.return_value = data + + mtrace = text.load('textdb', model=self.model, chains=[1]) + + self.assertEqual(mtrace.chains, [1]) diff --git a/pymc/tests/test_trace.py b/pymc/tests/test_trace.py index 35c03a740e..ef82ee29df 100644 --- a/pymc/tests/test_trace.py +++ b/pymc/tests/test_trace.py @@ -5,130 +5,6 @@ import warnings import nose -# Test if multiprocessing is available -import multiprocessing -try: - multiprocessing.Pool(2) - test_parallel = False -except: - test_parallel = False - - -def check_trace(model, trace, n, step, start): - # try using a trace object a few times - - for i in range(2): - trace = sample( - n, step, start, trace, progressbar=False, model=model) - - for (var, val) in start.items(): - - assert np.shape(trace[var]) == (n * (i + 1),) + np.shape(val) - -def test_trace(): - model, start, step, _ = simple_init() - - for h in [pm.NpTrace]: - for n in [20, 1000]: - for vars in [model.vars, model.vars + [model.vars[0] ** 2]]: - trace = h(vars) - - yield check_trace, model, trace, n, step, start - - -def test_multitrace(): - if not test_parallel: - return - model, start, step, _ = simple_init() - trace = None - for n in [20, 1000]: - - yield check_multi_trace, model, trace, n, step, start - - -def check_multi_trace(model, trace, n, step, start): - - for i in range(2): - trace = psample( - n, step, start, trace, model=model) - - for (var, val) in start.items(): - print([len(tr.samples[var].vals) for tr in trace.traces]) - for t in trace[var]: - assert np.shape(t) == (n * (i + 1),) + np.shape(val) - - ctrace = trace.combined() - for (var, val) in start.items(): - - assert np.shape( - ctrace[var]) == (len(trace.traces) * n * (i + 1),) + np.shape(val) - - -def test_get_point(): - - p, model = simple_2model() - p2 = p.copy() - p2['x'] *= 2. - - x = pm.NpTrace(model.vars) - x.record(p) - x.record(p2) - assert x.point(1) == x[1] - -def test_slice(): - - model, start, step, moments = simple_init() - - iterations = 100 - burn = 10 - - with model: - tr = sample(iterations, start=start, step=step, progressbar=False) - - burned = tr[burn:] - - # Slicing returns a trace - assert type(burned) is pm.trace.NpTrace - - # Indexing returns an array - assert type(tr[tr.varnames[0]]) is np.ndarray - - # Burned trace is of the correct length - assert np.all([burned[v].shape == (iterations-burn, start[v].size) for v in burned.varnames]) - - # Original trace did not change - assert np.all([tr[v].shape == (iterations, start[v].size) for v in tr.varnames]) - - # Now take more burn-in from the burned trace - burned_again = burned[burn:] - assert np.all([burned_again[v].shape == (iterations-2*burn, start[v].size) for v in burned_again.varnames]) - assert np.all([burned[v].shape == (iterations-burn, start[v].size) for v in burned.varnames]) - -def test_multi_slice(): - - model, start, step, moments = simple_init() - - iterations = 100 - burn = 10 - - with model: - tr = psample(iterations, start=start, step=step, threads=2) - - burned = tr[burn:] - - # Slicing returns a MultiTrace - assert type(burned) is pm.trace.MultiTrace - - # Indexing returns a list of arrays - assert type(tr[tr.varnames[0]]) is list - assert type(tr[tr.varnames[0]][0]) is np.ndarray - - # # Burned trace is of the correct length - assert np.all([burned[v][0].shape == (iterations-burn, start[v].size) for v in burned.varnames]) - - # Original trace did not change - assert np.all([tr[v][0].shape == (iterations, start[v].size) for v in tr.varnames]) - def test_summary_1_value_model(): mu = -2.1 diff --git a/pymc/trace.py b/pymc/trace.py index a1a70c0a35..2d76549a94 100644 --- a/pymc/trace.py +++ b/pymc/trace.py @@ -5,113 +5,11 @@ import types import warnings -__all__ = ['NpTrace', 'MultiTrace', 'summary'] - -class NpTrace(object): - """ - encapsulates the recording of a process chain - """ - def __init__(self, vars): - vars = list(vars) - model = vars[0].model - self.f = model.fastfn(vars) - self.vars = vars - self.varnames = list(map(str, vars)) - self.samples = dict((v, ListArray()) for v in self.varnames) - - def record(self, point): - """ - Records the position of a chain at a certain point in time. - """ - for var, value in zip(self.varnames, self.f(point)): - self.samples[var].append(value) - return self - - def __getitem__(self, index_value): - """ - Return copy NpTrace with sliced sample values if a slice is passed, - or the array of samples if a varname is passed. - """ - - if isinstance(index_value, slice): - - sliced_trace = NpTrace(self.vars) - sliced_trace.samples = dict((name, vals[index_value]) for (name, vals) in self.samples.items()) - - return sliced_trace - - else: - try: - return self.point(index_value) - except (ValueError, TypeError, IndexError): - pass - - return self.samples[str(index_value)].value - - def __len__(self): - return len(self.samples[self.varnames[0]]) - - def point(self, index): - return dict((k, v.value[index]) for (k, v) in self.samples.items()) - - -class ListArray(object): - def __init__(self, *args): - self.vals = list(args) - - @property - def value(self): - if len(self.vals) > 1: - self.vals = [np.concatenate(self.vals, axis=0)] - - return self.vals[0] - - def __getitem__(self, idx): - return ListArray(self.value[idx]) - - - def append(self, v): - self.vals.append(v[np.newaxis]) - - def __len__(self): - if self.vals: - return self.value.shape[0] - else: - return 0 - - -class MultiTrace(object): - def __init__(self, traces, vars=None): - try: - self.traces = list(traces) - except TypeError: - if vars is None: - raise ValueError("vars can't be None if trace count specified") - self.traces = [NpTrace(vars) for _ in range(traces)] - - def __getitem__(self, index_value): - - item_list = [h[index_value] for h in self.traces] - - if isinstance(index_value, slice): - return MultiTrace(item_list) - return item_list - - @property - def varnames(self): - return self.traces[0].varnames - - def point(self, index): - return [h.point(index) for h in self.traces] - - def combined(self): - # Returns a trace consisting of concatenated MultiTrace elements - h = NpTrace(self.traces[0].vars) - for k in self.traces[0].samples: - h.samples[k].vals = [s[k] for s in self.traces] - return h +__all__ = ['summary'] +# TODO: Move this to pymc.stats. (It was left here for diffing +# purposes). def summary(trace, vars=None, alpha=0.05, start=0, batches=100, roundto=3): """ Generate a pretty-printed summary of the node. @@ -142,15 +40,13 @@ def summary(trace, vars=None, alpha=0.05, start=0, batches=100, roundto=3): """ if vars is None: vars = trace.varnames - if isinstance(trace, MultiTrace): - trace = trace.combined() stat_summ = _StatSummary(roundto, batches, alpha) pq_summ = _PosteriorQuantileSummary(roundto, alpha) for var in vars: # Extract sampled values - sample = trace[var][start:] + sample = trace.get_values(var, burn=start, combine=True) if sample.ndim == 1: sample = sample[:, None] elif sample.ndim > 2: diff --git a/setup.py b/setup.py index 95052d6128..adcd609177 100755 --- a/setup.py +++ b/setup.py @@ -1,5 +1,6 @@ #!/usr/bin/env python from setuptools import setup +import sys DISTNAME = 'pymc' @@ -25,8 +26,12 @@ 'Topic :: Scientific/Engineering :: Mathematics', 'Operating System :: OS Independent'] -required = ['numpy>=1.7.1', 'scipy>=0.12.0', 'matplotlib>=1.2.1', - 'Theano==0.6.0'] +install_reqs = ['numpy>=1.7.1', 'scipy>=0.12.0', 'matplotlib>=1.2.1', + 'Theano==0.6.0'] + +test_reqs = ['nose'] +if sys.version_info[0] == 2: # py3 has mock in stdlib + test_reqs.append('mock') if __name__ == "__main__": setup(name=DISTNAME, @@ -39,6 +44,8 @@ long_description=LONG_DESCRIPTION, packages=['pymc', 'pymc.distributions', 'pymc.step_methods', 'pymc.tuning', - 'pymc.tests', 'pymc.glm'], + 'pymc.tests', 'pymc.glm', 'pymc.backends'], classifiers=classifiers, - install_requires=required) + install_requires=install_reqs, + tests_require=test_reqs, + test_suite='nose.collector')