Skip to content

Commit cd8919d

Browse files
committed
Finished merge
1 parent 001bb6d commit cd8919d

File tree

2 files changed

+100
-318
lines changed

2 files changed

+100
-318
lines changed

nipype/interfaces/utility.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,3 +464,87 @@ def _run_interface(self, runtime):
464464
assert_equal(data1, data2)
465465

466466
return runtime
467+
468+
class CollateInterfaceInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec):
469+
_outputs = traits.Dict(traits.Any, value={}, usedefault=True)
470+
471+
def __setattr__(self, key, value):
472+
if key not in self.copyable_trait_names():
473+
if not isdefined(value):
474+
super(CollateInterfaceInputSpec, self).__setattr__(key, value)
475+
self._outputs[key] = value
476+
else:
477+
if key in self._outputs:
478+
self._outputs[key] = value
479+
super(CollateInterfaceInputSpec, self).__setattr__(key, value)
480+
481+
class CollateInterface(IOBase):
482+
"""
483+
A simple interface to multiplex inputs through a unique output set.
484+
Channel is defined by the prefix of the fields. In order to avoid
485+
inconsistencies, output fields should be defined forehand at initialization..
486+
487+
Example
488+
-------
489+
490+
>>> from nipype.interfaces.utility import CollateInterface
491+
>>> coll = CollateInterface(fields=['file','miscdata'])
492+
>>> coll.inputs.src1_file = 'scores.csv'
493+
>>> coll.inputs.src2_file = 'scores2.csv'
494+
>>> coll.inputs.src1_miscdata = 1.0
495+
>>> coll.inputs.src2_miscdata = 2.0
496+
>>> coll.run() # doctest: +SKIP
497+
"""
498+
499+
input_spec = CollateInterfaceInputSpec
500+
output_spec = DynamicTraitedSpec
501+
502+
def __init__(self, fields=None, fill_missing=False, **kwargs):
503+
super(CollateInterface, self).__init__(**kwargs)
504+
505+
if fields is None or not fields:
506+
raise ValueError('CollateInterface fields must be a non-empty list')
507+
# Each input must be in the fields.
508+
self._fields = fields
509+
self._fill_missing = fill_missing
510+
511+
def _add_output_traits(self, base):
512+
undefined_traits = {}
513+
for key in self._fields:
514+
base.add_trait(key, traits.Any)
515+
undefined_traits[key] = Undefined
516+
base.trait_set(trait_change_notify=False, **undefined_traits)
517+
return base
518+
519+
def _list_outputs(self):
520+
#manual mandatory inputs check
521+
valuedict = dict( (key, {}) for key in self._fields)
522+
nodekeys = []
523+
524+
for inputkey, inputval in self.inputs._outputs.items():
525+
for key in self._fields:
526+
if inputkey.endswith(key):
527+
nodekey = inputkey[::-1].replace(key[::-1], '', 1)[::-1]
528+
nodekeys.append(nodekey)
529+
530+
if nodekey in valuedict[key].keys():
531+
msg = ('Trying to add field from existing node')
532+
raise ValueError(msg)
533+
valuedict[key][nodekey] = inputval
534+
535+
nodekeys = sorted(set(nodekeys))
536+
outputs = self._outputs().get()
537+
for key in self._fields:
538+
outputs[key] = []
539+
for nk in nodekeys:
540+
541+
if nk in valuedict[key]:
542+
val = valuedict[key][nk]
543+
else:
544+
if self._fill_missing:
545+
val = None
546+
else:
547+
raise RuntimeError('Input missing for field to collate.')
548+
outputs[key].append(val)
549+
550+
return outputs

0 commit comments

Comments
 (0)