Skip to content

Commit 7a464e3

Browse files
committed
Merge pull request #938 from oesteban/enh/ChainedNameSource
ENH: Allow chained name_source
2 parents ee31238 + b43941c commit 7a464e3

File tree

4 files changed

+147
-17
lines changed

4 files changed

+147
-17
lines changed

CHANGES

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
Next release
22
============
33

4+
* ENH: Inputs with name_source can be now chained in cascade (https://github.com/nipy/nipype/pull/938)
45
* ENH: Improve JSON interfaces: default settings when reading and consistent output creation
56
when writing (https://github.com/nipy/nipype/pull/1047)
67
* FIX: AddCSVRow problems when using infields (https://github.com/nipy/nipype/pull/1028)

doc/devel/interface_specs.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,9 @@ CommandLine
358358
``name_source``
359359
Indicates the list of input fields from which the value of the current File
360360
output variable will be drawn. This input field must be the name of a File.
361+
Chaining is allowed, meaning that an input field can point to another as
362+
``name_source``, which also points as ``name_source`` to a third field.
363+
In this situation, the templates for substitutions are also accumulated.
361364

362365
``name_template``
363366
By default a ``%s_generated`` template is used to create the output

nipype/interfaces/base.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@
4747

4848
__docformat__ = 'restructuredtext'
4949

50+
class NipypeInterfaceError(Exception):
51+
def __init__(self, value):
52+
self.value = value
53+
def __str__(self):
54+
return repr(self.value)
55+
5056
def _lock_files():
5157
tmpdir = '/tmp'
5258
pattern = '.X*-lock'
@@ -1510,9 +1516,13 @@ def _format_arg(self, name, trait_spec, value):
15101516
# Append options using format string.
15111517
return argstr % value
15121518

1513-
def _filename_from_source(self, name):
1519+
def _filename_from_source(self, name, chain=None):
1520+
if chain is None:
1521+
chain = []
1522+
15141523
trait_spec = self.inputs.trait(name)
15151524
retval = getattr(self.inputs, name)
1525+
15161526
if not isdefined(retval) or "%s" in retval:
15171527
if not trait_spec.name_source:
15181528
return retval
@@ -1522,26 +1532,42 @@ def _filename_from_source(self, name):
15221532
name_template = trait_spec.name_template
15231533
if not name_template:
15241534
name_template = "%s_generated"
1525-
if isinstance(trait_spec.name_source, list):
1526-
for ns in trait_spec.name_source:
1527-
if isdefined(getattr(self.inputs, ns)):
1528-
name_source = ns
1529-
break
1535+
1536+
ns = trait_spec.name_source
1537+
while isinstance(ns, list):
1538+
if len(ns) > 1:
1539+
iflogger.warn('Only one name_source per trait is allowed')
1540+
ns = ns[0]
1541+
1542+
if not isinstance(ns, six.string_types):
1543+
raise ValueError(('name_source of \'%s\' trait sould be an '
1544+
'input trait name') % name)
1545+
1546+
if isdefined(getattr(self.inputs, ns)):
1547+
name_source = ns
1548+
source = getattr(self.inputs, name_source)
1549+
while isinstance(source, list):
1550+
source = source[0]
1551+
1552+
# special treatment for files
1553+
try:
1554+
_, base, _ = split_filename(source)
1555+
except AttributeError:
1556+
base = source
15301557
else:
1531-
name_source = trait_spec.name_source
1532-
source = getattr(self.inputs, name_source)
1533-
while isinstance(source, list):
1534-
source = source[0]
1535-
#special treatment for files
1536-
try:
1537-
_, base, _ = split_filename(source)
1538-
except AttributeError:
1539-
base = source
1558+
if name in chain:
1559+
raise NipypeInterfaceError('Mutually pointing name_sources')
1560+
1561+
chain.append(name)
1562+
base = self._filename_from_source(ns, chain)
1563+
1564+
chain = None
15401565
retval = name_template % base
15411566
_, _, ext = split_filename(retval)
15421567
if trait_spec.keep_extension and ext:
15431568
return retval
15441569
return self._overload_extension(retval, name)
1570+
15451571
return retval
15461572

15471573
def _gen_filename(self, name):
@@ -1557,7 +1583,7 @@ def _list_outputs(self):
15571583
outputs = self.output_spec().get()
15581584
for name, trait_spec in traits.iteritems():
15591585
out_name = name
1560-
if trait_spec.output_name != None:
1586+
if trait_spec.output_name is not None:
15611587
out_name = trait_spec.output_name
15621588
outputs[out_name] = \
15631589
os.path.abspath(self._filename_from_source(name))

nipype/interfaces/tests/test_base.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,11 +172,13 @@ class DeprecationSpec3(nib.TraitedSpec):
172172
yield assert_equal, spec_instance.foo, Undefined
173173
yield assert_equal, spec_instance.bar, 1
174174

175+
175176
def test_namesource():
176177
tmp_infile = setup_file()
177178
tmpd, nme, ext = split_filename(tmp_infile)
178179
pwd = os.getcwd()
179180
os.chdir(tmpd)
181+
180182
class spec2(nib.CommandLineInputSpec):
181183
moo = nib.File(name_source=['doo'], hash_files=False, argstr="%s",
182184
position=2)
@@ -196,6 +198,104 @@ class TestName(nib.CommandLine):
196198
os.chdir(pwd)
197199
teardown_file(tmpd)
198200

201+
202+
def test_chained_namesource():
203+
tmp_infile = setup_file()
204+
tmpd, nme, ext = split_filename(tmp_infile)
205+
pwd = os.getcwd()
206+
os.chdir(tmpd)
207+
208+
class spec2(nib.CommandLineInputSpec):
209+
doo = nib.File(exists=True, argstr="%s", position=1)
210+
moo = nib.File(name_source=['doo'], hash_files=False, argstr="%s",
211+
position=2, name_template='%s_mootpl')
212+
poo = nib.File(name_source=['moo'], hash_files=False,
213+
argstr="%s", position=3)
214+
215+
class TestName(nib.CommandLine):
216+
_cmd = "mycommand"
217+
input_spec = spec2
218+
219+
testobj = TestName()
220+
testobj.inputs.doo = tmp_infile
221+
res = testobj.cmdline
222+
yield assert_true, '%s' % tmp_infile in res
223+
yield assert_true, '%s_mootpl ' % nme in res
224+
yield assert_true, '%s_mootpl_generated' % nme in res
225+
226+
os.chdir(pwd)
227+
teardown_file(tmpd)
228+
229+
230+
def test_cycle_namesource1():
231+
tmp_infile = setup_file()
232+
tmpd, nme, ext = split_filename(tmp_infile)
233+
pwd = os.getcwd()
234+
os.chdir(tmpd)
235+
236+
class spec3(nib.CommandLineInputSpec):
237+
moo = nib.File(name_source=['doo'], hash_files=False, argstr="%s",
238+
position=1, name_template='%s_mootpl')
239+
poo = nib.File(name_source=['moo'], hash_files=False,
240+
argstr="%s", position=2)
241+
doo = nib.File(name_source=['poo'], hash_files=False,
242+
argstr="%s", position=3)
243+
244+
class TestCycle(nib.CommandLine):
245+
_cmd = "mycommand"
246+
input_spec = spec3
247+
248+
# Check that an exception is raised
249+
to0 = TestCycle()
250+
not_raised = True
251+
try:
252+
to0.cmdline
253+
except nib.NipypeInterfaceError:
254+
not_raised = False
255+
yield assert_false, not_raised
256+
257+
os.chdir(pwd)
258+
teardown_file(tmpd)
259+
260+
def test_cycle_namesource2():
261+
tmp_infile = setup_file()
262+
tmpd, nme, ext = split_filename(tmp_infile)
263+
pwd = os.getcwd()
264+
os.chdir(tmpd)
265+
266+
267+
class spec3(nib.CommandLineInputSpec):
268+
moo = nib.File(name_source=['doo'], hash_files=False, argstr="%s",
269+
position=1, name_template='%s_mootpl')
270+
poo = nib.File(name_source=['moo'], hash_files=False,
271+
argstr="%s", position=2)
272+
doo = nib.File(name_source=['poo'], hash_files=False,
273+
argstr="%s", position=3)
274+
275+
class TestCycle(nib.CommandLine):
276+
_cmd = "mycommand"
277+
input_spec = spec3
278+
279+
# Check that loop can be broken by setting one of the inputs
280+
to1 = TestCycle()
281+
to1.inputs.poo = tmp_infile
282+
283+
not_raised = True
284+
try:
285+
res = to1.cmdline
286+
except nib.NipypeInterfaceError:
287+
not_raised = False
288+
print res
289+
290+
yield assert_true, not_raised
291+
yield assert_true, '%s' % tmp_infile in res
292+
yield assert_true, '%s_generated' % nme in res
293+
yield assert_true, '%s_generated_mootpl' % nme in res
294+
295+
os.chdir(pwd)
296+
teardown_file(tmpd)
297+
298+
199299
def checknose():
200300
"""check version of nose for known incompatability"""
201301
mod = __import__('nose')
@@ -536,4 +636,4 @@ def test_global_CommandLine_output():
536636
res = ci.run()
537637
yield assert_equal, res.runtime.stdout, ''
538638
os.chdir(pwd)
539-
teardown_file(tmpd)
639+
teardown_file(tmpd)

0 commit comments

Comments
 (0)