From 4f343533340a84a6f211dc1581457e4e84a17ce4 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Tue, 29 Dec 2015 19:46:29 +0100 Subject: [PATCH 1/8] refactoring engine --- nipype/pipeline/engine/__init__.py | 13 + nipype/pipeline/engine/base.py | 125 ++ .../pipeline/{engine.py => engine/nodes.py} | 1089 +---------------- .../{ => engine}/report_template.html | 0 .../{ => engine}/report_template2.html | 0 .../pipeline/{ => engine}/tests/__init__.py | 0 .../{ => engine}/tests/test_engine.py | 0 .../pipeline/{ => engine}/tests/test_join.py | 0 .../pipeline/{ => engine}/tests/test_utils.py | 0 nipype/pipeline/{ => engine}/utils.py | 71 ++ nipype/pipeline/engine/workflows.py | 999 +++++++++++++++ 11 files changed, 1234 insertions(+), 1063 deletions(-) create mode 100644 nipype/pipeline/engine/__init__.py create mode 100644 nipype/pipeline/engine/base.py rename nipype/pipeline/{engine.py => engine/nodes.py} (54%) rename nipype/pipeline/{ => engine}/report_template.html (100%) rename nipype/pipeline/{ => engine}/report_template2.html (100%) rename nipype/pipeline/{ => engine}/tests/__init__.py (100%) rename nipype/pipeline/{ => engine}/tests/test_engine.py (100%) rename nipype/pipeline/{ => engine}/tests/test_join.py (100%) rename nipype/pipeline/{ => engine}/tests/test_utils.py (100%) rename nipype/pipeline/{ => engine}/utils.py (94%) create mode 100644 nipype/pipeline/engine/workflows.py diff --git a/nipype/pipeline/engine/__init__.py b/nipype/pipeline/engine/__init__.py new file mode 100644 index 0000000000..51ae8e92ea --- /dev/null +++ b/nipype/pipeline/engine/__init__.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +""" +Package contains modules for generating pipelines using interfaces + +""" + +from __future__ import absolute_import +__docformat__ = 'restructuredtext' +from .workflows import Workflow +from .nodes import Node, MapNode, JoinNode diff --git a/nipype/pipeline/engine/base.py b/nipype/pipeline/engine/base.py new file mode 100644 index 0000000000..a4d21328f2 --- /dev/null +++ b/nipype/pipeline/engine/base.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +"""Defines functionality for pipelined execution of interfaces + +The `EngineBase` class implements the more general view of a task. + + .. 
testsetup:: + # Change directory to provide relative paths for doctests + import os + filepath = os.path.dirname(os.path.realpath( __file__ )) + datadir = os.path.realpath(os.path.join(filepath, '../../testing/data')) + os.chdir(datadir) + +""" + +from future import standard_library +standard_library.install_aliases() +from builtins import object + +try: + from collections import OrderedDict +except ImportError: + from ordereddict import OrderedDict + +from copy import deepcopy +import re +import numpy as np +from nipype.interfaces.traits_extension import traits, Undefined +from nipype.interfaces.base import DynamicTraitedSpec +from nipype.utils.filemanip import loadpkl, savepkl + +from nipype import logging +logger = logging.getLogger('workflow') + + +class EngineBase(object): + """Defines common attributes and functions for workflows and nodes.""" + + def __init__(self, name=None, base_dir=None): + """ Initialize base parameters of a workflow or node + + Parameters + ---------- + name : string (mandatory) + Name of this node. Name must be alphanumeric and not contain any + special characters (e.g., '.', '@'). + base_dir : string + base output directory (will be hashed before creations) + default=None, which results in the use of mkdtemp + + """ + self.base_dir = base_dir + self.config = None + self._verify_name(name) + self.name = name + # for compatibility with node expansion using iterables + self._id = self.name + self._hierarchy = None + + @property + def inputs(self): + raise NotImplementedError + + @property + def outputs(self): + raise NotImplementedError + + @property + def fullname(self): + fullname = self.name + if self._hierarchy: + fullname = self._hierarchy + '.' + self.name + return fullname + + def clone(self, name): + """Clone a workflowbase object + + Parameters + ---------- + + name : string (mandatory) + A clone of node or workflow must have a new name + """ + if (name is None) or (name == self.name): + raise Exception('Cloning requires a new name') + self._verify_name(name) + clone = deepcopy(self) + clone.name = name + clone._id = name + clone._hierarchy = None + return clone + + def _check_outputs(self, parameter): + return hasattr(self.outputs, parameter) + + def _check_inputs(self, parameter): + if isinstance(self.inputs, DynamicTraitedSpec): + return True + return hasattr(self.inputs, parameter) + + def _verify_name(self, name): + valid_name = bool(re.match('^[\w-]+$', name)) + if not valid_name: + raise ValueError('[Workflow|Node] name \'%s\' contains' + ' special characters' % name) + + def __repr__(self): + if self._hierarchy: + return '.'.join((self._hierarchy, self._id)) + else: + return self._id + + def save(self, filename=None): + if filename is None: + filename = 'temp.pklz' + savepkl(filename, self) + + def load(self, filename): + if '.npz' in filename: + DeprecationWarning(('npz files will be deprecated in the next ' + 'release. 
you can use numpy to open them.')) + return np.load(filename) + return loadpkl(filename) diff --git a/nipype/pipeline/engine.py b/nipype/pipeline/engine/nodes.py similarity index 54% rename from nipype/pipeline/engine.py rename to nipype/pipeline/engine/nodes.py index 1c73918bf8..abb61bafe1 100644 --- a/nipype/pipeline/engine.py +++ b/nipype/pipeline/engine/nodes.py @@ -1,14 +1,17 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: """Defines functionality for pipelined execution of interfaces -The `Workflow` class provides core functionality for batch processing. +The `Node` class provides core functionality for batch processing. - Change directory to provide relative paths for doctests - >>> import os - >>> filepath = os.path.dirname( os.path.realpath( __file__ ) ) - >>> datadir = os.path.realpath(os.path.join(filepath, '../testing/data')) - >>> os.chdir(datadir) + .. testsetup:: + # Change directory to provide relative paths for doctests + import os + filepath = os.path.dirname(os.path.realpath( __file__ )) + datadir = os.path.realpath(os.path.join(filepath, '../../testing/data')) + os.chdir(datadir) """ @@ -44,1072 +47,32 @@ import numpy as np import networkx as nx -from ..utils.misc import package_check, str2bool +from ...utils.misc import package_check, str2bool package_check('networkx', '1.3') -from .. import config, logging +from ... import config, logging logger = logging.getLogger('workflow') -from ..interfaces.base import (traits, InputMultiPath, CommandLine, - Undefined, TraitedSpec, DynamicTraitedSpec, - Bunch, InterfaceResult, md5, Interface, - TraitDictObject, TraitListObject, isdefined) -from ..utils.misc import (getsource, create_function_from_source, - flatten, unflatten) -from ..utils.filemanip import (save_json, FileNotFoundError, - filename_to_list, list_to_filename, - copyfiles, fnames_presuffix, loadpkl, - split_filename, load_json, savepkl, - write_rst_header, write_rst_dict, - write_rst_list) -from ..external.six import string_types +from ...interfaces.base import (traits, InputMultiPath, CommandLine, + Undefined, TraitedSpec, DynamicTraitedSpec, + Bunch, InterfaceResult, md5, Interface, + TraitDictObject, TraitListObject, isdefined) +from ...utils.misc import (getsource, create_function_from_source, + flatten, unflatten) +from ...utils.filemanip import (save_json, FileNotFoundError, + filename_to_list, list_to_filename, + copyfiles, fnames_presuffix, loadpkl, + split_filename, load_json, savepkl, + write_rst_header, write_rst_dict, + write_rst_list) +from ...external.six import string_types from .utils import (generate_expanded_graph, modify_paths, export_graph, make_output_dir, write_workflow_prov, clean_working_directory, format_dot, topological_sort, get_print_name, merge_dict, evaluate_connect_function) +from .base import EngineBase -def _write_inputs(node): - lines = [] - nodename = node.fullname.replace('.', '_') - for key, _ in list(node.inputs.items()): - val = getattr(node.inputs, key) - if isdefined(val): - if type(val) == str: - try: - func = create_function_from_source(val) - except RuntimeError as e: - lines.append("%s.inputs.%s = '%s'" % (nodename, key, val)) - else: - funcname = [name for name in func.__globals__ - if name != '__builtins__'][0] - lines.append(pickle.loads(val)) - if funcname == nodename: - lines[-1] = lines[-1].replace(' %s(' % funcname, - ' %s_1(' % funcname) - funcname = '%s_1' % funcname - lines.append('from 
nipype.utils.misc import getsource') - lines.append("%s.inputs.%s = getsource(%s)" % (nodename, - key, - funcname)) - else: - lines.append('%s.inputs.%s = %s' % (nodename, key, val)) - return lines - - -def format_node(node, format='python', include_config=False): - """Format a node in a given output syntax.""" - lines = [] - name = node.fullname.replace('.', '_') - if format == 'python': - klass = node._interface - importline = 'from %s import %s' % (klass.__module__, - klass.__class__.__name__) - comment = '# Node: %s' % node.fullname - spec = inspect.signature(node._interface.__init__) - args = spec.args[1:] - if args: - filled_args = [] - for arg in args: - if hasattr(node._interface, '_%s' % arg): - filled_args.append('%s=%s' % (arg, getattr(node._interface, - '_%s' % arg))) - args = ', '.join(filled_args) - else: - args = '' - klass_name = klass.__class__.__name__ - if isinstance(node, MapNode): - nodedef = '%s = MapNode(%s(%s), iterfield=%s, name="%s")' \ - % (name, klass_name, args, node.iterfield, name) - else: - nodedef = '%s = Node(%s(%s), name="%s")' \ - % (name, klass_name, args, name) - lines = [importline, comment, nodedef] - - if include_config: - lines = [importline, "from collections import OrderedDict", - comment, nodedef] - lines.append('%s.config = %s' % (name, node.config)) - - if node.iterables is not None: - lines.append('%s.iterables = %s' % (name, node.iterables)) - lines.extend(_write_inputs(node)) - - return lines - - -class WorkflowBase(object): - """Defines common attributes and functions for workflows and nodes.""" - - def __init__(self, name=None, base_dir=None): - """ Initialize base parameters of a workflow or node - - Parameters - ---------- - name : string (mandatory) - Name of this node. Name must be alphanumeric and not contain any - special characters (e.g., '.', '@'). - base_dir : string - base output directory (will be hashed before creations) - default=None, which results in the use of mkdtemp - - """ - self.base_dir = base_dir - self.config = None - self._verify_name(name) - self.name = name - # for compatibility with node expansion using iterables - self._id = self.name - self._hierarchy = None - - @property - def inputs(self): - raise NotImplementedError - - @property - def outputs(self): - raise NotImplementedError - - @property - def fullname(self): - fullname = self.name - if self._hierarchy: - fullname = self._hierarchy + '.' 
+ self.name - return fullname - - def clone(self, name): - """Clone a workflowbase object - - Parameters - ---------- - - name : string (mandatory) - A clone of node or workflow must have a new name - """ - if (name is None) or (name == self.name): - raise Exception('Cloning requires a new name') - self._verify_name(name) - clone = deepcopy(self) - clone.name = name - clone._id = name - clone._hierarchy = None - return clone - - def _check_outputs(self, parameter): - return hasattr(self.outputs, parameter) - - def _check_inputs(self, parameter): - if isinstance(self.inputs, DynamicTraitedSpec): - return True - return hasattr(self.inputs, parameter) - - def _verify_name(self, name): - valid_name = bool(re.match('^[\w-]+$', name)) - if not valid_name: - raise Exception('the name must not contain any special characters') - - def __repr__(self): - if self._hierarchy: - return '.'.join((self._hierarchy, self._id)) - else: - return self._id - - def save(self, filename=None): - if filename is None: - filename = 'temp.pklz' - savepkl(filename, self) - - def load(self, filename): - if '.npz' in filename: - DeprecationWarning(('npz files will be deprecated in the next ' - 'release. you can use numpy to open them.')) - return np.load(filename) - return loadpkl(filename) - - -class Workflow(WorkflowBase): - """Controls the setup and execution of a pipeline of processes.""" - - def __init__(self, name, base_dir=None): - """Create a workflow object. - - Parameters - ---------- - name : alphanumeric string - unique identifier for the workflow - base_dir : string, optional - path to workflow storage - - """ - super(Workflow, self).__init__(name, base_dir) - self._graph = nx.DiGraph() - self.config = deepcopy(config._sections) - - # PUBLIC API - def clone(self, name): - """Clone a workflow - - .. note:: - - Will reset attributes used for executing workflow. See - _init_runtime_fields. - - Parameters - ---------- - - name: alphanumeric name - unique name for the workflow - - """ - clone = super(Workflow, self).clone(name) - clone._reset_hierarchy() - return clone - - # Graph creation functions - def connect(self, *args, **kwargs): - """Connect nodes in the pipeline. - - This routine also checks if inputs and outputs are actually provided by - the nodes that are being connected. - - Creates edges in the directed graph using the nodes and edges specified - in the `connection_list`. Uses the NetworkX method - DiGraph.add_edges_from. - - Parameters - ---------- - - args : list or a set of four positional arguments - - Four positional arguments of the form:: - - connect(source, sourceoutput, dest, destinput) - - source : nodewrapper node - sourceoutput : string (must be in source.outputs) - dest : nodewrapper node - destinput : string (must be in dest.inputs) - - A list of 3-tuples of the following form:: - - [(source, target, - [('sourceoutput/attribute', 'targetinput'), - ...]), - ...] - - Or:: - - [(source, target, [(('sourceoutput1', func, arg2, ...), - 'targetinput'), ...]), - ...] 
- sourceoutput1 will always be the first argument to func - and func will be evaluated and the results sent ot targetinput - - currently func needs to define all its needed imports within the - function as we use the inspect module to get at the source code - and execute it remotely - """ - if len(args) == 1: - connection_list = args[0] - elif len(args) == 4: - connection_list = [(args[0], args[2], [(args[1], args[3])])] - else: - raise Exception('unknown set of parameters to connect function') - if not kwargs: - disconnect = False - else: - disconnect = kwargs['disconnect'] - newnodes = [] - for srcnode, destnode, _ in connection_list: - if self in [srcnode, destnode]: - msg = ('Workflow connect cannot contain itself as node:' - ' src[%s] dest[%s] workflow[%s]') % (srcnode, - destnode, - self.name) - - raise IOError(msg) - if (srcnode not in newnodes) and not self._has_node(srcnode): - newnodes.append(srcnode) - if (destnode not in newnodes) and not self._has_node(destnode): - newnodes.append(destnode) - if newnodes: - self._check_nodes(newnodes) - for node in newnodes: - if node._hierarchy is None: - node._hierarchy = self.name - not_found = [] - connected_ports = {} - for srcnode, destnode, connects in connection_list: - if destnode not in connected_ports: - connected_ports[destnode] = [] - # check to see which ports of destnode are already - # connected. - if not disconnect and (destnode in self._graph.nodes()): - for edge in self._graph.in_edges_iter(destnode): - data = self._graph.get_edge_data(*edge) - for sourceinfo, destname in data['connect']: - if destname not in connected_ports[destnode]: - connected_ports[destnode] += [destname] - for source, dest in connects: - # Currently datasource/sink/grabber.io modules - # determine their inputs/outputs depending on - # connection settings. Skip these modules in the check - if dest in connected_ports[destnode]: - raise Exception(""" -Trying to connect %s:%s to %s:%s but input '%s' of node '%s' is already -connected. 
-""" % (srcnode, source, destnode, dest, dest, destnode)) - if not (hasattr(destnode, '_interface') and - '.io' in str(destnode._interface.__class__)): - if not destnode._check_inputs(dest): - not_found.append(['in', destnode.name, dest]) - if not (hasattr(srcnode, '_interface') and - '.io' in str(srcnode._interface.__class__)): - if isinstance(source, tuple): - # handles the case that source is specified - # with a function - sourcename = source[0] - elif isinstance(source, string_types): - sourcename = source - else: - raise Exception(('Unknown source specification in ' - 'connection from output of %s') % - srcnode.name) - if sourcename and not srcnode._check_outputs(sourcename): - not_found.append(['out', srcnode.name, sourcename]) - connected_ports[destnode] += [dest] - infostr = [] - for info in not_found: - infostr += ["Module %s has no %sput called %s\n" % (info[1], - info[0], - info[2])] - if not_found: - raise Exception('\n'.join(['Some connections were not found'] + - infostr)) - - # turn functions into strings - for srcnode, destnode, connects in connection_list: - for idx, (src, dest) in enumerate(connects): - if isinstance(src, tuple) and not isinstance(src[1], string_types): - function_source = getsource(src[1]) - connects[idx] = ((src[0], function_source, src[2:]), dest) - - # add connections - for srcnode, destnode, connects in connection_list: - edge_data = self._graph.get_edge_data(srcnode, destnode, None) - if edge_data: - logger.debug('(%s, %s): Edge data exists: %s' - % (srcnode, destnode, str(edge_data))) - for data in connects: - if data not in edge_data['connect']: - edge_data['connect'].append(data) - if disconnect: - logger.debug('Removing connection: %s' % str(data)) - edge_data['connect'].remove(data) - if edge_data['connect']: - self._graph.add_edges_from([(srcnode, - destnode, - edge_data)]) - else: - # pass - logger.debug('Removing connection: %s->%s' % (srcnode, - destnode)) - self._graph.remove_edges_from([(srcnode, destnode)]) - elif not disconnect: - logger.debug('(%s, %s): No edge data' % (srcnode, destnode)) - self._graph.add_edges_from([(srcnode, destnode, - {'connect': connects})]) - edge_data = self._graph.get_edge_data(srcnode, destnode) - logger.debug('(%s, %s): new edge data: %s' % (srcnode, destnode, - str(edge_data))) - - def disconnect(self, *args): - """Disconnect two nodes - - See the docstring for connect for format. 
- """ - # yoh: explicit **dict was introduced for compatibility with Python 2.5 - return self.connect(*args, **dict(disconnect=True)) - - def add_nodes(self, nodes): - """ Add nodes to a workflow - - Parameters - ---------- - nodes : list - A list of WorkflowBase-based objects - """ - newnodes = [] - all_nodes = self._get_all_nodes() - for node in nodes: - if self._has_node(node): - raise IOError('Node %s already exists in the workflow' % node) - if isinstance(node, Workflow): - for subnode in node._get_all_nodes(): - if subnode in all_nodes: - raise IOError(('Subnode %s of node %s already exists ' - 'in the workflow') % (subnode, node)) - newnodes.append(node) - if not newnodes: - logger.debug('no new nodes to add') - return - for node in newnodes: - if not issubclass(node.__class__, WorkflowBase): - raise Exception('Node %s must be a subclass of WorkflowBase' % - str(node)) - self._check_nodes(newnodes) - for node in newnodes: - if node._hierarchy is None: - node._hierarchy = self.name - self._graph.add_nodes_from(newnodes) - - def remove_nodes(self, nodes): - """ Remove nodes from a workflow - - Parameters - ---------- - nodes : list - A list of WorkflowBase-based objects - """ - self._graph.remove_nodes_from(nodes) - - # Input-Output access - @property - def inputs(self): - return self._get_inputs() - - @property - def outputs(self): - return self._get_outputs() - - def get_node(self, name): - """Return an internal node by name - """ - nodenames = name.split('.') - nodename = nodenames[0] - outnode = [node for node in self._graph.nodes() if - str(node).endswith('.' + nodename)] - if outnode: - outnode = outnode[0] - if nodenames[1:] and issubclass(outnode.__class__, Workflow): - outnode = outnode.get_node('.'.join(nodenames[1:])) - else: - outnode = None - return outnode - - def list_node_names(self): - """List names of all nodes in a workflow - """ - outlist = [] - for node in nx.topological_sort(self._graph): - if isinstance(node, Workflow): - outlist.extend(['.'.join((node.name, nodename)) for nodename in - node.list_node_names()]) - else: - outlist.append(node.name) - return sorted(outlist) - - def write_graph(self, dotfilename='graph.dot', graph2use='hierarchical', - format="png", simple_form=True): - """Generates a graphviz dot file and a png file - - Parameters - ---------- - - graph2use: 'orig', 'hierarchical' (default), 'flat', 'exec', 'colored' - orig - creates a top level graph without expanding internal - workflow nodes; - flat - expands workflow nodes recursively; - hierarchical - expands workflow nodes recursively with a - notion on hierarchy; - colored - expands workflow nodes recursively with a - notion on hierarchy in color; - exec - expands workflows to depict iterables - - format: 'png', 'svg' - - simple_form: boolean (default: True) - Determines if the node name used in the graph should be of the form - 'nodename (package)' when True or 'nodename.Class.package' when - False. - - """ - graphtypes = ['orig', 'flat', 'hierarchical', 'exec', 'colored'] - if graph2use not in graphtypes: - raise ValueError('Unknown graph2use keyword. 
Must be one of: ' + - str(graphtypes)) - base_dir, dotfilename = op.split(dotfilename) - if base_dir == '': - if self.base_dir: - base_dir = self.base_dir - if self.name: - base_dir = op.join(base_dir, self.name) - else: - base_dir = os.getcwd() - base_dir = make_output_dir(base_dir) - if graph2use in ['hierarchical', 'colored']: - dotfilename = op.join(base_dir, dotfilename) - self.write_hierarchical_dotfile(dotfilename=dotfilename, - colored=graph2use == "colored", - simple_form=simple_form) - format_dot(dotfilename, format=format) - else: - graph = self._graph - if graph2use in ['flat', 'exec']: - graph = self._create_flat_graph() - if graph2use == 'exec': - graph = generate_expanded_graph(deepcopy(graph)) - export_graph(graph, base_dir, dotfilename=dotfilename, - format=format, simple_form=simple_form) - - def write_hierarchical_dotfile(self, dotfilename=None, colored=False, - simple_form=True): - dotlist = ['digraph %s{' % self.name] - dotlist.append(self._get_dot(prefix=' ', colored=colored, - simple_form=simple_form)) - dotlist.append('}') - dotstr = '\n'.join(dotlist) - if dotfilename: - fp = open(dotfilename, 'wt') - fp.writelines(dotstr) - fp.close() - else: - logger.info(dotstr) - - def export(self, filename=None, prefix="output", format="python", - include_config=False): - """Export object into a different format - - Parameters - ---------- - filename: string - file to save the code to; overrides prefix - prefix: string - prefix to use for output file - format: string - one of "python" - include_config: boolean - whether to include node and workflow config values - - """ - formats = ["python"] - if format not in formats: - raise ValueError('format must be one of: %s' % '|'.join(formats)) - flatgraph = self._create_flat_graph() - nodes = nx.topological_sort(flatgraph) - - lines = ['# Workflow'] - importlines = ['from nipype.pipeline.engine import Workflow, ' - 'Node, MapNode'] - functions = {} - if format == "python": - connect_template = '%s.connect(%%s, %%s, %%s, "%%s")' % self.name - connect_template2 = '%s.connect(%%s, "%%s", %%s, "%%s")' \ - % self.name - wfdef = '%s = Workflow("%s")' % (self.name, self.name) - lines.append(wfdef) - if include_config: - lines.append('%s.config = %s' % (self.name, self.config)) - for idx, node in enumerate(nodes): - nodename = node.fullname.replace('.', '_') - # write nodes - nodelines = format_node(node, format='python', - include_config=include_config) - for line in nodelines: - if line.startswith('from'): - if line not in importlines: - importlines.append(line) - else: - lines.append(line) - # write connections - for u, _, d in flatgraph.in_edges_iter(nbunch=node, - data=True): - for cd in d['connect']: - if isinstance(cd[0], tuple): - args = list(cd[0]) - if args[1] in functions: - funcname = functions[args[1]] - else: - func = create_function_from_source(args[1]) - funcname = [name for name in func.__globals__ - if name != '__builtins__'][0] - functions[args[1]] = funcname - args[1] = funcname - args = tuple([arg for arg in args if arg]) - line_args = (u.fullname.replace('.', '_'), - args, nodename, cd[1]) - line = connect_template % line_args - line = line.replace("'%s'" % funcname, funcname) - lines.append(line) - else: - line_args = (u.fullname.replace('.', '_'), - cd[0], nodename, cd[1]) - lines.append(connect_template2 % line_args) - functionlines = ['# Functions'] - for function in functions: - functionlines.append(pickle.loads(function).rstrip()) - all_lines = importlines + functionlines + lines - - if not filename: - filename = 
'%s%s.py' % (prefix, self.name) - with open(filename, 'wt') as fp: - fp.writelines('\n'.join(all_lines)) - return all_lines - - def run(self, plugin=None, plugin_args=None, updatehash=False): - """ Execute the workflow - - Parameters - ---------- - - plugin: plugin name or object - Plugin to use for execution. You can create your own plugins for - execution. - plugin_args : dictionary containing arguments to be sent to plugin - constructor. see individual plugin doc strings for details. - """ - if plugin is None: - plugin = config.get('execution', 'plugin') - if not isinstance(plugin, string_types): - runner = plugin - else: - name = 'nipype.pipeline.plugins' - try: - __import__(name) - except ImportError: - msg = 'Could not import plugin module: %s' % name - logger.error(msg) - raise ImportError(msg) - else: - plugin_mod = getattr(sys.modules[name], '%sPlugin' % plugin) - runner = plugin_mod(plugin_args=plugin_args) - flatgraph = self._create_flat_graph() - self.config = merge_dict(deepcopy(config._sections), self.config) - if 'crashdump_dir' in self.config: - warn(("Deprecated: workflow.config['crashdump_dir']\n" - "Please use config['execution']['crashdump_dir']")) - crash_dir = self.config['crashdump_dir'] - self.config['execution']['crashdump_dir'] = crash_dir - del self.config['crashdump_dir'] - logger.info(str(sorted(self.config))) - self._set_needed_outputs(flatgraph) - execgraph = generate_expanded_graph(deepcopy(flatgraph)) - for index, node in enumerate(execgraph.nodes()): - node.config = merge_dict(deepcopy(self.config), node.config) - node.base_dir = self.base_dir - node.index = index - if isinstance(node, MapNode): - node.use_plugin = (plugin, plugin_args) - self._configure_exec_nodes(execgraph) - if str2bool(self.config['execution']['create_report']): - self._write_report_info(self.base_dir, self.name, execgraph) - runner.run(execgraph, updatehash=updatehash, config=self.config) - datestr = datetime.utcnow().strftime('%Y%m%dT%H%M%S') - if str2bool(self.config['execution']['write_provenance']): - prov_base = op.join(self.base_dir, - 'workflow_provenance_%s' % datestr) - logger.info('Provenance file prefix: %s' % prov_base) - write_workflow_prov(execgraph, prov_base, format='all') - return execgraph - - # PRIVATE API AND FUNCTIONS - - def _write_report_info(self, workingdir, name, graph): - if workingdir is None: - workingdir = os.getcwd() - report_dir = op.join(workingdir, name) - if not op.exists(report_dir): - os.makedirs(report_dir) - shutil.copyfile(op.join(op.dirname(__file__), - 'report_template.html'), - op.join(report_dir, 'index.html')) - shutil.copyfile(op.join(op.dirname(__file__), - '..', 'external', 'd3.js'), - op.join(report_dir, 'd3.js')) - nodes, groups = topological_sort(graph, depth_first=True) - graph_file = op.join(report_dir, 'graph1.json') - json_dict = {'nodes': [], 'links': [], 'groups': [], 'maxN': 0} - for i, node in enumerate(nodes): - report_file = "%s/_report/report.rst" % \ - node.output_dir().replace(report_dir, '') - result_file = "%s/result_%s.pklz" % \ - (node.output_dir().replace(report_dir, ''), - node.name) - json_dict['nodes'].append(dict(name='%d_%s' % (i, node.name), - report=report_file, - result=result_file, - group=groups[i])) - maxN = 0 - for gid in np.unique(groups): - procs = [i for i, val in enumerate(groups) if val == gid] - N = len(procs) - if N > maxN: - maxN = N - json_dict['groups'].append(dict(procs=procs, - total=N, - name='Group_%05d' % gid)) - json_dict['maxN'] = maxN - for u, v in graph.in_edges_iter(): - 
json_dict['links'].append(dict(source=nodes.index(u), - target=nodes.index(v), - value=1)) - save_json(graph_file, json_dict) - graph_file = op.join(report_dir, 'graph.json') - template = '%%0%dd_' % np.ceil(np.log10(len(nodes))).astype(int) - - def getname(u, i): - name_parts = u.fullname.split('.') - # return '.'.join(name_parts[:-1] + [template % i + name_parts[-1]]) - return template % i + name_parts[-1] - json_dict = [] - for i, node in enumerate(nodes): - imports = [] - for u, v in graph.in_edges_iter(nbunch=node): - imports.append(getname(u, nodes.index(u))) - json_dict.append(dict(name=getname(node, i), - size=1, - group=groups[i], - imports=imports)) - save_json(graph_file, json_dict) - - def _set_needed_outputs(self, graph): - """Initialize node with list of which outputs are needed.""" - rm_outputs = self.config['execution']['remove_unnecessary_outputs'] - if not str2bool(rm_outputs): - return - for node in graph.nodes(): - node.needed_outputs = [] - for edge in graph.out_edges_iter(node): - data = graph.get_edge_data(*edge) - sourceinfo = [v1[0] if isinstance(v1, tuple) else v1 - for v1, v2 in data['connect']] - node.needed_outputs += [v for v in sourceinfo - if v not in node.needed_outputs] - if node.needed_outputs: - node.needed_outputs = sorted(node.needed_outputs) - - def _configure_exec_nodes(self, graph): - """Ensure that each node knows where to get inputs from - """ - for node in graph.nodes(): - node.input_source = {} - for edge in graph.in_edges_iter(node): - data = graph.get_edge_data(*edge) - for sourceinfo, field in sorted(data['connect']): - node.input_source[field] = \ - (op.join(edge[0].output_dir(), - 'result_%s.pklz' % edge[0].name), - sourceinfo) - - def _check_nodes(self, nodes): - """Checks if any of the nodes are already in the graph - - """ - node_names = [node.name for node in self._graph.nodes()] - node_lineage = [node._hierarchy for node in self._graph.nodes()] - for node in nodes: - if node.name in node_names: - idx = node_names.index(node.name) - if node_lineage[idx] in [node._hierarchy, self.name]: - raise IOError('Duplicate node name %s found.' 
% node.name) - else: - node_names.append(node.name) - - def _has_attr(self, parameter, subtype='in'): - """Checks if a parameter is available as an input or output - """ - if subtype == 'in': - subobject = self.inputs - else: - subobject = self.outputs - attrlist = parameter.split('.') - cur_out = subobject - for attr in attrlist: - if not hasattr(cur_out, attr): - return False - cur_out = getattr(cur_out, attr) - return True - - def _get_parameter_node(self, parameter, subtype='in'): - """Returns the underlying node corresponding to an input or - output parameter - """ - if subtype == 'in': - subobject = self.inputs - else: - subobject = self.outputs - attrlist = parameter.split('.') - cur_out = subobject - for attr in attrlist[:-1]: - cur_out = getattr(cur_out, attr) - return cur_out.traits()[attrlist[-1]].node - - def _check_outputs(self, parameter): - return self._has_attr(parameter, subtype='out') - - def _check_inputs(self, parameter): - return self._has_attr(parameter, subtype='in') - - def _get_inputs(self): - """Returns the inputs of a workflow - - This function does not return any input ports that are already - connected - """ - inputdict = TraitedSpec() - for node in self._graph.nodes(): - inputdict.add_trait(node.name, traits.Instance(TraitedSpec)) - if isinstance(node, Workflow): - setattr(inputdict, node.name, node.inputs) - else: - taken_inputs = [] - for _, _, d in self._graph.in_edges_iter(nbunch=node, - data=True): - for cd in d['connect']: - taken_inputs.append(cd[1]) - unconnectedinputs = TraitedSpec() - for key, trait in list(node.inputs.items()): - if key not in taken_inputs: - unconnectedinputs.add_trait(key, - traits.Trait(trait, - node=node)) - value = getattr(node.inputs, key) - setattr(unconnectedinputs, key, value) - setattr(inputdict, node.name, unconnectedinputs) - getattr(inputdict, node.name).on_trait_change(self._set_input) - return inputdict - - def _get_outputs(self): - """Returns all possible output ports that are not already connected - """ - outputdict = TraitedSpec() - for node in self._graph.nodes(): - outputdict.add_trait(node.name, traits.Instance(TraitedSpec)) - if isinstance(node, Workflow): - setattr(outputdict, node.name, node.outputs) - elif node.outputs: - outputs = TraitedSpec() - for key, _ in list(node.outputs.items()): - outputs.add_trait(key, traits.Any(node=node)) - setattr(outputs, key, None) - setattr(outputdict, node.name, outputs) - return outputdict - - def _set_input(self, object, name, newvalue): - """Trait callback function to update a node input - """ - object.traits()[name].node.set_input(name, newvalue) - - def _set_node_input(self, node, param, source, sourceinfo): - """Set inputs of a node given the edge connection""" - if isinstance(sourceinfo, string_types): - val = source.get_output(sourceinfo) - elif isinstance(sourceinfo, tuple): - if callable(sourceinfo[1]): - val = sourceinfo[1](source.get_output(sourceinfo[0]), - *sourceinfo[2:]) - newval = val - if isinstance(val, TraitDictObject): - newval = dict(val) - if isinstance(val, TraitListObject): - newval = val[:] - logger.debug('setting node input: %s->%s', param, str(newval)) - node.set_input(param, deepcopy(newval)) - - def _get_all_nodes(self): - allnodes = [] - for node in self._graph.nodes(): - if isinstance(node, Workflow): - allnodes.extend(node._get_all_nodes()) - else: - allnodes.append(node) - return allnodes - - def _has_node(self, wanted_node): - for node in self._graph.nodes(): - if wanted_node == node: - return True - if isinstance(node, Workflow): - if 
node._has_node(wanted_node): - return True - return False - - def _create_flat_graph(self): - """Make a simple DAG where no node is a workflow.""" - logger.debug('Creating flat graph for workflow: %s', self.name) - workflowcopy = deepcopy(self) - workflowcopy._generate_flatgraph() - return workflowcopy._graph - - def _reset_hierarchy(self): - """Reset the hierarchy on a graph - """ - for node in self._graph.nodes(): - if isinstance(node, Workflow): - node._reset_hierarchy() - for innernode in node._graph.nodes(): - innernode._hierarchy = '.'.join((self.name, - innernode._hierarchy)) - else: - node._hierarchy = self.name - - def _generate_flatgraph(self): - """Generate a graph containing only Nodes or MapNodes - """ - logger.debug('expanding workflow: %s', self) - nodes2remove = [] - if not nx.is_directed_acyclic_graph(self._graph): - raise Exception(('Workflow: %s is not a directed acyclic graph ' - '(DAG)') % self.name) - nodes = nx.topological_sort(self._graph) - for node in nodes: - logger.debug('processing node: %s' % node) - if isinstance(node, Workflow): - nodes2remove.append(node) - # use in_edges instead of in_edges_iter to allow - # disconnections to take place properly. otherwise, the - # edge dict is modified. - for u, _, d in self._graph.in_edges(nbunch=node, data=True): - logger.debug('in: connections-> %s' % str(d['connect'])) - for cd in deepcopy(d['connect']): - logger.debug("in: %s" % str(cd)) - dstnode = node._get_parameter_node(cd[1], subtype='in') - srcnode = u - srcout = cd[0] - dstin = cd[1].split('.')[-1] - logger.debug('in edges: %s %s %s %s' % - (srcnode, srcout, dstnode, dstin)) - self.disconnect(u, cd[0], node, cd[1]) - self.connect(srcnode, srcout, dstnode, dstin) - # do not use out_edges_iter for reasons stated in in_edges - for _, v, d in self._graph.out_edges(nbunch=node, data=True): - logger.debug('out: connections-> %s' % str(d['connect'])) - for cd in deepcopy(d['connect']): - logger.debug("out: %s" % str(cd)) - dstnode = v - if isinstance(cd[0], tuple): - parameter = cd[0][0] - else: - parameter = cd[0] - srcnode = node._get_parameter_node(parameter, - subtype='out') - if isinstance(cd[0], tuple): - srcout = list(cd[0]) - srcout[0] = parameter.split('.')[-1] - srcout = tuple(srcout) - else: - srcout = parameter.split('.')[-1] - dstin = cd[1] - logger.debug('out edges: %s %s %s %s' % (srcnode, - srcout, - dstnode, - dstin)) - self.disconnect(node, cd[0], v, cd[1]) - self.connect(srcnode, srcout, dstnode, dstin) - # expand the workflow node - # logger.debug('expanding workflow: %s', node) - node._generate_flatgraph() - for innernode in node._graph.nodes(): - innernode._hierarchy = '.'.join((self.name, - innernode._hierarchy)) - self._graph.add_nodes_from(node._graph.nodes()) - self._graph.add_edges_from(node._graph.edges(data=True)) - if nodes2remove: - self._graph.remove_nodes_from(nodes2remove) - logger.debug('finished expanding workflow: %s', self) - - def _get_dot(self, prefix=None, hierarchy=None, colored=False, - simple_form=True, level=0): - """Create a dot file with connection info - """ - if prefix is None: - prefix = ' ' - if hierarchy is None: - hierarchy = [] - colorset = ['#FFFFC8', '#0000FF', '#B4B4FF', '#E6E6FF', '#FF0000', - '#FFB4B4', '#FFE6E6', '#00A300', '#B4FFB4', '#E6FFE6'] - - dotlist = ['%slabel="%s";' % (prefix, self.name)] - for node in nx.topological_sort(self._graph): - fullname = '.'.join(hierarchy + [node.fullname]) - nodename = fullname.replace('.', '_') - if not isinstance(node, Workflow): - node_class_name = 
get_print_name(node, simple_form=simple_form) - if not simple_form: - node_class_name = '.'.join(node_class_name.split('.')[1:]) - if hasattr(node, 'iterables') and node.iterables: - dotlist.append(('%s[label="%s", shape=box3d,' - 'style=filled, color=black, colorscheme' - '=greys7 fillcolor=2];') % (nodename, - node_class_name)) - else: - if colored: - dotlist.append(('%s[label="%s", style=filled,' - ' fillcolor="%s"];') - % (nodename, node_class_name, - colorset[level])) - else: - dotlist.append(('%s[label="%s"];') - % (nodename, node_class_name)) - - for node in nx.topological_sort(self._graph): - if isinstance(node, Workflow): - fullname = '.'.join(hierarchy + [node.fullname]) - nodename = fullname.replace('.', '_') - dotlist.append('subgraph cluster_%s {' % nodename) - if colored: - dotlist.append(prefix + prefix + 'edge [color="%s"];' % (colorset[level + 1])) - dotlist.append(prefix + prefix + 'style=filled;') - dotlist.append(prefix + prefix + 'fillcolor="%s";' % (colorset[level + 2])) - dotlist.append(node._get_dot(prefix=prefix + prefix, - hierarchy=hierarchy + [self.name], - colored=colored, - simple_form=simple_form, level=level + 3)) - dotlist.append('}') - if level == 6: - level = 2 - else: - for subnode in self._graph.successors_iter(node): - if node._hierarchy != subnode._hierarchy: - continue - if not isinstance(subnode, Workflow): - nodefullname = '.'.join(hierarchy + [node.fullname]) - subnodefullname = '.'.join(hierarchy + - [subnode.fullname]) - nodename = nodefullname.replace('.', '_') - subnodename = subnodefullname.replace('.', '_') - for _ in self._graph.get_edge_data(node, - subnode)['connect']: - dotlist.append('%s -> %s;' % (nodename, - subnodename)) - logger.debug('connection: ' + dotlist[-1]) - # add between workflow connections - for u, v, d in self._graph.edges_iter(data=True): - uname = '.'.join(hierarchy + [u.fullname]) - vname = '.'.join(hierarchy + [v.fullname]) - for src, dest in d['connect']: - uname1 = uname - vname1 = vname - if isinstance(src, tuple): - srcname = src[0] - else: - srcname = src - if '.' in srcname: - uname1 += '.' + '.'.join(srcname.split('.')[:-1]) - if '.' in dest and '@' not in dest: - if not isinstance(v, Workflow): - if 'datasink' not in \ - str(v._interface.__class__).lower(): - vname1 += '.' + '.'.join(dest.split('.')[:-1]) - else: - vname1 += '.' 
+ '.'.join(dest.split('.')[:-1])
-                if uname1.split('.')[:-1] != vname1.split('.')[:-1]:
-                    dotlist.append('%s -> %s;' % (uname1.replace('.', '_'),
-                                                  vname1.replace('.', '_')))
-                    logger.debug('cross connection: ' + dotlist[-1])
-        return ('\n' + prefix).join(dotlist)
-
-
-class Node(WorkflowBase):
+class Node(EngineBase):
     """Wraps interface objects for use in pipeline
 
     A Node creates a sandbox-like directory for executing the underlying
diff --git a/nipype/pipeline/report_template.html b/nipype/pipeline/engine/report_template.html
similarity index 100%
rename from nipype/pipeline/report_template.html
rename to nipype/pipeline/engine/report_template.html
diff --git a/nipype/pipeline/report_template2.html b/nipype/pipeline/engine/report_template2.html
similarity index 100%
rename from nipype/pipeline/report_template2.html
rename to nipype/pipeline/engine/report_template2.html
diff --git a/nipype/pipeline/tests/__init__.py b/nipype/pipeline/engine/tests/__init__.py
similarity index 100%
rename from nipype/pipeline/tests/__init__.py
rename to nipype/pipeline/engine/tests/__init__.py
diff --git a/nipype/pipeline/tests/test_engine.py b/nipype/pipeline/engine/tests/test_engine.py
similarity index 100%
rename from nipype/pipeline/tests/test_engine.py
rename to nipype/pipeline/engine/tests/test_engine.py
diff --git a/nipype/pipeline/tests/test_join.py b/nipype/pipeline/engine/tests/test_join.py
similarity index 100%
rename from nipype/pipeline/tests/test_join.py
rename to nipype/pipeline/engine/tests/test_join.py
diff --git a/nipype/pipeline/tests/test_utils.py b/nipype/pipeline/engine/tests/test_utils.py
similarity index 100%
rename from nipype/pipeline/tests/test_utils.py
rename to nipype/pipeline/engine/tests/test_utils.py
diff --git a/nipype/pipeline/utils.py b/nipype/pipeline/engine/utils.py
similarity index 94%
rename from nipype/pipeline/utils.py
rename to nipype/pipeline/engine/utils.py
index 64caa482eb..ac9560da69 100644
--- a/nipype/pipeline/utils.py
+++ b/nipype/pipeline/engine/utils.py
@@ -1,3 +1,5 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
 # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
 # vi: set ft=python sts=4 ts=4 sw=4 et:
 """Utility routines for workflow graphs
@@ -83,6 +85,75 @@ def relpath(path, start=None):
     return op.join(*rel_list)
 
 
+def _write_inputs(node):
+    lines = []
+    nodename = node.fullname.replace('.', '_')
+    for key, _ in list(node.inputs.items()):
+        val = getattr(node.inputs, key)
+        if isdefined(val):
+            if type(val) == str:
+                try:
+                    func = create_function_from_source(val)
+                except RuntimeError as e:
+                    lines.append("%s.inputs.%s = '%s'" % (nodename, key, val))
+                else:
+                    funcname = [name for name in func.__globals__
+                                if name != '__builtins__'][0]
+                    lines.append(pickle.loads(val))
+                    if funcname == nodename:
+                        lines[-1] = lines[-1].replace(' %s(' % funcname,
+                                                      ' %s_1(' % funcname)
+                        funcname = '%s_1' % funcname
+                    lines.append('from nipype.utils.misc import getsource')
+                    lines.append("%s.inputs.%s = getsource(%s)" % (nodename,
+                                                                   key,
+                                                                   funcname))
+            else:
+                lines.append('%s.inputs.%s = %s' % (nodename, key, val))
+    return lines
+
+
+def format_node(node, format='python', include_config=False):
+    """Format a node in a given output syntax."""
+    lines = []
+    name = node.fullname.replace('.', '_')
+    if format == 'python':
+        klass = node._interface
+        importline = 'from %s import %s' % (klass.__module__,
+                                            klass.__class__.__name__)
+        comment = '# Node: %s' % node.fullname
+        spec = inspect.getargspec(node._interface.__init__)
+        args = spec.args[1:]
+        
if args: + filled_args = [] + for arg in args: + if hasattr(node._interface, '_%s' % arg): + filled_args.append('%s=%s' % (arg, getattr(node._interface, + '_%s' % arg))) + args = ', '.join(filled_args) + else: + args = '' + klass_name = klass.__class__.__name__ + if isinstance(node, MapNode): + nodedef = '%s = MapNode(%s(%s), iterfield=%s, name="%s")' \ + % (name, klass_name, args, node.iterfield, name) + else: + nodedef = '%s = Node(%s(%s), name="%s")' \ + % (name, klass_name, args, name) + lines = [importline, comment, nodedef] + + if include_config: + lines = [importline, "from collections import OrderedDict", + comment, nodedef] + lines.append('%s.config = %s' % (name, node.config)) + + if node.iterables is not None: + lines.append('%s.iterables = %s' % (name, node.iterables)) + lines.extend(_write_inputs(node)) + + return lines + + def modify_paths(object, relative=True, basedir=None): """Convert paths in data structure to either full paths or relative paths diff --git a/nipype/pipeline/engine/workflows.py b/nipype/pipeline/engine/workflows.py new file mode 100644 index 0000000000..9328b414bd --- /dev/null +++ b/nipype/pipeline/engine/workflows.py @@ -0,0 +1,999 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +"""Defines functionality for pipelined execution of interfaces + +The `Workflow` class provides core functionality for batch processing. + + .. testsetup:: + # Change directory to provide relative paths for doctests + import os + filepath = os.path.dirname(os.path.realpath( __file__ )) + datadir = os.path.realpath(os.path.join(filepath, '../../testing/data')) + os.chdir(datadir) + +""" + +from future import standard_library +standard_library.install_aliases() +from builtins import range +from builtins import object + +from datetime import datetime +from nipype.utils.misc import flatten, unflatten +try: + from collections import OrderedDict +except ImportError: + from ordereddict import OrderedDict + +from copy import deepcopy +import pickle +from glob import glob +import gzip +import inspect +import os +import os.path as op +import re +import shutil +import errno +import socket +from shutil import rmtree +import sys +from tempfile import mkdtemp +from warnings import warn +from hashlib import sha1 + +import numpy as np +import networkx as nx + +from ...utils.misc import package_check, str2bool +package_check('networkx', '1.3') + +from ... 
import config, logging +logger = logging.getLogger('workflow') +from ...interfaces.base import (traits, InputMultiPath, CommandLine, + Undefined, TraitedSpec, DynamicTraitedSpec, + Bunch, InterfaceResult, md5, Interface, + TraitDictObject, TraitListObject, isdefined) +from ...utils.misc import (getsource, create_function_from_source, + flatten, unflatten) +from ...utils.filemanip import (save_json, FileNotFoundError, + filename_to_list, list_to_filename, + copyfiles, fnames_presuffix, loadpkl, + split_filename, load_json, savepkl, + write_rst_header, write_rst_dict, + write_rst_list) +from ...external.six import string_types +from .utils import (generate_expanded_graph, modify_paths, + export_graph, make_output_dir, write_workflow_prov, + clean_working_directory, format_dot, topological_sort, + get_print_name, merge_dict, evaluate_connect_function, + _write_inputs, format_node) + +from .base import EngineBase +from .nodes import Node, MapNode + + +class Workflow(EngineBase): + """Controls the setup and execution of a pipeline of processes.""" + + def __init__(self, name, base_dir=None): + """Create a workflow object. + + Parameters + ---------- + name : alphanumeric string + unique identifier for the workflow + base_dir : string, optional + path to workflow storage + + """ + super(Workflow, self).__init__(name, base_dir) + self._graph = nx.DiGraph() + self.config = deepcopy(config._sections) + + # PUBLIC API + def clone(self, name): + """Clone a workflow + + .. note:: + + Will reset attributes used for executing workflow. See + _init_runtime_fields. + + Parameters + ---------- + + name: alphanumeric name + unique name for the workflow + + """ + clone = super(Workflow, self).clone(name) + clone._reset_hierarchy() + return clone + + # Graph creation functions + def connect(self, *args, **kwargs): + """Connect nodes in the pipeline. + + This routine also checks if inputs and outputs are actually provided by + the nodes that are being connected. + + Creates edges in the directed graph using the nodes and edges specified + in the `connection_list`. Uses the NetworkX method + DiGraph.add_edges_from. + + Parameters + ---------- + + args : list or a set of four positional arguments + + Four positional arguments of the form:: + + connect(source, sourceoutput, dest, destinput) + + source : nodewrapper node + sourceoutput : string (must be in source.outputs) + dest : nodewrapper node + destinput : string (must be in dest.inputs) + + A list of 3-tuples of the following form:: + + [(source, target, + [('sourceoutput/attribute', 'targetinput'), + ...]), + ...] + + Or:: + + [(source, target, [(('sourceoutput1', func, arg2, ...), + 'targetinput'), ...]), + ...] 
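+
+        For instance, assuming two hypothetical, already-created nodes
+        ``realigner`` and ``smoother`` and a helper function ``pick_first``
+        (illustrative names only, not part of this changeset)::
+
+            workflow.connect(realigner, 'realigned_files',
+                             smoother, 'in_files')
+
+            # equivalent list form, with a function interposed on the link
+            workflow.connect([(realigner, smoother,
+                               [(('realigned_files', pick_first),
+                                 'in_files')])])
+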
+         sourceoutput1 will always be the first argument to func
+         and func will be evaluated and the results sent to targetinput
+
+         currently func needs to define all its needed imports within the
+         function as we use the inspect module to get at the source code
+         and execute it remotely
+        """
+        if len(args) == 1:
+            connection_list = args[0]
+        elif len(args) == 4:
+            connection_list = [(args[0], args[2], [(args[1], args[3])])]
+        else:
+            raise TypeError('connect() takes either 4 arguments, or 1 list of'
+                            ' connection tuples (%d args given)' % len(args))
+
+        disconnect = False
+        if kwargs:
+            disconnect = kwargs.get('disconnect', False)
+
+        if disconnect:
+            self.disconnect(connection_list)
+            return
+
+        newnodes = []
+        for srcnode, destnode, _ in connection_list:
+            if self in [srcnode, destnode]:
+                msg = ('Workflow connect cannot contain itself as node:'
+                       ' src[%s] dest[%s] workflow[%s]') % (srcnode,
+                                                            destnode,
+                                                            self.name)
+
+                raise IOError(msg)
+            if (srcnode not in newnodes) and not self._has_node(srcnode):
+                newnodes.append(srcnode)
+            if (destnode not in newnodes) and not self._has_node(destnode):
+                newnodes.append(destnode)
+        if newnodes:
+            self._check_nodes(newnodes)
+            for node in newnodes:
+                if node._hierarchy is None:
+                    node._hierarchy = self.name
+        not_found = []
+        connected_ports = {}
+        for srcnode, destnode, connects in connection_list:
+            if destnode not in connected_ports:
+                connected_ports[destnode] = []
+            # check to see which ports of destnode are already
+            # connected.
+            if not disconnect and (destnode in self._graph.nodes()):
+                for edge in self._graph.in_edges_iter(destnode):
+                    data = self._graph.get_edge_data(*edge)
+                    for sourceinfo, destname in data['connect']:
+                        if destname not in connected_ports[destnode]:
+                            connected_ports[destnode] += [destname]
+            for source, dest in connects:
+                # Currently datasource/sink/grabber.io modules
+                # determine their inputs/outputs depending on
+                # connection settings. Skip these modules in the check
+                if dest in connected_ports[destnode]:
+                    raise Exception("""
+Trying to connect %s:%s to %s:%s but input '%s' of node '%s' is already
+connected.
+""" % (srcnode, source, destnode, dest, dest, destnode)) + if not (hasattr(destnode, '_interface') and + '.io' in str(destnode._interface.__class__)): + if not destnode._check_inputs(dest): + not_found.append(['in', destnode.name, dest]) + if not (hasattr(srcnode, '_interface') and + '.io' in str(srcnode._interface.__class__)): + if isinstance(source, tuple): + # handles the case that source is specified + # with a function + sourcename = source[0] + elif isinstance(source, string_types): + sourcename = source + else: + raise Exception(('Unknown source specification in ' + 'connection from output of %s') % + srcnode.name) + if sourcename and not srcnode._check_outputs(sourcename): + not_found.append(['out', srcnode.name, sourcename]) + connected_ports[destnode] += [dest] + infostr = [] + for info in not_found: + infostr += ["Module %s has no %sput called %s\n" % (info[1], + info[0], + info[2])] + if not_found: + raise Exception('\n'.join(['Some connections were not found'] + + infostr)) + + # turn functions into strings + for srcnode, destnode, connects in connection_list: + for idx, (src, dest) in enumerate(connects): + if isinstance(src, tuple) and not isinstance(src[1], string_types): + function_source = getsource(src[1]) + connects[idx] = ((src[0], function_source, src[2:]), dest) + + # add connections + for srcnode, destnode, connects in connection_list: + edge_data = self._graph.get_edge_data(srcnode, destnode, None) + if edge_data: + logger.debug('(%s, %s): Edge data exists: %s' + % (srcnode, destnode, str(edge_data))) + for data in connects: + if data not in edge_data['connect']: + edge_data['connect'].append(data) + if disconnect: + logger.debug('Removing connection: %s' % str(data)) + edge_data['connect'].remove(data) + if edge_data['connect']: + self._graph.add_edges_from([(srcnode, + destnode, + edge_data)]) + else: + # pass + logger.debug('Removing connection: %s->%s' % (srcnode, + destnode)) + self._graph.remove_edges_from([(srcnode, destnode)]) + elif not disconnect: + logger.debug('(%s, %s): No edge data' % (srcnode, destnode)) + self._graph.add_edges_from([(srcnode, destnode, + {'connect': connects})]) + edge_data = self._graph.get_edge_data(srcnode, destnode) + logger.debug('(%s, %s): new edge data: %s' % (srcnode, destnode, + str(edge_data))) + + def disconnect(self, *args): + """Disconnect nodes + See the docstring for connect for format. 
+ """ + if len(args) == 1: + connection_list = args[0] + elif len(args) == 4: + connection_list = [(args[0], args[2], [(args[1], args[3])])] + else: + raise TypeError('disconnect() takes either 4 arguments, or 1 list ' + 'of connection tuples (%d args given)' % len(args)) + + for srcnode, dstnode, conn in connection_list: + logger.debug('disconnect(): %s->%s %s' % (srcnode, dstnode, conn)) + if self in [srcnode, dstnode]: + raise IOError( + 'Workflow connect cannot contain itself as node: src[%s] ' + 'dest[%s] workflow[%s]') % (srcnode, dstnode, self.name) + + # If node is not in the graph, not connected + if not self._has_node(srcnode) or not self._has_node(dstnode): + continue + + edge_data = self._graph.get_edge_data( + srcnode, dstnode, {'connect': []}) + ed_conns = [(c[0], c[1]) for c in edge_data['connect']] + + remove = [] + for edge in conn: + if edge in ed_conns: + idx = ed_conns.index(edge) + remove.append((edge[0], edge[1])) + + logger.debug('disconnect(): remove list %s' % remove) + for el in remove: + edge_data['connect'].remove(el) + logger.debug('disconnect(): removed connection %s' % str(el)) + + if not edge_data['connect']: + self._graph.remove_edge(srcnode, dstnode) + else: + self._graph.add_edges_from([(srcnode, dstnode, edge_data)]) + + def add_nodes(self, nodes): + """ Add nodes to a workflow + + Parameters + ---------- + nodes : list + A list of WorkflowBase-based objects + """ + newnodes = [] + all_nodes = self._get_all_nodes() + for node in nodes: + if self._has_node(node): + raise IOError('Node %s already exists in the workflow' % node) + if isinstance(node, Workflow): + for subnode in node._get_all_nodes(): + if subnode in all_nodes: + raise IOError(('Subnode %s of node %s already exists ' + 'in the workflow') % (subnode, node)) + newnodes.append(node) + if not newnodes: + logger.debug('no new nodes to add') + return + for node in newnodes: + if not issubclass(node.__class__, WorkflowBase): + raise Exception('Node %s must be a subclass of WorkflowBase' % + str(node)) + self._check_nodes(newnodes) + for node in newnodes: + if node._hierarchy is None: + node._hierarchy = self.name + self._graph.add_nodes_from(newnodes) + + def remove_nodes(self, nodes): + """ Remove nodes from a workflow + + Parameters + ---------- + nodes : list + A list of WorkflowBase-based objects + """ + self._graph.remove_nodes_from(nodes) + + # Input-Output access + @property + def inputs(self): + return self._get_inputs() + + @property + def outputs(self): + return self._get_outputs() + + def get_node(self, name): + """Return an internal node by name + """ + nodenames = name.split('.') + nodename = nodenames[0] + outnode = [node for node in self._graph.nodes() if + str(node).endswith('.' 
+ nodename)] + if outnode: + outnode = outnode[0] + if nodenames[1:] and issubclass(outnode.__class__, Workflow): + outnode = outnode.get_node('.'.join(nodenames[1:])) + else: + outnode = None + return outnode + + def list_node_names(self): + """List names of all nodes in a workflow + """ + outlist = [] + for node in nx.topological_sort(self._graph): + if isinstance(node, Workflow): + outlist.extend(['.'.join((node.name, nodename)) for nodename in + node.list_node_names()]) + else: + outlist.append(node.name) + return sorted(outlist) + + def write_graph(self, dotfilename='graph.dot', graph2use='hierarchical', + format="png", simple_form=True): + """Generates a graphviz dot file and a png file + + Parameters + ---------- + + graph2use: 'orig', 'hierarchical' (default), 'flat', 'exec', 'colored' + orig - creates a top level graph without expanding internal + workflow nodes; + flat - expands workflow nodes recursively; + hierarchical - expands workflow nodes recursively with a + notion on hierarchy; + colored - expands workflow nodes recursively with a + notion on hierarchy in color; + exec - expands workflows to depict iterables + + format: 'png', 'svg' + + simple_form: boolean (default: True) + Determines if the node name used in the graph should be of the form + 'nodename (package)' when True or 'nodename.Class.package' when + False. + + """ + graphtypes = ['orig', 'flat', 'hierarchical', 'exec', 'colored'] + if graph2use not in graphtypes: + raise ValueError('Unknown graph2use keyword. Must be one of: ' + + str(graphtypes)) + base_dir, dotfilename = op.split(dotfilename) + if base_dir == '': + if self.base_dir: + base_dir = self.base_dir + if self.name: + base_dir = op.join(base_dir, self.name) + else: + base_dir = os.getcwd() + base_dir = make_output_dir(base_dir) + if graph2use in ['hierarchical', 'colored']: + dotfilename = op.join(base_dir, dotfilename) + self.write_hierarchical_dotfile(dotfilename=dotfilename, + colored=graph2use == "colored", + simple_form=simple_form) + format_dot(dotfilename, format=format) + else: + graph = self._graph + if graph2use in ['flat', 'exec']: + graph = self._create_flat_graph() + if graph2use == 'exec': + graph = generate_expanded_graph(deepcopy(graph)) + export_graph(graph, base_dir, dotfilename=dotfilename, + format=format, simple_form=simple_form) + + def write_hierarchical_dotfile(self, dotfilename=None, colored=False, + simple_form=True): + dotlist = ['digraph %s{' % self.name] + dotlist.append(self._get_dot(prefix=' ', colored=colored, + simple_form=simple_form)) + dotlist.append('}') + dotstr = '\n'.join(dotlist) + if dotfilename: + fp = open(dotfilename, 'wt') + fp.writelines(dotstr) + fp.close() + else: + logger.info(dotstr) + + def export(self, filename=None, prefix="output", format="python", + include_config=False): + """Export object into a different format + + Parameters + ---------- + filename: string + file to save the code to; overrides prefix + prefix: string + prefix to use for output file + format: string + one of "python" + include_config: boolean + whether to include node and workflow config values + + """ + formats = ["python"] + if format not in formats: + raise ValueError('format must be one of: %s' % '|'.join(formats)) + flatgraph = self._create_flat_graph() + nodes = nx.topological_sort(flatgraph) + + lines = ['# Workflow'] + importlines = ['from nipype.pipeline.engine import Workflow, ' + 'Node, MapNode'] + functions = {} + if format == "python": + connect_template = '%s.connect(%%s, %%s, %%s, "%%s")' % self.name + 
+            wfdef = '%s = Workflow("%s")' % (self.name, self.name)
+            lines.append(wfdef)
+            if include_config:
+                lines.append('%s.config = %s' % (self.name, self.config))
+            for idx, node in enumerate(nodes):
+                nodename = node.fullname.replace('.', '_')
+                # write nodes
+                nodelines = format_node(node, format='python',
+                                        include_config=include_config)
+                for line in nodelines:
+                    if line.startswith('from'):
+                        if line not in importlines:
+                            importlines.append(line)
+                    else:
+                        lines.append(line)
+                # write connections
+                for u, _, d in flatgraph.in_edges_iter(nbunch=node,
+                                                       data=True):
+                    for cd in d['connect']:
+                        if isinstance(cd[0], tuple):
+                            args = list(cd[0])
+                            if args[1] in functions:
+                                funcname = functions[args[1]]
+                            else:
+                                func = create_function_from_source(args[1])
+                                funcname = [name for name in func.__globals__
+                                            if name != '__builtins__'][0]
+                                functions[args[1]] = funcname
+                            args[1] = funcname
+                            args = tuple([arg for arg in args if arg])
+                            line_args = (u.fullname.replace('.', '_'),
+                                         args, nodename, cd[1])
+                            line = connect_template % line_args
+                            line = line.replace("'%s'" % funcname, funcname)
+                            lines.append(line)
+                        else:
+                            line_args = (u.fullname.replace('.', '_'),
+                                         cd[0], nodename, cd[1])
+                            lines.append(connect_template2 % line_args)
+        functionlines = ['# Functions']
+        for function in functions:
+            functionlines.append(pickle.loads(function).rstrip())
+        all_lines = importlines + functionlines + lines
+
+        if not filename:
+            filename = '%s%s.py' % (prefix, self.name)
+        with open(filename, 'wt') as fp:
+            fp.writelines('\n'.join(all_lines))
+        return all_lines
+
+    def run(self, plugin=None, plugin_args=None, updatehash=False):
+        """ Execute the workflow
+
+        Parameters
+        ----------
+
+        plugin : plugin name or object
+            Plugin to use for execution. You can create your own plugins
+            for execution.
+        plugin_args : dictionary
+            arguments to be sent to the plugin constructor; see the
+            individual plugin docstrings for details.
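+
+        Examples
+        --------
+        A minimal sketch (nodes would still have to be added before
+        running; 'Linear' and 'MultiProc' are plugins shipped in
+        nipype.pipeline.plugins):
+
+        >>> wf = Workflow(name='demo', base_dir='.')  # doctest: +SKIP
+        >>> wf.run(plugin='Linear')  # doctest: +SKIP
+        >>> wf.run(plugin='MultiProc',
+        ...        plugin_args={'n_procs': 2})  # doctest: +SKIP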
+ """ + if plugin is None: + plugin = config.get('execution', 'plugin') + if not isinstance(plugin, string_types): + runner = plugin + else: + name = 'nipype.pipeline.plugins' + try: + __import__(name) + except ImportError: + msg = 'Could not import plugin module: %s' % name + logger.error(msg) + raise ImportError(msg) + else: + plugin_mod = getattr(sys.modules[name], '%sPlugin' % plugin) + runner = plugin_mod(plugin_args=plugin_args) + flatgraph = self._create_flat_graph() + self.config = merge_dict(deepcopy(config._sections), self.config) + if 'crashdump_dir' in self.config: + warn(("Deprecated: workflow.config['crashdump_dir']\n" + "Please use config['execution']['crashdump_dir']")) + crash_dir = self.config['crashdump_dir'] + self.config['execution']['crashdump_dir'] = crash_dir + del self.config['crashdump_dir'] + logger.info(str(sorted(self.config))) + self._set_needed_outputs(flatgraph) + execgraph = generate_expanded_graph(deepcopy(flatgraph)) + for index, node in enumerate(execgraph.nodes()): + node.config = merge_dict(deepcopy(self.config), node.config) + node.base_dir = self.base_dir + node.index = index + if isinstance(node, MapNode): + node.use_plugin = (plugin, plugin_args) + self._configure_exec_nodes(execgraph) + if str2bool(self.config['execution']['create_report']): + self._write_report_info(self.base_dir, self.name, execgraph) + runner.run(execgraph, updatehash=updatehash, config=self.config) + datestr = datetime.utcnow().strftime('%Y%m%dT%H%M%S') + if str2bool(self.config['execution']['write_provenance']): + prov_base = op.join(self.base_dir, + 'workflow_provenance_%s' % datestr) + logger.info('Provenance file prefix: %s' % prov_base) + write_workflow_prov(execgraph, prov_base, format='all') + return execgraph + + # PRIVATE API AND FUNCTIONS + + def _write_report_info(self, workingdir, name, graph): + if workingdir is None: + workingdir = os.getcwd() + report_dir = op.join(workingdir, name) + if not op.exists(report_dir): + os.makedirs(report_dir) + shutil.copyfile(op.join(op.dirname(__file__), + 'report_template.html'), + op.join(report_dir, 'index.html')) + shutil.copyfile(op.join(op.dirname(__file__), + '..', '..', 'external', 'd3.js'), + op.join(report_dir, 'd3.js')) + nodes, groups = topological_sort(graph, depth_first=True) + graph_file = op.join(report_dir, 'graph1.json') + json_dict = {'nodes': [], 'links': [], 'groups': [], 'maxN': 0} + for i, node in enumerate(nodes): + report_file = "%s/_report/report.rst" % \ + node.output_dir().replace(report_dir, '') + result_file = "%s/result_%s.pklz" % \ + (node.output_dir().replace(report_dir, ''), + node.name) + json_dict['nodes'].append(dict(name='%d_%s' % (i, node.name), + report=report_file, + result=result_file, + group=groups[i])) + maxN = 0 + for gid in np.unique(groups): + procs = [i for i, val in enumerate(groups) if val == gid] + N = len(procs) + if N > maxN: + maxN = N + json_dict['groups'].append(dict(procs=procs, + total=N, + name='Group_%05d' % gid)) + json_dict['maxN'] = maxN + for u, v in graph.in_edges_iter(): + json_dict['links'].append(dict(source=nodes.index(u), + target=nodes.index(v), + value=1)) + save_json(graph_file, json_dict) + graph_file = op.join(report_dir, 'graph.json') + template = '%%0%dd_' % np.ceil(np.log10(len(nodes))).astype(int) + + def getname(u, i): + name_parts = u.fullname.split('.') + # return '.'.join(name_parts[:-1] + [template % i + name_parts[-1]]) + return template % i + name_parts[-1] + json_dict = [] + for i, node in enumerate(nodes): + imports = [] + for u, v in 
+                imports.append(getname(u, nodes.index(u)))
+            json_dict.append(dict(name=getname(node, i),
+                                  size=1,
+                                  group=groups[i],
+                                  imports=imports))
+        save_json(graph_file, json_dict)
+
+    def _set_needed_outputs(self, graph):
+        """Initialize node with list of which outputs are needed."""
+        rm_outputs = self.config['execution']['remove_unnecessary_outputs']
+        if not str2bool(rm_outputs):
+            return
+        for node in graph.nodes():
+            node.needed_outputs = []
+            for edge in graph.out_edges_iter(node):
+                data = graph.get_edge_data(*edge)
+                sourceinfo = [v1[0] if isinstance(v1, tuple) else v1
+                              for v1, v2 in data['connect']]
+                node.needed_outputs += [v for v in sourceinfo
+                                        if v not in node.needed_outputs]
+            if node.needed_outputs:
+                node.needed_outputs = sorted(node.needed_outputs)
+
+    def _configure_exec_nodes(self, graph):
+        """Ensure that each node knows where to get inputs from
+        """
+        for node in graph.nodes():
+            node.input_source = {}
+            for edge in graph.in_edges_iter(node):
+                data = graph.get_edge_data(*edge)
+                for sourceinfo, field in sorted(data['connect']):
+                    node.input_source[field] = \
+                        (op.join(edge[0].output_dir(),
+                                 'result_%s.pklz' % edge[0].name),
+                         sourceinfo)
+
+    def _check_nodes(self, nodes):
+        """Checks if any of the nodes are already in the graph
+        """
+        node_names = [node.name for node in self._graph.nodes()]
+        node_lineage = [node._hierarchy for node in self._graph.nodes()]
+        for node in nodes:
+            if node.name in node_names:
+                idx = node_names.index(node.name)
+                if node_lineage[idx] in [node._hierarchy, self.name]:
+                    raise IOError('Duplicate node name %s found.' % node.name)
+            else:
+                node_names.append(node.name)
+
+    def _has_attr(self, parameter, subtype='in'):
+        """Checks if a parameter is available as an input or output
+        """
+        if subtype == 'in':
+            subobject = self.inputs
+        else:
+            subobject = self.outputs
+        attrlist = parameter.split('.')
+        cur_out = subobject
+        for attr in attrlist:
+            if not hasattr(cur_out, attr):
+                return False
+            cur_out = getattr(cur_out, attr)
+        return True
+
+    def _get_parameter_node(self, parameter, subtype='in'):
+        """Returns the underlying node corresponding to an input or
+        output parameter
+        """
+        if subtype == 'in':
+            subobject = self.inputs
+        else:
+            subobject = self.outputs
+        attrlist = parameter.split('.')
+        cur_out = subobject
+        for attr in attrlist[:-1]:
+            cur_out = getattr(cur_out, attr)
+        return cur_out.traits()[attrlist[-1]].node
+
+    def _check_outputs(self, parameter):
+        return self._has_attr(parameter, subtype='out')
+
+    def _check_inputs(self, parameter):
+        return self._has_attr(parameter, subtype='in')
+
+    def _get_inputs(self):
+        """Returns the inputs of a workflow
+
+        This function does not return any input ports that are already
+        connected
+        """
+        inputdict = TraitedSpec()
+        for node in self._graph.nodes():
+            inputdict.add_trait(node.name, traits.Instance(TraitedSpec))
+            if isinstance(node, Workflow):
+                setattr(inputdict, node.name, node.inputs)
+            else:
+                taken_inputs = []
+                for _, _, d in self._graph.in_edges_iter(nbunch=node,
+                                                         data=True):
+                    for cd in d['connect']:
+                        taken_inputs.append(cd[1])
+                unconnectedinputs = TraitedSpec()
+                for key, trait in list(node.inputs.items()):
+                    if key not in taken_inputs:
+                        unconnectedinputs.add_trait(key,
+                                                    traits.Trait(trait,
+                                                                 node=node))
+                        value = getattr(node.inputs, key)
+                        setattr(unconnectedinputs, key, value)
+                setattr(inputdict, node.name, unconnectedinputs)
+                getattr(inputdict, node.name).on_trait_change(self._set_input)
+        return inputdict
+
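+    # outputs mirror inputs: sub-workflows contribute their own nested
+    # TraitedSpec, plain nodes contribute every declared output port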
+    def _get_outputs(self):
+        """Returns all possible output ports that are not already connected
+        """
+        outputdict = TraitedSpec()
+        for node in self._graph.nodes():
+            outputdict.add_trait(node.name, traits.Instance(TraitedSpec))
+            if isinstance(node, Workflow):
+                setattr(outputdict, node.name, node.outputs)
+            elif node.outputs:
+                outputs = TraitedSpec()
+                for key, _ in list(node.outputs.items()):
+                    outputs.add_trait(key, traits.Any(node=node))
+                    setattr(outputs, key, None)
+                setattr(outputdict, node.name, outputs)
+        return outputdict
+
+    def _set_input(self, object, name, newvalue):
+        """Trait callback function to update a node input
+        """
+        object.traits()[name].node.set_input(name, newvalue)
+
+    def _set_node_input(self, node, param, source, sourceinfo):
+        """Set inputs of a node given the edge connection"""
+        if isinstance(sourceinfo, string_types):
+            val = source.get_output(sourceinfo)
+        elif isinstance(sourceinfo, tuple):
+            if callable(sourceinfo[1]):
+                val = sourceinfo[1](source.get_output(sourceinfo[0]),
+                                    *sourceinfo[2:])
+        newval = val
+        if isinstance(val, TraitDictObject):
+            newval = dict(val)
+        if isinstance(val, TraitListObject):
+            newval = val[:]
+        logger.debug('setting node input: %s->%s', param, str(newval))
+        node.set_input(param, deepcopy(newval))
+
+    def _get_all_nodes(self):
+        allnodes = []
+        for node in self._graph.nodes():
+            if isinstance(node, Workflow):
+                allnodes.extend(node._get_all_nodes())
+            else:
+                allnodes.append(node)
+        return allnodes
+
+    def _has_node(self, wanted_node):
+        for node in self._graph.nodes():
+            if wanted_node == node:
+                return True
+            if isinstance(node, Workflow):
+                if node._has_node(wanted_node):
+                    return True
+        return False
+
+    def _create_flat_graph(self):
+        """Make a simple DAG where no node is a workflow."""
+        logger.debug('Creating flat graph for workflow: %s', self.name)
+        workflowcopy = deepcopy(self)
+        workflowcopy._generate_flatgraph()
+        return workflowcopy._graph
+
+    def _reset_hierarchy(self):
+        """Reset the hierarchy on a graph
+        """
+        for node in self._graph.nodes():
+            if isinstance(node, Workflow):
+                node._reset_hierarchy()
+                for innernode in node._graph.nodes():
+                    innernode._hierarchy = '.'.join((self.name,
+                                                     innernode._hierarchy))
+            else:
+                node._hierarchy = self.name
+
+    def _generate_flatgraph(self):
+        """Generate a graph containing only Nodes or MapNodes
+        """
+        logger.debug('expanding workflow: %s', self)
+        nodes2remove = []
+        if not nx.is_directed_acyclic_graph(self._graph):
+            raise Exception(('Workflow: %s is not a directed acyclic graph '
+                             '(DAG)') % self.name)
+        nodes = nx.topological_sort(self._graph)
+        for node in nodes:
+            logger.debug('processing node: %s' % node)
+            if isinstance(node, Workflow):
+                nodes2remove.append(node)
+                # use in_edges instead of in_edges_iter to allow
+                # disconnections to take place properly. otherwise, the
+                # edge dict is modified.
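+                # rewire every edge that touches the sub-workflow: the
+                # connection is dropped from the composite node and
+                # re-attached to the concrete node that owns the port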
+                for u, _, d in self._graph.in_edges(nbunch=node, data=True):
+                    logger.debug('in: connections-> %s' % str(d['connect']))
+                    for cd in deepcopy(d['connect']):
+                        logger.debug("in: %s" % str(cd))
+                        dstnode = node._get_parameter_node(cd[1], subtype='in')
+                        srcnode = u
+                        srcout = cd[0]
+                        dstin = cd[1].split('.')[-1]
+                        logger.debug('in edges: %s %s %s %s' %
+                                     (srcnode, srcout, dstnode, dstin))
+                        self.disconnect(u, cd[0], node, cd[1])
+                        self.connect(srcnode, srcout, dstnode, dstin)
+                # do not use out_edges_iter for reasons stated in in_edges
+                for _, v, d in self._graph.out_edges(nbunch=node, data=True):
+                    logger.debug('out: connections-> %s' % str(d['connect']))
+                    for cd in deepcopy(d['connect']):
+                        logger.debug("out: %s" % str(cd))
+                        dstnode = v
+                        if isinstance(cd[0], tuple):
+                            parameter = cd[0][0]
+                        else:
+                            parameter = cd[0]
+                        srcnode = node._get_parameter_node(parameter,
+                                                           subtype='out')
+                        if isinstance(cd[0], tuple):
+                            srcout = list(cd[0])
+                            srcout[0] = parameter.split('.')[-1]
+                            srcout = tuple(srcout)
+                        else:
+                            srcout = parameter.split('.')[-1]
+                        dstin = cd[1]
+                        logger.debug('out edges: %s %s %s %s' % (srcnode,
+                                                                 srcout,
+                                                                 dstnode,
+                                                                 dstin))
+                        self.disconnect(node, cd[0], v, cd[1])
+                        self.connect(srcnode, srcout, dstnode, dstin)
+                # expand the workflow node
+                # logger.debug('expanding workflow: %s', node)
+                node._generate_flatgraph()
+                for innernode in node._graph.nodes():
+                    innernode._hierarchy = '.'.join((self.name,
+                                                     innernode._hierarchy))
+                self._graph.add_nodes_from(node._graph.nodes())
+                self._graph.add_edges_from(node._graph.edges(data=True))
+        if nodes2remove:
+            self._graph.remove_nodes_from(nodes2remove)
+        logger.debug('finished expanding workflow: %s', self)
+
+    def _get_dot(self, prefix=None, hierarchy=None, colored=False,
+                 simple_form=True, level=0):
+        """Create a dot file with connection info
+        """
+        if prefix is None:
+            prefix = '  '
+        if hierarchy is None:
+            hierarchy = []
+        colorset = ['#FFFFC8', '#0000FF', '#B4B4FF', '#E6E6FF', '#FF0000',
+                    '#FFB4B4', '#FFE6E6', '#00A300', '#B4FFB4', '#E6FFE6']
+
+        dotlist = ['%slabel="%s";' % (prefix, self.name)]
+        for node in nx.topological_sort(self._graph):
+            fullname = '.'.join(hierarchy + [node.fullname])
+            nodename = fullname.replace('.', '_')
+            if not isinstance(node, Workflow):
+                node_class_name = get_print_name(node, simple_form=simple_form)
+                if not simple_form:
+                    node_class_name = '.'.join(node_class_name.split('.')[1:])
+                if hasattr(node, 'iterables') and node.iterables:
+                    dotlist.append(('%s[label="%s", shape=box3d,'
+                                    'style=filled, color=black, colorscheme'
+                                    '=greys7 fillcolor=2];') %
+                                   (nodename, node_class_name))
+                else:
+                    if colored:
+                        dotlist.append(('%s[label="%s", style=filled,'
+                                        ' fillcolor="%s"];')
+                                       % (nodename, node_class_name,
+                                          colorset[level]))
+                    else:
+                        dotlist.append(('%s[label="%s"];')
+                                       % (nodename, node_class_name))
+
+        for node in nx.topological_sort(self._graph):
+            if isinstance(node, Workflow):
+                fullname = '.'.join(hierarchy + [node.fullname])
+                nodename = fullname.replace('.', '_')
+                dotlist.append('subgraph cluster_%s {' % nodename)
+                if colored:
+                    dotlist.append(prefix + prefix + 'edge [color="%s"];'
+                                   % (colorset[level + 1]))
+                    dotlist.append(prefix + prefix + 'style=filled;')
+                    dotlist.append(prefix + prefix + 'fillcolor="%s";'
+                                   % (colorset[level + 2]))
+                dotlist.append(node._get_dot(prefix=prefix + prefix,
+                                             hierarchy=hierarchy + [self.name],
+                                             colored=colored,
+                                             simple_form=simple_form,
+                                             level=level + 3))
+                dotlist.append('}')
+                if level == 6:
+                    level = 2
+            else:
+                for subnode in self._graph.successors_iter(node):
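+                    # only draw edges between nodes at the same level of
+                    # the hierarchy; cross-hierarchy links are handled in
+                    # the between-workflow pass below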
+                    if node._hierarchy != subnode._hierarchy:
+                        continue
+                    if not isinstance(subnode, Workflow):
+                        nodefullname = '.'.join(hierarchy + [node.fullname])
+                        subnodefullname = '.'.join(hierarchy +
+                                                   [subnode.fullname])
+                        nodename = nodefullname.replace('.', '_')
+                        subnodename = subnodefullname.replace('.', '_')
+                        for _ in self._graph.get_edge_data(node,
+                                                           subnode)['connect']:
+                            dotlist.append('%s -> %s;' % (nodename,
+                                                          subnodename))
+                            logger.debug('connection: ' + dotlist[-1])
+        # add between workflow connections
+        for u, v, d in self._graph.edges_iter(data=True):
+            uname = '.'.join(hierarchy + [u.fullname])
+            vname = '.'.join(hierarchy + [v.fullname])
+            for src, dest in d['connect']:
+                uname1 = uname
+                vname1 = vname
+                if isinstance(src, tuple):
+                    srcname = src[0]
+                else:
+                    srcname = src
+                if '.' in srcname:
+                    uname1 += '.' + '.'.join(srcname.split('.')[:-1])
+                if '.' in dest and '@' not in dest:
+                    if not isinstance(v, Workflow):
+                        if 'datasink' not in \
+                           str(v._interface.__class__).lower():
+                            vname1 += '.' + '.'.join(dest.split('.')[:-1])
+                    else:
+                        vname1 += '.' + '.'.join(dest.split('.')[:-1])
+                if uname1.split('.')[:-1] != vname1.split('.')[:-1]:
+                    dotlist.append('%s -> %s;' % (uname1.replace('.', '_'),
+                                                  vname1.replace('.', '_')))
+                    logger.debug('cross connection: ' + dotlist[-1])
+        return ('\n' + prefix).join(dotlist)

From 879862f583081c8ae8d9475b467c5d74ccac6592 Mon Sep 17 00:00:00 2001
From: Oscar Esteban
Date: Tue, 29 Dec 2015 20:37:14 +0100
Subject: [PATCH 2/8] restructure engine module

---
 nipype/caching/memory.py                |  2 +-
 nipype/pipeline/engine/base.py          |  8 ++++----
 nipype/pipeline/engine/utils.py         | 24 ++++++++++++------------
 nipype/pipeline/plugins/base.py         |  7 ++++---
 nipype/pipeline/plugins/debug.py        |  2 +-
 nipype/pipeline/plugins/linear.py       |  2 +-
 nipype/workflows/fmri/spm/preprocess.py |  3 ++-
 7 files changed, 25 insertions(+), 23 deletions(-)

diff --git a/nipype/caching/memory.py b/nipype/caching/memory.py
index 67ef605e32..d8b14fb396 100644
--- a/nipype/caching/memory.py
+++ b/nipype/caching/memory.py
@@ -21,7 +21,7 @@
 from ..interfaces.base import BaseInterface
 from ..pipeline.engine import Node
-from ..pipeline.utils import modify_paths
+from ..pipeline.engine.utils import modify_paths
 
 ################################################################################
 # PipeFunc object: callable interface to nipype.interface objects

diff --git a/nipype/pipeline/engine/base.py b/nipype/pipeline/engine/base.py
index a4d21328f2..f93a02247a 100644
--- a/nipype/pipeline/engine/base.py
+++ b/nipype/pipeline/engine/base.py
@@ -27,11 +27,11 @@
 from copy import deepcopy
 import re
 import numpy as np
-from nipype.interfaces.traits_extension import traits, Undefined
-from nipype.interfaces.base import DynamicTraitedSpec
-from nipype.utils.filemanip import loadpkl, savepkl
+from ...interfaces.traits_extension import traits, Undefined
+from ...interfaces.base import DynamicTraitedSpec
+from ...utils.filemanip import loadpkl, savepkl
 
-from nipype import logging
+from ...
import logging logger = logging.getLogger('workflow') diff --git a/nipype/pipeline/engine/utils.py b/nipype/pipeline/engine/utils.py index ac9560da69..d81361d771 100644 --- a/nipype/pipeline/engine/utils.py +++ b/nipype/pipeline/engine/utils.py @@ -32,16 +32,16 @@ import networkx as nx -from ..external.six import string_types -from ..utils.filemanip import (fname_presuffix, FileNotFoundError, - filename_to_list, get_related_files) -from ..utils.misc import create_function_from_source, str2bool -from ..interfaces.base import (CommandLine, isdefined, Undefined, - InterfaceResult) -from ..interfaces.utility import IdentityInterface -from ..utils.provenance import ProvStore, pm, nipype_ns, get_id - -from .. import logging, config +from ...external.six import string_types +from ...utils.filemanip import (fname_presuffix, FileNotFoundError, + filename_to_list, get_related_files) +from ...utils.misc import create_function_from_source, str2bool +from ...interfaces.base import (CommandLine, isdefined, Undefined, + InterfaceResult) +from ...interfaces.utility import IdentityInterface +from ...utils.provenance import ProvStore, pm, nipype_ns, get_id + +from ... import logging, config logger = logging.getLogger('workflow') try: @@ -361,7 +361,7 @@ def walk(children, level=0, path=None, usename=True): Examples -------- - >>> from nipype.pipeline.utils import walk + >>> from nipype.pipeline.engine.utils import walk >>> iterables = [('a', lambda: [1, 2]), ('b', lambda: [3, 4])] >>> [val['a'] for val in walk(iterables)] [1, 1, 2, 2] @@ -396,7 +396,7 @@ def synchronize_iterables(iterables): Examples -------- - >>> from nipype.pipeline.utils import synchronize_iterables + >>> from nipype.pipeline.engine.utils import synchronize_iterables >>> iterables = dict(a=lambda: [1, 2], b=lambda: [3, 4]) >>> synced = synchronize_iterables(iterables) >>> synced == [{'a': 1, 'b': 3}, {'a': 2, 'b': 4}] diff --git a/nipype/pipeline/plugins/base.py b/nipype/pipeline/plugins/base.py index 9b7adad343..162ddd9df4 100644 --- a/nipype/pipeline/plugins/base.py +++ b/nipype/pipeline/plugins/base.py @@ -21,10 +21,11 @@ import scipy.sparse as ssp -from ..utils import (nx, dfs_preorder, topological_sort) -from ..engine import (MapNode, str2bool) +from ...utils.filemanip import savepkl, loadpkl +from ...utils.misc import str2bool +from ..engine.utils import (nx, dfs_preorder, topological_sort) +from ..engine import MapNode -from nipype.utils.filemanip import savepkl, loadpkl from ... 
import logging logger = logging.getLogger('workflow') diff --git a/nipype/pipeline/plugins/debug.py b/nipype/pipeline/plugins/debug.py index 9d0a52adaa..9d219ac7df 100644 --- a/nipype/pipeline/plugins/debug.py +++ b/nipype/pipeline/plugins/debug.py @@ -4,7 +4,7 @@ """ from .base import (PluginBase, logger) -from ..utils import (nx) +from ..engine.utils import (nx) class DebugPlugin(PluginBase): diff --git a/nipype/pipeline/plugins/linear.py b/nipype/pipeline/plugins/linear.py index 216d037757..48e8aba64c 100644 --- a/nipype/pipeline/plugins/linear.py +++ b/nipype/pipeline/plugins/linear.py @@ -6,7 +6,7 @@ from .base import (PluginBase, logger, report_crash, report_nodes_not_run, str2bool) -from ..utils import (nx, dfs_preorder, topological_sort) +from ..engine.utils import (nx, dfs_preorder, topological_sort) class LinearPlugin(PluginBase): diff --git a/nipype/workflows/fmri/spm/preprocess.py b/nipype/workflows/fmri/spm/preprocess.py index c1f533323d..1208b30c90 100644 --- a/nipype/workflows/fmri/spm/preprocess.py +++ b/nipype/workflows/fmri/spm/preprocess.py @@ -10,7 +10,8 @@ from ....interfaces.matlab import no_matlab from ...smri.freesurfer.utils import create_getmask_flow -logger = pe.logger +from .... import logging +logger = logging.getLogger('workflow') def create_spm_preproc(name='preproc'): From 04fd0d2ecc55ee27be87ef8375c233d54c812d0c Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Tue, 29 Dec 2015 20:38:10 +0100 Subject: [PATCH 3/8] remove unnecessary pass --- nipype/pipeline/engine/nodes.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nipype/pipeline/engine/nodes.py b/nipype/pipeline/engine/nodes.py index abb61bafe1..9f9165e3b2 100644 --- a/nipype/pipeline/engine/nodes.py +++ b/nipype/pipeline/engine/nodes.py @@ -367,7 +367,6 @@ def run(self, updatehash=False): logger.warn(('An exception was raised trying to remove old %s, ' 'but the path seems empty. Is it an NFS mount?. 
' 'Passing the exception.') % outdir) - pass elif ((ex.errno == errno.ENOTEMPTY) and (len(outdircont) != 0)): logger.debug(('Folder contents (%d items): ' '%s') % (len(outdircont), outdircont)) From 7aed6915a7383ff2c53f5e19f3025d56fdb06913 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Tue, 29 Dec 2015 20:59:37 +0100 Subject: [PATCH 4/8] add new packages --- setup.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 5a5159d166..f32d3792b8 100755 --- a/setup.py +++ b/setup.py @@ -380,9 +380,10 @@ def main(**extra_args): 'nipype.interfaces.vista', 'nipype.interfaces.vista.tests', 'nipype.pipeline', + 'nipype.pipeline.engine', + 'nipype.pipeline.engine.tests', 'nipype.pipeline.plugins', 'nipype.pipeline.plugins.tests', - 'nipype.pipeline.tests', 'nipype.testing', 'nipype.testing.data', 'nipype.testing.data.bedpostxout', @@ -424,7 +425,7 @@ def main(**extra_args): pjoin('testing', 'data', 'bedpostxout', '*'), pjoin('testing', 'data', 'tbss_dir', '*'), pjoin('workflows', 'data', '*'), - pjoin('pipeline', 'report_template.html'), + pjoin('pipeline', 'engine', 'report_template.html'), pjoin('external', 'd3.js'), pjoin('interfaces', 'script_templates', '*'), pjoin('interfaces', 'tests', 'realign_json.json') From 018195def146125d6654f8d4b5245a480e2c6dea Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Tue, 29 Dec 2015 21:16:37 +0100 Subject: [PATCH 5/8] fix tests --- nipype/caching/tests/test_memory.py | 2 +- nipype/pipeline/engine/__init__.py | 1 + nipype/pipeline/engine/base.py | 2 +- nipype/pipeline/engine/tests/test_engine.py | 3 ++- nipype/pipeline/engine/workflows.py | 8 ++++---- 5 files changed, 9 insertions(+), 7 deletions(-) diff --git a/nipype/caching/tests/test_memory.py b/nipype/caching/tests/test_memory.py index 784eca1b93..89afcfa7a7 100644 --- a/nipype/caching/tests/test_memory.py +++ b/nipype/caching/tests/test_memory.py @@ -7,7 +7,7 @@ from nose.tools import assert_equal from .. import Memory -from ...pipeline.tests.test_engine import TestInterface +from ...pipeline.engine.tests.test_engine import TestInterface from ... 
import config config.set_default_config() diff --git a/nipype/pipeline/engine/__init__.py b/nipype/pipeline/engine/__init__.py index 51ae8e92ea..e950086307 100644 --- a/nipype/pipeline/engine/__init__.py +++ b/nipype/pipeline/engine/__init__.py @@ -11,3 +11,4 @@ __docformat__ = 'restructuredtext' from .workflows import Workflow from .nodes import Node, MapNode, JoinNode +from .utils import generate_expanded_graph diff --git a/nipype/pipeline/engine/base.py b/nipype/pipeline/engine/base.py index f93a02247a..148db2f271 100644 --- a/nipype/pipeline/engine/base.py +++ b/nipype/pipeline/engine/base.py @@ -75,7 +75,7 @@ def fullname(self): return fullname def clone(self, name): - """Clone a workflowbase object + """Clone an EngineBase object Parameters ---------- diff --git a/nipype/pipeline/engine/tests/test_engine.py b/nipype/pipeline/engine/tests/test_engine.py index 30b2981b4c..5f46d09a28 100644 --- a/nipype/pipeline/engine/tests/test_engine.py +++ b/nipype/pipeline/engine/tests/test_engine.py @@ -540,7 +540,8 @@ def func1(in1): try: n2.run() except Exception as e: - pe.logger.info('Exception: %s' % str(e)) + from nipype.pipeline.engine.base import logger + logger.info('Exception: %s' % str(e)) error_raised = True yield assert_true, error_raised diff --git a/nipype/pipeline/engine/workflows.py b/nipype/pipeline/engine/workflows.py index 9328b414bd..193fb0650a 100644 --- a/nipype/pipeline/engine/workflows.py +++ b/nipype/pipeline/engine/workflows.py @@ -326,7 +326,7 @@ def add_nodes(self, nodes): Parameters ---------- nodes : list - A list of WorkflowBase-based objects + A list of EngineBase-based objects """ newnodes = [] all_nodes = self._get_all_nodes() @@ -343,8 +343,8 @@ def add_nodes(self, nodes): logger.debug('no new nodes to add') return for node in newnodes: - if not issubclass(node.__class__, WorkflowBase): - raise Exception('Node %s must be a subclass of WorkflowBase' % + if not issubclass(node.__class__, EngineBase): + raise Exception('Node %s must be a subclass of EngineBase' % str(node)) self._check_nodes(newnodes) for node in newnodes: @@ -358,7 +358,7 @@ def remove_nodes(self, nodes): Parameters ---------- nodes : list - A list of WorkflowBase-based objects + A list of EngineBase-based objects """ self._graph.remove_nodes_from(nodes) From d0cc5d6da0f8961e040fde31501e1859b67a54a8 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Tue, 29 Dec 2015 21:26:05 +0100 Subject: [PATCH 6/8] fix tests using logger --- nipype/pipeline/engine/tests/test_engine.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/nipype/pipeline/engine/tests/test_engine.py b/nipype/pipeline/engine/tests/test_engine.py index 5f46d09a28..aadfb2607d 100644 --- a/nipype/pipeline/engine/tests/test_engine.py +++ b/nipype/pipeline/engine/tests/test_engine.py @@ -586,7 +586,8 @@ def _submit_job(self, node, updatehash=False): try: w1.run(plugin=RaiseError()) except Exception as e: - pe.logger.info('Exception: %s' % str(e)) + from nipype.pipeline.engine.base import logger + logger.info('Exception: %s' % str(e)) error_raised = True yield assert_true, error_raised # yield assert_true, 'Submit called' in e @@ -600,7 +601,8 @@ def _submit_job(self, node, updatehash=False): try: w1.run(plugin=RaiseError()) except Exception as e: - pe.logger.info('Exception: %s' % str(e)) + from nipype.pipeline.engine.base import logger + logger.info('Exception: %s' % str(e)) error_raised = True yield assert_false, error_raised os.chdir(cwd) @@ -638,7 +640,8 @@ def func2(a): try: w1.run(plugin='Linear') 
except Exception as e: - pe.logger.info('Exception: %s' % str(e)) + from nipype.pipeline.engine.base import logger + logger.info('Exception: %s' % str(e)) error_raised = True yield assert_false, error_raised os.chdir(cwd) @@ -722,7 +725,8 @@ def func1(in1): try: w1.run(plugin='MultiProc') except Exception as e: - pe.logger.info('Exception: %s' % str(e)) + from nipype.pipeline.engine.base import logger + logger.info('Exception: %s' % str(e)) error_raised = True yield assert_false, error_raised @@ -735,7 +739,8 @@ def func1(in1): try: w1.run(plugin='MultiProc') except Exception as e: - pe.logger.info('Exception: %s' % str(e)) + from nipype.pipeline.engine.base import logger + logger.info('Exception: %s' % str(e)) error_raised = True yield assert_false, error_raised From 913ec40d63e6de6608ff4128780c1f060f6f36fd Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Tue, 29 Dec 2015 21:39:07 +0100 Subject: [PATCH 7/8] fix test_utils --- nipype/pipeline/engine/tests/test_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nipype/pipeline/engine/tests/test_utils.py b/nipype/pipeline/engine/tests/test_utils.py index 50d44b78a0..8420f587c2 100644 --- a/nipype/pipeline/engine/tests/test_utils.py +++ b/nipype/pipeline/engine/tests/test_utils.py @@ -9,11 +9,11 @@ from tempfile import mkdtemp from shutil import rmtree -from ...testing import (assert_equal, assert_true, assert_false) -import nipype.pipeline.engine as pe -import nipype.interfaces.base as nib -import nipype.interfaces.utility as niu -from ... import config +from ....testing import (assert_equal, assert_true, assert_false) +from ... import engine as pe +from ....interfaces import base as nib +from ....interfaces import utility as niu +from .... import config from ..utils import merge_dict, clean_working_directory, write_workflow_prov From 901ae6c4276a82c9caf870568587fa2421edcf05 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Tue, 29 Dec 2015 21:47:26 +0100 Subject: [PATCH 8/8] use relative imports in tests --- nipype/pipeline/engine/tests/test_engine.py | 6 +++--- nipype/pipeline/engine/tests/test_join.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/nipype/pipeline/engine/tests/test_engine.py b/nipype/pipeline/engine/tests/test_engine.py index aadfb2607d..5eaaa81fbf 100644 --- a/nipype/pipeline/engine/tests/test_engine.py +++ b/nipype/pipeline/engine/tests/test_engine.py @@ -12,9 +12,9 @@ import networkx as nx -from nipype.testing import (assert_raises, assert_equal, assert_true, assert_false) -import nipype.interfaces.base as nib -import nipype.pipeline.engine as pe +from ....testing import (assert_raises, assert_equal, assert_true, assert_false) +from ... import engine as pe +from ....interfaces import base as nib class InputSpec(nib.TraitedSpec): diff --git a/nipype/pipeline/engine/tests/test_join.py b/nipype/pipeline/engine/tests/test_join.py index 4c5119ff46..b0882de91e 100644 --- a/nipype/pipeline/engine/tests/test_join.py +++ b/nipype/pipeline/engine/tests/test_join.py @@ -7,11 +7,11 @@ from shutil import rmtree from tempfile import mkdtemp -from nipype.testing import (assert_equal, assert_true) -import nipype.interfaces.base as nib -import nipype.pipeline.engine as pe -from nipype.interfaces.utility import IdentityInterface -from nipype.interfaces.base import traits, File +from ....testing import (assert_equal, assert_true) +from ... 
import engine as pe +from ....interfaces import base as nib +from ....interfaces.utility import IdentityInterface +from ....interfaces.base import traits, File class PickFirstSpec(nib.TraitedSpec):