@@ -584,23 +584,15 @@ def generate_expanded_graph(graph_in):
584
584
"""
585
585
logger .debug ("PE: expanding iterables" )
586
586
graph_in = _remove_nonjoin_identity_nodes (graph_in , keep_iterables = True )
587
- # convert list of tuples to dict fields
587
+ # standardize the iterables as {(field, function)} dictionaries
588
588
for node in graph_in .nodes_iter ():
589
- if isinstance (node .iterables , tuple ):
590
- node .iterables = [node .iterables ]
591
- for node in graph_in .nodes_iter ():
592
- if isinstance (node .iterables , list ):
593
- node .iterables = dict (map (lambda (x ): (x [0 ],
594
- lambda : x [1 ]),
595
- node .iterables ))
589
+ if node .iterables :
590
+ _standardize_iterables (node )
596
591
allprefixes = list ('abcdefghijklmnopqrstuvwxyz' )
597
592
598
593
# the iterable nodes
599
594
inodes = _iterable_nodes (graph_in )
600
595
logger .debug ("Detected iterable nodes %s" % inodes )
601
- # record the iterable fields, since expansion removes them
602
- iter_fld_dict = {inode .name : inode .iterables .keys ()
603
- for inode in inodes }
604
596
# while there is an iterable node, expand the iterable node's
605
597
# subgraphs
606
598
while inodes :
@@ -626,22 +618,25 @@ def generate_expanded_graph(graph_in):
626
618
% (src , dest ))
627
619
628
620
if inode .itersource :
621
+ # the itersource is a (node name, fields) tuple
622
+ src_name , src_fields = inode .itersource
623
+ # convert a single field to a list
624
+ if isinstance (src_fields , str ):
625
+ src_fields = [src_fields ]
629
626
# find the unique iterable source node in the graph
630
- iter_src = None
631
- for node in graph_in .nodes_iter ():
632
- if (node .name == inode .itersource
633
- and nx .has_path (graph_in , node , inode )):
634
- iter_src = node
635
- break
636
- if not iter_src or not iter_fld_dict .has_key (inode .itersource ):
627
+ try :
628
+ iter_src = next ((node for node in graph_in .nodes_iter ()
629
+ if node .name == src_name
630
+ and nx .has_path (graph_in , node , inode )))
631
+ except StopIteration :
637
632
raise ValueError ("The node %s itersource %s was not found"
638
- " among the iterable nodes %s"
639
- % (inode , inode .itersource , iter_fld_dict .keys ()))
633
+ " among the iterable predecessor nodes"
634
+ % (inode , src_name ))
635
+ logger .debug ("The node %s has iterable source node %s"
636
+ % (inode , iter_src ))
640
637
# look up the iterables for this particular itersource descendant
641
638
# using the iterable source ancestor values as a key
642
639
iterables = {}
643
- # the source node iterables fields
644
- src_fields = iter_fld_dict [inode .itersource ]
645
640
# the source node iterables values
646
641
src_values = [getattr (iter_src .inputs , field ) for field in src_fields ]
647
642
# if there is one source field, then the key is the the source value,
@@ -714,9 +709,12 @@ def generate_expanded_graph(graph_in):
714
709
for src_id , edge_data in old_edge_dict .iteritems ():
715
710
if node ._id .startswith (src_id ):
716
711
expansions [src_id ].append (node )
712
+ for in_id , in_nodes in expansions .iteritems ():
713
+ logger .debug ("The join node %s input %s was expanded"
714
+ " to %d nodes." % (jnode , in_id , len (in_nodes )))
717
715
# preserve the node iteration order by sorting on the node id
718
- for src_nodes in expansions .itervalues ():
719
- src_nodes .sort (key = lambda node : node ._id )
716
+ for in_nodes in expansions .itervalues ():
717
+ in_nodes .sort (key = lambda node : node ._id )
720
718
721
719
# the number of iterations.
722
720
iter_cnt = count_iterables (iterables , inode .synchronize )
@@ -731,28 +729,28 @@ def generate_expanded_graph(graph_in):
731
729
# field 'in' are qualified as ('out_file', 'in1') and
732
730
# ('out_file', 'in2'), resp. This preserves connection port
733
731
# integrity.
734
- for old_id , src_nodes in expansions .iteritems ():
732
+ for old_id , in_nodes in expansions .iteritems ():
735
733
# reconnect each replication of the current join in-edge
736
734
# source
737
- for si , src in enumerate (src_nodes ):
735
+ for in_idx , in_node in enumerate (in_nodes ):
738
736
olddata = old_edge_dict [old_id ]
739
737
newdata = deepcopy (olddata )
740
738
connects = newdata ['connect' ]
741
739
join_fields = [field for _ , field in connects
742
740
if field in dest .joinfield ]
743
- slots = slot_dicts [si ]
744
- for ci , connect in enumerate (connects ):
741
+ slots = slot_dicts [in_idx ]
742
+ for con_idx , connect in enumerate (connects ):
745
743
src_field , dest_field = connect
746
744
# qualify a join destination field name
747
745
if dest_field in slots :
748
746
slot_field = slots [dest_field ]
749
- connects [ci ] = (src_field , slot_field )
747
+ connects [con_idx ] = (src_field , slot_field )
750
748
logger .debug ("Qualified the %s -> %s join field"
751
749
" %s as %s." %
752
- (src , jnode , dest_field , slot_field ))
753
- graph_in .add_edge (src , jnode , newdata )
750
+ (in_node , jnode , dest_field , slot_field ))
751
+ graph_in .add_edge (in_node , jnode , newdata )
754
752
logger .debug ("Connected the join node %s subgraph to the"
755
- " expanded join point %s" % (jnode , src ))
753
+ " expanded join point %s" % (jnode , in_node ))
756
754
757
755
#nx.write_dot(graph_in, '%s_post.dot' % node)
758
756
# the remaining iterable nodes
@@ -792,6 +790,65 @@ def _iterable_nodes(graph_in):
792
790
inodes_src = [node for node in inodes if node .itersource ]
793
791
inodes_no_src .reverse ()
794
792
return inodes_no_src + inodes_src
793
+
794
+ def _standardize_iterables (node ):
795
+ """Converts the given iterables to a {field: function} dictionary,
796
+ if necessary, where the function returns a list."""
797
+ # trivial case
798
+ if not node .iterables :
799
+ return
800
+ iterables = node .iterables
801
+ # The candidate iterable fields
802
+ fields = set (node .inputs .copyable_trait_names ())
803
+
804
+ # Convert a tuple to a list
805
+ if isinstance (iterables , tuple ):
806
+ iterables = [iterables ]
807
+ # Convert a list to a dictionary
808
+ if isinstance (iterables , list ):
809
+ # Synchronize iterables can be in [fields, value tuples] format
810
+ # rather than [(field, value list), (field, value list), ...]
811
+ if node .synchronize and len (iterables ) == 2 :
812
+ first , last = iterables
813
+ if all ((isinstance (item , str ) and item in fields
814
+ for item in first )):
815
+ iterables = _transpose_iterables (first , last )
816
+ # Validate the format
817
+ for item in iterables :
818
+ try :
819
+ if len (item ) != 2 :
820
+ raise ValueError ("The %s iterables do not consist of"
821
+ " (field, values) pairs" % node .name )
822
+ except TypeError , e :
823
+ raise TypeError ("The %s iterables is not iterable: %s"
824
+ % (node .name , e ))
825
+ # Convert the values to functions. This is a legacy Nipype
826
+ # requirement with unknown rationale.
827
+ iter_items = map (lambda (field , value ): (field , lambda : value ),
828
+ iterables )
829
+ # Make the iterables dictionary
830
+ iterables = dict (iter_items )
831
+ elif not isinstance (iterables , dict ):
832
+ raise ValueError ("The %s iterables type is not a list or a dictionary:"
833
+ " %s" % (node .name , iterables .__class__ ))
834
+
835
+ # Validate the iterable fields
836
+ for field in iterables .iterkeys ():
837
+ if field not in fields :
838
+ raise ValueError ("The %s iterables field is unrecognized: %s"
839
+ % (node .name , field ))
840
+
841
+ # Assign to the standard form
842
+ node .iterables = iterables
843
+
844
+ def _transpose_iterables (fields , values ):
845
+ """
846
+ Converts the given fields and tuple values into a list of
847
+ iterable (field: value list) pairs, suitable for setting
848
+ a node iterables property.
849
+ """
850
+ return zip (fields , [filter (lambda (v ): v != None , transpose )
851
+ for transpose in zip (* values )])
795
852
796
853
def export_graph (graph_in , base_dir = None , show = False , use_execgraph = False ,
797
854
show_connectinfo = False , dotfilename = 'graph.dot' , format = 'png' ,
0 commit comments