Skip to content

Commit 6dc492d

Browse files
authored
Merge pull request #2479 from effigies/fix/joinnode_connection
FIX: Check against full node name when reconnecting JoinNodes
2 parents ca4999e + 94c298f commit 6dc492d

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

nipype/pipeline/engine/tests/test_join.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@
77
absolute_import)
88
from builtins import open
99

10-
import os
11-
1210
from ... import engine as pe
1311
from ....interfaces import base as nib
14-
from ....interfaces.utility import IdentityInterface
12+
from ....interfaces.utility import IdentityInterface, Function, Merge
1513
from ....interfaces.base import traits, File
1614

1715

@@ -612,3 +610,20 @@ def nested_wf(i, name='smallwf'):
612610
# there should be six nodes in total
613611
assert len(result.nodes()) == 6, \
614612
"The number of expanded nodes is incorrect."
613+
614+
615+
def test_name_prefix_join(tmpdir):
616+
tmpdir.chdir()
617+
618+
def sq(x):
619+
return x ** 2
620+
621+
wf = pe.Workflow('wf', base_dir=tmpdir.strpath)
622+
square = pe.Node(Function(function=sq), name='square')
623+
square.iterables = [('x', [1, 2])]
624+
square_join = pe.JoinNode(Merge(1, ravel_inputs=True),
625+
name='square_join',
626+
joinsource='square',
627+
joinfield=['in1'])
628+
wf.connect(square, 'out', square_join, "in1")
629+
wf.run()

nipype/pipeline/engine/utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1050,7 +1050,17 @@ def make_field_func(*pair):
10501050
expansions = defaultdict(list)
10511051
for node in graph_in.nodes():
10521052
for src_id in list(old_edge_dict.keys()):
1053-
if node.itername.startswith(src_id):
1053+
# Drop the original JoinNodes; only concerned with
1054+
# generated Nodes
1055+
if hasattr(node, 'joinfield'):
1056+
continue
1057+
# Patterns:
1058+
# - src_id : Non-iterable node
1059+
# - src_id.[a-z]\d+ : IdentityInterface w/ iterables
1060+
# - src_id.[a-z]I.[a-z]\d+ : Non-IdentityInterface w/ iterables
1061+
# - src_idJ\d+ : JoinNode(IdentityInterface)
1062+
if re.match(src_id + r'((\.[a-z](I\.[a-z])?|J)\d+)?$',
1063+
node.itername):
10541064
expansions[src_id].append(node)
10551065
for in_id, in_nodes in list(expansions.items()):
10561066
logger.debug("The join node %s input %s was expanded"

0 commit comments

Comments
 (0)