diff --git a/nipype/pipeline/engine/utils.py b/nipype/pipeline/engine/utils.py
index f77f771ea7..4e8b6d2e8c 100644
--- a/nipype/pipeline/engine/utils.py
+++ b/nipype/pipeline/engine/utils.py
@@ -753,7 +753,7 @@ def _merge_graphs(
     # nodes of the supergraph.
     supernodes = supergraph.nodes()
     ids = [n._hierarchy + n._id for n in supernodes]
-    if len(np.unique(ids)) != len(ids):
+    if len(set(ids)) != len(ids):
         # This should trap the problem of miswiring when multiple iterables are
         # used at the same level. The use of the template below for naming
         # updates to nodes is the general solution.
@@ -1100,11 +1100,12 @@ def make_field_func(*pair):
         old_edge_dict = jedge_dict[jnode]
         # the edge source node replicates
         expansions = defaultdict(list)
-        for node in graph_in.nodes():
+        for node in graph_in:
             for src_id in list(old_edge_dict.keys()):
                 # Drop the original JoinNodes; only concerned with
                 # generated Nodes
-                if hasattr(node, "joinfield") and node.itername == src_id:
+                itername = node.itername
+                if hasattr(node, "joinfield") and itername == src_id:
                     continue
                 # Patterns:
                 # - src_id : Non-iterable node
@@ -1113,10 +1114,10 @@ def make_field_func(*pair):
                 # - src_id.[a-z]I.[a-z]\d+ :
                 #   Non-IdentityInterface w/ iterables
                 # - src_idJ\d+ : JoinNode(IdentityInterface)
-                if re.match(
-                    src_id + r"((\.[a-z](I\.[a-z])?|J)\d+)?$", node.itername
-                ):
-                    expansions[src_id].append(node)
+                if itername.startswith(src_id):
+                    suffix = itername[len(src_id):]
+                    if re.fullmatch(r"((\.[a-z](I\.[a-z])?|J)\d+)?", suffix):
+                        expansions[src_id].append(node)
         for in_id, in_nodes in list(expansions.items()):
             logger.debug(
                 "The join node %s input %s was expanded" " to %d nodes.",
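The utils.py change above splits the old anchored `re.match` into a cheap literal prefix test plus a `re.fullmatch` over whatever suffix remains, so most nodes are rejected without running the regex at all. A standalone sketch checking that the two forms agree on the documented itername patterns (assuming `src_id` contains no regex metacharacters; all `src_id`/`itername` values below are made up for illustration):

```python
import re

SUFFIX = r"((\.[a-z](I\.[a-z])?|J)\d+)?"

def old_way(src_id, itername):
    # Original: one anchored regex over the full itername.
    return bool(re.match(src_id + SUFFIX + r"$", itername))

def new_way(src_id, itername):
    # Rewrite: literal prefix check first, regex only on the short suffix.
    if not itername.startswith(src_id):
        return False
    return bool(re.fullmatch(SUFFIX, itername[len(src_id):]))

cases = [
    ("join", "join", True),         # src_id : non-iterable node
    ("join", "join.a3", True),      # src_id.[a-z]\d+ : IdentityInterface w/ iterables
    ("join", "join.aI.b12", True),  # src_id.[a-z]I.[a-z]\d+ : non-IdentityInterface
    ("join", "joinJ2", True),       # src_idJ\d+ : JoinNode(IdentityInterface)
    ("join", "joined", False),      # prefix alone is not enough
    ("join", "other.a1", False),    # unrelated node
]
for src_id, itername, expected in cases:
    assert old_way(src_id, itername) == new_way(src_id, itername) == expected
```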
diff --git a/nipype/pipeline/engine/workflows.py b/nipype/pipeline/engine/workflows.py
index 9b6e60ffaf..184cfd5a57 100644
--- a/nipype/pipeline/engine/workflows.py
+++ b/nipype/pipeline/engine/workflows.py
@@ -59,6 +59,9 @@ def __init__(self, name, base_dir=None):
         super(Workflow, self).__init__(name, base_dir)
         self._graph = nx.DiGraph()
 
+        self._nodes_cache = set()
+        self._nested_workflows_cache = set()
+
     # PUBLIC API
     def clone(self, name):
         """Clone a workflow
@@ -141,7 +144,7 @@ def connect(self, *args, **kwargs):
                 self.disconnect(connection_list)
                 return
 
-        newnodes = []
+        newnodes = set()
         for srcnode, destnode, _ in connection_list:
             if self in [srcnode, destnode]:
                 msg = (
@@ -151,9 +154,9 @@ def connect(self, *args, **kwargs):
                 raise IOError(msg)
 
             if (srcnode not in newnodes) and not self._has_node(srcnode):
-                newnodes.append(srcnode)
+                newnodes.add(srcnode)
             if (destnode not in newnodes) and not self._has_node(destnode):
-                newnodes.append(destnode)
+                newnodes.add(destnode)
         if newnodes:
             self._check_nodes(newnodes)
             for node in newnodes:
@@ -163,15 +166,16 @@ def connect(self, *args, **kwargs):
         connected_ports = {}
         for srcnode, destnode, connects in connection_list:
             if destnode not in connected_ports:
-                connected_ports[destnode] = []
+                connected_ports[destnode] = set()
             # check to see which ports of destnode are already
             # connected.
             if not disconnect and (destnode in self._graph.nodes()):
                 for edge in self._graph.in_edges(destnode):
                     data = self._graph.get_edge_data(*edge)
-                    for sourceinfo, destname in data["connect"]:
-                        if destname not in connected_ports[destnode]:
-                            connected_ports[destnode] += [destname]
+                    connected_ports[destnode].update(
+                        destname
+                        for _, destname in data["connect"]
+                    )
             for source, dest in connects:
                 # Currently datasource/sink/grabber.io modules
                 # determine their inputs/outputs depending on
@@ -226,7 +230,7 @@ def connect(self, *args, **kwargs):
                     )
                     if sourcename and not srcnode._check_outputs(sourcename):
                         not_found.append(["out", srcnode.name, sourcename])
-                connected_ports[destnode] += [dest]
+                connected_ports[destnode].add(dest)
         infostr = []
         for info in not_found:
             infostr += [
@@ -269,6 +273,9 @@ def connect(self, *args, **kwargs):
                 "(%s, %s): new edge data: %s", srcnode, destnode, str(edge_data)
             )
 
+        if newnodes:
+            self._update_node_cache()
+
     def disconnect(self, *args):
         """Disconnect nodes
         See the docstring for connect for format.
@@ -325,7 +332,7 @@ def add_nodes(self, nodes):
         newnodes = []
         all_nodes = self._get_all_nodes()
         for node in nodes:
-            if self._has_node(node):
+            if node in all_nodes:
                 raise IOError("Node %s already exists in the workflow" % node)
             if isinstance(node, Workflow):
                 for subnode in node._get_all_nodes():
@@ -346,6 +353,7 @@ def add_nodes(self, nodes):
             if node._hierarchy is None:
                 node._hierarchy = self.name
         self._graph.add_nodes_from(newnodes)
+        self._update_node_cache()
 
     def remove_nodes(self, nodes):
         """ Remove nodes from a workflow
@@ -356,6 +364,7 @@ def remove_nodes(self, nodes):
             A list of EngineBase-based objects
         """
         self._graph.remove_nodes_from(nodes)
+        self._update_node_cache()
 
     # Input-Output access
     @property
@@ -903,22 +912,32 @@ def _set_node_input(self, node, param, source, sourceinfo):
         node.set_input(param, deepcopy(newval))
 
     def _get_all_nodes(self):
-        allnodes = []
-        for node in self._graph.nodes():
-            if isinstance(node, Workflow):
-                allnodes.extend(node._get_all_nodes())
-            else:
-                allnodes.append(node)
+        allnodes = self._nodes_cache - self._nested_workflows_cache
+        for node in self._nested_workflows_cache:
+            allnodes |= node._get_all_nodes()
         return allnodes
 
-    def _has_node(self, wanted_node):
-        for node in self._graph.nodes():
-            if wanted_node == node:
-                return True
+    def _update_node_cache(self):
+        nodes = set(self._graph)
+
+        added_nodes = nodes.difference(self._nodes_cache)
+        removed_nodes = self._nodes_cache.difference(nodes)
+
+        self._nodes_cache = nodes
+        self._nested_workflows_cache.difference_update(removed_nodes)
+
+        for node in added_nodes:
             if isinstance(node, Workflow):
-                if node._has_node(wanted_node):
-                    return True
-        return False
+                self._nested_workflows_cache.add(node)
+
+    def _has_node(self, wanted_node):
+        return (
+            wanted_node in self._nodes_cache or
+            any(
+                wf._has_node(wanted_node)
+                for wf in self._nested_workflows_cache
+            )
+        )
 
     def _create_flat_graph(self):
         """Make a simple DAG where no node is a workflow."""
@@ -949,7 +968,7 @@ def _generate_flatgraph(self):
             raise Exception(
                 ("Workflow: %s is not a directed acyclic graph " "(DAG)") % self.name
             )
-        nodes = list(nx.topological_sort(self._graph))
+        nodes = list(self._graph.nodes)
         for node in nodes:
             logger.debug("processing node: %s", node)
             if isinstance(node, Workflow):
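On the workflows.py side, `_has_node` and `_get_all_nodes` stop walking `self._graph` on every call; instead, `_update_node_cache` maintains two sets at each mutation point (`connect`, `add_nodes`, `remove_nodes`): every direct member of the graph, and the subset that are nested workflows. Membership becomes a set lookup, and recursion only descends into actual sub-workflows. A minimal, self-contained sketch of that caching pattern (the `Container` class and its method names are illustrative stand-ins, not nipype's API):

```python
import networkx as nx


class Container:
    """Stand-in for Workflow: a DiGraph plus two membership caches."""

    def __init__(self):
        self._graph = nx.DiGraph()
        self._nodes_cache = set()    # every direct member of _graph
        self._nested_cache = set()   # the subset that are Containers

    def add(self, *nodes):
        self._graph.add_nodes_from(nodes)
        self._update_node_cache()

    def _update_node_cache(self):
        nodes = set(self._graph)     # iterating a DiGraph yields its nodes
        added = nodes - self._nodes_cache
        removed = self._nodes_cache - nodes
        self._nodes_cache = nodes
        self._nested_cache -= removed
        self._nested_cache |= {n for n in added if isinstance(n, Container)}

    def has_node(self, wanted):
        # O(1) local lookup, then recurse only into nested containers.
        return wanted in self._nodes_cache or any(
            c.has_node(wanted) for c in self._nested_cache
        )


inner, outer = Container(), Container()
inner.add("a", "b")
outer.add(inner, "c")
assert outer.has_node("a") and not outer.has_node("z")
```

The trade-off is that the caches are only as fresh as the last mutation through those entry points; direct edits to `_graph` would bypass them, which is why the diff wires `_update_node_cache` into `connect`, `add_nodes`, and `remove_nodes` rather than recomputing lazily.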