Skip to content

[REF] Cache nodes in workflow to speed up construction, other optimizations #3331

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits on Apr 30, 2021
Merged
15 changes: 8 additions & 7 deletions nipype/pipeline/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ def _merge_graphs(
# nodes of the supergraph.
supernodes = supergraph.nodes()
ids = [n._hierarchy + n._id for n in supernodes]
if len(np.unique(ids)) != len(ids):
if len(set(ids)) != len(ids):
# This should trap the problem of miswiring when multiple iterables are
# used at the same level. The use of the template below for naming
# updates to nodes is the general solution.
Expand Down Expand Up @@ -1100,11 +1100,12 @@ def make_field_func(*pair):
old_edge_dict = jedge_dict[jnode]
# the edge source node replicates
expansions = defaultdict(list)
for node in graph_in.nodes():
for node in graph_in:
for src_id in list(old_edge_dict.keys()):
# Drop the original JoinNodes; only concerned with
# generated Nodes
if hasattr(node, "joinfield") and node.itername == src_id:
itername = node.itername
if hasattr(node, "joinfield") and itername == src_id:
continue
# Patterns:
# - src_id : Non-iterable node
Expand All @@ -1113,10 +1114,10 @@ def make_field_func(*pair):
# - src_id.[a-z]I.[a-z]\d+ :
# Non-IdentityInterface w/ iterables
# - src_idJ\d+ : JoinNode(IdentityInterface)
if re.match(
src_id + r"((\.[a-z](I\.[a-z])?|J)\d+)?$", node.itername
):
expansions[src_id].append(node)
if itername.startswith(src_id):
suffix = itername[len(src_id):]
if re.fullmatch(r"((\.[a-z](I\.[a-z])?|J)\d+)?", suffix):
expansions[src_id].append(node)
for in_id, in_nodes in list(expansions.items()):
logger.debug(
"The join node %s input %s was expanded" " to %d nodes.",
Expand Down
65 changes: 42 additions & 23 deletions nipype/pipeline/engine/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def __init__(self, name, base_dir=None):
super(Workflow, self).__init__(name, base_dir)
self._graph = nx.DiGraph()

self._nodes_cache = set()
self._nested_workflows_cache = set()

# PUBLIC API
def clone(self, name):
"""Clone a workflow
Expand Down Expand Up @@ -141,7 +144,7 @@ def connect(self, *args, **kwargs):
self.disconnect(connection_list)
return

newnodes = []
newnodes = set()
for srcnode, destnode, _ in connection_list:
if self in [srcnode, destnode]:
msg = (
Expand All @@ -151,9 +154,9 @@ def connect(self, *args, **kwargs):

raise IOError(msg)
if (srcnode not in newnodes) and not self._has_node(srcnode):
newnodes.append(srcnode)
newnodes.add(srcnode)
if (destnode not in newnodes) and not self._has_node(destnode):
newnodes.append(destnode)
newnodes.add(destnode)
if newnodes:
self._check_nodes(newnodes)
for node in newnodes:
Expand All @@ -163,15 +166,16 @@ def connect(self, *args, **kwargs):
connected_ports = {}
for srcnode, destnode, connects in connection_list:
if destnode not in connected_ports:
connected_ports[destnode] = []
connected_ports[destnode] = set()
# check to see which ports of destnode are already
# connected.
if not disconnect and (destnode in self._graph.nodes()):
for edge in self._graph.in_edges(destnode):
data = self._graph.get_edge_data(*edge)
for sourceinfo, destname in data["connect"]:
if destname not in connected_ports[destnode]:
connected_ports[destnode] += [destname]
connected_ports[destnode].update(
destname
for _, destname in data["connect"]
)
for source, dest in connects:
# Currently datasource/sink/grabber.io modules
# determine their inputs/outputs depending on
Expand Down Expand Up @@ -226,7 +230,7 @@ def connect(self, *args, **kwargs):
)
if sourcename and not srcnode._check_outputs(sourcename):
not_found.append(["out", srcnode.name, sourcename])
connected_ports[destnode] += [dest]
connected_ports[destnode].add(dest)
infostr = []
for info in not_found:
infostr += [
Expand Down Expand Up @@ -269,6 +273,9 @@ def connect(self, *args, **kwargs):
"(%s, %s): new edge data: %s", srcnode, destnode, str(edge_data)
)

if newnodes:
self._update_node_cache()

def disconnect(self, *args):
"""Disconnect nodes
See the docstring for connect for format.
Expand Down Expand Up @@ -325,7 +332,7 @@ def add_nodes(self, nodes):
newnodes = []
all_nodes = self._get_all_nodes()
for node in nodes:
if self._has_node(node):
if node in all_nodes:
raise IOError("Node %s already exists in the workflow" % node)
if isinstance(node, Workflow):
for subnode in node._get_all_nodes():
Expand All @@ -346,6 +353,7 @@ def add_nodes(self, nodes):
if node._hierarchy is None:
node._hierarchy = self.name
self._graph.add_nodes_from(newnodes)
self._update_node_cache()

def remove_nodes(self, nodes):
    """Remove the given nodes from this workflow's graph.

    Parameters
    ----------
    nodes : list
        A list of EngineBase-based objects to drop from the graph.
    """
    # Delete the nodes (and their incident edges) from the underlying graph,
    # then refresh the node caches so lookups stay consistent.
    graph = self._graph
    graph.remove_nodes_from(nodes)
    self._update_node_cache()

# Input-Output access
@property
Expand Down Expand Up @@ -903,22 +912,32 @@ def _set_node_input(self, node, param, source, sourceinfo):
node.set_input(param, deepcopy(newval))

def _get_all_nodes(self):
allnodes = []
for node in self._graph.nodes():
if isinstance(node, Workflow):
allnodes.extend(node._get_all_nodes())
else:
allnodes.append(node)
allnodes = self._nodes_cache - self._nested_workflows_cache
for node in self._nested_workflows_cache:
allnodes |= node._get_all_nodes()
return allnodes

def _has_node(self, wanted_node):
for node in self._graph.nodes():
if wanted_node == node:
return True
def _update_node_cache(self):
nodes = set(self._graph)

added_nodes = nodes.difference(self._nodes_cache)
removed_nodes = self._nodes_cache.difference(nodes)

self._nodes_cache = nodes
self._nested_workflows_cache.difference_update(removed_nodes)

for node in added_nodes:
if isinstance(node, Workflow):
if node._has_node(wanted_node):
return True
return False
self._nested_workflows_cache.add(node)

def _has_node(self, wanted_node):
return (
wanted_node in self._nodes_cache or
any(
wf._has_node(wanted_node)
for wf in self._nested_workflows_cache
)
)

def _create_flat_graph(self):
"""Make a simple DAG where no node is a workflow."""
Expand Down Expand Up @@ -949,7 +968,7 @@ def _generate_flatgraph(self):
raise Exception(
("Workflow: %s is not a directed acyclic graph " "(DAG)") % self.name
)
nodes = list(nx.topological_sort(self._graph))
nodes = list(self._graph.nodes)
for node in nodes:
logger.debug("processing node: %s", node)
if isinstance(node, Workflow):
Expand Down