diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 4d03674e6d..b7a77a0722 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -111,31 +111,31 @@ def instantiate_steppers(_model, steps, selected_steps, step_kwargs=None): ---------- model : Model object A fully-specified model object; legacy argument -- ignored - steps : step function or vector of step functions - One or more step functions that have been assigned to some subset of - the model's parameters. Defaults to None (no assigned variables). - selected_steps : dictionary of step methods and variables - The step methods and the variables that have were assigned to them. + steps : list + A list of zero or more step function instances that have been assigned to some subset of + the model's parameters. + selected_steps : dict + A dictionary that maps a step method class to a list of zero or more model variables. step_kwargs : dict Parameters for the samplers. Keys are the lower case names of - the step method, values a dict of arguments. + the step method, values a dict of arguments. Defaults to None. Returns ------- - methods : list - List of step methods associated with the model's variables. + methods : list or step + List of step methods associated with the model's variables, or step method + if there is only one. """ if step_kwargs is None: step_kwargs = {} used_keys = set() for step_class, vars in selected_steps.items(): - if len(vars) == 0: - continue - args = step_kwargs.get(step_class.name, {}) - used_keys.add(step_class.name) - step = step_class(vars=vars, **args) - steps.append(step) + if vars: + args = step_kwargs.get(step_class.name, {}) + used_keys.add(step_class.name) + step = step_class(vars=vars, **args) + steps.append(step) unused_args = set(step_kwargs).difference(used_keys) if unused_args: