diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py new file mode 100644 index 0000000000..8998f26108 --- /dev/null +++ b/pymc/backends/__init__.py @@ -0,0 +1,142 @@ +"""Backends for traces + +Available backends +------------------ + +1. NumPy array (pymc.backends.NDArray) +2. Text files (pymc.backends.Text) +3. SQLite (pymc.backends.SQLite) + +The NumPy arrays and text files both hold the entire trace in memory, +whereas SQLite commits the trace to the database while sampling. + +Selecting a backend +------------------- + +By default, a NumPy array is used as the backend. To specify a different +backend, pass a backend instance to `sample`. + +For example, the following would save traces to the file 'test.db'. + + >>> import pymc as pm + >>> db = pm.backends.SQLite('test.db') + >>> trace = pm.sample(..., db=db) + +Selecting values from a backend +------------------------------- + +After a backend is finished sampling, it returns a Trace object. Values +can be accessed in a few ways. The easiest way is to index the backend +object with a variable or variable name. + + >>> trace['x'] # or trace[x] + +The call will return a list containing the sampling values for all +chains of `x`. (Each call to `pymc.sample` creates a separate chain of +samples.) + +For more control is needed of which values are returned, the +`get_values` method can be used. The call below will return values from +all chains, burning the first 1000 iterations from each chain. + + >>> trace.get_values('x', burn=1000) + +Setting the `combined` flag will concatenate the results from all the +chains. + + >>> trace.get_values('x', burn=1000, combine=True) + +The `chains` parameter of `get_values` can be used to limit the chains +that are retrieved. To work with a subset of chains without having to +specify `chains` each call, you can set the `active_chains` attribute. 
+ + >>> trace.chains + [0, 1, 2] + >>> trace.active_chains = [0, 2] + +After this, only chains 0 and 2 will be used in operations that work +with multiple chains. + +Similary, the `default_chain` attribute sets which chain is used for +functions that require a single chain (e.g., point). + + >>> trace.point(4) # or trace[4] + +Backends can also suppport slicing the trace object. For example, the +following call would return a new trace object without the first 1000 +sampling iterations for all variables. + + >>> sliced_trace = trace[1000:] + +Loading a saved backend +----------------------- + +Saved backends can be loaded using `load` function in the module for the +specific backend. + + >>> trace = pm.backends.sqlite.load('test.db') + +Writing custom backends +----------------------- + +Backends consist of two classes: one that handles storing the sample +results (e.g., backends.ndarray.NDArray or backends.sqlite.SQLite) and +one that handles value selection (e.g., backends.ndarray.Trace or +backends.sqlite.Trace). + +Three methods of the storage class will be called: + +- setup: Before sampling is started, the `setup` method will be called + with two arguments: the number of draws and the chain number. This is + useful setting up any structure for storing the sampling values that + require the above information. + +- record: Record the sampling results for the current draw. This method + will be called with a dictionary of values mapped to the variable + names. This is the only function that *must* do something to have a + meaningful backend. + +- close: This method is called following sampling and should perform any + actions necessary for finalizing and cleaning up the backend. + +The base storage class `backends.base.Backend` provides model setup that +is used by PyMC backends. + +After sampling has completed, the `trace` attribute of the storage +object will be returned. 
To have a consistent interface with the backend +trace objects in PyMC, this attribute should be an instance of a class +that inherits from pymc.backends.base.Trace, and several methods in the +inherited Trace object should be defined. + +- get_values: This is the core method for selecting values from the + backend. It can be called directly and is used by __getitem__ when the + backend is indexed with a variable name or object. + +- _slice: Defines how the backend returns a slice of itself. This + is called if the backend is indexed with a slice range. + +- point: Returns values for each variables at a single iteration. This + is called if the backend is indexed with a single integer. + +- __len__: This should return the number of draws (for the default + chain). + +- chains: Property that returns a list of chains + +In addtion, a `merge_chains` method should be defined if the backend +will be used with parallel sampling. This method describes how to merge +sampling chains from a list of other traces. + +As mentioned above, the only method necessary to store the sampling +values is `record`. Other methods in the storage may consist of only a +pass statement. The storage object should have an attribute `trace` +(with a `merge_chains` method for parallel sampling), but this does not +have to do anything if storing the values is all that is desired. The +backends.base.Trace is provided for convenience in setting up a +consistent Trace object. + +For specific examples, see pymc.backends.{ndarray,text,sqlite}.py. +""" +from pymc.backends.ndarray import NDArray +from pymc.backends.text import Text +from pymc.backends.sqlite import SQLite diff --git a/pymc/backends/base.py b/pymc/backends/base.py new file mode 100644 index 0000000000..34eda3fd48 --- /dev/null +++ b/pymc/backends/base.py @@ -0,0 +1,219 @@ +"""Base backend for traces + +See the docstring for pymc.backends for more information (includng +creating custom backends). 
+""" +import numpy as np +from pymc.model import modelcontext + + +class Backend(object): + """Base storage class + + Parameters + ---------- + name : str + Name of backend. + model : Model + If None, the model is taken from the `with` context. + variables : list of variable objects + Sampling values will be stored for these variables + """ + def __init__(self, name, model=None, variables=None): + self.name = name + + model = modelcontext(model) + if variables is None: + variables = model.unobserved_RVs + self.variables = variables + self.var_names = [str(var) for var in variables] + self.fn = model.fastfn(variables) + + ## get variable shapes. common enough that I think most backends + ## will use this + var_values = zip(self.var_names, self.fn(model.test_point)) + self.var_shapes = {var: value.shape + for var, value in var_values} + self.chain = None + self.trace = None + + def setup(self, draws, chain): + """Perform chain-specific setup + + draws : int + Expected number of draws + chain : int + chain number + """ + pass + + def record(self, point): + """Record results of a sampling iteration + + point : dict + Values mappled to variable names + """ + raise NotImplementedError + + def close(self): + """Close the database backend + + This is called after sampling has finished. 
+ """ + pass + + +class Trace(object): + """ + Parameters + ---------- + var_names : list of strs + Sample variables names + backend : Backend object + + Attributes + ---------- + var_names + backend : Backend object + nchains : int + Number of sampling chains + chains : list of ints + List of sampling chain numbers + default_chain : int + Chain to be used if single chain requested + active_chains : list of ints + Values from chains to be used operations + """ + def __init__(self, var_names, backend=None): + self.var_names = var_names + self.backend = backend + self._active_chains = [] + self._default_chain = None + + @property + def default_chain(self): + """Default chain to use for operations that require one chain (e.g., + `point`) + """ + if self._default_chain is None: + return self.active_chains[-1] + return self._default_chain + + @default_chain.setter + def default_chain(self, value): + self._default_chain = value + + @property + def active_chains(self): + """List of chains to be used. Defaults to all. + """ + if not self._active_chains: + return self.chains + return self._active_chains + + @active_chains.setter + def active_chains(self, values): + try: + self._active_chains = [chain for chain in values] + except TypeError: + self._active_chains = [values] + + @property + def nchains(self): + """Number of chains + + A chain is created for each sample call (including parallel + threads). 
+ """ + return len(self.chains) + + def __getitem__(self, idx): + if isinstance(idx, slice): + return self._slice(idx) + + try: + return self.point(idx) + except ValueError: + pass + except TypeError: + pass + return self.get_values(idx) + + ## Selection methods that children must define + + @property + def chains(self): + """All chains in trace""" + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + def get_values(self, var_name, burn=0, thin=1, combine=False, chains=None, + squeeze=True): + """Get values from samples + + Parameters + ---------- + var_name : str + burn : int + thin : int + combine : bool + If True, results from all chains will be concatenated. + chains : list + Chains to retrieve. If None, `active_chains` is used. + squeeze : bool + If `combine` is False, return a single array element if the + resulting list of values only has one element (even if + `combine` is True). + + Returns + ------- + A list of NumPy array of values + """ + raise NotImplementedError + + def _slice(self, idx): + """Slice trace object""" + raise NotImplementedError + + def point(self, idx, chain=None): + """Return dictionary of point values at `idx` for current chain + with variables names as keys. + + If `chain` is not specified, `default_chain` is used. + """ + raise NotImplementedError + + def merge_chains(traces): + """Merge chains from trace instances + + Parameters + ---------- + traces : list + Backend trace instances. Each instance should have only one + chain, and all chain numbers should be unique. + + Raises + ------ + ValueError is raised if any traces have the same current chain + number. 
+ + Returns + ------- + Backend instance with merge chains + """ + raise NotImplementedError + + +def _squeeze_cat(results, combine, squeeze): + """Squeeze and concatenate the results dependending on values of + `combine` and `squeeze`""" + if combine: + results = np.concatenate(results) + if not squeeze: + results = [results] + else: + if squeeze and len(results) == 1: + results = results[0] + return results diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py new file mode 100644 index 0000000000..3960325d97 --- /dev/null +++ b/pymc/backends/ndarray.py @@ -0,0 +1,174 @@ +"""NumPy trace backend + +Store sampling values in memory as a NumPy array. +""" +import numpy as np +from pymc.backends import base + + +class NDArray(base.Backend): + """NDArray storage + + Parameters + ---------- + name : str + Name of backend. + model : Model + If None, the model is taken from the `with` context. + variables : list of variable objects + Sampling values will be stored for these variables + """ + ## make `name` an optional argument for NDArray + def __init__(self, name=None, model=None, variables=None): + super(NDArray, self).__init__(name, model, variables) + + self.trace = Trace(self.var_names, self) + self.draw_idx = 0 + self.draws = None + + def setup(self, draws, chain): + """Perform chain-specific setup + + draws : int + Expected number of draws + chain : int + chain number + """ + self.chain = chain + ## Concatenate new array if chain is already present. + if chain in self.trace.samples: + chain_trace = self.trace.samples[chain] + old_draws = len(self.trace) + self.draws = old_draws + draws + self.draws_idx = old_draws + for var_name, shape in self.var_shapes.items(): + old_trace = chain_trace[var_name] + new_trace = np.zeros((draws, ) + shape) + chain_trace[var_name] = np.concatenate((old_trace, new_trace), + axis=0) + else: # Otherwise, make array of zeros for each variable. 
+ self.draws = draws + var_arrays = {} + for var_name, shape in self.var_shapes.items(): + var_arrays[var_name] = np.zeros((draws, ) + shape) + self.trace.samples[chain] = var_arrays + + def record(self, point): + """Record results of a sampling iteration + + point : dict + Values mappled to variable names + """ + for var_name, value in zip(self.var_names, self.fn(point)): + self.trace.samples[self.chain][var_name][self.draw_idx] = value + self.draw_idx += 1 + + def close(self): + if self.draw_idx == self.draws: + return + ## Remove trailing zeros if interrupted before completed all draws + traces = self.trace.samples[self.chain] + traces = {var: trace[:self.draw_idx] for var, trace in traces.items()} + self.trace.samples[self.chain] = traces + + +class Trace(base.Trace): + + __doc__ = 'NumPy array trace\n' + base.Trace.__doc__ + + def __init__(self, var_names, backend=None): + super(Trace, self).__init__(var_names, backend) + self.samples = {} # chain -> var name -> values + + def __len__(self): + var_name = self.var_names[0] + return self.samples[self.default_chain][var_name].shape[0] + + @property + def chains(self): + """All chains in trace""" + return list(self.samples.keys()) + + def get_values(self, var_name, burn=0, thin=1, combine=False, chains=None, + squeeze=True): + """Get values from samples + + Parameters + ---------- + var_name : str + burn : int + thin : int + combine : bool + If True, results from all chains will be concatenated. + chains : list + Chains to retrieve. If None, `active_chains` is used. + squeeze : bool + If `combine` is False, return a single array element if the + resulting list of values only has one element (even if + `combine` is True). 
+ + Returns + ------- + A list of NumPy array of values + """ + if chains is None: + chains = self.active_chains + + var_name = str(var_name) + results = (self.samples[chain][var_name] for chain in chains) + results = [arr[burn::thin] for arr in results] + + return base._squeeze_cat(results, combine, squeeze) + + def _slice(self, idx): + sliced = Trace(self.var_names) + sliced.backend = self.backend + sliced._active_chains = sliced._active_chains + sliced._default_chain = sliced._default_chain + + sliced.samples = {} + for chain, trace in self.samples.items(): + sliced_values = {var_name: values[idx] + for var_name, values in trace.items()} + sliced.samples[chain] = sliced_values + return sliced + + def point(self, idx, chain=None): + """Return dictionary of point values at `idx` for current chain + with variables names as keys. + + If `chain` is not specified, `default_chain` is used. + """ + if chain is None: + chain = self.default_chain + return {var_name: values[idx] + for var_name, values in self.samples[chain].items()} + + def merge_chains(self, traces): + """Merge chains from trace instances + + Parameters + ---------- + traces : list + Backend trace instances. Each instance should have only one + chain, and all chain numbers should be unique. + + Raises + ------ + ValueError is raised if any traces have the same current chain + number. + + Returns + ------- + Backend instance with merge chains + """ + var_name = self.var_names[0] # Select any variable. + for new_trace in traces: + for new_chain in new_trace.chains: + if new_chain in self.chains: + ## Take the new chain if it has more draws. 
+ draws = self.samples[new_chain][var_name].shape[0] + new_draws = new_trace.samples[new_chain][var_name].shape[0] + if draws >= new_draws: + continue + self.samples[new_chain] = new_trace.samples[new_chain] diff --git a/pymc/backends/sqlite.py b/pymc/backends/sqlite.py new file mode 100644 index 0000000000..4e66f35b75 --- /dev/null +++ b/pymc/backends/sqlite.py @@ -0,0 +1,341 @@ +"""SQLite trace backend + +Store sampling values in SQLite database file. + +Database format +--------------- +For each variable, a table is created with the following format: + + recid (INT), draw (INT), chain (INT), v1 (FLOAT), v2 (FLOAT), v3 (FLOAT) ... + +The variable column names are extended to reflect addition dimensions. +For example, a variable with the shape (2, 2) would be stored as + + key (INT), draw (INT), chain (INT), v1_1 (FLOAT), v1_2 (FLOAT), v2_1 (FLOAT) ... + +The key is autoincremented each time a new row is added to the table. +The chain column denotes the chain index, and starts at 0. +""" +import numpy as np +import sqlite3 +import warnings + +from pymc.backends import base + +QUERIES = { + 'table': ('CREATE TABLE IF NOT EXISTS [{table}] ' + '(recid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ' + 'draw INTEGER, chain INT(5), ' + '{value_cols})'), + 'insert': ('INSERT INTO [{table}] ' + '(recid, draw, chain, {value_cols}) ' + 'VALUES (NULL, {{draw}}, {chain}, {{value}})'), + 'select': ('SELECT * FROM [{table}] ' + 'WHERE (chain = {chain})'), + 'select_burn': ('SELECT * FROM [{table}] ' + 'WHERE (chain = {chain}) AND (draw > {burn})'), + 'select_thin': ('SELECT * FROM [{table}] ' + 'WHERE (chain = {chain}) AND ' + '(draw - (SELECT draw FROM [{table}] ' + 'WHERE chain = {chain} ' + 'ORDER BY draw LIMIT 1)) % {thin} = 0'), + 'select_burn_thin': ('SELECT * FROM [{table}] ' + 'WHERE (chain = {chain}) AND (draw > {burn}) ' + 'AND (draw - (SELECT draw FROM [{table}] ' + 'WHERE (chain = {chain}) AND (draw > {burn}) ' + 'ORDER BY draw LIMIT 1)) % {thin} = 0'), + 
'select_point': ('SELECT * FROM [{{table}}] ' + 'WHERE (chain={chain}) AND (draw={draw})'), + 'max_draw': ('SELECT MAX(draw) FROM [{table}] ' + 'WHERE chain={chain}'), + 'draw_count': ('SELECT COUNT(*) FROM [{table}] ' + 'WHERE chain={chain}'), +} + + +class SQLite(base.Backend): + """SQLite storage + + Parameters + ---------- + name : str + Name of database file + model : Model + If None, the model is taken from the `with` context. + variables : list of variable objects + Sampling values will be stored for these variables + """ + + def __init__(self, name, model=None, variables=None): + super(SQLite, self).__init__(name, model, variables) + + self.trace = Trace(self.var_names, self) + + ## These are set in `setup` to avoid sqlite3.OperationalError + ## (Base Cursor.__init__ not called) when performing parallel + ## sampling + self.conn = None + self.cursor = None + self.connected = False + + self._var_cols = {} + self.var_inserts = {} # var_name -> insert query + self.draw_idx = 0 + + def setup(self, draws, chain): + """Perform chain-specific setup + + draws : int + Expected number of draws + chain : int + chain number + """ + self.chain = chain + + if not self._var_cols: # Table has not been created. 
+ self._var_cols = {var_name: _create_colnames(shape) + for var_name, shape in self.var_shapes.items()} + self._create_table() + else: + self.draw_idx = self.trace._get_max_draw(chain) + 1 + self.trace._len = None + + self._create_insert_queries(chain) + + def _create_table(self): + self.connect() + table = QUERIES['table'] + for var_name, var_cols in self._var_cols.items(): + var_float = ', '.join([v + ' FLOAT' for v in var_cols]) + self.cursor.execute(table.format(table=var_name, + value_cols=var_float)) + + def _create_insert_queries(self, chain): + self.connect() + insert = QUERIES['insert'] + for var_name, var_cols in self._var_cols.items(): + ## Create insert query for each variable + var_str = ', '.join(var_cols) + self.var_inserts[var_name] = insert.format(table=var_name, + value_cols=var_str, + chain=chain) + + def connect(self): + if self.connected: + return + + self.conn = sqlite3.connect(self.name, check_same_thread=False) + self.cursor = self.conn.cursor() + self.connected = True + + def record(self, point): + """Record results of a sampling iteration + + point : dict + Values mappled to variable names + """ + for var_name, value in zip(self.var_names, self.fn(point)): + val_str = ', '.join(['{}'.format(val) for val in np.ravel(value)]) + query = self.var_inserts[var_name].format(draw=self.draw_idx, + value=val_str) + self.cursor.execute(query) + self.draw_idx += 1 + + if not self.draw_idx % 1000: + self.conn.commit() + + def close(self): + if not self.connected: + return + + self.cursor.close() + self.conn.commit() + self.conn.close() + self.connected = False + + +class Trace(base.Trace): + + __doc__ = 'SQLite trace\n' + base.Trace.__doc__ + + def __init__(self, var_names, backend=None): + super(Trace, self).__init__(var_names, backend) + self._len = None + self._chains = None + + def __len__(self): + if self._len is None: + self._len = self._get_number_draws(self.default_chain) + return self._len + + def _get_max_draw(self, chain): + 
self.backend.connect() + query = QUERIES['max_draw'].format(table=self.var_names[0], + chain=chain) + return self.backend.cursor.execute(query).fetchall()[0][0] + + def _get_number_draws(self, chain): + self.backend.connect() + query = QUERIES['draw_count'].format(table=self.var_names[0], + chain=chain) + return self.backend.cursor.execute(query).fetchall()[0][0] + + @property + def chains(self): + """All chains in trace""" + if self._chains is None: + self.backend.connect() + var_name = self.var_names[0] # any variable should do + self._chains = _get_chain_list(self.backend.cursor, var_name) + return self._chains + + def get_values(self, var_name, burn=0, thin=1, combine=False, chains=None, + squeeze=True): + """Get values from samples + + Parameters + ---------- + var_name : str + burn : int + thin : int + combine : bool + If True, results from all chains will be concatenated. + chains : list + Chains to retrieve. If None, `active_chains` is used. + squeeze : bool + If `combine` is False, return a single array element if the + resulting list of values only has one element (even if + `combine` is True). 
+ + Returns + ------- + A list of NumPy array of values + """ + if burn < 0: + raise ValueError('Negative burn values not supported ' + 'in SQLite backend.') + if thin < 1: + raise ValueError('Only positive thin values are supported ' + 'in SQLite backend.') + if chains is None: + chains = self.active_chains + + var_name = str(var_name) + + query_args = {} + if burn == 0 and thin == 1: + query = 'select' + elif thin == 1: + query = 'select_burn' + query_args = {'burn': burn - 1} + elif burn == 0: + query = 'select_thin' + query_args = {'thin': thin} + else: + query = 'select_burn_thin' + query_args = {'burn': burn - 1, 'thin': thin} + + self.backend.connect() + results = [] + for chain in chains: + call = QUERIES[query].format(table=var_name, chain=chain, + **query_args) + self.backend.cursor.execute(call) + results.append(_rows_to_ndarray(self.backend.cursor)) + return base._squeeze_cat(results, combine, squeeze) + + def _slice(self, idx): + warnings.warn('Slice for SQLite backend has no effect.') + + def point(self, idx, chain=None): + """Return dictionary of point values at `idx` for current chain + with variables names as keys. + + If `chain` is not specified, `default_chain` is used. 
+ """ + if idx < 0: + raise ValueError('Negtive indexing is not supported ' + 'in SQLite backend.') + if chain is None: + chain = self.default_chain + + query = QUERIES['select_point'].format(chain=chain, + draw=idx) + self.backend.connect() + var_values = {} + for var_name in self.var_names: + self.backend.cursor.execute(query.format(table=var_name)) + var_values[var_name] = np.squeeze( + _rows_to_ndarray(self.backend.cursor)) + return var_values + + def merge_chains(self, traces): + pass + + +def _create_colnames(shape): + """Return column names based on `shape` + + Examples + -------- + >>> create_colnames((5,)) + ['v1', 'v2', 'v3', 'v4', 'v5'] + + >>> create_colnames((2,2)) + ['v1_1', 'v1_2', 'v2_1', 'v2_2'] + """ + if not shape: + return ['v1'] + + size = np.prod(shape) + indices = (np.indices(shape) + 1).reshape(-1, size) + return ['v' + '_'.join(map(str, i)) for i in zip(*indices)] + + +def load(name, model=None): + """Load SQLite database from file name + + Parameters + ---------- + name : str + Path to SQLite database file + model : Model + If None, the model is taken from the `with` context. + + Returns + ------- + SQLite backend instance + """ + db = SQLite(name, model=model) + db.connect() + var_names = _get_table_list(db.cursor) + return Trace(var_names, db) + + +def _get_table_list(cursor): + """Return a list of table names in the current database""" + ## Modified from Django. Skips the sqlite_sequence system table used + ## for autoincrement key generation. 
+ cursor.execute("SELECT name FROM sqlite_master " + "WHERE type='table' AND NOT name='sqlite_sequence' " + "ORDER BY name") + return [row[0] for row in cursor.fetchall()] + + +def _get_var_strs(cursor, var_name): + cursor.execute('SELECT * FROM [{}]'.format(var_name)) + col_names = (col_descr[0] for col_descr in cursor.description) + return [name for name in col_names if name.startswith('v')] + + +def _get_chain_list(cursor, var_name): + """Return a list of sorted chains for `var_name`""" + cursor.execute('SELECT DISTINCT chain FROM [{}]'.format(var_name)) + chains = [chain[0] for chain in cursor.fetchall()] + chains.sort() + return chains + + +def _rows_to_ndarray(cursor): + """Convert SQL row to NDArray""" + return np.array([row[3:] for row in cursor.fetchall()]) diff --git a/pymc/backends/text.py b/pymc/backends/text.py new file mode 100644 index 0000000000..5a753b4a7b --- /dev/null +++ b/pymc/backends/text.py @@ -0,0 +1,113 @@ +"""Text file trace backend + +After sampling with NDArray backend, save results as text files. + +Database format +--------------- + +For each chain, a directory named `chain-N` is created. In this +directory, one file per variable is created containing the values of the +object. To deal with multidimensional variables, the array is reshaped +to one dimension before saving with `numpy.savetxt`. The shape +information is saved in a json file in the same directory and is used to +load the database back again using `numpy.loadtxt`. +""" +import os +import glob +import json +import numpy as np +from contextlib import contextmanager + +from pymc.backends import base +from pymc.backends.ndarray import NDArray, Trace + + +class Text(NDArray): + """Text storage + + Parameters + ---------- + name : str + Name of directory to store text files. + model : Model + If None, the model is taken from the `with` context. 
+ variables : list of variable objects + Sampling values will be stored for these variables + """ + def __init__(self, name, model=None, variables=None): + super(Text, self).__init__(name, model, variables) + if not os.path.exists(name): + os.mkdir(name) + + def close(self): + super(Text, self).close() + for chain in self.trace.chains: + chain_name = 'chain-{}'.format(chain) + chain_dir = os.path.join(self.name, chain_name) + os.mkdir(chain_dir) + + shapes = {} + for var_name in self.var_names: + data = self.trace.samples[chain][var_name] + var_file = os.path.join(chain_dir, var_name + '.txt') + np.savetxt(var_file, data.reshape(-1, data.size)) + shapes[var_name] = data.shape + ## Store shape information for reloading. + with _get_shape_fh(chain_dir, 'w') as sfh: + json.dump(shapes, sfh) + + +def load(name, chains=None, model=None): + """Load text database from name + + Parameters + ---------- + name : str + Path to root directory for text database + chains : list or None + Chains to load. If None, all chains are loaded. + model : Model + If None, the model is taken from the `with` context. 
+ + Returns + ------- + ndarray.Trace instance + """ + chain_dirs = _get_chain_dirs(name) + if chains is None: + chains = list(chain_dirs.keys()) + + trace = Trace(None) + + for chain in chains: + chain_dir = chain_dirs[chain] + with _get_shape_fh(chain_dir, 'r') as sfh: + shapes = json.load(sfh) + samples = {} + for var_name, shape in shapes.items(): + var_file = os.path.join(chain_dir, var_name + '.txt') + samples[var_name] = np.loadtxt(var_file).reshape(shape) + trace.samples[chain] = samples + trace.var_names = list(trace.samples[chain].keys()) + return trace + + +## Not opening json directory in `Text.close` and `load` for testing +## convenience +@contextmanager +def _get_shape_fh(chain_dir, mode='r'): + fh = open(os.path.join(chain_dir, 'shapes.json'), mode) + try: + yield fh + finally: + fh.close() + + +def _get_chain_dirs(name): + """Return mapping of chain number to directory""" + return {_chain_dir_to_chain(chain_dir): chain_dir + for chain_dir in glob.glob(os.path.join(name, 'chain-*'))} + + +def _chain_dir_to_chain(chain_dir): + return int(os.path.basename(chain_dir).split('-')[1]) diff --git a/pymc/diagnostics.py b/pymc/diagnostics.py index f5f7669a43..9c81d8897c 100644 --- a/pymc/diagnostics.py +++ b/pymc/diagnostics.py @@ -87,7 +87,7 @@ def geweke(x, first=.1, last=.5, intervals=20): return np.array(zscores) -def gelman_rubin(mtrace): +def gelman_rubin(trace): """ Returns estimate of R for a set of traces. The Gelman-Rubin diagnostic tests for lack of convergence by comparing @@ -99,8 +99,8 @@ def gelman_rubin(mtrace): Parameters ---------- - mtrace : MultiTrace - A MultiTrace object containing parallel traces (minimum 2) + trace + A trace object containing parallel traces (minimum 2) of one or more stochastic parameters. 
Returns @@ -126,8 +126,7 @@ def gelman_rubin(mtrace): Brooks and Gelman (1998) Gelman and Rubin (1992)""" - m = len(mtrace.traces) - if m < 2: + if trace.nchains < 2: raise ValueError( 'Gelman-Rubin diagnostic requires multiple chains of the same length.') @@ -147,10 +146,11 @@ def calc_rhat(x): return np.sqrt(Vhat/W) Rhat = {} - for var in mtrace.varnames: + for var in trace.var_names: # Get all traces for var - x = np.array([mtrace.traces[i][var] for i in range(m)]) + x = np.array([values[var] for chain in trace.chains + for values in trace.samples.values()]) try: Rhat[var] = calc_rhat(x) @@ -159,9 +159,11 @@ def calc_rhat(x): return Rhat + def trace_to_dataframe(trace): """Convert a PyMC trace consisting of 1-D variables to a pandas DataFrame """ import pandas as pd - return pd.DataFrame({name: np.squeeze(trace_var.vals) - for name, trace_var in trace.samples.items()}) + return pd.DataFrame( + {var_name: np.squeeze(trace.get_values(var_name, combine=True)) + for var_name in trace.var_names}) diff --git a/pymc/examples/gelman_bioassay.py b/pymc/examples/gelman_bioassay.py index 4725ded00b..5aa23d038d 100644 --- a/pymc/examples/gelman_bioassay.py +++ b/pymc/examples/gelman_bioassay.py @@ -28,7 +28,7 @@ def run(n=1000): if n == "short": n = 50 with model: - trace = sample(n, step, trace=model.unobserved_RVs) + trace = sample(n, step) if __name__ == '__main__': run() diff --git a/pymc/examples/sqlite_dump_load.py b/pymc/examples/sqlite_dump_load.py new file mode 100644 index 0000000000..b46636d1f1 --- /dev/null +++ b/pymc/examples/sqlite_dump_load.py @@ -0,0 +1,50 @@ +import os +import numpy as np +import numpy.testing as npt + +import pymc as pm + +# import pydevd +# pydevd.set_pm_excepthook() +np.seterr(invalid='raise') + +data = np.random.normal(size=(2, 20)) +model = pm.Model() + +with model: + x = pm.Normal('x', mu=.5, tau=2. 
** -2, shape=(2, 1)) + z = pm.Beta('z', alpha=10, beta=5.5) + d = pm.Normal('data', mu=x, tau=.75 ** -2, observed=data) + + +def run(n=50): + if n == 'short': + n = 5 + with model: + try: + trace = pm.sample(n, step=pm.Metropolis(), + db=pm.backends.SQLite('test.db'), + threads=2) + dumped = pm.backends.sqlite.load('test.db') + + assert trace[x][0].shape[0] == n + assert trace[x][1].shape[0] == n + assert trace.get_values('z', burn=3, + combine=True).shape[0] == n * 2 - 3 * 2 + + assert trace.nchains == dumped.nchains + assert list(sorted(trace.var_names)) == list(sorted(dumped.var_names)) + + for chain in trace.chains: + for var_name in trace.var_names: + data = trace.get_values(var_name, chains=[chain]) + dumped_data = dumped.get_values(var_name, chains=[chain]) + npt.assert_equal(data, dumped_data) + finally: + try: + os.remove('test.db') + except FileNotFoundError: + pass + +if __name__ == '__main__': + run('short') diff --git a/pymc/examples/stochastic_volatility.py b/pymc/examples/stochastic_volatility.py index 2f0297bd63..a62692317f 100644 --- a/pymc/examples/stochastic_volatility.py +++ b/pymc/examples/stochastic_volatility.py @@ -119,12 +119,12 @@ def run(n=2000): if n == "short": n = 50 with model: - trace = sample(5, step, start, trace=model.vars + [sigma]) + trace = sample(5, step, start, variables=model.vars + [sigma]) # Start next run at the last sampled position. start2 = trace.point(-1) step2 = HamiltonianMC(model.vars, hessian(start2, 6), path_length=4.) 
- trace = sample(2000, step2, trace=trace) + trace = sample(n, step2, db=trace.backend) # diff --git a/pymc/examples/text_dump_load.py b/pymc/examples/text_dump_load.py new file mode 100644 index 0000000000..361d95949d --- /dev/null +++ b/pymc/examples/text_dump_load.py @@ -0,0 +1,47 @@ +import shutil +import numpy as np +import numpy.testing as npt + +import pymc as pm + +# import pydevd +# pydevd.set_pm_excepthook() +np.seterr(invalid='raise') + +data = np.random.normal(size=(2, 20)) +model = pm.Model() + +with model: + x = pm.Normal('x', mu=.5, tau=2. ** -2, shape=(2, 1)) + z = pm.Beta('z', alpha=10, beta=5.5) + d = pm.Normal('data', mu=x, tau=.75 ** -2, observed=data) + + +def run(n=50): + if n == 'short': + n = 5 + with model: + try: + trace = pm.sample(n, step=pm.Metropolis(), + db=pm.backends.Text('textdb'), + threads=2) + dumped = pm.backends.text.load('textdb') + + assert trace[x][0].shape[0] == n + assert trace[x][1].shape[0] == n + assert trace.get_values('z', burn=3, + combine=True).shape[0] == n * 2 - 3 * 2 + + assert trace.nchains == dumped.nchains + assert list(sorted(trace.var_names)) == list(sorted(dumped.var_names)) + + for chain in trace.chains: + for var_name in trace.var_names: + data = trace.samples[chain][var_name] + dumped_data = dumped.samples[chain][var_name] + npt.assert_equal(data, dumped_data) + finally: + shutil.rmtree('textdb') + +if __name__ == '__main__': + run('short') diff --git a/pymc/plots.py b/pymc/plots.py index ce90439913..5f299ddad0 100644 --- a/pymc/plots.py +++ b/pymc/plots.py @@ -7,12 +7,11 @@ import numpy as np from scipy.stats import kde from .stats import * -from .trace import * __all__ = ['traceplot', 'kdeplot', 'kde2plot', 'forestplot', 'autocorrplot'] -def traceplot(trace, vars=None, figsize=None, +def traceplot(trace, var_names=None, figsize=None, lines=None, combined=False, grid=True): """Plot samples histograms and values @@ -20,7 +19,7 @@ def traceplot(trace, vars=None, figsize=None, ---------- trace : result of 
MCMC run - vars : list of variable names + var_names : list of variable names Variables to be plotted, if None all variable are plotted figsize : figure size tuple If None, size is (12, num of variables * 2) inch @@ -29,8 +28,8 @@ def traceplot(trace, vars=None, figsize=None, lines to the posteriors and horizontal lines on sample values e.g. mean of posteriors, true values of a simulation combined : bool - Flag for combining MultiTrace into a single trace. If False (default) - traces will be plotted separately on the same set of axes. + Flag for combining multiple chains into a single chain. If False + (default), chains will be plotted separately. grid : bool Flag for adding gridlines to histogram. Defaults to True. @@ -40,30 +39,19 @@ def traceplot(trace, vars=None, figsize=None, fig : figure object """ + if var_names is None: + var_names = trace.var_names - if vars is None: - vars = trace.varnames - - if isinstance(trace, MultiTrace): - if combined: - traces = [trace.combined()] - else: - traces = trace.traces - else: - traces = [trace] - - n = len(vars) + n = len(var_names) if figsize is None: figsize = (12, n*2) fig, ax = plt.subplots(n, 2, squeeze=False, figsize=figsize) - for trace in traces: - for i, v in enumerate(vars): - d = np.squeeze(trace[v]) - - if trace[v].dtype.kind == 'i': + for i, v in enumerate(var_names): + for d in trace.get_values(v, combine=combined, squeeze=False): + if d.dtype.kind == 'i': histplot_op(ax[i, 0], d) else: kdeplot_op(ax[i, 0], d) @@ -138,36 +126,25 @@ def kde2plot(x, y, grid=200): def autocorrplot(trace, vars=None, fontmap=None, max_lag=100): """Bar plot of the autocorrelation function for a trace""" - - try: - # MultiTrace - traces = trace.traces - - except AttributeError: - # NpTrace - traces = [trace] - if fontmap is None: fontmap = {1: 10, 2: 8, 3: 6, 4: 5, 5: 4} if vars is None: - vars = traces[0].varnames + var_names = trace.var_names + else: + var_names = [str(var) for var in vars] - # Extract sample data - samples = 
[{v: trace[v] for v in vars} for trace in traces] + chains = trace.nchains - chains = len(traces) - - n = len(samples[0]) - f, ax = subplots(n, chains, squeeze=False) + # Extract sample data - max_lag = min(len(samples[0][vars[0]])-1, max_lag) + f, ax = subplots(len(var_names), chains, squeeze=False) - for i, v in enumerate(vars): + max_lag = min(len(trace) - 1, max_lag) + for i, v in enumerate(var_names): for j in range(chains): - - d = np.squeeze(samples[j][v]) + d = np.squeeze(trace.get_values(v, chains=[j])) ax[i, j].acorr(d, detrend=mlab.detrend_mean, maxlags=max_lag) @@ -205,7 +182,7 @@ def var_str(name, shape): names[0] = '%s %s' % (name, names[0]) return names - +## FIXME: This has not been updated to work with backends def forestplot(trace_obj, vars=None, alpha=0.05, quartiles=True, rhat=True, main=None, xtitle=None, xrange=None, ylabels=None, chain_spacing=0.05, vline=0): diff --git a/pymc/progressbar.py b/pymc/progressbar.py index 3027c4ba07..1ca256414e 100644 --- a/pymc/progressbar.py +++ b/pymc/progressbar.py @@ -121,3 +121,22 @@ def progress_bar(iters): return TextProgressBar(iters, ipythonprint) else: return TextProgressBar(iters, consoleprint) + + +def enumerate_progress(iterable, total, meter=None): + """Wrapper `enumerate` in a progress bar update + + Parameters + ---------- + iterable + total : int + meter + Any object with an `update` method that accepts current + iteration as an argument. If none, `progress_bar` is used. + """ + if meter is None: + meter = progress_bar + progress = meter(total) + for i, item in enumerate(iterable): + progress.update(i) + yield i, item diff --git a/pymc/sample.py b/pymc/sample.py index 6220243efc..f3138a0a29 100644 --- a/pymc/sample.py +++ b/pymc/sample.py @@ -1,109 +1,151 @@ from .point import * -from .trace import NpTrace, MultiTrace +from pymc.backends.ndarray import NDArray import multiprocessing as mp from time import time from .core import * from . 
import step_methods -from .progressbar import progress_bar +from .progressbar import enumerate_progress from numpy.random import seed -__all__ = ['sample', 'psample', 'iter_sample'] +__all__ = ['sample', 'iter_sample'] -def sample(draws, step, start=None, trace=None, tune=None, progressbar=True, model=None, random_seed=None): - """ - Draw a number of samples using the given step method. - Multiple step methods supported via compound step method - returns the amount of time taken. +def sample(draws, step, start=None, db=None, chain=0, threads=1, tune=None, + progressbar=True, model=None, variables=None, random_seed=None): + """Draw samples using the given step method Parameters ---------- - draws : int The number of samples to draw - step : function - A step function + step : step method or list of step methods start : dict - Starting point in parameter space (or partial point) - Defaults to trace.point(-1)) if there is a trace provided and - model.test_point if not (defaults to empty dict) - trace : NpTrace or list - Either a trace of past values or a list of variables to track - (defaults to None) + Starting point in parameter space (or partial point). Defaults + to model.test_point. + db : backend + A storage backend. If None, NDArray is used. + chain : int + Chain number used to store sample in trace. If threads greater + than one, chain numbers will start here + threads : int + Number of parallel traces to start. If None, set to number of + cpus in the system - 2. tune : int - Number of iterations to tune, if applicable (defaults to None) + Number of iterations to tune, if applicable progressbar : bool Flag for progress bar model : Model (optional if in `with` context) + variables : list + Variables to sample. If None, defaults to model.unobserved_RVs. + Ignored if model argument is supplied. + random_seed : int or list of ints + List accepted if more than one thread. 
+ Returns + ------- + Trace object with access to sampling values """ - progress = progress_bar(draws) + if threads is None: + threads = max(mp.cpu_count() - 2, 1) + if threads > 1: + try: + if not len(random_seed) == threads: + random_seeds = [random_seed] * threads + else: + random_seeds = random_seed + except TypeError: # None, int + random_seeds = [random_seed] * threads + + chains = list(range(chain, chain + threads)) + argset = zip([draws] * threads, + [step] * threads, + [start] * threads, + [db] * threads, + chains, + [tune] * threads, + [False] * threads, + [model] * threads, + [variables] * threads, + random_seeds) + sample_func = _thread_sample + sample_args = [threads, argset] + else: + sample_func = _sample + sample_args = [draws, step, start, db, chain, + tune, progressbar, model, variables, random_seed] + return sample_func(*sample_args) + + +def _sample(draws, step, start=None, db=None, chain=0, tune=None, + progressbar=True, model=None, variables=None, random_seed=None): + sampling = _iter_sample(draws, step, start, db, chain, + tune, model, variables, random_seed) + + if progressbar: + sampling = enumerate_progress(sampling, draws) + else: + sampling = enumerate(sampling) try: - for i, trace in enumerate(iter_sample(draws, step, - start=start, - trace=trace, - tune=tune, - model=model, - random_seed=random_seed)): - if progressbar: - progress.update(i) + for i, trace in sampling: + pass except KeyboardInterrupt: - pass + trace.backend.close() return trace -def iter_sample(draws, step, start=None, trace=None, tune=None, model=None, random_seed=None): + +def iter_sample(draws, step, start=None, db=None, chain=0, tune=None, + model=None, variables=None, random_seed=None): """ - Generator that returns a trace on each iteration using the given - step method. Multiple step methods supported via compound step - method returns the amount of time taken. + Generator that returns a trace on each iteration using the given step + method. 
Parameters ---------- draws : int The number of samples to draw - step : function - A step function + step : step method or list of step methods start : dict - Starting point in parameter space (or partial point) - Defaults to trace.point(-1)) if there is a trace provided and - model.test_point if not (defaults to empty dict) - trace : NpTrace or list - Either a trace of past values or a list of variables to track - (defaults to None) + Starting point in parameter space (or partial point). Defaults + to model.test_point. + db : backend + If None, NDArray is used. + chain : int + Chain number used to store sample in trace. If threads greater + than one, chain numbers will start here tune : int - Number of iterations to tune, if applicable (defaults to None) + Number of iterations to tune, if applicable model : Model (optional if in `with` context) + variables : list + Variables to sample. If None, defaults to model.unobserved_RVs. + Ignored if model argument is supplied. + random_seed : int or list of ints + List accepted if more than one thread. Example ------- for trace in iter_sample(500, step): ... 
- """ + sampling = _iter_sample(draws, step, start, db, chain, + tune, model, variables, random_seed) + for i, trace in enumerate(sampling): + yield trace[:i + 1] + + +def _iter_sample(draws, step, start=None, db=None, chain=0, tune=None, + model=None, variables=None, random_seed=None): + seed(random_seed) model = modelcontext(model) draws = int(draws) - seed(random_seed) + if draws < 1: + raise ValueError('Argument `draws` should be above 0') if start is None: start = {} - - if isinstance(trace, NpTrace) and len(trace) > 0: - trace_point = trace.point(-1) - trace_point.update(start) - start = trace_point - - else: - test_point = model.test_point.copy() - test_point.update(start) - start = test_point - - if not isinstance(trace, NpTrace): - if trace is None: - trace = model.unobserved_RVs - trace = NpTrace(trace) + _soft_update(start, model.test_point) try: step = step_methods.CompoundStep(step) @@ -112,12 +154,31 @@ def iter_sample(draws, step, start=None, trace=None, tune=None, model=None, rand point = Point(start, model=model) + if db is None: + db = NDArray(model=model, variables=variables) + db.setup(draws, chain) + for i in range(draws): - if (i == tune): + if i == tune: step = stop_tuning(step) point = step.step(point) - trace.record(point) - yield trace + db.record(point) + yield db.trace + else: + db.close() + + +def _thread_sample(threads, args): + p = mp.Pool(threads) + traces = p.map(_argsample, args) + p.close() + traces[0].merge_chains(traces[1:]) + return traces[0] + + +def _argsample(args): + """Defined at top level so it can be pickled""" + return _sample(*args) def stop_tuning(step): @@ -132,71 +193,7 @@ def stop_tuning(step): return step -def argsample(args): - """ defined at top level so it can be pickled""" - return sample(*args) - - -def psample(draws, step, start=None, trace=None, tune=None, model=None, threads=None, - random_seeds=None): - """draw a number of samples using the given step method. 
- Multiple step methods supported via compound step method - returns the amount of time taken - - Parameters - ---------- - - draws : int - The number of samples to draw - step : function - A step function - start : dict - Starting point in parameter space (Defaults to trace.point(-1)) - trace : MultiTrace or list - Either a trace of past values or a list of variables to track (defaults to None) - tune : int - Number of iterations to tune, if applicable (defaults to None) - model : Model (optional if in `with` context) - threads : int - Number of parallel traces to start - - Examples - -------- - - >>> an example - +def _soft_update(a, b): + """As opposed to dict.update, don't overwrite keys if present. """ - - model = modelcontext(model) - - if not threads: - threads = max(mp.cpu_count() - 2, 1) - - if start is None: - start = {} - - if isinstance(start, dict): - start = threads * [start] - - if trace is None: - trace = model.vars - - if type(trace) is MultiTrace: - mtrace = trace - else: - mtrace = MultiTrace(threads, trace) - - p = mp.Pool(threads) - - if random_seeds is None: - random_seeds = [None] * threads - - argset = zip([draws] * threads, [step] * threads, start, mtrace.traces, - [tune] * threads, [False] * threads, [model] * threads, - random_seeds) - - traces = p.map(argsample, argset) - - p.close() - - return MultiTrace(traces) + a.update({k: v for k, v in b.items() if k not in a}) diff --git a/pymc/stats.py b/pymc/stats.py index b028e35b54..60ca02fe71 100644 --- a/pymc/stats.py +++ b/pymc/stats.py @@ -11,39 +11,28 @@ def statfunc(f): """ def wrapped_f(pymc_obj, *args, **kwargs): + burn = kwargs.pop('burn', 0) + thin = kwargs.pop('thin', 1) + combine = kwargs.pop('combine', False) try: - burn = kwargs.pop('burn') - except KeyError: - burn = 0 - - try: - # MultiTrace - traces = pymc_obj.traces - - try: - vars = kwargs.pop('vars') - except KeyError: - vars = traces[0].varnames - - return [{v: f(trace[v][burn:], *args, **kwargs) for v in vars} for trace 
in traces] - + var_names = kwargs.pop('vars', pymc_obj.var_names) + chains = kwargs.pop('chains', pymc_obj.active_chains) except AttributeError: - pass - - try: - # NpTrace - try: - vars = kwargs.pop('vars') - except KeyError: - vars = pymc_obj.varnames - - return {v: f(pymc_obj[v][burn:], *args, **kwargs) for v in vars} - except AttributeError: - pass - - # If others fail, assume that raw data is passed - return f(pymc_obj, *args, **kwargs) + # If fails, assume that raw data is passed + return f(pymc_obj, *args, **kwargs) + + results = {chain: {} for chain in chains} + for var_name in var_names: + samples = pymc_obj.get_values(var_name, chains=chains, burn=burn, + thin=thin, combine=combine, + squeeze=False) + for chain, data in zip(chains, samples): + results[chain][var_name] = f(np.squeeze(data), *args, **kwargs) + + if len(chains) == 1 or combine: + results = results[chains[0]] + return results wrapped_f.__doc__ = f.__doc__ wrapped_f.__name__ = f.__name__ diff --git a/pymc/tests/checks.py b/pymc/tests/checks.py index 2e5470bc8f..6d4adf2180 100644 --- a/pymc/tests/checks.py +++ b/pymc/tests/checks.py @@ -1,7 +1,7 @@ import pymc as pm import numpy as np -from pymc import sample, psample +from pymc import sample from numpy.testing import assert_almost_equal diff --git a/pymc/tests/test_diagnostics.py b/pymc/tests/test_diagnostics.py index 8322a7ed27..4525fa2052 100644 --- a/pymc/tests/test_diagnostics.py +++ b/pymc/tests/test_diagnostics.py @@ -9,8 +9,8 @@ def test_gelman_rubin(n=1000): step1 = Slice([dm.early_mean, dm.late_mean]) step2 = Metropolis([dm.switchpoint]) start = {'early_mean': 2., 'late_mean': 3., 'switchpoint': 50} - ptrace = psample(n, [step1, step2], start, threads=2, - random_seeds=[1, 3]) + ptrace = sample(n, [step1, step2], start, threads=2, + random_seed=[1, 3]) rhat = gelman_rubin(ptrace) diff --git a/pymc/tests/test_glm.py b/pymc/tests/test_glm.py index 1882d6040c..0262261e45 100644 --- a/pymc/tests/test_glm.py +++ b/pymc/tests/test_glm.py @@ 
-44,9 +44,9 @@ def test_linear_component(self): step = Slice(model.vars) trace = sample(2000, step, start, progressbar=False) - self.assertAlmostEqual(np.mean(trace.samples['Intercept'].value), true_intercept, 1) - self.assertAlmostEqual(np.mean(trace.samples['x'].value), true_slope, 1) - self.assertAlmostEqual(np.mean(trace.samples['sigma'].value), true_sd, 1) + self.assertAlmostEqual(np.mean(trace['Intercept']), true_intercept, 1) + self.assertAlmostEqual(np.mean(trace['x']), true_slope, 1) + self.assertAlmostEqual(np.mean(trace['sigma']), true_sd, 1) @unittest.skip("Fails only on travis. Investigate") def test_glm(self): @@ -57,9 +57,9 @@ def test_glm(self): step = Slice(model.vars) trace = sample(2000, step, progressbar=False) - self.assertAlmostEqual(np.mean(trace.samples['Intercept'].value), true_intercept, 1) - self.assertAlmostEqual(np.mean(trace.samples['x'].value), true_slope, 1) - self.assertAlmostEqual(np.mean(trace.samples['sigma'].value), true_sd, 1) + self.assertAlmostEqual(np.mean(trace['Intercept']), true_intercept, 1) + self.assertAlmostEqual(np.mean(trace['x']), true_slope, 1) + self.assertAlmostEqual(np.mean(trace['sigma']), true_sd, 1) def test_glm_link_func(self): with Model() as model: @@ -71,5 +71,5 @@ def test_glm_link_func(self): step = Slice(model.vars) trace = sample(2000, step, progressbar=False) - self.assertAlmostEqual(np.mean(trace.samples['Intercept'].value), true_intercept, 1) - self.assertAlmostEqual(np.mean(trace.samples['x'].value), true_slope, 0) + self.assertAlmostEqual(np.mean(trace['Intercept']), true_intercept, 1) + self.assertAlmostEqual(np.mean(trace['x']), true_slope, 0) diff --git a/pymc/tests/test_ndarray_backend.py b/pymc/tests/test_ndarray_backend.py new file mode 100644 index 0000000000..e329ed5844 --- /dev/null +++ b/pymc/tests/test_ndarray_backend.py @@ -0,0 +1,392 @@ +import numpy as np +import numpy.testing as npt +try: + import unittest.mock as mock # py3 +except ImportError: + import mock +import unittest + 
+from pymc.backends import ndarray + + +class TestNDArraySampling(unittest.TestCase): + + def setUp(self): + self.variables = ['x', 'y'] + self.model = mock.Mock() + self.model.unobserved_RVs = self.variables + self.model.fastfn = mock.MagicMock() + + with mock.patch('pymc.backends.base.modelcontext') as context: + context.return_value = self.model + self.db = ndarray.NDArray() + + def test_setup_scalar(self): + db = self.db + db.var_shapes = {'x': ()} + draws, chain = 3, 0 + db.setup(draws, chain) + npt.assert_equal(db.trace.samples[chain]['x'], np.zeros(draws)) + + def test_setup_1d(self): + db = self.db + shape = (2,) + db.var_shapes = {'x': shape} + draws, chain = 3, 0 + db.setup(draws, chain) + npt.assert_equal(db.trace.samples[chain]['x'], np.zeros((draws,) + shape)) + + def test_record(self): + db = self.db + draws = 3 + + db.var_shapes = {'x': (), 'y': (4,)} + db.setup(draws, chain=0) + + def just_ones(*args): + while True: + yield 1. + + db.fn = just_ones + db.draw_idx = 0 + + db.record(point=None) + npt.assert_equal(1., db.trace.get_values('x')[0]) + npt.assert_equal(np.ones(4), db.trace['y'][0]) + + def test_clean_interrupt(self): + db = self.db + db.setup(draws=10, chain=0) + db.trace.samples = {0: {'x': np.zeros(10), 'y': np.zeros((10, 5))}} + db.draw_idx = 3 + db.close() + npt.assert_equal(np.zeros(3), db.trace['x']) + npt.assert_equal(np.zeros((3, 5)), db.trace['y']) + + def test_standard_close(self): + db = self.db + db.setup(draws=10, chain=0) + db.trace.samples = {0: {'x': np.zeros(10), 'y': np.zeros((10, 5))}} + db.draw_idx = 10 + db.close() + npt.assert_equal(np.zeros(10), db.trace['x']) + npt.assert_equal(np.zeros((10, 5)), db.trace['y']) + + +class TestNDArraySelection(unittest.TestCase): + + def setUp(self): + var_names = ['x', 'y'] + var_shapes = {'x': (), 'y': (2,)} + draws = 3 + self.trace = ndarray.Trace(var_names) + self.trace.samples = {0: + {'x': np.zeros(draws), + 'y': np.zeros((draws, 2))}} + self.draws = draws + self.var_names = 
var_names + self.var_shapes = var_shapes + + def test_chains_single_chain(self): + self.trace.chains == [0] + + def test_nchains_single_chain(self): + self.trace.nchains == 1 + + def test_get_values_default(self): + base_shape = (self.draws,) + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + xsample = self.trace.get_values('x') + npt.assert_equal(np.zeros(base_shape + xshape), xsample) + + ysample = self.trace.get_values('y') + npt.assert_equal(np.zeros(base_shape + yshape), ysample) + + def test_get_values_burn_keyword(self): + base_shape = (self.draws,) + burn = 2 + chain = 0 + + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + ## Make traces distinguishable + self.trace.samples[chain]['x'][:burn] = np.ones((burn,) + xshape) + self.trace.samples[chain]['y'][:burn] = np.ones((burn,) + yshape) + + xsample = self.trace.get_values('x', burn=burn) + npt.assert_equal(np.zeros(base_shape + xshape)[burn:], xsample) + + ysample = self.trace.get_values('y', burn=burn) + npt.assert_equal(np.zeros(base_shape + yshape)[burn:], ysample) + + def test_get_values_thin_keyword(self): + base_shape = (self.draws,) + thin = 2 + chain = 0 + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + ## Make traces distinguishable + xthin = np.ones((self.draws,) + xshape)[::thin] + ythin = np.ones((self.draws,) + yshape)[::thin] + self.trace.samples[chain]['x'][::thin] = xthin + self.trace.samples[chain]['y'][::thin] = ythin + + xsample = self.trace.get_values('x', thin=thin) + npt.assert_equal(xthin, xsample) + + ysample = self.trace.get_values('y', thin=thin) + npt.assert_equal(ythin, ysample) + + def test_point(self): + idx = 2 + chain = 0 + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + ## Make traces distinguishable + self.trace.samples[chain]['x'][idx] = 1. + self.trace.samples[chain]['y'][idx] = 1. 
+ + point = self.trace.point(idx) + expected = {'x': np.squeeze(np.ones(xshape)), + 'y': np.squeeze(np.ones(yshape))} + + for var_name, value in expected.items(): + npt.assert_equal(value, point[var_name]) + + def test_slice(self): + base_shape = (self.draws,) + burn = 2 + chain = 0 + + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + ## Make traces distinguishable + self.trace.samples[chain]['x'][:burn] = np.ones((burn,) + xshape) + self.trace.samples[chain]['y'][:burn] = np.ones((burn,) + yshape) + + sliced = self.trace[burn:] + + expected = {'x': np.zeros(base_shape + xshape)[burn:], + 'y': np.zeros(base_shape + yshape)[burn:]} + + for var_name, var_shape in self.var_shapes.items(): + npt.assert_equal(sliced.samples[chain][var_name], + expected[var_name]) + + +class TestNDArrayMultipleChains(unittest.TestCase): + + def setUp(self): + var_names = ['x', 'y'] + var_shapes = {'x': (), 'y': (2,)} + draws = 3 + self.trace = ndarray.Trace(var_names) + self.trace.samples = {0: + {'x': np.zeros(draws), + 'y': np.zeros((draws, 2))}, + 1: + {'x': np.ones(draws), + 'y': np.ones((draws, 2))}} + self.draws = draws + self.var_names = var_names + self.var_shapes = var_shapes + self.total_draws = 2 * draws + + def test_chains_multichain(self): + self.trace.chains == [0, 1] + + def test_nchains_multichain(self): + self.trace.nchains == [0, 1] + + def test_get_values_multi_default(self): + sample = self.trace.get_values('x') + xshape = self.var_shapes['x'] + + expected = [np.zeros((self.draws,) + xshape), + np.ones((self.draws,) + xshape)] + npt.assert_equal(sample, expected) + + def test_get_values_multi_chains_only_one_element(self): + sample = self.trace.get_values('x', chains=[0]) + xshape = self.var_shapes['x'] + + expected = np.zeros((self.draws,) + xshape) + + npt.assert_equal(sample, expected) + + def test_get_values_multi_chains_two_element_reversed(self): + sample = self.trace.get_values('x', chains=[1, 0]) + xshape = self.var_shapes['x'] + + expected = 
[np.ones((self.draws,) + xshape), + np.zeros((self.draws,) + xshape)] + npt.assert_equal(sample, expected) + + def test_get_values_multi_combine(self): + sample = self.trace.get_values('x', combine=True) + xshape = self.var_shapes['x'] + + expected = np.concatenate([np.zeros((self.draws,) + xshape), + np.ones((self.draws,) + xshape)]) + npt.assert_equal(sample, expected) + + def test_get_values_multi_burn(self): + sample = self.trace.get_values('x', burn=2) + xshape = self.var_shapes['x'] + + expected = [np.zeros((self.draws,) + xshape)[2:], + np.ones((self.draws,) + xshape)[2:]] + npt.assert_equal(sample, expected) + + def test_get_values_multi_burn_combine(self): + sample = self.trace.get_values('x', burn=2, combine=True) + xshape = self.var_shapes['x'] + + expected = np.concatenate([np.zeros((self.draws,) + xshape)[2:], + np.ones((self.draws,) + xshape)[2:]]) + npt.assert_equal(sample, expected) + + def test_get_values_multi_burn_one_active_chain(self): + self.trace.active_chains = 0 + sample = self.trace.get_values('x', burn=2) + xshape = self.var_shapes['x'] + + expected = np.zeros((self.draws,) + xshape)[2:] + npt.assert_equal(sample, expected) + + def test_get_values_multi_thin(self): + sample = self.trace.get_values('x', thin=2) + xshape = self.var_shapes['x'] + + expected = [np.zeros((self.draws,) + xshape)[::2], + np.ones((self.draws,) + xshape)[::2]] + npt.assert_equal(sample, expected) + + def test_get_values_multi_thin_combine(self): + sample = self.trace.get_values('x', thin=2, combine=True) + xshape = self.var_shapes['x'] + + expected = np.concatenate([np.zeros((self.draws,) + xshape)[::2], + np.ones((self.draws,) + xshape)[::2]]) + npt.assert_equal(sample, expected) + + def test_multichain_point(self): + idx = 2 + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + point = self.trace.point(idx) + expected = {'x': np.squeeze(np.ones(xshape)), + 'y': np.squeeze(np.ones(yshape))} + + for var_name, value in expected.items(): + 
npt.assert_equal(value, point[var_name]) + + def test_multichain_point_chain_arg(self): + idx = 2 + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + point = self.trace.point(idx, chain=0) + expected = {'x': np.squeeze(np.zeros(xshape)), + 'y': np.squeeze(np.zeros(yshape))} + + for var_name, value in expected.items(): + npt.assert_equal(value, point[var_name]) + + def test_multichain_point_change_default_chain(self): + idx = 2 + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + self.trace.default_chain = 0 + + point = self.trace.point(idx) + expected = {'x': np.squeeze(np.zeros(xshape)), + 'y': np.squeeze(np.zeros(yshape))} + + for var_name, value in expected.items(): + npt.assert_equal(value, point[var_name]) + + def test_multichain_slice(self): + base_shapes = [(self.draws,)] * 2 + burn = 2 + + xshape = self.var_shapes['x'] + yshape = self.var_shapes['y'] + + expected = {0: + {'x': np.zeros((self.draws, ) + xshape)[burn:], + 'y': np.zeros((self.draws, ) + yshape)[burn:]}, + 1: + {'x': np.ones((self.draws, ) + xshape)[burn:], + 'y': np.ones((self.draws, ) + yshape)[burn:]}} + + sliced = self.trace[burn:] + + for chain in self.trace.chains: + for var_name, var_shape in self.var_shapes.items(): + npt.assert_equal(sliced.samples[chain][var_name], + expected[chain][var_name]) + +class TestMergeChains(unittest.TestCase): + + def setUp(self): + var_names = ['x', 'y'] + var_shapes = {'x': (), 'y': (2,)} + self.trace1 = ndarray.Trace(var_names) + + self.trace2 = ndarray.Trace(var_names) + self.var_names = var_names + self.var_shapes = var_shapes + + def test_merge_chains_two_traces(self): + draws = 3 + self.trace1.samples = {0: + {'x': np.zeros(draws), + 'y': np.zeros((draws, 2))}} + self.trace2.samples = {1: + {'x': np.ones(draws), + 'y': np.ones((draws, 2))}} + + self.trace1.merge_chains([self.trace2]) + self.assertEqual(self.trace1.samples[1], self.trace2.samples[1]) + + def test_merge_chains_two_traces_same_slot_same_size(self): + draws 
= 3 + self.trace1.samples = {0: + {'x': np.zeros(draws), + 'y': np.zeros((draws, 2))}} + self.trace2.samples = {0: + {'x': np.ones(draws), + 'y': np.ones((draws, 2))}} + self.trace1.merge_chains([self.trace2]) + npt.assert_equal(self.trace1.samples[0]['x'], np.zeros(draws)) + + def test_merge_chains_two_traces_same_slot_base_longer(self): + draws = 3 + self.trace1.samples = {0: + {'x': np.zeros(draws), + 'y': np.zeros((draws, 2))}} + self.trace2.samples = {0: + {'x': np.ones(draws - 1), + 'y': np.ones((draws - 1, 2))}} + self.trace1.merge_chains([self.trace2]) + npt.assert_equal(self.trace1.samples[0]['x'], np.zeros(draws)) + + def test_merge_chains_two_traces_same_slot_new_longer(self): + draws = 3 + self.trace1.samples = {0: + {'x': np.zeros(draws - 1), + 'y': np.zeros((draws - 1, 2))}} + self.trace2.samples = {0: + {'x': np.ones(draws), + 'y': np.ones((draws, 2))}} + self.trace1.merge_chains([self.trace2]) + npt.assert_equal(self.trace1.samples[0]['x'], np.ones(draws)) diff --git a/pymc/tests/test_ndarray_sqlite_selection.py b/pymc/tests/test_ndarray_sqlite_selection.py new file mode 100644 index 0000000000..d076a7ef35 --- /dev/null +++ b/pymc/tests/test_ndarray_sqlite_selection.py @@ -0,0 +1,127 @@ +import os +import numpy as np +import numpy.testing as npt +import unittest +import multiprocessing as mp + +import pymc as pm + + +def remove_db(db): + try: + os.remove(db) + except FileNotFoundError: + pass + + +class TestCompareNDArraySQLite(unittest.TestCase): + + @classmethod + def setUpClass(cls): + ## Use two threads if available + try: + mp.Pool(2) + threads = 2 + except: + threads = 1 + + data = np.random.normal(size=(3, 20)) + n = 1 + + model = pm.Model() + draws = 5 + with model: + x = pm.Normal('x', 0, 1., shape=n) + + start = {'x': 0.} + step = pm.Metropolis() + cls.db = 'test.db' + + try: + cls.ntrace = pm.sample(draws, step=step, + threads=threads, random_seed=9) + cls.strace = pm.sample(draws, step=step, + threads=threads, random_seed=9, + 
db=pm.backends.SQLite(cls.db)) + ## Extend each trace. + cls.ntrace = pm.sample(draws, step=step, + threads=threads, random_seed=4, + db=cls.ntrace.backend) + cls.strace = pm.sample(draws, step=step, + threads=threads, random_seed=4, + db=cls.strace.backend) + cls.draws = draws * 2 # Account for extension. + except: + remove_db(cls.db) + raise + + @classmethod + def tearDownClass(cls): + remove_db(cls.db) + + def test_chain_length(self): + assert self.ntrace.nchains == self.strace.nchains + assert len(self.ntrace) == len(self.strace) + + def test_number_of_draws(self): + nvalues = self.ntrace.get_values('x', squeeze=False) + svalues = self.strace.get_values('x', squeeze=False) + assert nvalues[0].shape[0] == self.draws + assert svalues[0].shape[0] == self.draws + + def test_get_item(self): + npt.assert_equal(self.ntrace['x'], self.strace['x']) + + def test_get_values(self): + for cf in [False, True]: + npt.assert_equal(self.ntrace.get_values('x', combine=cf), + self.strace.get_values('x', combine=cf)) + + def test_get_values_no_squeeze(self): + npt.assert_equal(self.ntrace.get_values('x', combine=False, + squeeze=False), + self.strace.get_values('x', combine=False, + squeeze=False)) + + def test_get_values_combine_and_no_squeeze(self): + npt.assert_equal(self.ntrace.get_values('x', combine=True, + squeeze=False), + self.strace.get_values('x', combine=True, + squeeze=False)) + + def test_get_values_with_burn(self): + for cf in [False, True]: + npt.assert_equal(self.ntrace.get_values('x', combine=cf, burn=3), + self.strace.get_values('x', combine=cf, burn=3)) + + ## burn so just one value left + npt.assert_equal(self.ntrace.get_values('x', combine=cf, + burn=self.draws - 1), + self.strace.get_values('x', combine=cf, + burn=self.draws - 1)) + + def test_get_values_with_thin(self): + for cf in [False, True]: + npt.assert_equal(self.ntrace.get_values('x', combine=cf, thin=2), + self.strace.get_values('x', combine=cf, thin=2)) + + def 
test_get_values_with_burn_and_thin(self): + for cf in [False, True]: + npt.assert_equal(self.ntrace.get_values('x', combine=cf, + burn=2, thin=2), + self.strace.get_values('x', combine=cf, + burn=2, thin=2)) + + def test_get_values_with_chains_arg(self): + for cf in [False, True]: + npt.assert_equal(self.ntrace.get_values('x', chains=[0]), + self.strace.get_values('x', chains=[0])) + + def test_point(self): + npoint, spoint = self.ntrace[4], self.strace[4] + npt.assert_equal(npoint['x'], spoint['x']) + + def test_point_with_chain_arg(self): + npoint = self.ntrace.point(4, chain=0) + spoint = self.strace.point(4, chain=0) + npt.assert_equal(npoint['x'], spoint['x']) diff --git a/pymc/tests/test_plots.py b/pymc/tests/test_plots.py index f31b9a5006..b782a969fe 100644 --- a/pymc/tests/test_plots.py +++ b/pymc/tests/test_plots.py @@ -1,6 +1,6 @@ #from ..plots import * from pymc.plots import * -from pymc import psample, Slice, Metropolis, find_hessian, sample +from pymc import Slice, Metropolis, find_hessian, sample def test_plots(): @@ -15,7 +15,8 @@ def test_plots(): step = Metropolis(model.vars, h) trace = sample(3000, step, start) - forestplot(trace) + # FIXME: forestplot has not been rewritten for backend + # forestplot(trace) autocorrplot(trace) @@ -29,8 +30,9 @@ def test_multichain_plots(): step1 = Slice([dm.early_mean, dm.late_mean]) step2 = Metropolis([dm.switchpoint]) start = {'early_mean': 2., 'late_mean': 3., 'switchpoint': 50} - ptrace = psample(1000, [step1, step2], start, threads=2) + ptrace = sample(1000, [step1, step2], start, threads=2) - forestplot(ptrace, vars=['early_mean', 'late_mean']) + # FIXME: forestplot has not been rewritten for backend + # forestplot(ptrace, vars=['early_mean', 'late_mean']) autocorrplot(ptrace, vars=['switchpoint']) diff --git a/pymc/tests/test_progressbar.py b/pymc/tests/test_progressbar.py new file mode 100644 index 0000000000..ea6bab90a6 --- /dev/null +++ b/pymc/tests/test_progressbar.py @@ -0,0 +1,17 @@ +try: + import 
def test_enumerate_progress():
    """enumerate_progress updates the meter per item and yields the items.

    The original assertion used `meter.update.called_with(i)`, which is
    an auto-created Mock attribute: calling it returns a truthy child
    Mock, so the assert could never fail. Use the genuine Mock
    assertion method so a missing update call is actually detected.
    """
    iterable = list(range(5, 8))
    meter = mock.Mock()
    results = list(progressbar.enumerate_progress(iterable,
                                                  len(iterable),
                                                  meter))
    for i, _ in enumerate(iterable):
        # assumes enumerate_progress reports the zero-based index to the
        # meter -- TODO confirm against pymc.progressbar
        meter.update.assert_any_call(i)
    assert list(zip(*results))[1] == tuple(iterable)
diff --git a/pymc/tests/test_sqlite_backend.py b/pymc/tests/test_sqlite_backend.py new file mode 100644 index 0000000000..51e0f5e9dd --- /dev/null +++ b/pymc/tests/test_sqlite_backend.py @@ -0,0 +1,244 @@ +import numpy as np +import numpy.testing as npt +try: + import unittest.mock as mock # py3 +except ImportError: + import mock +import unittest +import warnings + +from pymc.backends import sqlite + + +class SQLiteTestCase(unittest.TestCase): + + def setUp(self): + self.variables = ['x', 'y'] + self.model = mock.Mock() + self.model.unobserved_RVs = self.variables + self.model.fastfn = mock.MagicMock() + + with mock.patch('pymc.backends.base.modelcontext') as context: + context.return_value = self.model + self.db = sqlite.SQLite('test.db') + + self.draws = 5 + + self.db.cursor = mock.Mock() + self.db.var_shapes = {'x': (), 'y': (3,)} + + self.db.trace._chains = [0] + self.db.trace._len = self.draws + + connect_patch = mock.patch('pymc.backends.sqlite.SQLite.connect') + self.addCleanup(connect_patch.stop) + self.connect = connect_patch.start() + + +class TestSQLiteSample(SQLiteTestCase): + + def test_setup_trace(self): + self.db.setup(self.draws, chain=0) + assert self.connect.called + + def test_setup_scalar(self): + db = self.db + db.setup(draws=3, chain=0) + tbl_expected = ('CREATE TABLE IF NOT EXISTS [x] ' + '(recid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ' + 'draw INTEGER, ' + 'chain INT(5), v1 FLOAT)') + db.cursor.execute.assert_any_call(tbl_expected) + + trace_expected = ('INSERT INTO [x] (recid, draw, chain, v1) ' + 'VALUES (NULL, {draw}, 0, {value})') + self.assertEqual(db.var_inserts['x'], trace_expected) + + def test_setup_1d(self): + db = self.db + db.setup(draws=3, chain=0) + db.trace._chains = [] + + tbl_expected = ('CREATE TABLE IF NOT EXISTS [y] ' + '(recid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ' + 'draw INTEGER, ' + 'chain INT(5), v1 FLOAT, v2 FLOAT, v3 FLOAT)') + db.cursor.execute.assert_any_call(tbl_expected) + + trace_expected = 
('INSERT INTO [y] (recid, draw, chain, v1, v2, v3) ' + 'VALUES (NULL, {draw}, 0, {value})') + self.assertEqual(db.var_inserts['y'], trace_expected) + + def test_setup_2d(self): + db = self.db + db.var_shapes = {'x': (2, 3)} + db.setup(draws=3, chain=0) + tbl_expected = ('CREATE TABLE IF NOT EXISTS [x] ' + '(recid INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ' + 'draw INTEGER, ' + 'chain INT(5), ' + 'v1_1 FLOAT, v1_2 FLOAT, v1_3 FLOAT, ' + 'v2_1 FLOAT, v2_2 FLOAT, v2_3 FLOAT)') + + db.cursor.execute.assert_any_call(tbl_expected) + trace_expected = ('INSERT INTO [x] (recid, draw, chain, ' + 'v1_1, v1_2, v1_3, ' + 'v2_1, v2_2, v2_3) ' + 'VALUES (NULL, {draw}, 0, {value})') + self.assertEqual(db.var_inserts['x'], trace_expected) + + def test_record_scalar(self): + db = self.db + db.setup(draws=3, chain=0) + + db.fn = mock.Mock(return_value=iter([3.])) + var_name = 'x' + db.var_names = ['x'] + + query = sqlite.QUERIES['insert'].format(table=var_name, + value_cols='v1', + chain=0) + db.var_inserts = {'x': query} + db.draw_idx = 0 + db.record({'x': 3.}) + expected = ('INSERT INTO [x] (recid, draw, chain, v1) ' + 'VALUES (NULL, 0, 0, 3.0)') + db.cursor.execute.assert_any_call(expected) + + def test_record_1d(self): + db = self.db + db.setup(draws=3, chain=0) + + db.fn = mock.Mock(return_value=iter([[3., 3.]])) + var_name = 'x' + db.var_names = ['x'] + + query = sqlite.QUERIES['insert'].format(table=var_name, + value_cols='v1, v2', + chain=0) + db.var_inserts = {'x': query} + db.draw_idx = 0 + db.record({'x': [3., 3.]}) + expected = ('INSERT INTO [x] (recid, draw, chain, v1, v2) ' + 'VALUES (NULL, 0, 0, 3.0, 3.0)') + db.cursor.execute.assert_any_call(expected) + + +class SQLiteSelectionTestCase(SQLiteTestCase): + + def setUp(self): + super(SQLiteSelectionTestCase, self).setUp() + self.db.var_shapes = {'x': (), 'y': (4,)} + self.db.setup(self.draws, chain=0) + + ndarray_patch = mock.patch('pymc.backends.sqlite._rows_to_ndarray') + self.addCleanup(ndarray_patch.stop) + 
ndarray_patch.start() + + self.draws = 5 + self.db.trace.active_chains = 0 + + +class TestSQLiteSelection(SQLiteSelectionTestCase): + + def test_get_values_default_keywords(self): + self.db.trace.get_values('x') + expected = 'SELECT * FROM [x] WHERE (chain = 0)' + self.db.cursor.execute.assert_called_with(expected) + + def test_get_values_burn_arg(self): + self.db.trace.get_values('x', burn=2) + expected = 'SELECT * FROM [x] WHERE (chain = 0) AND (draw > 1)' + self.db.cursor.execute.assert_called_with(expected) + + def test_get_values_thin_arg(self): + self.db.trace.get_values('x', thin=2) + expected = ('SELECT * FROM [x] ' + 'WHERE (chain = 0) AND ' + '(draw - (SELECT draw FROM [x] ' + 'WHERE chain = 0 ' + 'ORDER BY draw LIMIT 1)) % 2 = 0') + self.db.cursor.execute.assert_called_with(expected) + + def test_get_values_burn_thin_arg(self): + self.db.trace.get_values('x', thin=2, burn=1) + expected = ('SELECT * FROM [x] ' + 'WHERE (chain = 0) AND (draw > 0) ' + 'AND (draw - (SELECT draw FROM [x] ' + 'WHERE (chain = 0) AND (draw > 0) ' + 'ORDER BY draw LIMIT 1)) % 2 = 0') + self.db.cursor.execute.assert_called_with(expected) + + def test_point(self): + idx = 2 + + point = self.db.trace.point(idx) + expected = {'x': + 'SELECT * FROM [x] WHERE (chain=0) AND (draw=2)', + 'y': + 'SELECT * FROM [y] WHERE (chain=0) AND (draw=2)'} + + for var_name, value in expected.items(): + self.db.cursor.execute.assert_any_call(value) + + def test_slice(self): + with warnings.catch_warnings(record=True) as wrn: + self.db.trace[:10] + self.assertEqual(len(wrn), 1) + self.assertEqual(str(wrn[0].message), + 'Slice for SQLite backend has no effect.') + + +class TestSQLiteSelectionMultipleChains(SQLiteSelectionTestCase): + + def setUp(self): + super(TestSQLiteSelectionMultipleChains, self).setUp() + self.db.trace.active_chains = [0, 1] + + def test_get_values_default_keywords(self): + self.db.trace.get_values('x') + expected = ['SELECT * FROM [x] WHERE (chain = 0)', + 'SELECT * FROM [x] WHERE 
def test_create_column_empty():
    # A scalar (empty shape) gets a single value column.
    assert sqlite._create_colnames(()) == ['v1']


def test_create_column_1d():
    # A 1-d variable gets one column per element.
    assert sqlite._create_colnames((2,)) == ['v1', 'v2']


def test_create_column_2d():
    # A 2-d variable flattens row-major into row_col column names.
    assert sqlite._create_colnames((2, 2)) == ['v1_1', 'v1_2',
                                               'v2_1', 'v2_2']
mock.patch('pymc.backends.text._get_shape_fh') + self.addCleanup(shape_fh_patch.stop) + self.shape_fh = shape_fh_patch.start() + + mkdir_patch = mock.patch('pymc.backends.text.os.mkdir') + self.addCleanup(mkdir_patch.stop) + mkdir_patch.start() + + +class TestTextWrite(TextTestCase): + + def setUp(self): + super(TestTextWrite, self).setUp() + + with mock.patch('pymc.backends.base.modelcontext') as context: + context.return_value = self.model + self.db = text.Text('textdb') + + self.draws = 5 + self.db.var_shapes = {'x': (), 'y': (4,)} + self.db.setup(self.draws, chain=0) + self.db.draw_idx = self.draws + + savetxt_patch = mock.patch('pymc.backends.text.np.savetxt') + self.addCleanup(savetxt_patch.stop) + self.savetxt = savetxt_patch.start() + + def test_close_args(self): + db = self.db + + db.close() + + self.assertEqual(self.savetxt.call_count, 2) + + for call, var_name in enumerate(db.var_names): + fname, data = self.savetxt.call_args_list[call][0] + self.assertEqual(fname, 'textdb/chain-0/{}.txt'.format(var_name)) + npt.assert_equal(data, db.trace[var_name].reshape(-1, data.size)) + + def test_close_shape(self): + db = self.db + + fh = StringIO() + self.shape_fh.return_value.__enter__.return_value = fh + db.close() + self.shape_fh.assert_called_with('textdb/chain-0', 'w') + + shape_result = fh.getvalue() + expected = {var_name: [self.draws] + list(var_shape) + for var_name, var_shape in db.var_shapes.items()} + self.assertEqual(json.loads(shape_result), expected) + + +def test__chain_dir_to_chain(): + assert text._chain_dir_to_chain('/path/to/chain-0') == 0 + assert text._chain_dir_to_chain('chain-0') == 0 + + +class TestTextLoad(TextTestCase): + + def setUp(self): + super(TestTextLoad, self).setUp() + + data = {'chain-1/x.txt': np.zeros(4), 'chain-1/y.txt': np.ones(2)} + loadtxt_patch = mock.patch('pymc.backends.text.np.loadtxt') + self.addCleanup(loadtxt_patch.stop) + self.loadtxt = loadtxt_patch.start() + + chain_patch = 
mock.patch('pymc.backends.text._get_chain_dirs') + self.addCleanup(chain_patch.stop) + self._get_chain_dirs = chain_patch.start() + + def test_load_model_supplied_scalar(self): + draws = 4 + self._get_chain_dirs.return_value = {0: 'chain-0'} + fh = StringIO(json.dumps({'x': (draws,)})) + self.shape_fh.return_value.__enter__.return_value = fh + + data = np.zeros(draws) + self.loadtxt.return_value = data + + trace = text.load('textdb', model=self.model) + npt.assert_equal(trace.samples[0]['x'], data) + + def test_load_model_supplied_1d(self): + draws = 4 + var_shape = (2,) + self._get_chain_dirs.return_value = {0: 'chain-0'} + fh = StringIO(json.dumps({'x': (draws,) + var_shape})) + self.shape_fh.return_value.__enter__.return_value = fh + + data = np.zeros((draws,) + var_shape) + self.loadtxt.return_value = data.reshape(-1, data.size) + + db = text.load('textdb', model=self.model) + npt.assert_equal(db['x'], data) + + def test_load_model_supplied_2d(self): + draws = 4 + var_shape = (2, 3) + self._get_chain_dirs.return_value = {0: 'chain-0'} + fh = StringIO(json.dumps({'x': (draws,) + var_shape})) + self.shape_fh.return_value.__enter__.return_value = fh + + data = np.zeros((draws,) + var_shape) + self.loadtxt.return_value = data.reshape(-1, data.size) + + db = text.load('textdb', model=self.model) + npt.assert_equal(db['x'], data) + + def test_load_model_supplied_multichain_chains(self): + draws = 4 + self._get_chain_dirs.return_value = {0: 'chain-0', 1: 'chain-1'} + + def chain_fhs(): + for chain in [0, 1]: + yield StringIO(json.dumps({'x': (draws,)})) + fhs = chain_fhs() + + self.shape_fh.return_value.__enter__ = lambda x: next(fhs) + + data = np.zeros(draws) + self.loadtxt.return_value = data + + trace = text.load('textdb', model=self.model) + + self.assertEqual(trace.chains, [0, 1]) + + def test_load_model_supplied_multichain_chains_select_one(self): + draws = 4 + self._get_chain_dirs.return_value = {0: 'chain-0', 1: 'chain-1'} + + def chain_fhs(): + for chain in 
[0, 1]: + yield StringIO(json.dumps({'x': (draws,)})) + fhs = chain_fhs() + + self.shape_fh.return_value.__enter__ = lambda x: next(fhs) + + data = np.zeros(draws) + self.loadtxt.return_value = data + + trace = text.load('textdb', model=self.model, chains=[1]) + + self.assertEqual(trace.chains, [1]) diff --git a/pymc/tests/test_trace.py b/pymc/tests/test_trace.py index 35c03a740e..ef82ee29df 100644 --- a/pymc/tests/test_trace.py +++ b/pymc/tests/test_trace.py @@ -5,130 +5,6 @@ import warnings import nose -# Test if multiprocessing is available -import multiprocessing -try: - multiprocessing.Pool(2) - test_parallel = False -except: - test_parallel = False - - -def check_trace(model, trace, n, step, start): - # try using a trace object a few times - - for i in range(2): - trace = sample( - n, step, start, trace, progressbar=False, model=model) - - for (var, val) in start.items(): - - assert np.shape(trace[var]) == (n * (i + 1),) + np.shape(val) - -def test_trace(): - model, start, step, _ = simple_init() - - for h in [pm.NpTrace]: - for n in [20, 1000]: - for vars in [model.vars, model.vars + [model.vars[0] ** 2]]: - trace = h(vars) - - yield check_trace, model, trace, n, step, start - - -def test_multitrace(): - if not test_parallel: - return - model, start, step, _ = simple_init() - trace = None - for n in [20, 1000]: - - yield check_multi_trace, model, trace, n, step, start - - -def check_multi_trace(model, trace, n, step, start): - - for i in range(2): - trace = psample( - n, step, start, trace, model=model) - - for (var, val) in start.items(): - print([len(tr.samples[var].vals) for tr in trace.traces]) - for t in trace[var]: - assert np.shape(t) == (n * (i + 1),) + np.shape(val) - - ctrace = trace.combined() - for (var, val) in start.items(): - - assert np.shape( - ctrace[var]) == (len(trace.traces) * n * (i + 1),) + np.shape(val) - - -def test_get_point(): - - p, model = simple_2model() - p2 = p.copy() - p2['x'] *= 2. 
- - x = pm.NpTrace(model.vars) - x.record(p) - x.record(p2) - assert x.point(1) == x[1] - -def test_slice(): - - model, start, step, moments = simple_init() - - iterations = 100 - burn = 10 - - with model: - tr = sample(iterations, start=start, step=step, progressbar=False) - - burned = tr[burn:] - - # Slicing returns a trace - assert type(burned) is pm.trace.NpTrace - - # Indexing returns an array - assert type(tr[tr.varnames[0]]) is np.ndarray - - # Burned trace is of the correct length - assert np.all([burned[v].shape == (iterations-burn, start[v].size) for v in burned.varnames]) - - # Original trace did not change - assert np.all([tr[v].shape == (iterations, start[v].size) for v in tr.varnames]) - - # Now take more burn-in from the burned trace - burned_again = burned[burn:] - assert np.all([burned_again[v].shape == (iterations-2*burn, start[v].size) for v in burned_again.varnames]) - assert np.all([burned[v].shape == (iterations-burn, start[v].size) for v in burned.varnames]) - -def test_multi_slice(): - - model, start, step, moments = simple_init() - - iterations = 100 - burn = 10 - - with model: - tr = psample(iterations, start=start, step=step, threads=2) - - burned = tr[burn:] - - # Slicing returns a MultiTrace - assert type(burned) is pm.trace.MultiTrace - - # Indexing returns a list of arrays - assert type(tr[tr.varnames[0]]) is list - assert type(tr[tr.varnames[0]][0]) is np.ndarray - - # # Burned trace is of the correct length - assert np.all([burned[v][0].shape == (iterations-burn, start[v].size) for v in burned.varnames]) - - # Original trace did not change - assert np.all([tr[v][0].shape == (iterations, start[v].size) for v in tr.varnames]) - def test_summary_1_value_model(): mu = -2.1 diff --git a/pymc/trace.py b/pymc/trace.py index 7b174d7504..cf43352bf5 100644 --- a/pymc/trace.py +++ b/pymc/trace.py @@ -5,124 +5,19 @@ import types import warnings -__all__ = ['NpTrace', 'MultiTrace', 'summary'] - -class NpTrace(object): - """ - encapsulates the 
recording of a process chain - """ - def __init__(self, vars): - vars = list(vars) - model = vars[0].model - self.f = model.fastfn(vars) - self.vars = vars - self.varnames = list(map(str, vars)) - self.samples = dict((v, ListArray()) for v in self.varnames) - - def record(self, point): - """ - Records the position of a chain at a certain point in time. - """ - for var, value in zip(self.varnames, self.f(point)): - self.samples[var].append(value) - return self - - def __getitem__(self, index_value): - """ - Return copy NpTrace with sliced sample values if a slice is passed, - or the array of samples if a varname is passed. - """ - - if isinstance(index_value, slice): - - sliced_trace = NpTrace(self.vars) - sliced_trace.samples = dict((name, vals[index_value]) for (name, vals) in self.samples.items()) - - return sliced_trace - - else: - try: - return self.point(index_value) - except ValueError: - pass - except TypeError: - pass - - return self.samples[str(index_value)].value - - def __len__(self): - return len(self.samples[self.varnames[0]]) - - def point(self, index): - return dict((k, v.value[index]) for (k, v) in self.samples.items()) - - -class ListArray(object): - def __init__(self, *args): - self.vals = list(args) - - @property - def value(self): - if len(self.vals) > 1: - self.vals = [np.concatenate(self.vals, axis=0)] - - return self.vals[0] - - def __getitem__(self, idx): - return ListArray(self.value[idx]) - - - def append(self, v): - self.vals.append(v[np.newaxis]) - - def __len__(self): - if self.vals: - return self.value.shape[0] - else: - return 0 - - -class MultiTrace(object): - def __init__(self, traces, vars=None): - try: - self.traces = list(traces) - except TypeError: - if vars is None: - raise ValueError("vars can't be None if trace count specified") - self.traces = [NpTrace(vars) for _ in range(traces)] - - def __getitem__(self, index_value): - - item_list = [h[index_value] for h in self.traces] - - if isinstance(index_value, slice): - return 
MultiTrace(item_list) - return item_list - - @property - def varnames(self): - return self.traces[0].varnames - - def point(self, index): - return [h.point(index) for h in self.traces] - - def combined(self): - # Returns a trace consisting of concatenated MultiTrace elements - h = NpTrace(self.traces[0].vars) - for k in self.traces[0].samples: - h.samples[k].vals = [s[k] for s in self.traces] - return h +__all__ = ['summary'] -def summary(trace, vars=None, alpha=0.05, start=0, batches=100, roundto=3): +def summary(trace, var_names=None, alpha=0.05, start=0, batches=100, + roundto=3): """ Generate a pretty-printed summary of the node. :Parameters: trace : Trace object - Trace containing MCMC sample + Trace containing MCMC samples - vars : list of strings + var_names : list of strings List of variables to summarize. Defaults to None, which results in all variables summarized. @@ -142,25 +37,23 @@ def summary(trace, vars=None, alpha=0.05, start=0, batches=100, roundto=3): The number of digits to round posterior statistics. 
""" - if vars is None: - vars = trace.varnames - if isinstance(trace, MultiTrace): - trace = trace.combined() + if var_names is None: + var_names = trace.var_names stat_summ = _StatSummary(roundto, batches, alpha) pq_summ = _PosteriorQuantileSummary(roundto, alpha) - for var in vars: + for var_name in var_names: # Extract sampled values - sample = trace[var][start:] + sample = trace.get_values(var_name, burn=start, combine=True) if sample.ndim == 1: sample = sample[:, None] elif sample.ndim > 2: ## trace dimensions greater than 2 (variable greater than 1) - warnings.warn('Skipping {} (above 1 dimension)'.format(var)) + warnings.warn('Skipping {} (above 1 dimension)'.format(var_name)) continue - print('\n%s:' % var) + print('\n%s:' % var_name) print(' ') stat_summ.print_output(sample) diff --git a/setup.py b/setup.py index 95052d6128..adcd609177 100755 --- a/setup.py +++ b/setup.py @@ -1,5 +1,6 @@ #!/usr/bin/env python from setuptools import setup +import sys DISTNAME = 'pymc' @@ -25,8 +26,12 @@ 'Topic :: Scientific/Engineering :: Mathematics', 'Operating System :: OS Independent'] -required = ['numpy>=1.7.1', 'scipy>=0.12.0', 'matplotlib>=1.2.1', - 'Theano==0.6.0'] +install_reqs = ['numpy>=1.7.1', 'scipy>=0.12.0', 'matplotlib>=1.2.1', + 'Theano==0.6.0'] + +test_reqs = ['nose'] +if sys.version_info[0] == 2: # py3 has mock in stdlib + test_reqs.append('mock') if __name__ == "__main__": setup(name=DISTNAME, @@ -39,6 +44,8 @@ long_description=LONG_DESCRIPTION, packages=['pymc', 'pymc.distributions', 'pymc.step_methods', 'pymc.tuning', - 'pymc.tests', 'pymc.glm'], + 'pymc.tests', 'pymc.glm', 'pymc.backends'], classifiers=classifiers, - install_requires=required) + install_requires=install_reqs, + tests_require=test_reqs, + test_suite='nose.collector')