diff --git a/CHANGES b/CHANGES index 34d568e1a5..7449d38a19 100644 --- a/CHANGES +++ b/CHANGES @@ -1,6 +1,7 @@ Next release ============ +* ENH: Inputs with name_source can be now chained in cascade (https://github.com/nipy/nipype/pull/938) * ENH: Improve JSON interfaces: default settings when reading and consistent output creation when writing (https://github.com/nipy/nipype/pull/1047) * FIX: AddCSVRow problems when using infields (https://github.com/nipy/nipype/pull/1028) diff --git a/doc/devel/interface_specs.rst b/doc/devel/interface_specs.rst index 12e6d1227d..155a14ae10 100644 --- a/doc/devel/interface_specs.rst +++ b/doc/devel/interface_specs.rst @@ -358,6 +358,9 @@ CommandLine ``name_source`` Indicates the list of input fields from which the value of the current File output variable will be drawn. This input field must be the name of a File. + Chaining is allowed, meaning that an input field can point to another as + ``name_source``, which also points as ``name_source`` to a third field. + In this situation, the templates for substitutions are also accumulated. ``name_template`` By default a ``%s_generated`` template is used to create the output diff --git a/nipype/interfaces/base.py b/nipype/interfaces/base.py index cf56306aa6..b2ea92c750 100644 --- a/nipype/interfaces/base.py +++ b/nipype/interfaces/base.py @@ -47,6 +47,12 @@ __docformat__ = 'restructuredtext' +class NipypeInterfaceError(Exception): + def __init__(self, value): + self.value = value + def __str__(self): + return repr(self.value) + def _lock_files(): tmpdir = '/tmp' pattern = '.X*-lock' @@ -1510,9 +1516,13 @@ def _format_arg(self, name, trait_spec, value): # Append options using format string. return argstr % value - def _filename_from_source(self, name): + def _filename_from_source(self, name, chain=None): + if chain is None: + chain = [] + trait_spec = self.inputs.trait(name) retval = getattr(self.inputs, name) + if not isdefined(retval) or "%s" in retval: if not trait_spec.name_source: return retval @@ -1522,26 +1532,42 @@ def _filename_from_source(self, name): name_template = trait_spec.name_template if not name_template: name_template = "%s_generated" - if isinstance(trait_spec.name_source, list): - for ns in trait_spec.name_source: - if isdefined(getattr(self.inputs, ns)): - name_source = ns - break + + ns = trait_spec.name_source + while isinstance(ns, list): + if len(ns) > 1: + iflogger.warn('Only one name_source per trait is allowed') + ns = ns[0] + + if not isinstance(ns, six.string_types): + raise ValueError(('name_source of \'%s\' trait sould be an ' + 'input trait name') % name) + + if isdefined(getattr(self.inputs, ns)): + name_source = ns + source = getattr(self.inputs, name_source) + while isinstance(source, list): + source = source[0] + + # special treatment for files + try: + _, base, _ = split_filename(source) + except AttributeError: + base = source else: - name_source = trait_spec.name_source - source = getattr(self.inputs, name_source) - while isinstance(source, list): - source = source[0] - #special treatment for files - try: - _, base, _ = split_filename(source) - except AttributeError: - base = source + if name in chain: + raise NipypeInterfaceError('Mutually pointing name_sources') + + chain.append(name) + base = self._filename_from_source(ns, chain) + + chain = None retval = name_template % base _, _, ext = split_filename(retval) if trait_spec.keep_extension and ext: return retval return self._overload_extension(retval, name) + return retval def _gen_filename(self, name): @@ -1557,7 +1583,7 @@ def _list_outputs(self): outputs = self.output_spec().get() for name, trait_spec in traits.iteritems(): out_name = name - if trait_spec.output_name != None: + if trait_spec.output_name is not None: out_name = trait_spec.output_name outputs[out_name] = \ os.path.abspath(self._filename_from_source(name)) diff --git a/nipype/interfaces/tests/test_base.py b/nipype/interfaces/tests/test_base.py index a509137309..112754ea8e 100644 --- a/nipype/interfaces/tests/test_base.py +++ b/nipype/interfaces/tests/test_base.py @@ -172,11 +172,13 @@ class DeprecationSpec3(nib.TraitedSpec): yield assert_equal, spec_instance.foo, Undefined yield assert_equal, spec_instance.bar, 1 + def test_namesource(): tmp_infile = setup_file() tmpd, nme, ext = split_filename(tmp_infile) pwd = os.getcwd() os.chdir(tmpd) + class spec2(nib.CommandLineInputSpec): moo = nib.File(name_source=['doo'], hash_files=False, argstr="%s", position=2) @@ -196,6 +198,104 @@ class TestName(nib.CommandLine): os.chdir(pwd) teardown_file(tmpd) + +def test_chained_namesource(): + tmp_infile = setup_file() + tmpd, nme, ext = split_filename(tmp_infile) + pwd = os.getcwd() + os.chdir(tmpd) + + class spec2(nib.CommandLineInputSpec): + doo = nib.File(exists=True, argstr="%s", position=1) + moo = nib.File(name_source=['doo'], hash_files=False, argstr="%s", + position=2, name_template='%s_mootpl') + poo = nib.File(name_source=['moo'], hash_files=False, + argstr="%s", position=3) + + class TestName(nib.CommandLine): + _cmd = "mycommand" + input_spec = spec2 + + testobj = TestName() + testobj.inputs.doo = tmp_infile + res = testobj.cmdline + yield assert_true, '%s' % tmp_infile in res + yield assert_true, '%s_mootpl ' % nme in res + yield assert_true, '%s_mootpl_generated' % nme in res + + os.chdir(pwd) + teardown_file(tmpd) + + +def test_cycle_namesource1(): + tmp_infile = setup_file() + tmpd, nme, ext = split_filename(tmp_infile) + pwd = os.getcwd() + os.chdir(tmpd) + + class spec3(nib.CommandLineInputSpec): + moo = nib.File(name_source=['doo'], hash_files=False, argstr="%s", + position=1, name_template='%s_mootpl') + poo = nib.File(name_source=['moo'], hash_files=False, + argstr="%s", position=2) + doo = nib.File(name_source=['poo'], hash_files=False, + argstr="%s", position=3) + + class TestCycle(nib.CommandLine): + _cmd = "mycommand" + input_spec = spec3 + + # Check that an exception is raised + to0 = TestCycle() + not_raised = True + try: + to0.cmdline + except nib.NipypeInterfaceError: + not_raised = False + yield assert_false, not_raised + + os.chdir(pwd) + teardown_file(tmpd) + +def test_cycle_namesource2(): + tmp_infile = setup_file() + tmpd, nme, ext = split_filename(tmp_infile) + pwd = os.getcwd() + os.chdir(tmpd) + + + class spec3(nib.CommandLineInputSpec): + moo = nib.File(name_source=['doo'], hash_files=False, argstr="%s", + position=1, name_template='%s_mootpl') + poo = nib.File(name_source=['moo'], hash_files=False, + argstr="%s", position=2) + doo = nib.File(name_source=['poo'], hash_files=False, + argstr="%s", position=3) + + class TestCycle(nib.CommandLine): + _cmd = "mycommand" + input_spec = spec3 + + # Check that loop can be broken by setting one of the inputs + to1 = TestCycle() + to1.inputs.poo = tmp_infile + + not_raised = True + try: + res = to1.cmdline + except nib.NipypeInterfaceError: + not_raised = False + print res + + yield assert_true, not_raised + yield assert_true, '%s' % tmp_infile in res + yield assert_true, '%s_generated' % nme in res + yield assert_true, '%s_generated_mootpl' % nme in res + + os.chdir(pwd) + teardown_file(tmpd) + + def checknose(): """check version of nose for known incompatability""" mod = __import__('nose') @@ -536,4 +636,4 @@ def test_global_CommandLine_output(): res = ci.run() yield assert_equal, res.runtime.stdout, '' os.chdir(pwd) - teardown_file(tmpd) \ No newline at end of file + teardown_file(tmpd)