diff --git a/nipype/pipeline/engine/nodes.py b/nipype/pipeline/engine/nodes.py index cbfa70cebb..d52e589a08 100644 --- a/nipype/pipeline/engine/nodes.py +++ b/nipype/pipeline/engine/nodes.py @@ -1112,9 +1112,14 @@ def _make_nodes(self, cwd=None): nitems = len(filename_to_list(getattr(self.inputs, self.iterfield[0]))) for i in range(nitems): nodename = '_' + self.name + str(i) - node = Node(deepcopy(self._interface), name=nodename) - node.overwrite = self.overwrite - node.run_without_submitting = self.run_without_submitting + node = Node(deepcopy(self._interface), + n_procs=self._interface.num_threads, + mem_gb=self._interface.estimated_memory_gb, + overwrite=self.overwrite, + needed_outputs=self.needed_outputs, + run_without_submitting=self.run_without_submitting, + base_dir=op.join(cwd, 'mapflow'), + name=nodename) node.plugin_args = self.plugin_args node._interface.inputs.set( **deepcopy(self._interface.inputs.get())) @@ -1126,7 +1131,6 @@ def _make_nodes(self, cwd=None): logger.debug('setting input %d %s %s', i, field, fieldvals[i]) setattr(node.inputs, field, fieldvals[i]) node.config = self.config - node.base_dir = op.join(cwd, 'mapflow') yield i, node def _node_runner(self, nodes, updatehash=False): diff --git a/nipype/pipeline/engine/tests/test_engine.py b/nipype/pipeline/engine/tests/test_engine.py index 5cd107bc69..e2624d03c8 100644 --- a/nipype/pipeline/engine/tests/test_engine.py +++ b/nipype/pipeline/engine/tests/test_engine.py @@ -486,6 +486,28 @@ def func1(in1): assert "can only concatenate list" in str(excinfo.value) +def test_mapnode_expansion(tmpdir): + os.chdir(str(tmpdir)) + from nipype import MapNode, Function + + def func1(in1): + return in1 + 1 + + mapnode = MapNode(Function(function=func1), + iterfield='in1', + name='mapnode') + mapnode.inputs.in1 = [1, 2] + mapnode.interface.num_threads = 2 + mapnode.interface.estimated_memory_gb = 2 + + for idx, node in mapnode._make_nodes(): + for attr in ('overwrite', 'run_without_submitting', 'plugin_args'): + assert getattr(node, attr) == getattr(mapnode, attr) + for attr in ('num_threads', 'estimated_memory_gb'): + assert (getattr(node._interface, attr) == + getattr(mapnode._interface, attr)) + + def test_node_hash(tmpdir): wd = str(tmpdir) os.chdir(wd)