diff --git a/nipype/exceptions.py b/nipype/exceptions.py new file mode 100644 index 0000000000..f6352c78f0 --- /dev/null +++ b/nipype/exceptions.py @@ -0,0 +1,26 @@ +class NipypeError(Exception): + pass + + +class PipelineError(NipypeError): + pass + + +class NodeError(PipelineError): + pass + + +class WorkflowError(NodeError): + pass + + +class MappingError(NodeError): + pass + + +class JoinError(NodeError): + pass + + +class InterfaceError(NipypeError): + pass diff --git a/nipype/pipeline/engine/workflows.py b/nipype/pipeline/engine/workflows.py index d2f040786e..c0e253c0e3 100644 --- a/nipype/pipeline/engine/workflows.py +++ b/nipype/pipeline/engine/workflows.py @@ -22,6 +22,7 @@ import networkx as nx from ... import config, logging +from ...exceptions import NodeError, WorkflowError, MappingError, JoinError from ...utils.misc import str2bool from ...utils.functions import (getsource, create_function_from_source) @@ -33,7 +34,7 @@ get_print_name, merge_dict, format_node) from .base import EngineBase -from .nodes import MapNode +from .nodes import MapNode, Node # Py2 compat: http://python-future.org/compatible_idioms.html#collections-counter-and-ordereddict from future import standard_library @@ -1043,3 +1044,109 @@ def _get_dot(self, vname1.replace('.', '_'))) logger.debug('cross connection: %s', dotlist[-1]) return ('\n' + prefix).join(dotlist) + + def add(self, name, node_like): + if is_interface(node_like): + node = Node(node_like, name=name) + elif is_node(node_like): + node = node_like + + self.add_nodes([node]) + + +class Map(Node): + pass + + +class Join(Node): + pass + + +class MapState(object): + pass + +class NewNode(EngineBase): + def __init__(self, inputs={}, map_on=None, join_by=None, + *args, **kwargs): + self._mappers = {} + self._joiners = {} + + def map(self, field, values=None): + if isinstance(field, list): + for field_ + if values is not None: + if len(values != len(field)): + elif isinstance(field, tuple): + pass + if values is None: + values = getattr(self._inputs, field) + if values is None: + raise MappingError('Cannot map unassigned input field') + self._mappers[field] = values + + def join(self, field): + pass + + +class NewWorkflow(NewNode): + def __init__(self, inputs={}, *args, **kwargs): + super(NewWorkflow, self).__init__(*args, **kwargs) + + self._nodes = {} + + mro = self.__class__.mro() + wf_klasses = mro[:mro.index(NewWorkflow)][::-1] + items = {} + for klass in wf_klasses: + items.update(klass.__dict__) + for name, runnable in items.items(): + if name in ('__module__', '__doc__'): + continue + + self.add(name, value) + + def add(self, name, runnable): + if is_function(runnable): + node = Node(Function(function=runnable), name=name) + elif is_interface(runnable): + node = Node(runnable, name=name) + elif is_node(runnable): + node = runnable if runnable.name == name else runnable.clone(name=name) + else: + raise ValueError("Unknown workflow element: {!r}".format(runnable)) + setattr(self, name, node) + self._nodes[name] = node + self._last_added = name + + def map(self, field, node=None, values=None): + if node is None: + if '.' in field: + node, field = field.rsplit('.', 1) + else: + node = self._last_added + + if '.' in node: + subwf, node = node.split('.', 1) + self._nodes[subwf].map(field, node, values) + return + + if node in self._mappers: + raise WorkflowError("Cannot assign two mappings to the same input") + + self._mappers[node] = (field, values) + + def join(self, field, node=None): + pass + + +def is_function(obj): + return hasattr(obj, '__call__') + + +def is_interface(obj): + return all(hasattr(obj, protocol) + for protocol in ('input_spec', 'output_spec', 'run')) + + +def is_node(obj): + return hasattr(obj, itername)