Skip to content

Commit 7c27a9c

Browse files
committed
Merge remote-tracking branch 'origin/pr/851'
2 parents 18d65f7 + baaeeb2 commit 7c27a9c

File tree

2 files changed

+65
-7
lines changed

2 files changed

+65
-7
lines changed

nipype/pipeline/engine.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2008,7 +2008,7 @@ class MapNode(Node):
20082008
20092009
"""
20102010

2011-
def __init__(self, interface, iterfield, name, **kwargs):
2011+
def __init__(self, interface, iterfield, name, serial=False, **kwargs):
20122012
"""
20132013
20142014
Parameters
@@ -2022,9 +2022,12 @@ def __init__(self, interface, iterfield, name, **kwargs):
20222022
paired (i.e. it does not compute a combinatorial product).
20232023
name : alphanumeric string
20242024
node specific name
2025-
2025+
serial : boolean
2026+
flag to enforce executing the jobs of the mapnode in a serial manner rather than parallel
20262027
See Node docstring for additional keyword arguments.
20272028
"""
2029+
2030+
20282031
super(MapNode, self).__init__(interface, name, **kwargs)
20292032
if isinstance(iterfield, str):
20302033
iterfield = [iterfield]
@@ -2033,6 +2036,8 @@ def __init__(self, interface, iterfield, name, **kwargs):
20332036
fields=self.iterfield)
20342037
self._inputs.on_trait_change(self._set_mapnode_input)
20352038
self._got_inputs = False
2039+
2040+
self._serial = serial
20362041

20372042
def _create_dynamic_traits(self, basetraits, fields=None, nitems=None):
20382043
"""Convert specific fields of a trait to accept multiple inputs
@@ -2223,7 +2228,10 @@ def num_subnodes(self):
22232228
self._get_inputs()
22242229
self._got_inputs = True
22252230
self._check_iterfield()
2226-
return len(filename_to_list(getattr(self.inputs, self.iterfield[0])))
2231+
if self._serial :
2232+
return 1
2233+
else:
2234+
return len(filename_to_list(getattr(self.inputs, self.iterfield[0])))
22272235

22282236
def _get_inputs(self):
22292237
old_inputs = self._inputs.get()

nipype/pipeline/tests/test_engine.py

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010

1111
import networkx as nx
1212

13-
from nipype.testing import (assert_raises, assert_equal, assert_true,
14-
assert_false)
13+
from nipype.testing import (assert_raises, assert_equal, assert_true, assert_false)
1514
import nipype.interfaces.base as nib
1615
import nipype.pipeline.engine as pe
1716
from nipype import logging
@@ -551,7 +550,7 @@ def func2(a):
551550
w1.connect(n1, ('a', modify), n2,'a')
552551
w1.base_dir = wd
553552

554-
w1.config = {'crashdump_dir': wd}
553+
w1.config['execution']['crashdump_dir'] = wd
555554
# generate outputs
556555
error_raised = False
557556
try:
@@ -581,7 +580,7 @@ def func1(in1):
581580
n1.inputs.in1 = [1]
582581
w1 = Workflow(name='test')
583582
w1.base_dir = wd
584-
w1.config = {'crashdump_dir': wd}
583+
w1.config['execution']['crashdump_dir'] = wd
585584
w1.add_nodes([n1])
586585
w1.run()
587586
n1.inputs.in1 = [2]
@@ -606,3 +605,54 @@ def func1(in1):
606605
yield assert_false, error_raised
607606
os.chdir(cwd)
608607
rmtree(wd)
608+
609+
def test_serial_input():
610+
cwd = os.getcwd()
611+
wd = mkdtemp()
612+
os.chdir(wd)
613+
from nipype import MapNode, Function, Workflow
614+
def func1(in1):
615+
return in1
616+
n1 = MapNode(Function(input_names=['in1'],
617+
output_names=['out'],
618+
function=func1),
619+
iterfield=['in1'],
620+
name='n1')
621+
n1.inputs.in1 = [1,2,3]
622+
623+
624+
w1 = Workflow(name='test')
625+
w1.base_dir = wd
626+
w1.add_nodes([n1])
627+
# set local check
628+
w1.config['execution'] = {'stop_on_first_crash': 'true',
629+
'local_hash_check': 'true',
630+
'crashdump_dir': wd}
631+
632+
# test output of num_subnodes method when serial is default (False)
633+
yield assert_equal, n1.num_subnodes(), len(n1.inputs.in1)
634+
635+
# test running the workflow on default conditions
636+
error_raised = False
637+
try:
638+
w1.run(plugin='MultiProc')
639+
except Exception, e:
640+
pe.logger.info('Exception: %s' % str(e))
641+
error_raised = True
642+
yield assert_false, error_raised
643+
644+
# test output of num_subnodes method when serial is True
645+
n1._serial=True
646+
yield assert_equal, n1.num_subnodes(), 1
647+
648+
# test running the workflow on serial conditions
649+
error_raised = False
650+
try:
651+
w1.run(plugin='MultiProc')
652+
except Exception, e:
653+
pe.logger.info('Exception: %s' % str(e))
654+
error_raised = True
655+
yield assert_false, error_raised
656+
657+
os.chdir(cwd)
658+
rmtree(wd)

0 commit comments

Comments
 (0)