@@ -551,3 +551,63 @@ def _list_outputs(self):
551
551
outputs [key ].append (val )
552
552
553
553
return outputs
554
+
555
+
556
+ class MultipleSelectInputSpec (BaseInterfaceInputSpec , DynamicTraitedSpec ):
557
+ index = InputMultiPath (traits .Int , mandatory = True ,
558
+ desc = '0-based indices of values to choose' )
559
+
560
+ class MultipleSelectInterface (IdentityInterface ):
561
+ """
562
+ Basic interface that demultiplexes lists generated by CollateInterface
563
+
564
+ Example
565
+ -------
566
+
567
+ >>> from nipype.interfaces.utility import MultipleSelectInterface
568
+ >>> demux = MultipleSelectInterface(fields=['file','miscdata'], index=0)
569
+ >>> demux.inputs.file = ['exfile1.csv', 'exfile2.csv']
570
+ >>> demux.inputs.miscdata = [1.0, 2.0]
571
+ >>> res = demux.run()
572
+ >>> print res.outputs.file
573
+ exfile1.csv
574
+ >>> print res.outputs.miscdata
575
+ 1.0
576
+ """
577
+ input_spec = MultipleSelectInputSpec
578
+ output_spec = DynamicTraitedSpec
579
+
580
+ def __init__ (self , fields = None , mandatory_inputs = True , ** inputs ):
581
+ super (IdentityInterface , self ).__init__ (** inputs )
582
+ if fields is None or not fields :
583
+ raise ValueError ('Identity Interface fields must be a non-empty list' )
584
+ # Each input must be in the fields.
585
+ for in_field in inputs :
586
+ if in_field not in fields and in_field != 'index' :
587
+ raise ValueError ('Identity Interface input is not in the fields: %s' % in_field )
588
+ self ._fields = fields
589
+ self ._mandatory_inputs = mandatory_inputs
590
+ add_traits (self .inputs , fields )
591
+ # Adding any traits wipes out all input values set in superclass initialization,
592
+ # even it the trait is not in the add_traits argument. The work-around is to reset
593
+ # the values after adding the traits.
594
+ self .inputs .set (** inputs )
595
+
596
+ def _list_outputs (self ):
597
+ #manual mandatory inputs check
598
+ if self ._fields and self ._mandatory_inputs :
599
+ for key in self ._fields :
600
+ value = getattr (self .inputs , key )
601
+ if not isdefined (value ):
602
+ msg = "%s requires a value for input '%s' because it was listed in 'fields'. \
603
+ You can turn off mandatory inputs checking by passing mandatory_inputs = False to the constructor." % \
604
+ (self .__class__ .__name__ , key )
605
+ raise ValueError (msg )
606
+
607
+ outputs = self ._outputs ().get ()
608
+ for key in self ._fields :
609
+ val = getattr (self .inputs , key )
610
+ if isdefined (val ):
611
+ outputs [key ] = np .squeeze (np .array (val )[np .array (self .inputs .index )]).tolist ()
612
+ return outputs
613
+
0 commit comments