@@ -464,3 +464,87 @@ def _run_interface(self, runtime):
464
464
assert_equal (data1 , data2 )
465
465
466
466
return runtime
467
+
468
+ class CollateInterfaceInputSpec (DynamicTraitedSpec , BaseInterfaceInputSpec ):
469
+ _outputs = traits .Dict (traits .Any , value = {}, usedefault = True )
470
+
471
+ def __setattr__ (self , key , value ):
472
+ if key not in self .copyable_trait_names ():
473
+ if not isdefined (value ):
474
+ super (CollateInterfaceInputSpec , self ).__setattr__ (key , value )
475
+ self ._outputs [key ] = value
476
+ else :
477
+ if key in self ._outputs :
478
+ self ._outputs [key ] = value
479
+ super (CollateInterfaceInputSpec , self ).__setattr__ (key , value )
480
+
481
+ class CollateInterface (IOBase ):
482
+ """
483
+ A simple interface to multiplex inputs through a unique output set.
484
+ Channel is defined by the prefix of the fields. In order to avoid
485
+ inconsistencies, output fields should be defined forehand at initialization..
486
+
487
+ Example
488
+ -------
489
+
490
+ >>> from nipype.interfaces.utility import CollateInterface
491
+ >>> coll = CollateInterface(fields=['file','miscdata'])
492
+ >>> coll.inputs.src1_file = 'scores.csv'
493
+ >>> coll.inputs.src2_file = 'scores2.csv'
494
+ >>> coll.inputs.src1_miscdata = 1.0
495
+ >>> coll.inputs.src2_miscdata = 2.0
496
+ >>> coll.run() # doctest: +SKIP
497
+ """
498
+
499
+ input_spec = CollateInterfaceInputSpec
500
+ output_spec = DynamicTraitedSpec
501
+
502
+ def __init__ (self , fields = None , fill_missing = False , ** kwargs ):
503
+ super (CollateInterface , self ).__init__ (** kwargs )
504
+
505
+ if fields is None or not fields :
506
+ raise ValueError ('CollateInterface fields must be a non-empty list' )
507
+ # Each input must be in the fields.
508
+ self ._fields = fields
509
+ self ._fill_missing = fill_missing
510
+
511
+ def _add_output_traits (self , base ):
512
+ undefined_traits = {}
513
+ for key in self ._fields :
514
+ base .add_trait (key , traits .Any )
515
+ undefined_traits [key ] = Undefined
516
+ base .trait_set (trait_change_notify = False , ** undefined_traits )
517
+ return base
518
+
519
+ def _list_outputs (self ):
520
+ #manual mandatory inputs check
521
+ valuedict = dict ( (key , {}) for key in self ._fields )
522
+ nodekeys = []
523
+
524
+ for inputkey , inputval in self .inputs ._outputs .items ():
525
+ for key in self ._fields :
526
+ if inputkey .endswith (key ):
527
+ nodekey = inputkey [::- 1 ].replace (key [::- 1 ], '' , 1 )[::- 1 ]
528
+ nodekeys .append (nodekey )
529
+
530
+ if nodekey in valuedict [key ].keys ():
531
+ msg = ('Trying to add field from existing node' )
532
+ raise ValueError (msg )
533
+ valuedict [key ][nodekey ] = inputval
534
+
535
+ nodekeys = sorted (set (nodekeys ))
536
+ outputs = self ._outputs ().get ()
537
+ for key in self ._fields :
538
+ outputs [key ] = []
539
+ for nk in nodekeys :
540
+
541
+ if nk in valuedict [key ]:
542
+ val = valuedict [key ][nk ]
543
+ else :
544
+ if self ._fill_missing :
545
+ val = None
546
+ else :
547
+ raise RuntimeError ('Input missing for field to collate.' )
548
+ outputs [key ].append (val )
549
+
550
+ return outputs
0 commit comments