Skip to content

Commit 5f4a764

Browse files
FredLoneysatra
FredLoney
authored andcommitted
Refactor iterable standardization to account for the itersource alternate format.
1 parent 24b33d9 commit 5f4a764

File tree

1 file changed

+89
-32
lines changed

1 file changed

+89
-32
lines changed

nipype/pipeline/utils.py

Lines changed: 89 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -584,23 +584,15 @@ def generate_expanded_graph(graph_in):
584584
"""
585585
logger.debug("PE: expanding iterables")
586586
graph_in = _remove_nonjoin_identity_nodes(graph_in, keep_iterables=True)
587-
# convert list of tuples to dict fields
587+
# standardize the iterables as {(field, function)} dictionaries
588588
for node in graph_in.nodes_iter():
589-
if isinstance(node.iterables, tuple):
590-
node.iterables = [node.iterables]
591-
for node in graph_in.nodes_iter():
592-
if isinstance(node.iterables, list):
593-
node.iterables = dict(map(lambda(x): (x[0],
594-
lambda: x[1]),
595-
node.iterables))
589+
if node.iterables:
590+
_standardize_iterables(node)
596591
allprefixes = list('abcdefghijklmnopqrstuvwxyz')
597592

598593
# the iterable nodes
599594
inodes = _iterable_nodes(graph_in)
600595
logger.debug("Detected iterable nodes %s" % inodes)
601-
# record the iterable fields, since expansion removes them
602-
iter_fld_dict = {inode.name: inode.iterables.keys()
603-
for inode in inodes}
604596
# while there is an iterable node, expand the iterable node's
605597
# subgraphs
606598
while inodes:
@@ -626,22 +618,25 @@ def generate_expanded_graph(graph_in):
626618
% (src, dest))
627619

628620
if inode.itersource:
621+
# the itersource is a (node name, fields) tuple
622+
src_name, src_fields = inode.itersource
623+
# convert a single field to a list
624+
if isinstance(src_fields, str):
625+
src_fields = [src_fields]
629626
# find the unique iterable source node in the graph
630-
iter_src = None
631-
for node in graph_in.nodes_iter():
632-
if (node.name == inode.itersource
633-
and nx.has_path(graph_in, node, inode)):
634-
iter_src = node
635-
break
636-
if not iter_src or not iter_fld_dict.has_key(inode.itersource):
627+
try:
628+
iter_src = next((node for node in graph_in.nodes_iter()
629+
if node.name == src_name
630+
and nx.has_path(graph_in, node, inode)))
631+
except StopIteration:
637632
raise ValueError("The node %s itersource %s was not found"
638-
" among the iterable nodes %s"
639-
% (inode, inode.itersource, iter_fld_dict.keys()))
633+
" among the iterable predecessor nodes"
634+
% (inode, src_name))
635+
logger.debug("The node %s has iterable source node %s"
636+
% (inode, iter_src))
640637
# look up the iterables for this particular itersource descendant
641638
# using the iterable source ancestor values as a key
642639
iterables = {}
643-
# the source node iterables fields
644-
src_fields = iter_fld_dict[inode.itersource]
645640
# the source node iterables values
646641
src_values = [getattr(iter_src.inputs, field) for field in src_fields]
647642
# if there is one source field, then the key is the the source value,
@@ -714,9 +709,12 @@ def generate_expanded_graph(graph_in):
714709
for src_id, edge_data in old_edge_dict.iteritems():
715710
if node._id.startswith(src_id):
716711
expansions[src_id].append(node)
712+
for in_id, in_nodes in expansions.iteritems():
713+
logger.debug("The join node %s input %s was expanded"
714+
" to %d nodes." %(jnode, in_id, len(in_nodes)))
717715
# preserve the node iteration order by sorting on the node id
718-
for src_nodes in expansions.itervalues():
719-
src_nodes.sort(key=lambda node: node._id)
716+
for in_nodes in expansions.itervalues():
717+
in_nodes.sort(key=lambda node: node._id)
720718

721719
# the number of iterations.
722720
iter_cnt = count_iterables(iterables, inode.synchronize)
@@ -731,28 +729,28 @@ def generate_expanded_graph(graph_in):
731729
# field 'in' are qualified as ('out_file', 'in1') and
732730
# ('out_file', 'in2'), resp. This preserves connection port
733731
# integrity.
734-
for old_id, src_nodes in expansions.iteritems():
732+
for old_id, in_nodes in expansions.iteritems():
735733
# reconnect each replication of the current join in-edge
736734
# source
737-
for si, src in enumerate(src_nodes):
735+
for in_idx, in_node in enumerate(in_nodes):
738736
olddata = old_edge_dict[old_id]
739737
newdata = deepcopy(olddata)
740738
connects = newdata['connect']
741739
join_fields = [field for _, field in connects
742740
if field in dest.joinfield]
743-
slots = slot_dicts[si]
744-
for ci, connect in enumerate(connects):
741+
slots = slot_dicts[in_idx]
742+
for con_idx, connect in enumerate(connects):
745743
src_field, dest_field = connect
746744
# qualify a join destination field name
747745
if dest_field in slots:
748746
slot_field = slots[dest_field]
749-
connects[ci] = (src_field, slot_field)
747+
connects[con_idx] = (src_field, slot_field)
750748
logger.debug("Qualified the %s -> %s join field"
751749
" %s as %s." %
752-
(src, jnode, dest_field, slot_field))
753-
graph_in.add_edge(src, jnode, newdata)
750+
(in_node, jnode, dest_field, slot_field))
751+
graph_in.add_edge(in_node, jnode, newdata)
754752
logger.debug("Connected the join node %s subgraph to the"
755-
" expanded join point %s" % (jnode, src))
753+
" expanded join point %s" % (jnode, in_node))
756754

757755
#nx.write_dot(graph_in, '%s_post.dot' % node)
758756
# the remaining iterable nodes
@@ -792,6 +790,65 @@ def _iterable_nodes(graph_in):
792790
inodes_src = [node for node in inodes if node.itersource]
793791
inodes_no_src.reverse()
794792
return inodes_no_src + inodes_src
793+
794+
def _standardize_iterables(node):
795+
"""Converts the given iterables to a {field: function} dictionary,
796+
if necessary, where the function returns a list."""
797+
# trivial case
798+
if not node.iterables:
799+
return
800+
iterables = node.iterables
801+
# The candidate iterable fields
802+
fields = set(node.inputs.copyable_trait_names())
803+
804+
# Convert a tuple to a list
805+
if isinstance(iterables, tuple):
806+
iterables = [iterables]
807+
# Convert a list to a dictionary
808+
if isinstance(iterables, list):
809+
# Synchronize iterables can be in [fields, value tuples] format
810+
# rather than [(field, value list), (field, value list), ...]
811+
if node.synchronize and len(iterables) == 2:
812+
first, last = iterables
813+
if all((isinstance(item, str) and item in fields
814+
for item in first)):
815+
iterables = _transpose_iterables(first, last)
816+
# Validate the format
817+
for item in iterables:
818+
try:
819+
if len(item) != 2:
820+
raise ValueError("The %s iterables do not consist of"
821+
" (field, values) pairs" % node.name)
822+
except TypeError, e:
823+
raise TypeError("The %s iterables is not iterable: %s"
824+
% (node.name, e))
825+
# Convert the values to functions. This is a legacy Nipype
826+
# requirement with unknown rationale.
827+
iter_items = map(lambda(field, value): (field, lambda: value),
828+
iterables)
829+
# Make the iterables dictionary
830+
iterables = dict(iter_items)
831+
elif not isinstance(iterables, dict):
832+
raise ValueError("The %s iterables type is not a list or a dictionary:"
833+
" %s" % (node.name, iterables.__class__))
834+
835+
# Validate the iterable fields
836+
for field in iterables.iterkeys():
837+
if field not in fields:
838+
raise ValueError("The %s iterables field is unrecognized: %s"
839+
% (node.name, field))
840+
841+
# Assign to the standard form
842+
node.iterables = iterables
843+
844+
def _transpose_iterables(fields, values):
845+
"""
846+
Converts the given fields and tuple values into a list of
847+
iterable (field: value list) pairs, suitable for setting
848+
a node iterables property.
849+
"""
850+
return zip(fields, [filter(lambda(v): v != None, transpose)
851+
for transpose in zip(*values)])
795852

796853
def export_graph(graph_in, base_dir=None, show=False, use_execgraph=False,
797854
show_connectinfo=False, dotfilename='graph.dot', format='png',

0 commit comments

Comments
 (0)