Skip to content

Commit 959739c

Browse files
authored
Merge pull request #2019 from effigies/fix/mapnode_parallelism
FIX: Copy num_threads to MapNode-generated Nodes
2 parents a34d50f + 28198af commit 959739c

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

nipype/pipeline/engine/nodes.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,9 +1112,14 @@ def _make_nodes(self, cwd=None):
11121112
nitems = len(filename_to_list(getattr(self.inputs, self.iterfield[0])))
11131113
for i in range(nitems):
11141114
nodename = '_' + self.name + str(i)
1115-
node = Node(deepcopy(self._interface), name=nodename)
1116-
node.overwrite = self.overwrite
1117-
node.run_without_submitting = self.run_without_submitting
1115+
node = Node(deepcopy(self._interface),
1116+
n_procs=self._interface.num_threads,
1117+
mem_gb=self._interface.estimated_memory_gb,
1118+
overwrite=self.overwrite,
1119+
needed_outputs=self.needed_outputs,
1120+
run_without_submitting=self.run_without_submitting,
1121+
base_dir=op.join(cwd, 'mapflow'),
1122+
name=nodename)
11181123
node.plugin_args = self.plugin_args
11191124
node._interface.inputs.set(
11201125
**deepcopy(self._interface.inputs.get()))
@@ -1126,7 +1131,6 @@ def _make_nodes(self, cwd=None):
11261131
logger.debug('setting input %d %s %s', i, field, fieldvals[i])
11271132
setattr(node.inputs, field, fieldvals[i])
11281133
node.config = self.config
1129-
node.base_dir = op.join(cwd, 'mapflow')
11301134
yield i, node
11311135

11321136
def _node_runner(self, nodes, updatehash=False):

nipype/pipeline/engine/tests/test_engine.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,28 @@ def func1(in1):
486486
assert "can only concatenate list" in str(excinfo.value)
487487

488488

489+
def test_mapnode_expansion(tmpdir):
490+
os.chdir(str(tmpdir))
491+
from nipype import MapNode, Function
492+
493+
def func1(in1):
494+
return in1 + 1
495+
496+
mapnode = MapNode(Function(function=func1),
497+
iterfield='in1',
498+
name='mapnode')
499+
mapnode.inputs.in1 = [1, 2]
500+
mapnode.interface.num_threads = 2
501+
mapnode.interface.estimated_memory_gb = 2
502+
503+
for idx, node in mapnode._make_nodes():
504+
for attr in ('overwrite', 'run_without_submitting', 'plugin_args'):
505+
assert getattr(node, attr) == getattr(mapnode, attr)
506+
for attr in ('num_threads', 'estimated_memory_gb'):
507+
assert (getattr(node._interface, attr) ==
508+
getattr(mapnode._interface, attr))
509+
510+
489511
def test_node_hash(tmpdir):
490512
wd = str(tmpdir)
491513
os.chdir(wd)

0 commit comments

Comments
 (0)