diff --git a/CHANGES b/CHANGES
index 35ceb4a8fe..70df418482 100644
--- a/CHANGES
+++ b/CHANGES
@@ -1,6 +1,7 @@
 Next release
 ============
 
+* ENH: Introduced runtime decisions (https://github.com/nipy/nipype/pull/1299)
 * ENH: Provides a Nipype wrapper for ANTs DenoiseImage (https://github.com/nipy/nipype/pull/1291)
 * FIX: Minor bugfix logging hash differences (https://github.com/nipy/nipype/pull/1298)
 * FIX: Use released Prov python library (https://github.com/nipy/nipype/pull/1279)
diff --git a/doc/users/index.rst b/doc/users/index.rst
index 3a432135a6..d560dbd7cf 100644
--- a/doc/users/index.rst
+++ b/doc/users/index.rst
@@ -33,6 +33,7 @@
    function_interface
    mapnode_and_iterables
    joinnode_and_itersource
+   runtime_decisions
    model_specification
    saving_workflows
    spmmcr
diff --git a/doc/users/runtime_decisions.rst b/doc/users/runtime_decisions.rst
new file mode 100644
index 0000000000..cfe3debb55
--- /dev/null
+++ b/doc/users/runtime_decisions.rst
@@ -0,0 +1,120 @@
+.. _runtime_decisions:
+
+===========================
+Runtime decisions in nipype
+===========================
+
+Adding conditional execution (https://github.com/nipy/nipype/issues/878) and
+other runtime decisions (https://github.com/nipy/nipype/issues/819) to
+nipype is a long-standing request. Here we introduce some logic and
+signalling into workflows.
+
+Disable signal in nodes
+=======================
+
+The :class:`nipype.pipeline.engine.Node` now provides a `signals` attribute
+with a `disable` signal by default.
+When the `run()` method of a node is called, the interface runs
+normally *iff* `disable` is `False` (the default).
+
+Example
+-------
+
+For instance, the following code will run the BET interface from FSL:
+
+    >>> from nipype.pipeline.engine import Node
+    >>> from nipype.interfaces import fsl
+    >>> bet = Node(fsl.BET(), 'BET')
+    >>> bet.inputs.in_file = 'T1.nii'
+    >>> bet.run() # doctest: +SKIP
+
+However, if we set the `disable` signal, the interface is not run:
+
+    >>> bet.signals.disable = True
+    >>> bet.run() is None
+    True
+
+Disable signal in Workflow
+==========================
+
+:class:`nipype.pipeline.engine.Workflow` also provides signals, including
+`disable` by default.
+It is also possible to connect the output of a node to a signal of a
+workflow, using the `signalnode.` port prefix.
+
+
+Example
+-------
+
+    >>> from nipype.pipeline import engine as pe
+    >>> from nipype.interfaces import utility as niu
+    >>> def _myfunc(val):
+    ...     return val + 1
+    >>> wf = pe.Workflow('TestDisableWorkflow')
+    >>> inputnode = pe.Node(niu.IdentityInterface(
+    ...     fields=['in_value']), 'inputnode')
+    >>> outputnode = pe.Node(niu.IdentityInterface(
+    ...     fields=['out_value']), 'outputnode')
+    >>> func = pe.Node(niu.Function(
+    ...     input_names=['val'], output_names=['out'],
+    ...     function=_myfunc), 'functionnode')
+    >>> wf.connect([
+    ...     (inputnode, func, [('in_value', 'val')]),
+    ...     (func, outputnode, [('out', 'out_value')])
+    ... ])
+    >>> wf.inputs.inputnode.in_value = 0
+    >>> wf.run()  # Will produce 1 in outputnode.out_value
+
+The workflow can be disabled:
+
+    >>> wf.signals.disable = True
+    >>> wf.run()  # outputnode.out_value remains unset
+
+
+CachedWorkflow
+==============
+
+The :class:`nipype.pipeline.engine.CachedWorkflow` is a conditional workflow
+that is executed *iff* its cached inputs are not set.
+More precisely, the workflow decides whether its nodes should be executed
+based on whether the inputs of the input node called `cachenode` are set.
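+
+As an illustrative sketch, that decision boils down to asking whether the
+inputs of `cachenode` are defined, which is what the
+:class:`nipype.interfaces.utility.CheckInterface` described below computes:
+
+    >>> from nipype.interfaces.utility import CheckInterface
+    >>> check = CheckInterface(fields=['c'], operation='all')
+    >>> check._list_outputs()['out']  # cache empty: the inner nodes would run
+    False
+    >>> check.inputs.c = 7
+    >>> check._list_outputs()['out']  # cache filled: the inner nodes are skipped
+    True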
+
+This behaviour was requested, for instance, in
+https://github.com/nipy/nipype/pull/1081.
+The implementation makes use of :class:`nipype.interfaces.utility.CheckInterface`,
+which produces an output `out` set to `True` if any/all of its inputs are
+defined, and `False` otherwise.
+The `operation` input switches between the *any* and *all* conditions.
+
+
+Example
+-------
+
+    >>> from nipype.pipeline import engine as pe
+    >>> from nipype.interfaces import utility as niu
+    >>> def _myfunc(a, b):
+    ...     return a + b
+    >>> wf = pe.CachedWorkflow('InnerWorkflow',
+    ...                        cache_map=('c', 'out'))
+    >>> inputnode = pe.Node(niu.IdentityInterface(
+    ...     fields=['a', 'b']), 'inputnode')
+    >>> func = pe.Node(niu.Function(
+    ...     input_names=['a', 'b'], output_names=['out'],
+    ...     function=_myfunc), 'functionnode')
+    >>> wf.connect([
+    ...     (inputnode, func, [('a', 'a'), ('b', 'b')]),
+    ...     (func, 'output', [('out', 'out')])
+    ... ])
+    >>> wf.inputs.inputnode.a = 2
+    >>> wf.inputs.inputnode.b = 3
+    >>> wf.run()  # Will generate 5 in outputnode.out
+
+Please note that the output node should be referred to as 'output' in
+the *connect()* call.
+
+If we set all the inputs of the cache, then the workflow is skipped and
+the output is mapped from the cache:
+
+    >>> wf.inputs.cachenode.c = 7
+    >>> wf.run()  # Will produce 7 in outputnode.out
diff --git a/nipype/__init__.py b/nipype/__init__.py
index 3db5a359de..e057c2e78c 100644
--- a/nipype/__init__.py
+++ b/nipype/__init__.py
@@ -83,6 +83,6 @@ def _test_local_install():
     pass
 
-from .pipeline import Node, MapNode, JoinNode, Workflow
+from .pipeline import engine
 from .interfaces import (DataGrabber, DataSink, SelectFiles,
                          IdentityInterface, Rename, Function, Select, Merge)
diff --git a/nipype/caching/memory.py b/nipype/caching/memory.py
index 67ef605e32..d8b14fb396 100644
--- a/nipype/caching/memory.py
+++ b/nipype/caching/memory.py
@@ -21,7 +21,7 @@
 from ..interfaces.base import BaseInterface
 from ..pipeline.engine import Node
-from ..pipeline.utils import modify_paths
+from ..pipeline.engine.utils import modify_paths
 
 ################################################################################
 # PipeFunc object: callable interface to nipype.interface objects
diff --git a/nipype/caching/tests/test_memory.py b/nipype/caching/tests/test_memory.py
index 784eca1b93..89afcfa7a7 100644
--- a/nipype/caching/tests/test_memory.py
+++ b/nipype/caching/tests/test_memory.py
@@ -7,7 +7,7 @@
 from nose.tools import assert_equal
 
 from .. import Memory
-from ...pipeline.tests.test_engine import TestInterface
+from ...pipeline.engine.tests.test_engine import TestInterface
 from ... import config
 config.set_default_config()
diff --git a/nipype/interfaces/base.py b/nipype/interfaces/base.py
index e831fc67ce..78f5b0de59 100644
--- a/nipype/interfaces/base.py
+++ b/nipype/interfaces/base.py
@@ -40,9 +40,8 @@
                               TraitListObject, TraitError, isdefined, File,
                               Directory, has_metadata)
 
-from ..utils.filemanip import (md5, hash_infile, FileNotFoundError,
-                               hash_timestamp, save_json,
-                               split_filename)
+from ..utils.filemanip import md5, \
+    FileNotFoundError, save_json, split_filename
 from ..utils.misc import is_container, trim, str2bool
 from ..utils.provenance import write_provenance
 from ..
import config, logging, LooseVersion @@ -463,6 +462,7 @@ def _deprecated_warn(self, obj, name, old, new): def _hash_infile(self, adict, key): """ Inject file hashes into adict[key]""" + from nipype.utils.filemanip import hash_infile, hash_timestamp stuff = adict[key] if not is_container(stuff): stuff = [stuff] @@ -578,6 +578,7 @@ def get_hashval(self, hash_method=None): def _get_sorteddict(self, object, dictwithhash=False, hash_method=None, hash_files=True): + from nipype.utils.filemanip import hash_infile, hash_timestamp if isinstance(object, dict): out = [] for key, val in sorted(object.items()): diff --git a/nipype/interfaces/io.py b/nipype/interfaces/io.py index ed5c0b5f9f..378b966031 100644 --- a/nipype/interfaces/io.py +++ b/nipype/interfaces/io.py @@ -920,7 +920,8 @@ class SelectFiles(IOBase): -------- >>> import pprint - >>> from nipype import SelectFiles, Node + >>> from nipype.pipeline.engine import Node + >>> from nipype.interfaces.io import SelectFiles >>> templates={"T1": "{subject_id}/struct/T1.nii", ... "epi": "{subject_id}/func/f[0, 1].nii"} >>> dg = Node(SelectFiles(templates), "selectfiles") diff --git a/nipype/interfaces/utility.py b/nipype/interfaces/utility.py index 37883d4e5c..4cfeef3c72 100644 --- a/nipype/interfaces/utility.py +++ b/nipype/interfaces/utility.py @@ -566,3 +566,84 @@ def _list_outputs(self): entry = self._parse_line(line) outputs = self._append_entry(outputs, entry) return outputs + + +class CheckInterfaceOutputSpec(TraitedSpec): + out = traits.Bool(False, desc='Inputs meet condition') + + +class CheckInterface(IOBase): + """ + Interface that performs checks on inputs + + Examples + -------- + + >>> from nipype.interfaces.utility import CheckInterface + >>> checkif = CheckInterface(fields=['a', 'b'], operation='any') + >>> checkif._list_outputs()['out'] + False + + >>> checkif.inputs.a = 'foo' + >>> out = checkif.run() + >>> checkif._list_outputs()['out'] + True + + >>> checkif.inputs.operation = 'all' + >>> out = checkif.run() + >>> checkif._list_outputs()['out'] + False + + >>> checkif.inputs.b = 'bar' + >>> out = checkif.run() + >>> checkif._list_outputs()['out'] + True + """ + input_spec = DynamicTraitedSpec + output_spec = CheckInterfaceOutputSpec + _always_run = True + + def __init__(self, fields=None, operation='all', **inputs): + super(CheckInterface, self).__init__(**inputs) + + if fields is None or not fields: + raise ValueError('CheckInterface fields must be a non-empty ' + 'list') + + if 'operation' in fields: + raise ValueError('CheckInterface does not allow fields using' + ' special name \'operation\'') + # Each input must be in the fields. + for in_field in inputs: + if in_field not in fields: + raise ValueError('CheckInterface input is not in the ' + 'fields: %s' % in_field) + self._fields = fields + add_traits(self.inputs, fields + ['operation']) + + # Adding any traits wipes out all input values set in superclass initialization, + # even it the trait is not in the add_traits argument. The work-around is to reset + # the values after adding the traits. 
+        self.inputs.set(**inputs)
+
+        if operation not in ['all', 'any']:
+            raise ValueError('CheckInterface does not accept keyword '
+                             '\'%s\' as operation input' % operation)
+        self.inputs.operation = operation
+
+    def _check_result(self):
+        if self.inputs.operation not in ['all', 'any']:
+            raise ValueError('CheckInterface does not accept keyword '
+                             '\'%s\' as operation input' %
+                             self.inputs.operation)
+
+        results = [isdefined(getattr(self.inputs, key))
+                   for key in self._fields if key != 'operation']
+
+        if self.inputs.operation == 'any':
+            return any(results)
+        return all(results)
+
+    def _list_outputs(self):
+        outputs = self._outputs().get()
+        outputs['out'] = self._check_result()
+        return outputs
diff --git a/nipype/pipeline/__init__.py b/nipype/pipeline/__init__.py
index b7a6afe20e..32d72a062b 100644
--- a/nipype/pipeline/__init__.py
+++ b/nipype/pipeline/__init__.py
@@ -7,4 +7,5 @@
 from __future__ import absolute_import
 __docformat__ = 'restructuredtext'
 
-from .engine import Node, MapNode, JoinNode, Workflow
+
+from .engine import *
diff --git a/nipype/pipeline/engine/__init__.py b/nipype/pipeline/engine/__init__.py
new file mode 100644
index 0000000000..7563b7cdf6
--- /dev/null
+++ b/nipype/pipeline/engine/__init__.py
@@ -0,0 +1,7 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
+# vi: set ft=python sts=4 ts=4 sw=4 et:
+
+from .workflows import *
+from .nodes import *
diff --git a/nipype/pipeline/engine/base.py b/nipype/pipeline/engine/base.py
new file mode 100644
index 0000000000..9b8750823c
--- /dev/null
+++ b/nipype/pipeline/engine/base.py
@@ -0,0 +1,174 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
+# vi: set ft=python sts=4 ts=4 sw=4 et:
+"""Defines functionality for pipelined execution of interfaces
+
+The `NodeBase` class implements the more general view of a task.
+
+  .. testsetup::
+     # Change directory to provide relative paths for doctests
+     import os
+     filepath = os.path.dirname(os.path.realpath( __file__ ))
+     datadir = os.path.realpath(os.path.join(filepath, '../../testing/data'))
+     os.chdir(datadir)
+
+"""
+
+from future import standard_library
+standard_library.install_aliases()
+from builtins import object
+
+try:
+    from collections import OrderedDict
+except ImportError:
+    from ordereddict import OrderedDict
+
+from copy import deepcopy
+import re
+import numpy as np
+from nipype.interfaces.traits_extension import traits, Undefined
+from nipype.interfaces.base import DynamicTraitedSpec
+from nipype.utils.filemanip import loadpkl, savepkl
+
+from nipype import logging
+logger = logging.getLogger('workflow')
+
+
+class EngineBase(object):
+    """Defines common attributes and functions for workflows and nodes."""
+
+    def __init__(self, name=None, base_dir=None):
+        """ Initialize base parameters of a workflow or node
+
+        Parameters
+        ----------
+        name : string (mandatory)
+            Name of this node. Name must be alphanumeric and not contain any
+            special characters (e.g., '.', '@').
+ base_dir : string + base output directory (will be hashed before creations) + default=None, which results in the use of mkdtemp + + """ + self.base_dir = base_dir + self.config = None + self._verify_name(name) + self.name = name + # for compatibility with node expansion using iterables + self._id = self.name + self._hierarchy = None + + @property + def inputs(self): + raise NotImplementedError + + @property + def outputs(self): + raise NotImplementedError + + @property + def fullname(self): + fullname = self.name + if self._hierarchy: + fullname = self._hierarchy + '.' + self.name + return fullname + + def clone(self, name): + """Clone a workflowbase object + + Parameters + ---------- + + name : string (mandatory) + A clone of node or workflow must have a new name + """ + if (name is None) or (name == self.name): + raise Exception('Cloning requires a new name') + self._verify_name(name) + clone = deepcopy(self) + clone.name = name + clone._id = name + clone._hierarchy = None + return clone + + def _check_outputs(self, parameter): + return hasattr(self.outputs, parameter) + + def _check_inputs(self, parameter): + if isinstance(self.inputs, DynamicTraitedSpec): + return True + return hasattr(self.inputs, parameter) + + def _verify_name(self, name): + valid_name = bool(re.match('^[\w-]+$', name)) + if not valid_name: + raise ValueError('[Workflow|Node] name \'%s\' contains' + ' special characters' % name) + + def __repr__(self): + if self._hierarchy: + return '.'.join((self._hierarchy, self._id)) + else: + return self._id + + def save(self, filename=None): + if filename is None: + filename = 'temp.pklz' + savepkl(filename, self) + + def load(self, filename): + if '.npz' in filename: + DeprecationWarning(('npz files will be deprecated in the next ' + 'release. you can use numpy to open them.')) + return np.load(filename) + return loadpkl(filename) + + +class WorkflowSignalTraits(traits.HasTraits): + def __init__(self, **kwargs): + """ Initialize handlers and inputs""" + # NOTE: In python 2.6, object.__init__ no longer accepts input + # arguments. HasTraits does not define an __init__ and + # therefore these args were being ignored. + # super(TraitedSpec, self).__init__(*args, **kwargs) + super(WorkflowSignalTraits, self).__init__(**kwargs) + traits.push_exception_handler(reraise_exceptions=True) + undefined_traits = {} + for trait in self.copyable_trait_names(): + if not self.traits()[trait].usedefault: + undefined_traits[trait] = Undefined + self.trait_set(trait_change_notify=False, **undefined_traits) + self.set(**kwargs) + + +class BaseSignals(WorkflowSignalTraits): + disable = traits.Bool(False, usedefault=True) + + +class NodeBase(EngineBase): + def __init__(self, name, base_dir=None, control=True): + """Create a workflow object. 
+ + Parameters + ---------- + name : alphanumeric string + unique identifier for the workflow + base_dir : string, optional + path to workflow storage + + """ + super(NodeBase, self).__init__(name, base_dir) + # Initialize signals + self._signals = None + if control: + self._signals = BaseSignals() + for elem in self._signals.copyable_trait_names(): + self._signals.on_trait_change(self._update_disable, elem) + + @property + def signals(self): + return self._signals + + def _update_disable(self): + pass diff --git a/nipype/pipeline/utils.py b/nipype/pipeline/engine/graph.py similarity index 77% rename from nipype/pipeline/utils.py rename to nipype/pipeline/engine/graph.py index 64caa482eb..603e37bfef 100644 --- a/nipype/pipeline/utils.py +++ b/nipype/pipeline/engine/graph.py @@ -1,6 +1,14 @@ # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: """Utility routines for workflow graphs + + .. testsetup:: + # Change directory to provide relative paths for doctests + import os + filepath = os.path.dirname(os.path.realpath( __file__ )) + datadir = os.path.realpath(os.path.join(filepath, '../../testing/data')) + os.chdir(datadir) + """ from future import standard_library @@ -18,142 +26,34 @@ from collections import OrderedDict from copy import deepcopy -from glob import glob from collections import defaultdict import os import re import numpy as np -from nipype.utils.misc import package_check from functools import reduce -package_check('networkx', '1.3') +from nipype.utils.misc import package_check +from nipype.external.six import string_types +from nipype.utils.filemanip import fname_presuffix +from nipype.utils.misc import create_function_from_source +from nipype.interfaces.base import (CommandLine, isdefined, Undefined, + InterfaceResult) +from nipype.interfaces.utility import IdentityInterface +from nipype.utils.provenance import ProvStore, pm, nipype_ns, get_id -import networkx as nx +from nipype import logging, config +from .utils import get_print_name -from ..external.six import string_types -from ..utils.filemanip import (fname_presuffix, FileNotFoundError, - filename_to_list, get_related_files) -from ..utils.misc import create_function_from_source, str2bool -from ..interfaces.base import (CommandLine, isdefined, Undefined, - InterfaceResult) -from ..interfaces.utility import IdentityInterface -from ..utils.provenance import ProvStore, pm, nipype_ns, get_id -from .. 
import logging, config +package_check('networkx', '1.3') logger = logging.getLogger('workflow') +import networkx as nx try: dfs_preorder = nx.dfs_preorder + logger.debug('detected networkx < 1.4 dev') except AttributeError: dfs_preorder = nx.dfs_preorder_nodes - logger.debug('networkx 1.4 dev or higher detected') - -try: - from os.path import relpath -except ImportError: - import os.path as op - - def relpath(path, start=None): - """Return a relative version of a path""" - if start is None: - start = os.curdir - if not path: - raise ValueError("no path specified") - start_list = op.abspath(start).split(op.sep) - path_list = op.abspath(path).split(op.sep) - if start_list[0].lower() != path_list[0].lower(): - unc_path, rest = op.splitunc(path) - unc_start, rest = op.splitunc(start) - if bool(unc_path) ^ bool(unc_start): - raise ValueError(("Cannot mix UNC and non-UNC paths " - "(%s and %s)") % (path, start)) - else: - raise ValueError("path is on drive %s, start on drive %s" - % (path_list[0], start_list[0])) - # Work out how much of the filepath is shared by start and path. - for i in range(min(len(start_list), len(path_list))): - if start_list[i].lower() != path_list[i].lower(): - break - else: - i += 1 - - rel_list = [op.pardir] * (len(start_list) - i) + path_list[i:] - if not rel_list: - return os.curdir - return op.join(*rel_list) - - -def modify_paths(object, relative=True, basedir=None): - """Convert paths in data structure to either full paths or relative paths - - Supports combinations of lists, dicts, tuples, strs - - Parameters - ---------- - - relative : boolean indicating whether paths should be set relative to the - current directory - basedir : default os.getcwd() - what base directory to use as default - """ - if not basedir: - basedir = os.getcwd() - if isinstance(object, dict): - out = {} - for key, val in sorted(object.items()): - if isdefined(val): - out[key] = modify_paths(val, relative=relative, - basedir=basedir) - elif isinstance(object, (list, tuple)): - out = [] - for val in object: - if isdefined(val): - out.append(modify_paths(val, relative=relative, - basedir=basedir)) - if isinstance(object, tuple): - out = tuple(out) - else: - if isdefined(object): - if isinstance(object, string_types) and os.path.isfile(object): - if relative: - if config.getboolean('execution', 'use_relative_paths'): - out = relpath(object, start=basedir) - else: - out = object - else: - out = os.path.abspath(os.path.join(basedir, object)) - if not os.path.exists(out): - raise FileNotFoundError('File %s not found' % out) - else: - out = object - return out - - -def get_print_name(node, simple_form=True): - """Get the name of the node - - For example, a node containing an instance of interfaces.fsl.BET - would be called nodename.BET.fsl - - """ - name = node.fullname - if hasattr(node, '_interface'): - pkglist = node._interface.__class__.__module__.split('.') - interface = node._interface.__class__.__name__ - destclass = '' - if len(pkglist) > 2: - destclass = '.%s' % pkglist[2] - if simple_form: - name = node.fullname + destclass - else: - name = '.'.join([node.fullname, interface]) + destclass - if simple_form: - parts = name.split('.') - if len(parts) > 2: - return ' ('.join(parts[1:]) + ')' - elif len(parts) == 2: - return parts[1] - return name def _create_dot_graph(graph, show_connectinfo=False, simple_form=True): @@ -290,7 +190,7 @@ def walk(children, level=0, path=None, usename=True): Examples -------- - >>> from nipype.pipeline.utils import walk + >>> from nipype.pipeline.engine.graph 
import walk >>> iterables = [('a', lambda: [1, 2]), ('b', lambda: [3, 4])] >>> [val['a'] for val in walk(iterables)] [1, 1, 2, 2] @@ -325,7 +225,7 @@ def synchronize_iterables(iterables): Examples -------- - >>> from nipype.pipeline.utils import synchronize_iterables + >>> from nipype.pipeline.engine.graph import synchronize_iterables >>> iterables = dict(a=lambda: [1, 2], b=lambda: [3, 4]) >>> synced = synchronize_iterables(iterables) >>> synced == [{'a': 1, 'b': 3}, {'a': 2, 'b': 4}] @@ -473,16 +373,15 @@ def _merge_graphs(supergraph, nodes, subgraph, nodeid, iterables, node._id += template % i return supergraph - -def _connect_nodes(graph, srcnode, destnode, connection_info): - """Add a connection between two nodes - """ - data = graph.get_edge_data(srcnode, destnode, default=None) - if not data: - data = {'connect': connection_info} - graph.add_edges_from([(srcnode, destnode, data)]) - else: - data['connect'].extend(connection_info) +# def _connect_nodes(graph, srcnode, destnode, connection_info): +# """Add a connection between two nodes +# """ +# data = graph.get_edge_data(srcnode, destnode, default=None) +# if not data: +# data = {'connect': connection_info} +# graph.add_edges_from([(srcnode, destnode, data)]) +# else: +# data['connect'].extend(connection_info) def _remove_nonjoin_identity_nodes(graph, keep_iterables=False): @@ -506,7 +405,9 @@ def _identity_nodes(graph, include_iterables): are included if and only if the include_iterables flag is set to True. """ - return [node for node in nx.topological_sort(graph) + sorted_nodes = nx.topological_sort(graph) + logger.debug('Get identity nodes: %s' % [n.fullname for n in sorted_nodes]) + return [node for node in sorted_nodes if isinstance(node._interface, IdentityInterface) and (include_iterables or getattr(node, 'iterables') is None)] @@ -514,13 +415,23 @@ def _identity_nodes(graph, include_iterables): def _remove_identity_node(graph, node): """Remove identity nodes from an execution graph """ - portinputs, portoutputs = _node_ports(graph, node) + portinputs, portoutputs, signals = _node_ports(graph, node) + logger.debug('Remove Identity Node %s\n\tPortinputs=%s\n\tportoutputs=%s\n' + '\tsignals=%s' % (node, portinputs, portoutputs, signals)) + for field, connections in list(signals.items()): + if portinputs: + _propagate_internal_output(graph, node, field, connections, + portinputs) + else: + _propagate_signal(graph, node, field, connections) + for field, connections in list(portoutputs.items()): if portinputs: _propagate_internal_output(graph, node, field, connections, portinputs) else: _propagate_root_output(graph, node, field, connections) + graph.remove_nodes_from([node]) logger.debug("Removed the identity node %s from the graph." 
% node) @@ -537,19 +448,39 @@ def _node_ports(graph, node): """ portinputs = {} portoutputs = {} - for u, _, d in graph.in_edges_iter(node, data=True): - for src, dest in d['connect']: - portinputs[dest] = (u, src) - for _, v, d in graph.out_edges_iter(node, data=True): - for src, dest in d['connect']: + signals = {} + + in_edges = graph.in_edges_iter(node, data=True) + out_edges = graph.out_edges_iter(node, data=True) + + logger.debug('Edges of %s, (inputs=%s, signals=%s)' % (node, node.inputs, + node.signals)) + logger.debug('In edges') + for u, _, d in in_edges: + logger.debug('%s' % d) + for c in d['connect']: + portinputs[c[1]] = (u, c[0]) + + logger.debug('Out edges') + for _, v, d in out_edges: + logger.debug('%s' % d) + for c in d['connect']: + src, dest = c[0], c[1] + ctype = 'data' + if len(c) == 3: + ctype = c[-1] if isinstance(src, tuple): srcport = src[0] else: srcport = src - if srcport not in portoutputs: - portoutputs[srcport] = [] - portoutputs[srcport].append((v, dest, src)) - return (portinputs, portoutputs) + + if ctype == 'control': + signals[srcport] = signals.get(srcport, []) + \ + [(v, dest, src)] + else: + portoutputs[srcport] = portoutputs.get(srcport, []) + \ + [(v, dest, src)] + return (portinputs, portoutputs, signals) def _propagate_root_output(graph, node, field, connections): @@ -563,6 +494,20 @@ def _propagate_root_output(graph, node, field, connections): destnode.set_input(inport, value) +def _propagate_signal(graph, node, field, connections): + """Propagates the given graph root node output port + field connections to the out-edge destination nodes.""" + for destnode, inport, src in connections: + value = getattr(node.inputs, field) + if isinstance(src, tuple): + value = evaluate_connect_function(src[1], src[2], + value) + logger.debug( + 'Propagating signal %s.%s (value=%s) to %s.%s' % + (node, field, value, destnode, inport)) + destnode.set_signal(inport, value) + + def _propagate_internal_output(graph, node, field, connections, portinputs): """Propagates the given graph internal node output port field connections to the out-edge source node and in-edge @@ -748,15 +693,15 @@ def make_field_func(*pair): # the (source, destination) field tuples connects = newdata['connect'] # the join fields connected to the source - join_fields = [field for _, field in connects - if field in jnode.joinfield] + join_fields = [c[1] for c in connects + if c[1] in jnode.joinfield] # the {field: slot fields} maps assigned to the input # node, e.g. {'image': 'imageJ3', 'mask': 'maskJ3'} # for the third join source expansion replicate of a # join node with join fields image and mask slots = slot_dicts[in_idx] for con_idx, connect in enumerate(connects): - src_field, dest_field = connect + src_field, dest_field = connect[0], connect[1] # qualify a join destination field name if dest_field in slots: slot_field = slots[dest_field] @@ -977,165 +922,6 @@ def format_dot(dotfilename, format=None): logger.info('Converting dotfile: %s to %s format' % (dotfilename, format)) -def make_output_dir(outdir): - """Make the output_dir if it doesn't exist. 
- - Parameters - ---------- - outdir : output directory to create - - """ - if not os.path.exists(os.path.abspath(outdir)): - logger.debug("Creating %s" % outdir) - os.makedirs(outdir) - return outdir - - -def get_all_files(infile): - files = [infile] - if infile.endswith(".img"): - files.append(infile[:-4] + ".hdr") - files.append(infile[:-4] + ".mat") - if infile.endswith(".img.gz"): - files.append(infile[:-7] + ".hdr.gz") - return files - - -def walk_outputs(object): - """Extract every file and directory from a python structure - """ - out = [] - if isinstance(object, dict): - for key, val in sorted(object.items()): - if isdefined(val): - out.extend(walk_outputs(val)) - elif isinstance(object, (list, tuple)): - for val in object: - if isdefined(val): - out.extend(walk_outputs(val)) - else: - if isdefined(object) and isinstance(object, string_types): - if os.path.islink(object) or os.path.isfile(object): - out = [(filename, 'f') for filename in get_all_files(object)] - elif os.path.isdir(object): - out = [(object, 'd')] - return out - - -def walk_files(cwd): - for path, _, files in os.walk(cwd): - for f in files: - yield os.path.join(path, f) - - -def clean_working_directory(outputs, cwd, inputs, needed_outputs, config, - files2keep=None, dirs2keep=None): - """Removes all files not needed for further analysis from the directory - """ - if not outputs: - return - outputs_to_keep = list(outputs.get().keys()) - if needed_outputs and \ - str2bool(config['execution']['remove_unnecessary_outputs']): - outputs_to_keep = needed_outputs - # build a list of needed files - output_files = [] - outputdict = outputs.get() - for output in outputs_to_keep: - output_files.extend(walk_outputs(outputdict[output])) - needed_files = [path for path, type in output_files if type == 'f'] - if str2bool(config['execution']['keep_inputs']): - input_files = [] - inputdict = inputs.get() - input_files.extend(walk_outputs(inputdict)) - needed_files += [path for path, type in input_files if type == 'f'] - for extra in ['_0x*.json', 'provenance.*', 'pyscript*.m', 'pyjobs*.mat', - 'command.txt', 'result*.pklz', '_inputs.pklz', '_node.pklz']: - needed_files.extend(glob(os.path.join(cwd, extra))) - if files2keep: - needed_files.extend(filename_to_list(files2keep)) - needed_dirs = [path for path, type in output_files if type == 'd'] - if dirs2keep: - needed_dirs.extend(filename_to_list(dirs2keep)) - for extra in ['_nipype', '_report']: - needed_dirs.extend(glob(os.path.join(cwd, extra))) - temp = [] - for filename in needed_files: - temp.extend(get_related_files(filename)) - needed_files = temp - logger.debug('Needed files: %s' % (';'.join(needed_files))) - logger.debug('Needed dirs: %s' % (';'.join(needed_dirs))) - files2remove = [] - if str2bool(config['execution']['remove_unnecessary_outputs']): - for f in walk_files(cwd): - if f not in needed_files: - if len(needed_dirs) == 0: - files2remove.append(f) - elif not any([f.startswith(dname) for dname in needed_dirs]): - files2remove.append(f) - else: - if not str2bool(config['execution']['keep_inputs']): - input_files = [] - inputdict = inputs.get() - input_files.extend(walk_outputs(inputdict)) - input_files = [path for path, type in input_files if type == 'f'] - for f in walk_files(cwd): - if f in input_files and f not in needed_files: - files2remove.append(f) - logger.debug('Removing files: %s' % (';'.join(files2remove))) - for f in files2remove: - os.remove(f) - for key in outputs.copyable_trait_names(): - if key not in outputs_to_keep: - setattr(outputs, key, 
Undefined) - return outputs - - -def merge_dict(d1, d2, merge=lambda x, y: y): - """ - Merges two dictionaries, non-destructively, combining - values on duplicate keys as defined by the optional merge - function. The default behavior replaces the values in d1 - with corresponding values in d2. (There is no other generally - applicable merge strategy, but often you'll have homogeneous - types in your dicts, so specifying a merge technique can be - valuable.) - - Examples: - - >>> d1 = {'a': 1, 'c': 3, 'b': 2} - >>> d2 = merge_dict(d1, d1) - >>> len(d2) - 3 - >>> [d2[k] for k in ['a', 'b', 'c']] - [1, 2, 3] - - >>> d3 = merge_dict(d1, d1, lambda x,y: x+y) - >>> len(d3) - 3 - >>> [d3[k] for k in ['a', 'b', 'c']] - [2, 4, 6] - - """ - if not isinstance(d1, dict): - return merge(d1, d2) - result = dict(d1) - if d2 is None: - return result - for k, v in list(d2.items()): - if k in result: - result[k] = merge_dict(result[k], v, merge=merge) - else: - result[k] = v - return result - - -def merge_bundles(g1, g2): - for rec in g2.get_records(): - g1._add_record(rec) - return g1 - - def write_workflow_prov(graph, filename=None, format='all'): """Write W3C PROV Model JSON file """ diff --git a/nipype/pipeline/engine.py b/nipype/pipeline/engine/nodes.py similarity index 53% rename from nipype/pipeline/engine.py rename to nipype/pipeline/engine/nodes.py index 1c73918bf8..772f2f585f 100644 --- a/nipype/pipeline/engine.py +++ b/nipype/pipeline/engine/nodes.py @@ -1,24 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: """Defines functionality for pipelined execution of interfaces -The `Workflow` class provides core functionality for batch processing. +The `Node` class provides core functionality for atomic tasks processing. - Change directory to provide relative paths for doctests - >>> import os - >>> filepath = os.path.dirname( os.path.realpath( __file__ ) ) - >>> datadir = os.path.realpath(os.path.join(filepath, '../testing/data')) - >>> os.chdir(datadir) + .. testsetup:: + # Change directory to provide relative paths for doctests + import os + filepath = os.path.dirname(os.path.realpath( __file__ )) + datadir = os.path.realpath(os.path.join(filepath, '../../testing/data')) + os.chdir(datadir) """ from future import standard_library standard_library.install_aliases() from builtins import range -from builtins import object -from datetime import datetime -from nipype.utils.misc import flatten, unflatten try: from collections import OrderedDict except ImportError: @@ -28,1088 +28,36 @@ import pickle from glob import glob import gzip -import inspect import os import os.path as op -import re import shutil import errno import socket from shutil import rmtree import sys from tempfile import mkdtemp -from warnings import warn from hashlib import sha1 -import numpy as np -import networkx as nx - -from ..utils.misc import package_check, str2bool -package_check('networkx', '1.3') - -from .. 
import config, logging +from nipype.interfaces.base import ( + traits, InputMultiPath, CommandLine, Undefined, DynamicTraitedSpec, + Bunch, InterfaceResult, md5, Interface, isdefined) +from nipype.interfaces.utility import IdentityInterface +from nipype.utils.misc import flatten, unflatten, str2bool +from nipype.utils.filemanip import ( + save_json, FileNotFoundError, filename_to_list, list_to_filename, + copyfiles, fnames_presuffix, loadpkl, split_filename, load_json, + savepkl, write_rst_header, write_rst_dict, write_rst_list) +from .utils import (modify_paths, make_output_dir, clean_working_directory, + get_print_name, merge_dict) +from .base import NodeBase +from .graph import evaluate_connect_function + +from nipype.external.six import string_types +from nipype import config, logging logger = logging.getLogger('workflow') -from ..interfaces.base import (traits, InputMultiPath, CommandLine, - Undefined, TraitedSpec, DynamicTraitedSpec, - Bunch, InterfaceResult, md5, Interface, - TraitDictObject, TraitListObject, isdefined) -from ..utils.misc import (getsource, create_function_from_source, - flatten, unflatten) -from ..utils.filemanip import (save_json, FileNotFoundError, - filename_to_list, list_to_filename, - copyfiles, fnames_presuffix, loadpkl, - split_filename, load_json, savepkl, - write_rst_header, write_rst_dict, - write_rst_list) -from ..external.six import string_types -from .utils import (generate_expanded_graph, modify_paths, - export_graph, make_output_dir, write_workflow_prov, - clean_working_directory, format_dot, topological_sort, - get_print_name, merge_dict, evaluate_connect_function) - - -def _write_inputs(node): - lines = [] - nodename = node.fullname.replace('.', '_') - for key, _ in list(node.inputs.items()): - val = getattr(node.inputs, key) - if isdefined(val): - if type(val) == str: - try: - func = create_function_from_source(val) - except RuntimeError as e: - lines.append("%s.inputs.%s = '%s'" % (nodename, key, val)) - else: - funcname = [name for name in func.__globals__ - if name != '__builtins__'][0] - lines.append(pickle.loads(val)) - if funcname == nodename: - lines[-1] = lines[-1].replace(' %s(' % funcname, - ' %s_1(' % funcname) - funcname = '%s_1' % funcname - lines.append('from nipype.utils.misc import getsource') - lines.append("%s.inputs.%s = getsource(%s)" % (nodename, - key, - funcname)) - else: - lines.append('%s.inputs.%s = %s' % (nodename, key, val)) - return lines - - -def format_node(node, format='python', include_config=False): - """Format a node in a given output syntax.""" - lines = [] - name = node.fullname.replace('.', '_') - if format == 'python': - klass = node._interface - importline = 'from %s import %s' % (klass.__module__, - klass.__class__.__name__) - comment = '# Node: %s' % node.fullname - spec = inspect.signature(node._interface.__init__) - args = spec.args[1:] - if args: - filled_args = [] - for arg in args: - if hasattr(node._interface, '_%s' % arg): - filled_args.append('%s=%s' % (arg, getattr(node._interface, - '_%s' % arg))) - args = ', '.join(filled_args) - else: - args = '' - klass_name = klass.__class__.__name__ - if isinstance(node, MapNode): - nodedef = '%s = MapNode(%s(%s), iterfield=%s, name="%s")' \ - % (name, klass_name, args, node.iterfield, name) - else: - nodedef = '%s = Node(%s(%s), name="%s")' \ - % (name, klass_name, args, name) - lines = [importline, comment, nodedef] - - if include_config: - lines = [importline, "from collections import OrderedDict", - comment, nodedef] - lines.append('%s.config = %s' % 
(name, node.config)) - - if node.iterables is not None: - lines.append('%s.iterables = %s' % (name, node.iterables)) - lines.extend(_write_inputs(node)) - - return lines - - -class WorkflowBase(object): - """Defines common attributes and functions for workflows and nodes.""" - - def __init__(self, name=None, base_dir=None): - """ Initialize base parameters of a workflow or node - - Parameters - ---------- - name : string (mandatory) - Name of this node. Name must be alphanumeric and not contain any - special characters (e.g., '.', '@'). - base_dir : string - base output directory (will be hashed before creations) - default=None, which results in the use of mkdtemp - - """ - self.base_dir = base_dir - self.config = None - self._verify_name(name) - self.name = name - # for compatibility with node expansion using iterables - self._id = self.name - self._hierarchy = None - - @property - def inputs(self): - raise NotImplementedError - - @property - def outputs(self): - raise NotImplementedError - - @property - def fullname(self): - fullname = self.name - if self._hierarchy: - fullname = self._hierarchy + '.' + self.name - return fullname - - def clone(self, name): - """Clone a workflowbase object - - Parameters - ---------- - - name : string (mandatory) - A clone of node or workflow must have a new name - """ - if (name is None) or (name == self.name): - raise Exception('Cloning requires a new name') - self._verify_name(name) - clone = deepcopy(self) - clone.name = name - clone._id = name - clone._hierarchy = None - return clone - - def _check_outputs(self, parameter): - return hasattr(self.outputs, parameter) - - def _check_inputs(self, parameter): - if isinstance(self.inputs, DynamicTraitedSpec): - return True - return hasattr(self.inputs, parameter) - - def _verify_name(self, name): - valid_name = bool(re.match('^[\w-]+$', name)) - if not valid_name: - raise Exception('the name must not contain any special characters') - - def __repr__(self): - if self._hierarchy: - return '.'.join((self._hierarchy, self._id)) - else: - return self._id - - def save(self, filename=None): - if filename is None: - filename = 'temp.pklz' - savepkl(filename, self) - - def load(self, filename): - if '.npz' in filename: - DeprecationWarning(('npz files will be deprecated in the next ' - 'release. you can use numpy to open them.')) - return np.load(filename) - return loadpkl(filename) - - -class Workflow(WorkflowBase): - """Controls the setup and execution of a pipeline of processes.""" - - def __init__(self, name, base_dir=None): - """Create a workflow object. - - Parameters - ---------- - name : alphanumeric string - unique identifier for the workflow - base_dir : string, optional - path to workflow storage - - """ - super(Workflow, self).__init__(name, base_dir) - self._graph = nx.DiGraph() - self.config = deepcopy(config._sections) - - # PUBLIC API - def clone(self, name): - """Clone a workflow - - .. note:: - - Will reset attributes used for executing workflow. See - _init_runtime_fields. - - Parameters - ---------- - - name: alphanumeric name - unique name for the workflow - - """ - clone = super(Workflow, self).clone(name) - clone._reset_hierarchy() - return clone - - # Graph creation functions - def connect(self, *args, **kwargs): - """Connect nodes in the pipeline. - - This routine also checks if inputs and outputs are actually provided by - the nodes that are being connected. - - Creates edges in the directed graph using the nodes and edges specified - in the `connection_list`. 
Uses the NetworkX method - DiGraph.add_edges_from. - - Parameters - ---------- - - args : list or a set of four positional arguments - - Four positional arguments of the form:: - - connect(source, sourceoutput, dest, destinput) - - source : nodewrapper node - sourceoutput : string (must be in source.outputs) - dest : nodewrapper node - destinput : string (must be in dest.inputs) - A list of 3-tuples of the following form:: - [(source, target, - [('sourceoutput/attribute', 'targetinput'), - ...]), - ...] - - Or:: - - [(source, target, [(('sourceoutput1', func, arg2, ...), - 'targetinput'), ...]), - ...] - sourceoutput1 will always be the first argument to func - and func will be evaluated and the results sent ot targetinput - - currently func needs to define all its needed imports within the - function as we use the inspect module to get at the source code - and execute it remotely - """ - if len(args) == 1: - connection_list = args[0] - elif len(args) == 4: - connection_list = [(args[0], args[2], [(args[1], args[3])])] - else: - raise Exception('unknown set of parameters to connect function') - if not kwargs: - disconnect = False - else: - disconnect = kwargs['disconnect'] - newnodes = [] - for srcnode, destnode, _ in connection_list: - if self in [srcnode, destnode]: - msg = ('Workflow connect cannot contain itself as node:' - ' src[%s] dest[%s] workflow[%s]') % (srcnode, - destnode, - self.name) - - raise IOError(msg) - if (srcnode not in newnodes) and not self._has_node(srcnode): - newnodes.append(srcnode) - if (destnode not in newnodes) and not self._has_node(destnode): - newnodes.append(destnode) - if newnodes: - self._check_nodes(newnodes) - for node in newnodes: - if node._hierarchy is None: - node._hierarchy = self.name - not_found = [] - connected_ports = {} - for srcnode, destnode, connects in connection_list: - if destnode not in connected_ports: - connected_ports[destnode] = [] - # check to see which ports of destnode are already - # connected. - if not disconnect and (destnode in self._graph.nodes()): - for edge in self._graph.in_edges_iter(destnode): - data = self._graph.get_edge_data(*edge) - for sourceinfo, destname in data['connect']: - if destname not in connected_ports[destnode]: - connected_ports[destnode] += [destname] - for source, dest in connects: - # Currently datasource/sink/grabber.io modules - # determine their inputs/outputs depending on - # connection settings. Skip these modules in the check - if dest in connected_ports[destnode]: - raise Exception(""" -Trying to connect %s:%s to %s:%s but input '%s' of node '%s' is already -connected. 
-""" % (srcnode, source, destnode, dest, dest, destnode)) - if not (hasattr(destnode, '_interface') and - '.io' in str(destnode._interface.__class__)): - if not destnode._check_inputs(dest): - not_found.append(['in', destnode.name, dest]) - if not (hasattr(srcnode, '_interface') and - '.io' in str(srcnode._interface.__class__)): - if isinstance(source, tuple): - # handles the case that source is specified - # with a function - sourcename = source[0] - elif isinstance(source, string_types): - sourcename = source - else: - raise Exception(('Unknown source specification in ' - 'connection from output of %s') % - srcnode.name) - if sourcename and not srcnode._check_outputs(sourcename): - not_found.append(['out', srcnode.name, sourcename]) - connected_ports[destnode] += [dest] - infostr = [] - for info in not_found: - infostr += ["Module %s has no %sput called %s\n" % (info[1], - info[0], - info[2])] - if not_found: - raise Exception('\n'.join(['Some connections were not found'] + - infostr)) - - # turn functions into strings - for srcnode, destnode, connects in connection_list: - for idx, (src, dest) in enumerate(connects): - if isinstance(src, tuple) and not isinstance(src[1], string_types): - function_source = getsource(src[1]) - connects[idx] = ((src[0], function_source, src[2:]), dest) - - # add connections - for srcnode, destnode, connects in connection_list: - edge_data = self._graph.get_edge_data(srcnode, destnode, None) - if edge_data: - logger.debug('(%s, %s): Edge data exists: %s' - % (srcnode, destnode, str(edge_data))) - for data in connects: - if data not in edge_data['connect']: - edge_data['connect'].append(data) - if disconnect: - logger.debug('Removing connection: %s' % str(data)) - edge_data['connect'].remove(data) - if edge_data['connect']: - self._graph.add_edges_from([(srcnode, - destnode, - edge_data)]) - else: - # pass - logger.debug('Removing connection: %s->%s' % (srcnode, - destnode)) - self._graph.remove_edges_from([(srcnode, destnode)]) - elif not disconnect: - logger.debug('(%s, %s): No edge data' % (srcnode, destnode)) - self._graph.add_edges_from([(srcnode, destnode, - {'connect': connects})]) - edge_data = self._graph.get_edge_data(srcnode, destnode) - logger.debug('(%s, %s): new edge data: %s' % (srcnode, destnode, - str(edge_data))) - - def disconnect(self, *args): - """Disconnect two nodes - - See the docstring for connect for format. 
- """ - # yoh: explicit **dict was introduced for compatibility with Python 2.5 - return self.connect(*args, **dict(disconnect=True)) - - def add_nodes(self, nodes): - """ Add nodes to a workflow - - Parameters - ---------- - nodes : list - A list of WorkflowBase-based objects - """ - newnodes = [] - all_nodes = self._get_all_nodes() - for node in nodes: - if self._has_node(node): - raise IOError('Node %s already exists in the workflow' % node) - if isinstance(node, Workflow): - for subnode in node._get_all_nodes(): - if subnode in all_nodes: - raise IOError(('Subnode %s of node %s already exists ' - 'in the workflow') % (subnode, node)) - newnodes.append(node) - if not newnodes: - logger.debug('no new nodes to add') - return - for node in newnodes: - if not issubclass(node.__class__, WorkflowBase): - raise Exception('Node %s must be a subclass of WorkflowBase' % - str(node)) - self._check_nodes(newnodes) - for node in newnodes: - if node._hierarchy is None: - node._hierarchy = self.name - self._graph.add_nodes_from(newnodes) - - def remove_nodes(self, nodes): - """ Remove nodes from a workflow - - Parameters - ---------- - nodes : list - A list of WorkflowBase-based objects - """ - self._graph.remove_nodes_from(nodes) - - # Input-Output access - @property - def inputs(self): - return self._get_inputs() - - @property - def outputs(self): - return self._get_outputs() - - def get_node(self, name): - """Return an internal node by name - """ - nodenames = name.split('.') - nodename = nodenames[0] - outnode = [node for node in self._graph.nodes() if - str(node).endswith('.' + nodename)] - if outnode: - outnode = outnode[0] - if nodenames[1:] and issubclass(outnode.__class__, Workflow): - outnode = outnode.get_node('.'.join(nodenames[1:])) - else: - outnode = None - return outnode - - def list_node_names(self): - """List names of all nodes in a workflow - """ - outlist = [] - for node in nx.topological_sort(self._graph): - if isinstance(node, Workflow): - outlist.extend(['.'.join((node.name, nodename)) for nodename in - node.list_node_names()]) - else: - outlist.append(node.name) - return sorted(outlist) - - def write_graph(self, dotfilename='graph.dot', graph2use='hierarchical', - format="png", simple_form=True): - """Generates a graphviz dot file and a png file - - Parameters - ---------- - - graph2use: 'orig', 'hierarchical' (default), 'flat', 'exec', 'colored' - orig - creates a top level graph without expanding internal - workflow nodes; - flat - expands workflow nodes recursively; - hierarchical - expands workflow nodes recursively with a - notion on hierarchy; - colored - expands workflow nodes recursively with a - notion on hierarchy in color; - exec - expands workflows to depict iterables - - format: 'png', 'svg' - - simple_form: boolean (default: True) - Determines if the node name used in the graph should be of the form - 'nodename (package)' when True or 'nodename.Class.package' when - False. - - """ - graphtypes = ['orig', 'flat', 'hierarchical', 'exec', 'colored'] - if graph2use not in graphtypes: - raise ValueError('Unknown graph2use keyword. 
Must be one of: ' + - str(graphtypes)) - base_dir, dotfilename = op.split(dotfilename) - if base_dir == '': - if self.base_dir: - base_dir = self.base_dir - if self.name: - base_dir = op.join(base_dir, self.name) - else: - base_dir = os.getcwd() - base_dir = make_output_dir(base_dir) - if graph2use in ['hierarchical', 'colored']: - dotfilename = op.join(base_dir, dotfilename) - self.write_hierarchical_dotfile(dotfilename=dotfilename, - colored=graph2use == "colored", - simple_form=simple_form) - format_dot(dotfilename, format=format) - else: - graph = self._graph - if graph2use in ['flat', 'exec']: - graph = self._create_flat_graph() - if graph2use == 'exec': - graph = generate_expanded_graph(deepcopy(graph)) - export_graph(graph, base_dir, dotfilename=dotfilename, - format=format, simple_form=simple_form) - - def write_hierarchical_dotfile(self, dotfilename=None, colored=False, - simple_form=True): - dotlist = ['digraph %s{' % self.name] - dotlist.append(self._get_dot(prefix=' ', colored=colored, - simple_form=simple_form)) - dotlist.append('}') - dotstr = '\n'.join(dotlist) - if dotfilename: - fp = open(dotfilename, 'wt') - fp.writelines(dotstr) - fp.close() - else: - logger.info(dotstr) - - def export(self, filename=None, prefix="output", format="python", - include_config=False): - """Export object into a different format - - Parameters - ---------- - filename: string - file to save the code to; overrides prefix - prefix: string - prefix to use for output file - format: string - one of "python" - include_config: boolean - whether to include node and workflow config values - - """ - formats = ["python"] - if format not in formats: - raise ValueError('format must be one of: %s' % '|'.join(formats)) - flatgraph = self._create_flat_graph() - nodes = nx.topological_sort(flatgraph) - - lines = ['# Workflow'] - importlines = ['from nipype.pipeline.engine import Workflow, ' - 'Node, MapNode'] - functions = {} - if format == "python": - connect_template = '%s.connect(%%s, %%s, %%s, "%%s")' % self.name - connect_template2 = '%s.connect(%%s, "%%s", %%s, "%%s")' \ - % self.name - wfdef = '%s = Workflow("%s")' % (self.name, self.name) - lines.append(wfdef) - if include_config: - lines.append('%s.config = %s' % (self.name, self.config)) - for idx, node in enumerate(nodes): - nodename = node.fullname.replace('.', '_') - # write nodes - nodelines = format_node(node, format='python', - include_config=include_config) - for line in nodelines: - if line.startswith('from'): - if line not in importlines: - importlines.append(line) - else: - lines.append(line) - # write connections - for u, _, d in flatgraph.in_edges_iter(nbunch=node, - data=True): - for cd in d['connect']: - if isinstance(cd[0], tuple): - args = list(cd[0]) - if args[1] in functions: - funcname = functions[args[1]] - else: - func = create_function_from_source(args[1]) - funcname = [name for name in func.__globals__ - if name != '__builtins__'][0] - functions[args[1]] = funcname - args[1] = funcname - args = tuple([arg for arg in args if arg]) - line_args = (u.fullname.replace('.', '_'), - args, nodename, cd[1]) - line = connect_template % line_args - line = line.replace("'%s'" % funcname, funcname) - lines.append(line) - else: - line_args = (u.fullname.replace('.', '_'), - cd[0], nodename, cd[1]) - lines.append(connect_template2 % line_args) - functionlines = ['# Functions'] - for function in functions: - functionlines.append(pickle.loads(function).rstrip()) - all_lines = importlines + functionlines + lines - - if not filename: - filename = 
'%s%s.py' % (prefix, self.name) - with open(filename, 'wt') as fp: - fp.writelines('\n'.join(all_lines)) - return all_lines - - def run(self, plugin=None, plugin_args=None, updatehash=False): - """ Execute the workflow - - Parameters - ---------- - - plugin: plugin name or object - Plugin to use for execution. You can create your own plugins for - execution. - plugin_args : dictionary containing arguments to be sent to plugin - constructor. see individual plugin doc strings for details. - """ - if plugin is None: - plugin = config.get('execution', 'plugin') - if not isinstance(plugin, string_types): - runner = plugin - else: - name = 'nipype.pipeline.plugins' - try: - __import__(name) - except ImportError: - msg = 'Could not import plugin module: %s' % name - logger.error(msg) - raise ImportError(msg) - else: - plugin_mod = getattr(sys.modules[name], '%sPlugin' % plugin) - runner = plugin_mod(plugin_args=plugin_args) - flatgraph = self._create_flat_graph() - self.config = merge_dict(deepcopy(config._sections), self.config) - if 'crashdump_dir' in self.config: - warn(("Deprecated: workflow.config['crashdump_dir']\n" - "Please use config['execution']['crashdump_dir']")) - crash_dir = self.config['crashdump_dir'] - self.config['execution']['crashdump_dir'] = crash_dir - del self.config['crashdump_dir'] - logger.info(str(sorted(self.config))) - self._set_needed_outputs(flatgraph) - execgraph = generate_expanded_graph(deepcopy(flatgraph)) - for index, node in enumerate(execgraph.nodes()): - node.config = merge_dict(deepcopy(self.config), node.config) - node.base_dir = self.base_dir - node.index = index - if isinstance(node, MapNode): - node.use_plugin = (plugin, plugin_args) - self._configure_exec_nodes(execgraph) - if str2bool(self.config['execution']['create_report']): - self._write_report_info(self.base_dir, self.name, execgraph) - runner.run(execgraph, updatehash=updatehash, config=self.config) - datestr = datetime.utcnow().strftime('%Y%m%dT%H%M%S') - if str2bool(self.config['execution']['write_provenance']): - prov_base = op.join(self.base_dir, - 'workflow_provenance_%s' % datestr) - logger.info('Provenance file prefix: %s' % prov_base) - write_workflow_prov(execgraph, prov_base, format='all') - return execgraph - - # PRIVATE API AND FUNCTIONS - - def _write_report_info(self, workingdir, name, graph): - if workingdir is None: - workingdir = os.getcwd() - report_dir = op.join(workingdir, name) - if not op.exists(report_dir): - os.makedirs(report_dir) - shutil.copyfile(op.join(op.dirname(__file__), - 'report_template.html'), - op.join(report_dir, 'index.html')) - shutil.copyfile(op.join(op.dirname(__file__), - '..', 'external', 'd3.js'), - op.join(report_dir, 'd3.js')) - nodes, groups = topological_sort(graph, depth_first=True) - graph_file = op.join(report_dir, 'graph1.json') - json_dict = {'nodes': [], 'links': [], 'groups': [], 'maxN': 0} - for i, node in enumerate(nodes): - report_file = "%s/_report/report.rst" % \ - node.output_dir().replace(report_dir, '') - result_file = "%s/result_%s.pklz" % \ - (node.output_dir().replace(report_dir, ''), - node.name) - json_dict['nodes'].append(dict(name='%d_%s' % (i, node.name), - report=report_file, - result=result_file, - group=groups[i])) - maxN = 0 - for gid in np.unique(groups): - procs = [i for i, val in enumerate(groups) if val == gid] - N = len(procs) - if N > maxN: - maxN = N - json_dict['groups'].append(dict(procs=procs, - total=N, - name='Group_%05d' % gid)) - json_dict['maxN'] = maxN - for u, v in graph.in_edges_iter(): - 
json_dict['links'].append(dict(source=nodes.index(u), - target=nodes.index(v), - value=1)) - save_json(graph_file, json_dict) - graph_file = op.join(report_dir, 'graph.json') - template = '%%0%dd_' % np.ceil(np.log10(len(nodes))).astype(int) - - def getname(u, i): - name_parts = u.fullname.split('.') - # return '.'.join(name_parts[:-1] + [template % i + name_parts[-1]]) - return template % i + name_parts[-1] - json_dict = [] - for i, node in enumerate(nodes): - imports = [] - for u, v in graph.in_edges_iter(nbunch=node): - imports.append(getname(u, nodes.index(u))) - json_dict.append(dict(name=getname(node, i), - size=1, - group=groups[i], - imports=imports)) - save_json(graph_file, json_dict) - - def _set_needed_outputs(self, graph): - """Initialize node with list of which outputs are needed.""" - rm_outputs = self.config['execution']['remove_unnecessary_outputs'] - if not str2bool(rm_outputs): - return - for node in graph.nodes(): - node.needed_outputs = [] - for edge in graph.out_edges_iter(node): - data = graph.get_edge_data(*edge) - sourceinfo = [v1[0] if isinstance(v1, tuple) else v1 - for v1, v2 in data['connect']] - node.needed_outputs += [v for v in sourceinfo - if v not in node.needed_outputs] - if node.needed_outputs: - node.needed_outputs = sorted(node.needed_outputs) - - def _configure_exec_nodes(self, graph): - """Ensure that each node knows where to get inputs from - """ - for node in graph.nodes(): - node.input_source = {} - for edge in graph.in_edges_iter(node): - data = graph.get_edge_data(*edge) - for sourceinfo, field in sorted(data['connect']): - node.input_source[field] = \ - (op.join(edge[0].output_dir(), - 'result_%s.pklz' % edge[0].name), - sourceinfo) - - def _check_nodes(self, nodes): - """Checks if any of the nodes are already in the graph - - """ - node_names = [node.name for node in self._graph.nodes()] - node_lineage = [node._hierarchy for node in self._graph.nodes()] - for node in nodes: - if node.name in node_names: - idx = node_names.index(node.name) - if node_lineage[idx] in [node._hierarchy, self.name]: - raise IOError('Duplicate node name %s found.' 
% node.name) - else: - node_names.append(node.name) - - def _has_attr(self, parameter, subtype='in'): - """Checks if a parameter is available as an input or output - """ - if subtype == 'in': - subobject = self.inputs - else: - subobject = self.outputs - attrlist = parameter.split('.') - cur_out = subobject - for attr in attrlist: - if not hasattr(cur_out, attr): - return False - cur_out = getattr(cur_out, attr) - return True - - def _get_parameter_node(self, parameter, subtype='in'): - """Returns the underlying node corresponding to an input or - output parameter - """ - if subtype == 'in': - subobject = self.inputs - else: - subobject = self.outputs - attrlist = parameter.split('.') - cur_out = subobject - for attr in attrlist[:-1]: - cur_out = getattr(cur_out, attr) - return cur_out.traits()[attrlist[-1]].node - - def _check_outputs(self, parameter): - return self._has_attr(parameter, subtype='out') - - def _check_inputs(self, parameter): - return self._has_attr(parameter, subtype='in') - - def _get_inputs(self): - """Returns the inputs of a workflow - - This function does not return any input ports that are already - connected - """ - inputdict = TraitedSpec() - for node in self._graph.nodes(): - inputdict.add_trait(node.name, traits.Instance(TraitedSpec)) - if isinstance(node, Workflow): - setattr(inputdict, node.name, node.inputs) - else: - taken_inputs = [] - for _, _, d in self._graph.in_edges_iter(nbunch=node, - data=True): - for cd in d['connect']: - taken_inputs.append(cd[1]) - unconnectedinputs = TraitedSpec() - for key, trait in list(node.inputs.items()): - if key not in taken_inputs: - unconnectedinputs.add_trait(key, - traits.Trait(trait, - node=node)) - value = getattr(node.inputs, key) - setattr(unconnectedinputs, key, value) - setattr(inputdict, node.name, unconnectedinputs) - getattr(inputdict, node.name).on_trait_change(self._set_input) - return inputdict - - def _get_outputs(self): - """Returns all possible output ports that are not already connected - """ - outputdict = TraitedSpec() - for node in self._graph.nodes(): - outputdict.add_trait(node.name, traits.Instance(TraitedSpec)) - if isinstance(node, Workflow): - setattr(outputdict, node.name, node.outputs) - elif node.outputs: - outputs = TraitedSpec() - for key, _ in list(node.outputs.items()): - outputs.add_trait(key, traits.Any(node=node)) - setattr(outputs, key, None) - setattr(outputdict, node.name, outputs) - return outputdict - - def _set_input(self, object, name, newvalue): - """Trait callback function to update a node input - """ - object.traits()[name].node.set_input(name, newvalue) - - def _set_node_input(self, node, param, source, sourceinfo): - """Set inputs of a node given the edge connection""" - if isinstance(sourceinfo, string_types): - val = source.get_output(sourceinfo) - elif isinstance(sourceinfo, tuple): - if callable(sourceinfo[1]): - val = sourceinfo[1](source.get_output(sourceinfo[0]), - *sourceinfo[2:]) - newval = val - if isinstance(val, TraitDictObject): - newval = dict(val) - if isinstance(val, TraitListObject): - newval = val[:] - logger.debug('setting node input: %s->%s', param, str(newval)) - node.set_input(param, deepcopy(newval)) - - def _get_all_nodes(self): - allnodes = [] - for node in self._graph.nodes(): - if isinstance(node, Workflow): - allnodes.extend(node._get_all_nodes()) - else: - allnodes.append(node) - return allnodes - - def _has_node(self, wanted_node): - for node in self._graph.nodes(): - if wanted_node == node: - return True - if isinstance(node, Workflow): - if 
node._has_node(wanted_node): - return True - return False - - def _create_flat_graph(self): - """Make a simple DAG where no node is a workflow.""" - logger.debug('Creating flat graph for workflow: %s', self.name) - workflowcopy = deepcopy(self) - workflowcopy._generate_flatgraph() - return workflowcopy._graph - - def _reset_hierarchy(self): - """Reset the hierarchy on a graph - """ - for node in self._graph.nodes(): - if isinstance(node, Workflow): - node._reset_hierarchy() - for innernode in node._graph.nodes(): - innernode._hierarchy = '.'.join((self.name, - innernode._hierarchy)) - else: - node._hierarchy = self.name - - def _generate_flatgraph(self): - """Generate a graph containing only Nodes or MapNodes - """ - logger.debug('expanding workflow: %s', self) - nodes2remove = [] - if not nx.is_directed_acyclic_graph(self._graph): - raise Exception(('Workflow: %s is not a directed acyclic graph ' - '(DAG)') % self.name) - nodes = nx.topological_sort(self._graph) - for node in nodes: - logger.debug('processing node: %s' % node) - if isinstance(node, Workflow): - nodes2remove.append(node) - # use in_edges instead of in_edges_iter to allow - # disconnections to take place properly. otherwise, the - # edge dict is modified. - for u, _, d in self._graph.in_edges(nbunch=node, data=True): - logger.debug('in: connections-> %s' % str(d['connect'])) - for cd in deepcopy(d['connect']): - logger.debug("in: %s" % str(cd)) - dstnode = node._get_parameter_node(cd[1], subtype='in') - srcnode = u - srcout = cd[0] - dstin = cd[1].split('.')[-1] - logger.debug('in edges: %s %s %s %s' % - (srcnode, srcout, dstnode, dstin)) - self.disconnect(u, cd[0], node, cd[1]) - self.connect(srcnode, srcout, dstnode, dstin) - # do not use out_edges_iter for reasons stated in in_edges - for _, v, d in self._graph.out_edges(nbunch=node, data=True): - logger.debug('out: connections-> %s' % str(d['connect'])) - for cd in deepcopy(d['connect']): - logger.debug("out: %s" % str(cd)) - dstnode = v - if isinstance(cd[0], tuple): - parameter = cd[0][0] - else: - parameter = cd[0] - srcnode = node._get_parameter_node(parameter, - subtype='out') - if isinstance(cd[0], tuple): - srcout = list(cd[0]) - srcout[0] = parameter.split('.')[-1] - srcout = tuple(srcout) - else: - srcout = parameter.split('.')[-1] - dstin = cd[1] - logger.debug('out edges: %s %s %s %s' % (srcnode, - srcout, - dstnode, - dstin)) - self.disconnect(node, cd[0], v, cd[1]) - self.connect(srcnode, srcout, dstnode, dstin) - # expand the workflow node - # logger.debug('expanding workflow: %s', node) - node._generate_flatgraph() - for innernode in node._graph.nodes(): - innernode._hierarchy = '.'.join((self.name, - innernode._hierarchy)) - self._graph.add_nodes_from(node._graph.nodes()) - self._graph.add_edges_from(node._graph.edges(data=True)) - if nodes2remove: - self._graph.remove_nodes_from(nodes2remove) - logger.debug('finished expanding workflow: %s', self) - - def _get_dot(self, prefix=None, hierarchy=None, colored=False, - simple_form=True, level=0): - """Create a dot file with connection info - """ - if prefix is None: - prefix = ' ' - if hierarchy is None: - hierarchy = [] - colorset = ['#FFFFC8', '#0000FF', '#B4B4FF', '#E6E6FF', '#FF0000', - '#FFB4B4', '#FFE6E6', '#00A300', '#B4FFB4', '#E6FFE6'] - - dotlist = ['%slabel="%s";' % (prefix, self.name)] - for node in nx.topological_sort(self._graph): - fullname = '.'.join(hierarchy + [node.fullname]) - nodename = fullname.replace('.', '_') - if not isinstance(node, Workflow): - node_class_name = 
get_print_name(node, simple_form=simple_form) - if not simple_form: - node_class_name = '.'.join(node_class_name.split('.')[1:]) - if hasattr(node, 'iterables') and node.iterables: - dotlist.append(('%s[label="%s", shape=box3d,' - 'style=filled, color=black, colorscheme' - '=greys7 fillcolor=2];') % (nodename, - node_class_name)) - else: - if colored: - dotlist.append(('%s[label="%s", style=filled,' - ' fillcolor="%s"];') - % (nodename, node_class_name, - colorset[level])) - else: - dotlist.append(('%s[label="%s"];') - % (nodename, node_class_name)) - - for node in nx.topological_sort(self._graph): - if isinstance(node, Workflow): - fullname = '.'.join(hierarchy + [node.fullname]) - nodename = fullname.replace('.', '_') - dotlist.append('subgraph cluster_%s {' % nodename) - if colored: - dotlist.append(prefix + prefix + 'edge [color="%s"];' % (colorset[level + 1])) - dotlist.append(prefix + prefix + 'style=filled;') - dotlist.append(prefix + prefix + 'fillcolor="%s";' % (colorset[level + 2])) - dotlist.append(node._get_dot(prefix=prefix + prefix, - hierarchy=hierarchy + [self.name], - colored=colored, - simple_form=simple_form, level=level + 3)) - dotlist.append('}') - if level == 6: - level = 2 - else: - for subnode in self._graph.successors_iter(node): - if node._hierarchy != subnode._hierarchy: - continue - if not isinstance(subnode, Workflow): - nodefullname = '.'.join(hierarchy + [node.fullname]) - subnodefullname = '.'.join(hierarchy + - [subnode.fullname]) - nodename = nodefullname.replace('.', '_') - subnodename = subnodefullname.replace('.', '_') - for _ in self._graph.get_edge_data(node, - subnode)['connect']: - dotlist.append('%s -> %s;' % (nodename, - subnodename)) - logger.debug('connection: ' + dotlist[-1]) - # add between workflow connections - for u, v, d in self._graph.edges_iter(data=True): - uname = '.'.join(hierarchy + [u.fullname]) - vname = '.'.join(hierarchy + [v.fullname]) - for src, dest in d['connect']: - uname1 = uname - vname1 = vname - if isinstance(src, tuple): - srcname = src[0] - else: - srcname = src - if '.' in srcname: - uname1 += '.' + '.'.join(srcname.split('.')[:-1]) - if '.' in dest and '@' not in dest: - if not isinstance(v, Workflow): - if 'datasink' not in \ - str(v._interface.__class__).lower(): - vname1 += '.' + '.'.join(dest.split('.')[:-1]) - else: - vname1 += '.' 
+ '.'.join(dest.split('.')[:-1]) - if uname1.split('.')[:-1] != vname1.split('.')[:-1]: - dotlist.append('%s -> %s;' % (uname1.replace('.', '_'), - vname1.replace('.', '_'))) - logger.debug('cross connection: ' + dotlist[-1]) - return ('\n' + prefix).join(dotlist) - - -class Node(WorkflowBase): +class Node(NodeBase): """Wraps interface objects for use in pipeline A Node creates a sandbox-like directory for executing the underlying @@ -1121,12 +69,15 @@ class Node(WorkflowBase): Examples -------- - >>> from nipype import Node - >>> from nipype.interfaces import spm - >>> realign = Node(spm.Realign(), 'realign') - >>> realign.inputs.in_files = 'functional.nii' - >>> realign.inputs.register_to_mean = True - >>> realign.run() # doctest: +SKIP + >>> from nipype.pipeline.engine import Node + >>> from nipype.interfaces import fsl + >>> bet = Node(fsl.BET(), 'BET') + >>> bet.inputs.in_file = 'T1.nii' + >>> bet.run() # doctest: +SKIP + + >>> bet.signals.disable = True + >>> bet.run() is None + True """ @@ -1201,14 +152,18 @@ def __init__(self, interface, name, iterables=None, itersource=None, multiprocessing pool """ - base_dir = None - if 'base_dir' in kwargs: - base_dir = kwargs['base_dir'] - super(Node, self).__init__(name, base_dir) + base_dir = kwargs.get('base_dir', None) + if interface is None: raise IOError('Interface must be provided') if not isinstance(interface, Interface): raise IOError('interface must be an instance of an Interface') + + control = kwargs.get('control', True) + if isinstance(interface, IdentityInterface): + control = False + + super(Node, self).__init__(name, base_dir, control) self._interface = interface self.name = name self._result = None @@ -1273,6 +228,16 @@ def set_input(self, parameter, val): str(val))) setattr(self.inputs, parameter, deepcopy(val)) + def set_signal(self, parameter, val): + """ Set interface input value""" + logger.debug('setting nodelevel(%s) signal %s = %s' % (str(self), + parameter, + str(val))) + if isinstance(self._interface, IdentityInterface): + self.set_input(parameter, val) + elif self.signals is not None: + setattr(self.signals, parameter, deepcopy(val)) + def get_output(self, parameter): """Retrieve a particular output of the node""" val = None @@ -1312,6 +277,10 @@ def hash_exists(self, updatehash=False): self._save_hashfile(hashfile, hashed_inputs) return op.exists(hashfile), hashvalue, hashfile, hashed_inputs + def _update_disable(self): + logger.debug('Signal disable is now %s for node %s' % + (self.signals.disable, self.fullname)) + def run(self, updatehash=False): """Execute the node in its directory. @@ -1321,6 +290,10 @@ def run(self, updatehash=False): updatehash: boolean Update the hash stored in the output directory """ + if (self.signals is not None and self.signals.disable): + logger.debug('Node: %s skipped' % self.fullname) + return self._result + # check to see if output directory and hash exist if self.config is None: self.config = deepcopy(config._sections) @@ -1338,6 +311,7 @@ def run(self, updatehash=False): logger.debug(('updatehash, overwrite, always_run, hash_exists', updatehash, self.overwrite, self._interface.always_run, hash_exists)) + if (not updatehash and (((self.overwrite is None and self._interface.always_run) or self.overwrite) or not @@ -1500,7 +474,7 @@ def _get_inputs(self): This mechanism can be easily extended/replaced to retrieve data from other data sources (e.g., XNAT, HTTP, etc.,.) 
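        Each ``input_source`` entry maps a destination field to the pickled
        results file of the upstream node together with the name of the source
        field. A rough sketch (the field names and path here are hypothetical;
        the actual tuple is built in ``_configure_exec_nodes``)::

            input_source['in_file'] = ('/tmp/work/srcnode/result_srcnode.pklz',
                                       'out_file')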
""" - logger.debug('Setting node inputs') + logger.debug('Setting node inputs: %s' % self.input_source.keys()) for key, info in list(self.input_source.items()): logger.debug('input: %s' % key) results_file = info[0] @@ -1757,6 +731,9 @@ def write_report(self, report_type=None, cwd=None): 'Exec ID : %s' % self._id])) fp.writelines(write_rst_header('Original Inputs', level=1)) fp.writelines(write_rst_dict(self.inputs.get())) + if self.signals: + fp.writelines(write_rst_header('Signals', level=1)) + fp.writelines(write_rst_dict(self.signals.get())) if report_type == 'postexec': logger.debug('writing post-exec report to %s' % report_file) fp = open(report_file, 'at') @@ -1800,7 +777,7 @@ class JoinNode(Node): -------- >>> import nipype.pipeline.engine as pe - >>> from nipype import Node, JoinNode, Workflow + >>> from nipype.pipeline.engine import Node, JoinNode, Workflow >>> from nipype.interfaces.utility import IdentityInterface >>> from nipype.interfaces import (ants, dcm2nii, fsl) >>> wf = Workflow(name='preprocess') @@ -1893,7 +870,7 @@ def _add_join_item_fields(self): >>> from nipype.interfaces.utility import IdentityInterface >>> import nipype.pipeline.engine as pe - >>> from nipype import Node, JoinNode, Workflow + >>> from nipype.pipeline.engine import Node, JoinNode, Workflow >>> inputspec = Node(IdentityInterface(fields=['image']), ... name='inputspec'), >>> join = JoinNode(IdentityInterface(fields=['images', 'mask']), @@ -2026,7 +1003,7 @@ class MapNode(Node): Examples -------- - >>> from nipype import MapNode + >>> from nipype.pipeline.engine import MapNode >>> from nipype.interfaces import fsl >>> realign = MapNode(fsl.MCFLIRT(), 'in_file', 'realign') >>> realign.inputs.in_file = ['functional.nii', diff --git a/nipype/pipeline/report_template.html b/nipype/pipeline/engine/report_template.html similarity index 100% rename from nipype/pipeline/report_template.html rename to nipype/pipeline/engine/report_template.html diff --git a/nipype/pipeline/report_template2.html b/nipype/pipeline/engine/report_template2.html similarity index 100% rename from nipype/pipeline/report_template2.html rename to nipype/pipeline/engine/report_template2.html diff --git a/nipype/pipeline/tests/__init__.py b/nipype/pipeline/engine/tests/__init__.py similarity index 100% rename from nipype/pipeline/tests/__init__.py rename to nipype/pipeline/engine/tests/__init__.py diff --git a/nipype/pipeline/engine/tests/test_conditional.py b/nipype/pipeline/engine/tests/test_conditional.py new file mode 100644 index 0000000000..196493840d --- /dev/null +++ b/nipype/pipeline/engine/tests/test_conditional.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: + +from nipype.testing import (assert_raises, assert_equal, + assert_true, assert_false) +from nipype.interfaces import base as nib +from nipype.interfaces import utility as niu +from nipype.interfaces import io as nio +from nipype.pipeline import engine as pe +from copy import deepcopy +import os.path as op +from tempfile import mkdtemp +from shutil import rmtree +import json + + +ifresult = None + + +class SetInputSpec(nib.TraitedSpec): + val = nib.traits.Int(2, mandatory=True, desc='input') + + +class SetOutputSpec(nib.TraitedSpec): + out = nib.traits.Int(desc='ouput') + + +class SetInterface(nib.BaseInterface): + input_spec = SetInputSpec + output_spec = SetOutputSpec + _always_run = True + + def _run_interface(self, runtime): + 
global ifresult + runtime.returncode = 0 + ifresult = self.inputs.val + return runtime + + def _list_outputs(self): + global ifresult + outputs = self._outputs().get() + outputs['out'] = self.inputs.val + return outputs + + +def _base_workflow(): + + def _myfunc(val): + return val + 1 + + wf = pe.Workflow('InnerWorkflow') + inputnode = pe.Node(niu.IdentityInterface( + fields=['in_value']), 'inputnode') + outputnode = pe.Node(niu.IdentityInterface( + fields=['out_value']), 'outputnode') + func = pe.Node(niu.Function( + input_names=['val'], output_names=['out'], + function=_myfunc), 'functionnode') + ifset = pe.Node(SetInterface(), 'SetIface') + + wf.connect([ + (inputnode, func, [('in_value', 'val')]), + (func, ifset, [('out', 'val')]), + (ifset, outputnode, [('out', 'out_value')]) + ]) + return wf + + +def _base_cachedworkflow(): + + def _myfunc(a, b): + return a + b + + wf = pe.CachedWorkflow('InnerWorkflow', + cache_map=('c', 'out')) + + inputnode = pe.Node(niu.IdentityInterface( + fields=['a', 'b']), 'inputnode') + func = pe.Node(niu.Function( + input_names=['a', 'b'], output_names=['out'], + function=_myfunc), 'functionnode') + ifset = pe.Node(SetInterface(), 'SetIface') + + wf.connect([ + (inputnode, func, [('a', 'a'), ('b', 'b')]), + (func, 'output', [('out', 'out')]), + ('output', ifset, [('out', 'val')]) + ]) + return wf + + +def test_workflow_disable(): + global ifresult + wf = _base_workflow() + + ifresult = None + wf.inputs.inputnode.in_value = 0 + wf.run() + yield assert_equal, ifresult, 1 + + # Check if direct signal setting works + ifresult = None + wf.signals.disable = True + wf.run() + yield assert_equal, ifresult, None + + ifresult = None + wf.signals.disable = False + wf.run() + yield assert_equal, ifresult, 1 + + # Check if signalnode way works + ifresult = None + wf.inputs.signalnode.disable = True + wf.run() + yield assert_equal, ifresult, None + + ifresult = None + wf.inputs.signalnode.disable = False + wf.run() + yield assert_equal, ifresult, 1 + + # Check if one can set signal then node + ifresult = None + wf.signals.disable = True + wf.run() + yield assert_equal, ifresult, None + + ifresult = None + wf.inputs.signalnode.disable = False + wf.run() + yield assert_equal, ifresult, 1 + + # Check if one can set node then signal + ifresult = None + wf.inputs.signalnode.disable = True + wf.run() + yield assert_equal, ifresult, None + + ifresult = None + wf.signals.disable = False + wf.run() + yield assert_equal, ifresult, 1 + + +def test_workflow_disable_nested_A(): + global ifresult + + inner = _base_workflow() + dn = pe.Node(niu.IdentityInterface( + fields=['donotrun', 'value']), 'decisionnode') + + outer = pe.Workflow('OuterWorkflow', control=False) + + outer.connect([ + (dn, inner, [('donotrun', 'signalnode.disable')]) + ], conn_type='control') + + outer.connect([ + (dn, inner, [('value', 'inputnode.in_value')]) + ]) + + ifresult = None + outer.inputs.decisionnode.value = 0 + outer.run() + yield assert_equal, ifresult, 1 + + ifresult = None + outer.inputs.decisionnode.donotrun = False + outer.run() + yield assert_equal, ifresult, 1 + + ifresult = None + outer.inputs.decisionnode.donotrun = True + outer.run() + yield assert_equal, ifresult, None + + ifresult = None + outer.inputs.decisionnode.donotrun = False + outer.run() + yield assert_equal, ifresult, 1 + + +def test_workflow_disable_nested_B(): + global ifresult + + inner = _base_workflow() + dn = pe.Node(niu.IdentityInterface(fields=['value']), + 'inputnode') + + outer = pe.Workflow('OuterWorkflow') + + 
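+        # Control edge: the decision node's 'donotrun' output drives the inner
+        # workflow's 'disable' signal (hence conn_type='control' below).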
outer.connect([ + (dn, inner, [('value', 'inputnode.in_value')]) + ]) + + ifresult = None + outer.inputs.inputnode.value = 0 + outer.run() + yield assert_equal, ifresult, 1 + + ifresult = None + outer.signals.disable = True + outer.run() + yield assert_equal, ifresult, None + + ifresult = None + outer.signals.disable = False + outer.run() + yield assert_equal, ifresult, 1 + + +def test_cached_workflow(): + global ifresult + + cwf = _base_cachedworkflow() + cwf.inputs.inputnode.a = 2 + cwf.inputs.inputnode.b = 3 + + # check results + ifresult = None + res = cwf.run() + yield assert_equal, ifresult, 5 + + # check disable + # ifresult = None + # cwf.signals.disable = True + # res = cwf.run() + # yield assert_equal, ifresult, None + + ifresult = None + cwf.inputs.cachenode.c = 7 + res = cwf.run() + yield assert_equal, ifresult, 7 diff --git a/nipype/pipeline/tests/test_engine.py b/nipype/pipeline/engine/tests/test_engine.py similarity index 92% rename from nipype/pipeline/tests/test_engine.py rename to nipype/pipeline/engine/tests/test_engine.py index 30b2981b4c..f1feb892be 100644 --- a/nipype/pipeline/tests/test_engine.py +++ b/nipype/pipeline/engine/tests/test_engine.py @@ -54,7 +54,8 @@ def test_connect(): yield assert_true, mod1 in pipe._graph.nodes() yield assert_true, mod2 in pipe._graph.nodes() - yield assert_equal, pipe._graph.get_edge_data(mod1, mod2), {'connect': [('output1', 'input1')]} + yield assert_equal, pipe._graph.get_edge_data( + mod1, mod2), {'connect': [('output1', 'input1', 'data')]} def test_add_nodes(): @@ -514,27 +515,21 @@ def test_mapnode_nested(): cwd = os.getcwd() wd = mkdtemp() os.chdir(wd) - from nipype import MapNode, Function + from nipype import Function def func1(in1): return in1 + 1 - n1 = MapNode(Function(input_names=['in1'], - output_names=['out'], - function=func1), - iterfield=['in1'], - nested=True, - name='n1') + n1 = pe.MapNode(Function( + input_names=['in1'], output_names=['out'], function=func1), + iterfield=['in1'], nested=True, name='n1') n1.inputs.in1 = [[1, [2]], 3, [4, 5]] n1.run() print(n1.get_output('out')) yield assert_equal, n1.get_output('out'), [[2, [3]], 4, [5, 6]] - n2 = MapNode(Function(input_names=['in1'], - output_names=['out'], - function=func1), - iterfield=['in1'], - nested=False, - name='n1') + n2 = pe.MapNode(Function( + input_names=['in1'], output_names=['out'], function=func1), + iterfield=['in1'], nested=False, name='n1') n2.inputs.in1 = [[1, [2]], 3, [4, 5]] error_raised = False try: @@ -556,14 +551,10 @@ def func1(): def func2(a): return a + 1 - n1 = pe.Node(Function(input_names=[], - output_names=['a'], - function=func1), - name='n1') - n2 = pe.Node(Function(input_names=['a'], - output_names=['b'], - function=func2), - name='n2') + n1 = pe.Node(Function( + input_names=[], output_names=['a'], function=func1), name='n1') + n2 = pe.Node(Function( + input_names=['a'], output_names=['b'], function=func2), name='n2') w1 = pe.Workflow(name='test') modify = lambda x: x + 1 n1.inputs.a = 1 @@ -617,14 +608,10 @@ def func1(): def func2(a): return a + 1 - n1 = pe.Node(Function(input_names=[], - output_names=['a'], - function=func1), - name='n1') - n2 = pe.Node(Function(input_names=['a'], - output_names=['b'], - function=func2), - name='n2') + n1 = pe.Node(Function( + input_names=[], output_names=['a'], function=func1), name='n1') + n2 = pe.Node(Function( + input_names=['a'], output_names=['b'], function=func2), name='n2') w1 = pe.Workflow(name='test') modify = lambda x: x + 1 n1.inputs.a = 1 @@ -650,17 +637,15 @@ def 
test_mapnode_json(): cwd = os.getcwd() wd = mkdtemp() os.chdir(wd) - from nipype import MapNode, Function, Workflow + from nipype import Function def func1(in1): return in1 + 1 - n1 = MapNode(Function(input_names=['in1'], - output_names=['out'], - function=func1), - iterfield=['in1'], - name='n1') + n1 = pe.MapNode(Function( + input_names=['in1'], output_names=['out'], function=func1), + iterfield=['in1'], name='n1') n1.inputs.in1 = [1] - w1 = Workflow(name='test') + w1 = pe.Workflow(name='test') w1.base_dir = wd w1.config['execution']['crashdump_dir'] = wd w1.add_nodes([n1]) @@ -693,18 +678,16 @@ def test_serial_input(): cwd = os.getcwd() wd = mkdtemp() os.chdir(wd) - from nipype import MapNode, Function, Workflow + from nipype import Function def func1(in1): return in1 - n1 = MapNode(Function(input_names=['in1'], - output_names=['out'], - function=func1), - iterfield=['in1'], - name='n1') + n1 = pe.MapNode(Function( + input_names=['in1'], output_names=['out'], + function=func1), iterfield=['in1'], name='n1') n1.inputs.in1 = [1, 2, 3] - w1 = Workflow(name='test') + w1 = pe.Workflow(name='test') w1.base_dir = wd w1.add_nodes([n1]) # set local check diff --git a/nipype/pipeline/tests/test_join.py b/nipype/pipeline/engine/tests/test_join.py similarity index 100% rename from nipype/pipeline/tests/test_join.py rename to nipype/pipeline/engine/tests/test_join.py diff --git a/nipype/pipeline/tests/test_utils.py b/nipype/pipeline/engine/tests/test_utils.py similarity index 98% rename from nipype/pipeline/tests/test_utils.py rename to nipype/pipeline/engine/tests/test_utils.py index 50d44b78a0..6c48eba7e2 100644 --- a/nipype/pipeline/tests/test_utils.py +++ b/nipype/pipeline/engine/tests/test_utils.py @@ -9,12 +9,13 @@ from tempfile import mkdtemp from shutil import rmtree -from ...testing import (assert_equal, assert_true, assert_false) +from nipype.testing import (assert_equal, assert_true, assert_false) import nipype.pipeline.engine as pe import nipype.interfaces.base as nib import nipype.interfaces.utility as niu -from ... 
import config -from ..utils import merge_dict, clean_working_directory, write_workflow_prov +from nipype import config +from ..utils import merge_dict, clean_working_directory +from ..graph import write_workflow_prov def test_identitynode_removal(): diff --git a/nipype/pipeline/engine/utils.py b/nipype/pipeline/engine/utils.py new file mode 100644 index 0000000000..d3a41cbfc3 --- /dev/null +++ b/nipype/pipeline/engine/utils.py @@ -0,0 +1,354 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +from future import standard_library +standard_library.install_aliases() +from builtins import range + +import os +import os.path as op +from glob import glob +import pickle +import inspect +from nipype import logging, config +from nipype.external.six import string_types +from nipype.interfaces.base import isdefined, Undefined +from nipype.utils.misc import create_function_from_source, str2bool +from nipype.utils.filemanip import (FileNotFoundError, filename_to_list, + get_related_files) + +logger = logging.getLogger('workflow') + +try: + from os.path import relpath +except ImportError: + def relpath(path, start=None): + """Return a relative version of a path""" + if start is None: + start = os.curdir + if not path: + raise ValueError("no path specified") + start_list = op.abspath(start).split(op.sep) + path_list = op.abspath(path).split(op.sep) + if start_list[0].lower() != path_list[0].lower(): + unc_path, rest = op.splitunc(path) + unc_start, rest = op.splitunc(start) + if bool(unc_path) ^ bool(unc_start): + raise ValueError(("Cannot mix UNC and non-UNC paths " + "(%s and %s)") % (path, start)) + else: + raise ValueError("path is on drive %s, start on drive %s" + % (path_list[0], start_list[0])) + # Work out how much of the filepath is shared by start and path. 
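+            # After this loop, ``i`` is the number of leading components the
+            # two paths share; the relative path climbs the rest with pardir.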
+ for i in range(min(len(start_list), len(path_list))): + if start_list[i].lower() != path_list[i].lower(): + break + else: + i += 1 + + rel_list = [op.pardir] * (len(start_list) - i) + path_list[i:] + if not rel_list: + return os.curdir + return op.join(*rel_list) + + +def modify_paths(object, relative=True, basedir=None): + """Convert paths in data structure to either full paths or relative paths + + Supports combinations of lists, dicts, tuples, strs + + Parameters + ---------- + + relative : boolean indicating whether paths should be set relative to the + current directory + basedir : default os.getcwd() + what base directory to use as default + """ + if not basedir: + basedir = os.getcwd() + if isinstance(object, dict): + out = {} + for key, val in sorted(object.items()): + if isdefined(val): + out[key] = modify_paths(val, relative=relative, + basedir=basedir) + elif isinstance(object, (list, tuple)): + out = [] + for val in object: + if isdefined(val): + out.append(modify_paths(val, relative=relative, + basedir=basedir)) + if isinstance(object, tuple): + out = tuple(out) + else: + if isdefined(object): + if isinstance(object, string_types) and op.isfile(object): + if relative: + if config.getboolean('execution', 'use_relative_paths'): + out = relpath(object, start=basedir) + else: + out = object + else: + out = op.abspath(op.join(basedir, object)) + if not op.exists(out): + raise FileNotFoundError('File %s not found' % out) + else: + out = object + return out + + +def get_print_name(node, simple_form=True): + """Get the name of the node + + For example, a node containing an instance of interfaces.fsl.BET + would be called nodename.BET.fsl + + """ + name = node.fullname + if hasattr(node, '_interface'): + pkglist = node._interface.__class__.__module__.split('.') + interface = node._interface.__class__.__name__ + destclass = '' + if len(pkglist) > 2: + destclass = '.%s' % pkglist[2] + if simple_form: + name = node.fullname + destclass + else: + name = '.'.join([node.fullname, interface]) + destclass + if simple_form: + parts = name.split('.') + if len(parts) > 2: + return ' ('.join(parts[1:]) + ')' + elif len(parts) == 2: + return parts[1] + return name + + +def make_output_dir(outdir): + """Make the output_dir if it doesn't exist. 
+ + Parameters + ---------- + outdir : output directory to create + + """ + if not op.exists(op.abspath(outdir)): + logger.debug("Creating %s" % outdir) + os.makedirs(outdir) + return outdir + + +def clean_working_directory(outputs, cwd, inputs, needed_outputs, config, + files2keep=None, dirs2keep=None): + """Removes all files not needed for further analysis from the directory + """ + if not outputs: + return + outputs_to_keep = list(outputs.get().keys()) + if needed_outputs and \ + str2bool(config['execution']['remove_unnecessary_outputs']): + outputs_to_keep = needed_outputs + # build a list of needed files + output_files = [] + outputdict = outputs.get() + for output in outputs_to_keep: + output_files.extend(walk_outputs(outputdict[output])) + needed_files = [path for path, type in output_files if type == 'f'] + if str2bool(config['execution']['keep_inputs']): + input_files = [] + inputdict = inputs.get() + input_files.extend(walk_outputs(inputdict)) + needed_files += [path for path, type in input_files if type == 'f'] + for extra in ['_0x*.json', 'provenance.*', 'pyscript*.m', 'pyjobs*.mat', + 'command.txt', 'result*.pklz', '_inputs.pklz', '_node.pklz']: + needed_files.extend(glob(os.path.join(cwd, extra))) + if files2keep: + needed_files.extend(filename_to_list(files2keep)) + needed_dirs = [path for path, type in output_files if type == 'd'] + if dirs2keep: + needed_dirs.extend(filename_to_list(dirs2keep)) + for extra in ['_nipype', '_report']: + needed_dirs.extend(glob(os.path.join(cwd, extra))) + temp = [] + for filename in needed_files: + temp.extend(get_related_files(filename)) + needed_files = temp + logger.debug('Needed files: %s' % (';'.join(needed_files))) + logger.debug('Needed dirs: %s' % (';'.join(needed_dirs))) + files2remove = [] + if str2bool(config['execution']['remove_unnecessary_outputs']): + for f in walk_files(cwd): + if f not in needed_files: + if len(needed_dirs) == 0: + files2remove.append(f) + elif not any([f.startswith(dname) for dname in needed_dirs]): + files2remove.append(f) + else: + if not str2bool(config['execution']['keep_inputs']): + input_files = [] + inputdict = inputs.get() + input_files.extend(walk_outputs(inputdict)) + input_files = [path for path, type in input_files if type == 'f'] + for f in walk_files(cwd): + if f in input_files and f not in needed_files: + files2remove.append(f) + logger.debug('Removing files: %s' % (';'.join(files2remove))) + for f in files2remove: + os.remove(f) + for key in outputs.copyable_trait_names(): + if key not in outputs_to_keep: + setattr(outputs, key, Undefined) + return outputs + + +def get_all_files(infile): + files = [infile] + if infile.endswith(".img"): + files.append(infile[:-4] + ".hdr") + files.append(infile[:-4] + ".mat") + if infile.endswith(".img.gz"): + files.append(infile[:-7] + ".hdr.gz") + return files + + +def walk_outputs(object): + """Extract every file and directory from a python structure + """ + out = [] + if isinstance(object, dict): + for key, val in sorted(object.items()): + if isdefined(val): + out.extend(walk_outputs(val)) + elif isinstance(object, (list, tuple)): + for val in object: + if isdefined(val): + out.extend(walk_outputs(val)) + else: + if isdefined(object) and isinstance(object, string_types): + if os.path.islink(object) or os.path.isfile(object): + out = [(filename, 'f') for filename in get_all_files(object)] + elif os.path.isdir(object): + out = [(object, 'd')] + return out + + +def walk_files(cwd): + for path, _, files in os.walk(cwd): + for f in files: + yield 
os.path.join(path, f) + + +def merge_dict(d1, d2, merge=lambda x, y: y): + """ + Merges two dictionaries, non-destructively, combining + values on duplicate keys as defined by the optional merge + function. The default behavior replaces the values in d1 + with corresponding values in d2. (There is no other generally + applicable merge strategy, but often you'll have homogeneous + types in your dicts, so specifying a merge technique can be + valuable.) + + Examples: + + >>> d1 = {'a': 1, 'c': 3, 'b': 2} + >>> d2 = merge_dict(d1, d1) + >>> len(d2) + 3 + >>> [d2[k] for k in ['a', 'b', 'c']] + [1, 2, 3] + + >>> d3 = merge_dict(d1, d1, lambda x,y: x+y) + >>> len(d3) + 3 + >>> [d3[k] for k in ['a', 'b', 'c']] + [2, 4, 6] + + """ + if not isinstance(d1, dict): + return merge(d1, d2) + result = dict(d1) + if d2 is None: + return result + for k, v in list(d2.items()): + if k in result: + result[k] = merge_dict(result[k], v, merge=merge) + else: + result[k] = v + return result + + +def merge_bundles(g1, g2): + for rec in g2.get_records(): + g1._add_record(rec) + return g1 + +def _write_inputs(node): + lines = [] + nodename = node.fullname.replace('.', '_') + for key, _ in list(node.inputs.items()): + val = getattr(node.inputs, key) + if isdefined(val): + if type(val) == str: + try: + func = create_function_from_source(val) + except RuntimeError as e: + lines.append("%s.inputs.%s = '%s'" % (nodename, key, val)) + else: + funcname = [name for name in func.__globals__ + if name != '__builtins__'][0] + lines.append(pickle.loads(val)) + if funcname == nodename: + lines[-1] = lines[-1].replace(' %s(' % funcname, + ' %s_1(' % funcname) + funcname = '%s_1' % funcname + lines.append('from nipype.utils.misc import getsource') + lines.append("%s.inputs.%s = getsource(%s)" % (nodename, + key, + funcname)) + else: + lines.append('%s.inputs.%s = %s' % (nodename, key, val)) + return lines + + +def format_node(node, format='python', include_config=False): + """Format a node in a given output syntax.""" + from .nodes import MapNode + lines = [] + name = node.fullname.replace('.', '_') + if format == 'python': + klass = node._interface + importline = 'from %s import %s' % (klass.__module__, + klass.__class__.__name__) + comment = '# Node: %s' % node.fullname + spec = inspect.signature(node._interface.__init__) + args = spec.args[1:] + if args: + filled_args = [] + for arg in args: + if hasattr(node._interface, '_%s' % arg): + filled_args.append('%s=%s' % (arg, getattr(node._interface, + '_%s' % arg))) + args = ', '.join(filled_args) + else: + args = '' + klass_name = klass.__class__.__name__ + if isinstance(node, MapNode): + nodedef = '%s = MapNode(%s(%s), iterfield=%s, name="%s")' \ + % (name, klass_name, args, node.iterfield, name) + else: + nodedef = '%s = Node(%s(%s), name="%s")' \ + % (name, klass_name, args, name) + lines = [importline, comment, nodedef] + + if include_config: + lines = [importline, "from collections import OrderedDict", + comment, nodedef] + lines.append('%s.config = %s' % (name, node.config)) + + if node.iterables is not None: + lines.append('%s.iterables = %s' % (name, node.iterables)) + lines.extend(_write_inputs(node)) + + return lines diff --git a/nipype/pipeline/engine/workflows.py b/nipype/pipeline/engine/workflows.py new file mode 100644 index 0000000000..61dca75f32 --- /dev/null +++ b/nipype/pipeline/engine/workflows.py @@ -0,0 +1,1199 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python 
sts=4 ts=4 sw=4 et: +"""Defines functionality for pipelined execution of interfaces + +The `Workflow` class provides core functionality for batch processing. + + .. testsetup:: + # Change directory to provide relative paths for doctests + import os + filepath = os.path.dirname(os.path.realpath( __file__ )) + datadir = os.path.realpath(os.path.join(filepath, '../../testing/data')) + os.chdir(datadir) + +""" + +from future import standard_library +standard_library.install_aliases() + +from datetime import datetime +try: + from collections import OrderedDict +except ImportError: + from ordereddict import OrderedDict + +from copy import deepcopy +import pickle +import os +import os.path as op +import shutil +import sys +from warnings import warn +import numpy as np +import networkx as nx + +from nipype.utils.misc import (getsource, create_function_from_source, + package_check, str2bool) +from nipype.external.six import string_types +from nipype import config, logging + +from nipype.interfaces.base import (traits, TraitedSpec, TraitDictObject, + TraitListObject) +from nipype.interfaces.utility import IdentityInterface + +from .utils import (make_output_dir, get_print_name, merge_dict) +from .graph import (generate_expanded_graph, export_graph, write_workflow_prov, + format_dot, topological_sort) +from .base import NodeBase +from .nodes import Node, MapNode + +logger = logging.getLogger('workflow') +package_check('networkx', '1.3') + + +class Workflow(NodeBase): + """Controls the setup and execution of a pipeline of processes.""" + _control = True + + def __init__(self, name, base_dir=None, control=True): + """Create a workflow object. + + Parameters + ---------- + name : alphanumeric string + unique identifier for the workflow + base_dir : string, optional + path to workflow storage + + """ + super(Workflow, self).__init__(name, base_dir) + self._graph = nx.DiGraph() + self._control = control + self.config = deepcopy(config._sections) + + if control: + self._signalnode = Node(IdentityInterface( + fields=self.signals.copyable_trait_names()), 'signalnode') + self.add_nodes([self._signalnode]) + + # Automatically initialize signal + for s in self.signals.copyable_trait_names(): + setattr(self._signalnode.inputs, s, getattr(self.signals, s)) + + def _update_disable(self): + logger.debug('Signal disable is now %s for workflow %s' % + (self.signals.disable, self.fullname)) + self._signalnode.inputs.disable = self.signals.disable + + # PUBLIC API + def clone(self, name): + """Clone a workflow + + .. note:: + + Will reset attributes used for executing workflow. See + _init_runtime_fields. + + Parameters + ---------- + + name: alphanumeric name + unique name for the workflow + + """ + clone = super(Workflow, self).clone(name) + clone._reset_hierarchy() + return clone + + # Graph creation functions + def connect(self, *args, **kwargs): + """Connect nodes in the pipeline. + + This routine also checks if inputs and outputs are actually provided by + the nodes that are being connected. + + Creates edges in the directed graph using the nodes and edges specified + in the `connection_list`. Uses the NetworkX method + DiGraph.add_edges_from. 
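+        Two keyword arguments are also recognised: ``disconnect`` removes the
+        listed connections instead of adding them, and ``conn_type`` marks the
+        connections as ``'data'`` (default) or ``'control'`` (signal) edges.
+        A minimal sketch, assuming ``dn`` and ``inner`` were built elsewhere::
+
+            wf.connect([(dn, inner, [('donotrun', 'signalnode.disable')])],
+                       conn_type='control')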
+ + Parameters + ---------- + + args : list or a set of four positional arguments + + Four positional arguments of the form:: + + connect(source, sourceoutput, dest, destinput) + + source : nodewrapper node + sourceoutput : string (must be in source.outputs) + dest : nodewrapper node + destinput : string (must be in dest.inputs) + + A list of 3-tuples of the following form:: + + [(source, target, + [('sourceoutput/attribute', 'targetinput'), + ...]), + ...] + + Or:: + + [(source, target, [(('sourceoutput1', func, arg2, ...), + 'targetinput'), ...]), + ...] + sourceoutput1 will always be the first argument to func + and func will be evaluated and the results sent ot targetinput + + currently func needs to define all its needed imports within the + function as we use the inspect module to get at the source code + and execute it remotely + """ + if len(args) == 1: + connection_list = args[0] + elif len(args) == 4: + connection_list = [(args[0], args[2], [(args[1], args[3])])] + else: + raise TypeError('connect() takes either 4 arguments, or 1 list of' + ' connection tuples (%d args given)' % len(args)) + if not kwargs: + kwargs = {} + + disconnect = kwargs.get('disconnect', False) + + if disconnect: + self.disconnect(connection_list) + return + + conn_type = kwargs.get('conn_type', 'data') + logger.debug('connect(disconnect=%s, conn_type=%s): %s' % + (disconnect, conn_type, connection_list)) + + all_srcnodes = set([c[0] for c in connection_list]) + all_dstnodes = set([c[1] for c in connection_list]) + allnodes = all_srcnodes | all_dstnodes + + if self in allnodes: + raise IOError('Workflow connect cannot contain itself as node') + + # Check if nodes are already in the graph + nodesingraph = set(self._graph.nodes()) + newnodes = list(allnodes - nodesingraph) + if newnodes: + logger.debug('New nodes: %s, existing nodes: %s' % (newnodes, nodesingraph)) + for node in newnodes: + if node._hierarchy is None: + node._hierarchy = self.name + self._check_nodes(newnodes) + self._graph.add_nodes_from(newnodes) + + # check correctness of required connections + connected_ports = self._check_connected(list(all_dstnodes)) + + not_found = [] + redirected = [] + for srcnode, dstnode, connects in connection_list: + src_io = (hasattr(srcnode, '_interface') and + '.io' in str(srcnode._interface.__class__)) + dst_io = (hasattr(dstnode, '_interface') and + '.io' in str(dstnode._interface.__class__)) + + nodeconns = connected_ports.get(dstnode, []) + duplicated = [] + for source, dest in connects: + logger.debug('connect(%s): evaluating %s:%s -> %s:%s' % + (conn_type, srcnode, source, dstnode, dest)) + # Check port is not taken + if dest in nodeconns: + duplicated.append((srcnode, source, dstnode, dest)) + continue + + # Currently datasource/sink/grabber.io modules + # determine their inputs/outputs depending on + # connection settings. 
Skip these modules in the check + if not dst_io: + if not dstnode._check_inputs(dest): + not_found.append(['in', '%s' % dstnode, dest]) + + if not src_io: + if isinstance(source, tuple): + # handles the case that source is specified + # with a function + sourcename = source[0] + elif isinstance(source, string_types): + sourcename = source + else: + raise Exception( + 'Unknown source specification in connection from ' + 'output of %s' % srcnode.name) + + if sourcename and not srcnode._check_outputs(sourcename): + not_found.append(['out', '%s' % srcnode, sourcename]) + + nodeconns += [dest] + connected_ports[dstnode] = nodeconns + + if conn_type == 'data': + if duplicated: + raise Exception( + 'connect(): found duplicated connections.\n\t\t' + + '\n\t\t'.join(['%s.%s -> %s.%s' % c for c in duplicated])) + infostr = [] + for info in not_found: + infostr += ["Module %s has no %sput called %s\n" % (info[1], + info[0], + info[2])] + if not_found: + infostr.insert( + 0, 'Some connections were not found connecting %s.%s to ' + '%s.%s' % (srcnode, source, dstnode, dest)) + raise Exception('\n'.join(infostr)) + else: + if duplicated: + logger.debug('Duplicated signal' + '\n\t\t'.join( + ['%s.%s -> %s.%s' % c for c in duplicated])) + + # turn functions into strings + for srcnode, dstnode, connects in connection_list: + for idx, (src, dest) in enumerate(connects): + if isinstance(src, tuple) and not isinstance(src[1], string_types): + function_source = getsource(src[1]) + connects[idx] = ((src[0], function_source, src[2:]), dest) + + # add connections + for srcnode, dstnode, connects in connection_list: + edge_data = self._graph.get_edge_data( + srcnode, dstnode, {'connect': []}) + + msg = 'No existing connections' if not edge_data['connect'] else \ + 'Previous connections exist' + msg += ' from %s to %s %s' % (srcnode.fullname, dstnode.fullname, + connects) + logger.debug(msg) + + edge_data['connect'] += [(c[0], c[1], conn_type) + for c in connects] + logger.debug('(%s, %s): new edge data: %s' % + (srcnode, dstnode, str(edge_data))) + + self._graph.add_edges_from([(srcnode, dstnode, edge_data)]) + + # Check that connections are actually created + edge_data = self._graph.get_edge_data(srcnode, dstnode) + if not edge_data['connect']: + self._graph.remove_edge(srcnode, dstnode) + + def disconnect(self, *args): + """Disconnect nodes + See the docstring for connect for format. 
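+        For example (a minimal sketch, assuming ``n1`` and ``n2`` are nodes
+        already connected in ``wf``)::
+
+            wf.disconnect([(n1, n2, [('out', 'in1')])])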
+ """ + if len(args) == 1: + connection_list = args[0] + elif len(args) == 4: + connection_list = [(args[0], args[2], [(args[1], args[3])])] + else: + raise TypeError('disconnect() takes either 4 arguments, or 1 list ' + 'of connection tuples (%d args given)' % len(args)) + + for srcnode, dstnode, conn in connection_list: + logger.debug('disconnect(): %s->%s %s' % (srcnode, dstnode, conn)) + if self in [srcnode, dstnode]: + raise IOError( + 'Workflow connect cannot contain itself as node: src[%s] ' + 'dest[%s] workflow[%s]') % (srcnode, dstnode, self.name) + + # If node is not in the graph, not connected + if not self._has_node(srcnode) or not self._has_node(dstnode): + continue + + edge_data = self._graph.get_edge_data( + srcnode, dstnode, {'connect': []}) + ed_conns = [(c[0], c[1]) for c in edge_data['connect']] + ed_meta = [c[2] for c in edge_data['connect']] + + remove = [] + for edge in conn: + if edge in ed_conns: + idx = ed_conns.index(edge) + remove.append((edge[0], edge[1], ed_meta[idx])) + + logger.debug('disconnect(): remove list %s' % remove) + for el in remove: + edge_data['connect'].remove(el) + logger.debug('disconnect(): removed connection %s' % str(el)) + + if not edge_data['connect']: + self._graph.remove_edge(srcnode, dstnode) + else: + self._graph.add_edges_from( + [(srcnode, dstnode, edge_data)]) + + def add_nodes(self, nodes): + """ Add nodes to a workflow + + Parameters + ---------- + nodes : list + A list of NodeBase-based objects + """ + newnodes = [] + all_nodes = self._get_all_nodes() + all_nodenames = [n.name for n in all_nodes] + for node in nodes: + if self._has_node(node): + raise IOError('Node %s already exists in the workflow' % node) + + logger.debug('Node: %s, names: %s' % (node.name, all_nodenames)) + if node.name in all_nodenames: + raise IOError('Workflow %s already contains a node called' + ' %s' % (self, node.name)) + + if isinstance(node, Workflow): + for subnode in node._get_all_nodes(): + if subnode in all_nodes: + raise IOError(('Subnode %s of node %s already exists ' + 'in the workflow') % (subnode, node)) + newnodes.append(node) + if not newnodes: + logger.debug('no new nodes to add') + return + for node in newnodes: + if not issubclass(node.__class__, NodeBase): + raise Exception('Node %s must be a subclass of NodeBase' % + str(node)) + self._check_nodes(newnodes) + for node in newnodes: + if node._hierarchy is None: + node._hierarchy = self.name + self._graph.add_nodes_from(newnodes) + + def remove_nodes(self, nodes): + """ Remove nodes from a workflow + + Parameters + ---------- + nodes : list + A list of NodeBase-based objects + """ + self._graph.remove_nodes_from(nodes) + + # Input-Output access + @property + def inputs(self): + return self._get_inputs() + + @property + def outputs(self): + return self._get_outputs() + + def get_node(self, name): + """Return an internal node by name + """ + nodenames = name.split('.') + nodename = nodenames[0] + outnode = [node for node in self._graph.nodes() if + str(node).endswith('.' 
+ nodename)] + if outnode: + outnode = outnode[0] + if nodenames[1:] and issubclass(outnode.__class__, Workflow): + outnode = outnode.get_node('.'.join(nodenames[1:])) + else: + outnode = None + return outnode + + def list_node_names(self): + """List names of all nodes in a workflow + """ + outlist = [] + sorted_nodes = nx.topological_sort(self._graph) + logger.debug('list_node_names(): sorted nodes %s' % sorted_nodes) + for node in sorted_nodes: + if isinstance(node, Workflow): + outlist.extend(['.'.join((node.name, nodename)) for nodename in + node.list_node_names()]) + else: + outlist.append(node.name) + return sorted(outlist) + + def write_graph(self, dotfilename='graph.dot', graph2use='hierarchical', + format="png", simple_form=True): + """Generates a graphviz dot file and a png file + + Parameters + ---------- + + graph2use: 'orig', 'hierarchical' (default), 'flat', 'exec', 'colored' + orig - creates a top level graph without expanding internal + workflow nodes; + flat - expands workflow nodes recursively; + hierarchical - expands workflow nodes recursively with a + notion on hierarchy; + colored - expands workflow nodes recursively with a + notion on hierarchy in color; + exec - expands workflows to depict iterables + + format: 'png', 'svg' + + simple_form: boolean (default: True) + Determines if the node name used in the graph should be of the form + 'nodename (package)' when True or 'nodename.Class.package' when + False. + + """ + self._connect_signals() + + graphtypes = ['orig', 'flat', 'hierarchical', 'exec', 'colored'] + if graph2use not in graphtypes: + raise ValueError('Unknown graph2use keyword. Must be one of: ' + + str(graphtypes)) + base_dir, dotfilename = op.split(dotfilename) + if base_dir == '': + if self.base_dir: + base_dir = self.base_dir + if self.name: + base_dir = op.join(base_dir, self.name) + else: + base_dir = os.getcwd() + base_dir = make_output_dir(base_dir) + if graph2use in ['hierarchical', 'colored']: + dotfilename = op.join(base_dir, dotfilename) + self.write_hierarchical_dotfile(dotfilename=dotfilename, + colored=graph2use == "colored", + simple_form=simple_form) + format_dot(dotfilename, format=format) + else: + graph = self._graph + if graph2use in ['flat', 'exec']: + graph = self._create_flat_graph() + if graph2use == 'exec': + graph = generate_expanded_graph(deepcopy(graph)) + export_graph(graph, base_dir, dotfilename=dotfilename, + format=format, simple_form=simple_form) + + def write_hierarchical_dotfile(self, dotfilename=None, colored=False, + simple_form=True): + dotlist = ['digraph %s{' % self.name] + dotlist.append(self._get_dot(prefix=' ', colored=colored, + simple_form=simple_form)) + dotlist.append('}') + dotstr = '\n'.join(dotlist) + if dotfilename: + fp = open(dotfilename, 'wt') + fp.writelines(dotstr) + fp.close() + else: + logger.info(dotstr) + + def export(self, filename=None, prefix="output", format="python", + include_config=False): + """Export object into a different format + + Parameters + ---------- + filename: string + file to save the code to; overrides prefix + prefix: string + prefix to use for output file + format: string + one of "python" + include_config: boolean + whether to include node and workflow config values + + """ + from utils import format_node + + formats = ["python"] + if format not in formats: + raise ValueError('format must be one of: %s' % '|'.join(formats)) + flatgraph = self._create_flat_graph() + nodes = nx.topological_sort(flatgraph) + + lines = ['# Workflow'] + importlines = ['from nipype.pipeline.engine 
import Workflow, ' + 'Node, MapNode'] + functions = {} + if format == "python": + connect_template = '%s.connect(%%s, %%s, %%s, "%%s")' % self.name + connect_template2 = '%s.connect(%%s, "%%s", %%s, "%%s")' \ + % self.name + wfdef = '%s = Workflow("%s")' % (self.name, self.name) + lines.append(wfdef) + if include_config: + lines.append('%s.config = %s' % (self.name, self.config)) + for idx, node in enumerate(nodes): + nodename = node.fullname.replace('.', '_') + # write nodes + nodelines = format_node(node, format='python', + include_config=include_config) + for line in nodelines: + if line.startswith('from'): + if line not in importlines: + importlines.append(line) + else: + lines.append(line) + # write connections + for u, _, d in flatgraph.in_edges_iter(nbunch=node, + data=True): + for cd in d['connect']: + if isinstance(cd[0], tuple): + args = list(cd[0]) + if args[1] in functions: + funcname = functions[args[1]] + else: + func = create_function_from_source(args[1]) + funcname = [name for name in func.__globals__ + if name != '__builtins__'][0] + functions[args[1]] = funcname + args[1] = funcname + args = tuple([arg for arg in args if arg]) + line_args = (u.fullname.replace('.', '_'), + args, nodename, cd[1]) + line = connect_template % line_args + line = line.replace("'%s'" % funcname, funcname) + lines.append(line) + else: + line_args = (u.fullname.replace('.', '_'), + cd[0], nodename, cd[1]) + lines.append(connect_template2 % line_args) + functionlines = ['# Functions'] + for function in functions: + functionlines.append(pickle.loads(function).rstrip()) + all_lines = importlines + functionlines + lines + + if not filename: + filename = '%s%s.py' % (prefix, self.name) + with open(filename, 'wt') as fp: + fp.writelines('\n'.join(all_lines)) + return all_lines + + def run(self, plugin=None, plugin_args=None, updatehash=False): + """ Execute the workflow + + Parameters + ---------- + + plugin: plugin name or object + Plugin to use for execution. You can create your own plugins for + execution. + plugin_args : dictionary containing arguments to be sent to plugin + constructor. see individual plugin doc strings for details. 
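+        For example, to run with the multiprocessing plugin (a sketch; plugin
+        names and arguments follow the standard nipype plugin API)::
+
+            wf.run(plugin='MultiProc', plugin_args={'n_procs': 4})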
+ """ + self._connect_signals() + + if plugin is None: + plugin = config.get('execution', 'plugin') + if not isinstance(plugin, string_types): + runner = plugin + else: + name = 'nipype.pipeline.plugins' + try: + __import__(name) + except ImportError: + msg = 'Could not import plugin module: %s' % name + logger.error(msg) + raise ImportError(msg) + else: + plugin_mod = getattr(sys.modules[name], '%sPlugin' % plugin) + runner = plugin_mod(plugin_args=plugin_args) + flatgraph = self._create_flat_graph() + self.config = merge_dict(deepcopy(config._sections), self.config) + if 'crashdump_dir' in self.config: + warn(("Deprecated: workflow.config['crashdump_dir']\n" + "Please use config['execution']['crashdump_dir']")) + crash_dir = self.config['crashdump_dir'] + self.config['execution']['crashdump_dir'] = crash_dir + del self.config['crashdump_dir'] + logger.info(str(sorted(self.config))) + self._set_needed_outputs(flatgraph) + execgraph = generate_expanded_graph(deepcopy(flatgraph)) + for index, node in enumerate(execgraph.nodes()): + node.config = merge_dict(deepcopy(self.config), node.config) + node.base_dir = self.base_dir + node.index = index + if isinstance(node, MapNode): + node.use_plugin = (plugin, plugin_args) + self._configure_exec_nodes(execgraph) + if str2bool(self.config['execution']['create_report']): + self._write_report_info(self.base_dir, self.name, execgraph) + runner.run(execgraph, updatehash=updatehash, config=self.config) + datestr = datetime.utcnow().strftime('%Y%m%dT%H%M%S') + if str2bool(self.config['execution']['write_provenance']): + prov_base = op.join(self.base_dir, + 'workflow_provenance_%s' % datestr) + logger.info('Provenance file prefix: %s' % prov_base) + write_workflow_prov(execgraph, prov_base, format='all') + return execgraph + + # PRIVATE API AND FUNCTIONS + + def _write_report_info(self, workingdir, name, graph): + from nipype.utils.filemanip import save_json + if workingdir is None: + workingdir = os.getcwd() + report_dir = op.join(workingdir, name) + if not op.exists(report_dir): + os.makedirs(report_dir) + shutil.copyfile(op.join(op.dirname(__file__), + 'report_template.html'), + op.join(report_dir, 'index.html')) + shutil.copyfile(op.join(op.dirname(__file__), + '..', '..', 'external', 'd3.js'), + op.join(report_dir, 'd3.js')) + nodes, groups = topological_sort(graph, depth_first=True) + graph_file = op.join(report_dir, 'graph1.json') + json_dict = {'nodes': [], 'links': [], 'groups': [], 'maxN': 0} + for i, node in enumerate(nodes): + report_file = "%s/_report/report.rst" % \ + node.output_dir().replace(report_dir, '') + result_file = "%s/result_%s.pklz" % \ + (node.output_dir().replace(report_dir, ''), + node.name) + json_dict['nodes'].append(dict(name='%d_%s' % (i, node.name), + report=report_file, + result=result_file, + group=groups[i])) + maxN = 0 + for gid in np.unique(groups): + procs = [i for i, val in enumerate(groups) if val == gid] + N = len(procs) + if N > maxN: + maxN = N + json_dict['groups'].append(dict(procs=procs, + total=N, + name='Group_%05d' % gid)) + json_dict['maxN'] = maxN + for u, v in graph.in_edges_iter(): + json_dict['links'].append(dict(source=nodes.index(u), + target=nodes.index(v), + value=1)) + save_json(graph_file, json_dict) + graph_file = op.join(report_dir, 'graph.json') + template = '%%0%dd_' % np.ceil(np.log10(len(nodes))).astype(int) + + def getname(u, i): + name_parts = u.fullname.split('.') + # return '.'.join(name_parts[:-1] + [template % i + name_parts[-1]]) + return template % i + name_parts[-1] + json_dict = 
[] + for i, node in enumerate(nodes): + imports = [] + for u, v in graph.in_edges_iter(nbunch=node): + imports.append(getname(u, nodes.index(u))) + json_dict.append(dict(name=getname(node, i), + size=1, + group=groups[i], + imports=imports)) + save_json(graph_file, json_dict) + + def _set_needed_outputs(self, graph): + """Initialize node with list of which outputs are needed.""" + rm_outputs = self.config['execution']['remove_unnecessary_outputs'] + if not str2bool(rm_outputs): + return + for node in graph.nodes(): + node.needed_outputs = [] + for edge in graph.out_edges_iter(node): + data = graph.get_edge_data(*edge) + sourceinfo = [v1[0] if isinstance(v1, tuple) else v1 + for v1, v2, _ in data['connect']] + node.needed_outputs += [v for v in sourceinfo + if v not in node.needed_outputs] + if node.needed_outputs: + node.needed_outputs = sorted(node.needed_outputs) + + def _configure_exec_nodes(self, graph): + """Ensure that each node knows where to get inputs from + """ + for node in graph.nodes(): + node.input_source = {} + + for edge in graph.in_edges_iter(node): + data = graph.get_edge_data(*edge) + for conn in sorted(data['connect']): + sourceinfo, field = conn[0], conn[1] + + if node._check_inputs(field): + node.input_source[field] = \ + (op.join(edge[0].output_dir(), + 'result_%s.pklz' % edge[0].name), + sourceinfo) + + logger.debug('Node %s input_source is %s' % (node, node.input_source)) + + def _check_connected(self, nodes): + logger.debug('Checking input connections of %s' % nodes) + allnodes = self._graph.nodes() + + connected = {} + for node in nodes: + if node in allnodes: + edges = self._graph.in_edges_iter(node) + data = [self._graph.get_edge_data(*e)['connect'] + for e in edges] + data = [v for d in data for v in d] + + connected[node] = [] + for d in data: + is_control = (len(d) == 3 and d[2] == 'control') + if not is_control: + connected[node].append(d[1]) + + if not connected[node]: + connected.pop(node, None) + + if connected: + logger.debug('Connected ports found: %s' % connected) + return connected + + def _check_nodes(self, nodes): + """Checks if any of the nodes are already in the graph + """ + node_names = [node.fullname for node in self._graph.nodes()] + node_lineage = [node._hierarchy for node in self._graph.nodes()] + for node in nodes: + if node.fullname in node_names: + idx = node_names.index(node.fullname) + if node_lineage[idx] in [node._hierarchy, self.fullname]: + raise IOError('Duplicate node %s found.' 
% node) + else: + node_names.append(node.name) + + def _has_attr(self, parameter, subtype='in'): + """Checks if a parameter is available as an input or output + """ + if subtype == 'in': + subobject = self.inputs + else: + subobject = self.outputs + attrlist = parameter.split('.') + cur_out = subobject + for attr in attrlist: + if not hasattr(cur_out, attr): + return False + cur_out = getattr(cur_out, attr) + return True + + def _get_parameter_node(self, parameter, subtype='in'): + """Returns the underlying node corresponding to an input or + output parameter + """ + if subtype == 'in': + subobject = self.inputs + else: + subobject = self.outputs + attrlist = parameter.split('.') + cur_out = subobject + for attr in attrlist[:-1]: + cur_out = getattr(cur_out, attr) + return cur_out.traits()[attrlist[-1]].node + + def _check_outputs(self, parameter): + return self._has_attr(parameter, subtype='out') + + def _check_inputs(self, parameter): + return self._has_attr(parameter, subtype='in') + + def _get_inputs(self): + """Returns the inputs of a workflow + + This function does not return any input ports that are already + connected + """ + inputdict = TraitedSpec() + for node in self._graph.nodes(): + # if node == self._signalnode: + # continue + + inputdict.add_trait(node.name, traits.Instance(TraitedSpec)) + if isinstance(node, Workflow): + setattr(inputdict, node.name, node.inputs) + else: + taken_inputs = set([ + cd[1] for _, _, d in self._graph.in_edges_iter( + nbunch=node, data=True) for cd in d['connect']]) + av_inputs = set(node.inputs.get()) + + unconnectedinputs = TraitedSpec() + for key in (av_inputs - taken_inputs): + trait = node.inputs.trait(key) + unconnectedinputs.add_trait( + key, traits.Trait(trait, node=node)) + + value = getattr(node.inputs, key) + setattr(unconnectedinputs, key, value) + + setattr(inputdict, node.name, unconnectedinputs) + getattr(inputdict, node.name).on_trait_change(self._set_input) + return inputdict + + def _get_outputs(self): + """Returns all possible output ports that are not already connected + """ + outputdict = TraitedSpec() + for node in self._graph.nodes(): + outputdict.add_trait(node.name, traits.Instance(TraitedSpec)) + if isinstance(node, Workflow): + setattr(outputdict, node.name, node.outputs) + elif node.outputs: + outputs = TraitedSpec() + for key, _ in list(node.outputs.items()): + outputs.add_trait(key, traits.Any(node=node)) + setattr(outputs, key, None) + setattr(outputdict, node.name, outputs) + return outputdict + + def _set_input(self, object, name, newvalue): + """Trait callback function to update a node input + """ + logger.debug('_set_input(%s, %s) on %s.' 
% ( + name, newvalue, self.fullname)) + object.traits()[name].node.set_input(name, newvalue) + + def _set_node_input(self, node, param, source, sourceinfo): + """Set inputs of a node given the edge connection""" + if isinstance(sourceinfo, string_types): + val = source.get_output(sourceinfo) + elif isinstance(sourceinfo, tuple): + if callable(sourceinfo[1]): + val = sourceinfo[1](source.get_output(sourceinfo[0]), + *sourceinfo[2:]) + newval = val + if isinstance(val, TraitDictObject): + newval = dict(val) + if isinstance(val, TraitListObject): + newval = val[:] + logger.debug('setting node input: %s->%s', param, str(newval)) + node.set_input(param, deepcopy(newval)) + + def _get_all_nodes(self): + allnodes = [] + for node in self._graph.nodes(): + if isinstance(node, Workflow): + allnodes.extend(node._get_all_nodes()) + else: + allnodes.append(node) + return allnodes + + def _has_node(self, wanted_node): + for node in self._graph.nodes(): + if wanted_node == node: + return True + if isinstance(node, Workflow): + if node._has_node(wanted_node): + return True + return False + + def _connect_signals(self): + logger.debug('Workflow %s called _connect_signals()' % + self.fullname) + + for node in self._graph.nodes(): + if isinstance(node, Workflow): + node._connect_signals() + + if self._control: + signals = self.signals.copyable_trait_names() + + for node in self._graph.nodes(): + if node == self._signalnode: + continue + + logger.debug('connect_signals(%s) %s' % (self, node)) + if node.signals is None: + continue + + prefix = '' + if isinstance(node, Workflow): + prefix = 'signalnode.' + + for s in signals: + sdest = prefix + s + self.connect(self._signalnode, s, node, sdest, + conn_type='control') + + def _create_flat_graph(self): + """Make a simple DAG where no node is a workflow.""" + logger.debug('Creating flat graph for workflow: %s', self.name) + workflowcopy = deepcopy(self) + workflowcopy._generate_flatgraph() + return workflowcopy._graph + + def _reset_hierarchy(self): + """Reset the hierarchy on a graph + """ + for node in self._graph.nodes(): + if isinstance(node, Workflow): + node._reset_hierarchy() + for innernode in node._graph.nodes(): + innernode._hierarchy = '.'.join((self.name, + innernode._hierarchy)) + else: + node._hierarchy = self.name + + def _generate_flatgraph(self): + """Generate a graph containing only Nodes or MapNodes + """ + logger.debug('expanding workflow: %s', self) + nodes2remove = [] + if not nx.is_directed_acyclic_graph(self._graph): + raise Exception(('Workflow: %s is not a directed acyclic graph ' + '(DAG)') % self.name) + sorted_nodes = nx.topological_sort(self._graph) + logger.debug('_generate_flatgraph(): sorted nodes %s' % sorted_nodes) + for node in sorted_nodes: + logger.debug('processing node: %s' % node) + if isinstance(node, Workflow): + nodes2remove.append(node) + # use in_edges instead of in_edges_iter to allow + # disconnections to take place properly. otherwise, the + # edge dict is modified. 
+ for u, _, d in self._graph.in_edges(nbunch=node, data=True): + logger.debug('in: connections-> %s' % str(d['connect'])) + for cd in deepcopy(d['connect']): + logger.debug("in: %s" % str(cd)) + dstnode = node._get_parameter_node(cd[1], subtype='in') + srcnode = u + srcout = cd[0] + dstin = cd[1].split('.')[-1] + logger.debug('in edges: %s %s %s %s' % + (srcnode, srcout, dstnode, dstin)) + self.disconnect(u, cd[0], node, cd[1]) + self.connect(srcnode, srcout, dstnode, dstin, + conn_type=cd[2]) + # do not use out_edges_iter for reasons stated in in_edges + for _, v, d in self._graph.out_edges(nbunch=node, data=True): + logger.debug('out: connections-> %s' % str(d['connect'])) + for cd in deepcopy(d['connect']): + logger.debug("out: %s" % str(cd)) + dstnode = v + if isinstance(cd[0], tuple): + parameter = cd[0][0] + else: + parameter = cd[0] + srcnode = node._get_parameter_node(parameter, + subtype='out') + if isinstance(cd[0], tuple): + srcout = list(cd[0]) + srcout[0] = parameter.split('.')[-1] + srcout = tuple(srcout) + else: + srcout = parameter.split('.')[-1] + dstin = cd[1] + logger.debug('out edges: %s %s %s %s' % (srcnode, + srcout, + dstnode, + dstin)) + self.disconnect(node, cd[0], v, cd[1]) + self.connect(srcnode, srcout, dstnode, dstin) + # expand the workflow node + # logger.debug('expanding workflow: %s', node) + node._generate_flatgraph() + for innernode in node._graph.nodes(): + innernode._hierarchy = '.'.join((self.name, + innernode._hierarchy)) + self._graph.add_nodes_from(node._graph.nodes()) + self._graph.add_edges_from(node._graph.edges(data=True)) + if nodes2remove: + self._graph.remove_nodes_from(nodes2remove) + logger.debug('finished expanding workflow: %s', self) + + def _get_dot(self, prefix=None, hierarchy=None, colored=False, + simple_form=True, level=0): + """Create a dot file with connection info + """ + if prefix is None: + prefix = ' ' + if hierarchy is None: + hierarchy = [] + colorset = ['#FFFFC8', '#0000FF', '#B4B4FF', '#E6E6FF', '#FF0000', + '#FFB4B4', '#FFE6E6', '#00A300', '#B4FFB4', '#E6FFE6'] + + dotlist = ['%slabel="%s";' % (prefix, self.name)] + for node in nx.topological_sort(self._graph): + fullname = '.'.join(hierarchy + [node.fullname]) + nodename = fullname.replace('.', '_') + if not isinstance(node, Workflow): + node_class_name = get_print_name(node, simple_form=simple_form) + if not simple_form: + node_class_name = '.'.join(node_class_name.split('.')[1:]) + if hasattr(node, 'iterables') and node.iterables: + dotlist.append(('%s[label="%s", shape=box3d,' + 'style=filled, color=black, colorscheme' + '=greys7 fillcolor=2];') % (nodename, + node_class_name)) + else: + if colored: + dotlist.append(('%s[label="%s", style=filled,' + ' fillcolor="%s"];') + % (nodename, node_class_name, + colorset[level])) + else: + dotlist.append(('%s[label="%s"];') + % (nodename, node_class_name)) + + for node in nx.topological_sort(self._graph): + if isinstance(node, Workflow): + fullname = '.'.join(hierarchy + [node.fullname]) + nodename = fullname.replace('.', '_') + dotlist.append('subgraph cluster_%s {' % nodename) + if colored: + dotlist.append(prefix + prefix + 'edge [color="%s"];' % (colorset[level + 1])) + dotlist.append(prefix + prefix + 'style=filled;') + dotlist.append(prefix + prefix + 'fillcolor="%s";' % (colorset[level + 2])) + dotlist.append(node._get_dot(prefix=prefix + prefix, + hierarchy=hierarchy + [self.name], + colored=colored, + simple_form=simple_form, level=level + 3)) + dotlist.append('}') + if level == 6: + level = 2 + else: + for subnode in 
self._graph.successors_iter(node):
+                    if node._hierarchy != subnode._hierarchy:
+                        continue
+                    if not isinstance(subnode, Workflow):
+                        nodefullname = '.'.join(hierarchy + [node.fullname])
+                        subnodefullname = '.'.join(hierarchy +
+                                                   [subnode.fullname])
+                        nodename = nodefullname.replace('.', '_')
+                        subnodename = subnodefullname.replace('.', '_')
+                        for _ in self._graph.get_edge_data(node,
+                                                           subnode)['connect']:
+                            dotlist.append('%s -> %s;' % (nodename,
+                                                          subnodename))
+                        logger.debug('connection: ' + dotlist[-1])
+        # add between workflow connections
+        for u, v, d in self._graph.edges_iter(data=True):
+            uname = '.'.join(hierarchy + [u.fullname])
+            vname = '.'.join(hierarchy + [v.fullname])
+            for src, dest, _ in d['connect']:
+                uname1 = uname
+                vname1 = vname
+                if isinstance(src, tuple):
+                    srcname = src[0]
+                else:
+                    srcname = src
+                if '.' in srcname:
+                    uname1 += '.' + '.'.join(srcname.split('.')[:-1])
+                if '.' in dest and '@' not in dest:
+                    if not isinstance(v, Workflow):
+                        if 'datasink' not in \
+                           str(v._interface.__class__).lower():
+                            vname1 += '.' + '.'.join(dest.split('.')[:-1])
+                    else:
+                        vname1 += '.' + '.'.join(dest.split('.')[:-1])
+                if uname1.split('.')[:-1] != vname1.split('.')[:-1]:
+                    dotlist.append('%s -> %s;' % (uname1.replace('.', '_'),
+                                                  vname1.replace('.', '_')))
+                    logger.debug('cross connection: ' + dotlist[-1])
+        return ('\n' + prefix).join(dotlist)
+
+
+class CachedWorkflow(Workflow):
+    """
+    Implements a workflow that is bypassed when all the fields of its
+    `cachenode` input are set.
+    """
+
+    def __init__(self, name, base_dir=None, cache_map=None):
+        """Create a workflow object.
+
+        Parameters
+        ----------
+
+        name : alphanumeric string
+            unique identifier for the workflow
+        base_dir : string, optional
+            path to workflow storage
+        cache_map : list of tuples, non-empty
+            each tuple pairs a cache input port with an output port;
+            for instance, ('b', 'sum') maps the workflow input
+            'cachenode.b' to the output 'outputnode.sum'.
+        """
+
+        from nipype.interfaces.utility import CheckInterface, Merge, Select
+        super(CachedWorkflow, self).__init__(name, base_dir)
+
+        if cache_map is None or not cache_map:
+            raise ValueError('CachedWorkflow cache_map must be a '
+                             'non-empty list of tuples')
+
+        if isinstance(cache_map, tuple):
+            cache_map = [cache_map]
+
+        cond_in, cond_out = zip(*cache_map)
+        self._cache = Node(IdentityInterface(
+            fields=list(cond_in)), name='cachenode')
+        self._check = Node(CheckInterface(
+            fields=list(cond_in)), 'decidenode', control=False)
+        self._outputnode = Node(IdentityInterface(
+            fields=cond_out), name='outputnode')
+
+        def _switch_idx(val):
+            return [int(val)]
+
+        def _fix_undefined(val):
+            from nipype.interfaces.base import isdefined
+            if isdefined(val):
+                return val
+            return None
+
+        self._plain_connect(self._check, 'out', self._signalnode, 'disable',
+                            conn_type='control')
+        self._switches = {}
+        for ci, co in cache_map:
+            m = Node(Merge(2), 'Merge_%s' % co, control=False)
+            s = Node(Select(), 'Switch_%s' % co, control=False)
+            self._plain_connect([
+                (m, s, [('out', 'inlist')]),
+                (self._cache, self._check, [(ci, ci)]),
+                (self._cache, m, [((ci, _fix_undefined), 'in2')]),
+                (self._signalnode, s, [(('disable', _switch_idx), 'index')]),
+                (s, self._outputnode, [('out', co)])
+            ])
+            self._switches[co] = m
+
+    def _plain_connect(self, *args, **kwargs):
+        super(CachedWorkflow, self).connect(*args, **kwargs)
+
+    def connect(self, *args, **kwargs):
+        """Connect nodes in the pipeline.
+ """ + if len(args) == 1: + connection_list = args[0] + elif len(args) == 4: + connection_list = [(args[0], args[2], [(args[1], args[3])])] + else: + raise TypeError('connect() takes either 4 arguments, or 1 list of' + ' connection tuples (%d args given)' % len(args)) + if not kwargs: + kwargs = {} + + disconnect = kwargs.get('disconnect', False) + + if disconnect: + self.disconnect(connection_list) + return + + conn_type = kwargs.get('conn_type', 'data') + + list_conns = [] + for srcnode, dstnode, conns in connection_list: + is_output = (isinstance(dstnode, string_types) and + dstnode == 'output') + if is_output: + for srcport, dstport in conns: + mrgnode = self._switches.get(dstport, None) + if mrgnode is None: + raise RuntimeError('Destination port not found') + logger.debug('Mapping %s to %s' % (srcport, dstport)) + list_conns.append((srcnode, mrgnode, [(srcport, 'in1')])) + else: + if (isinstance(srcnode, string_types) and + srcnode == 'output'): + srcnode = self._outputnode + list_conns.append((srcnode, dstnode, conns)) + super(CachedWorkflow, self).connect(list_conns, disconnect=disconnect, + conn_type=conn_type) \ No newline at end of file diff --git a/nipype/pipeline/plugins/base.py b/nipype/pipeline/plugins/base.py index 9b7adad343..5b6a4cafe3 100644 --- a/nipype/pipeline/plugins/base.py +++ b/nipype/pipeline/plugins/base.py @@ -21,8 +21,8 @@ import scipy.sparse as ssp -from ..utils import (nx, dfs_preorder, topological_sort) -from ..engine import (MapNode, str2bool) +from ..engine.graph import (nx, dfs_preorder, topological_sort) +from ..engine.nodes import (MapNode, str2bool) from nipype.utils.filemanip import savepkl, loadpkl diff --git a/nipype/pipeline/plugins/debug.py b/nipype/pipeline/plugins/debug.py index 9d0a52adaa..f7dc8f357f 100644 --- a/nipype/pipeline/plugins/debug.py +++ b/nipype/pipeline/plugins/debug.py @@ -4,7 +4,7 @@ """ from .base import (PluginBase, logger) -from ..utils import (nx) +from ..engine.graph import (nx) class DebugPlugin(PluginBase): diff --git a/nipype/pipeline/plugins/linear.py b/nipype/pipeline/plugins/linear.py index 216d037757..5f740ddcf0 100644 --- a/nipype/pipeline/plugins/linear.py +++ b/nipype/pipeline/plugins/linear.py @@ -6,7 +6,7 @@ from .base import (PluginBase, logger, report_crash, report_nodes_not_run, str2bool) -from ..utils import (nx, dfs_preorder, topological_sort) +from ..engine.graph import (nx, dfs_preorder, topological_sort) class LinearPlugin(PluginBase): diff --git a/setup.py b/setup.py index 5a5159d166..406e0b8dc3 100755 --- a/setup.py +++ b/setup.py @@ -380,9 +380,10 @@ def main(**extra_args): 'nipype.interfaces.vista', 'nipype.interfaces.vista.tests', 'nipype.pipeline', + 'nipype.pipeline.engine', + 'nipype.pipeline.engine.tests', 'nipype.pipeline.plugins', 'nipype.pipeline.plugins.tests', - 'nipype.pipeline.tests', 'nipype.testing', 'nipype.testing.data', 'nipype.testing.data.bedpostxout',