From 34f1e31dfa291efc7195d87035b58232a4dbab27 Mon Sep 17 00:00:00 2001 From: "Christopher J. Markiewicz" Date: Thu, 1 Mar 2018 17:18:58 -0500 Subject: [PATCH 1/6] TEST: Connect JoinNode to input with overlapping name --- nipype/pipeline/engine/tests/test_join.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/nipype/pipeline/engine/tests/test_join.py b/nipype/pipeline/engine/tests/test_join.py index 436d29d9e7..5c2cb1c9d2 100644 --- a/nipype/pipeline/engine/tests/test_join.py +++ b/nipype/pipeline/engine/tests/test_join.py @@ -7,11 +7,9 @@ absolute_import) from builtins import open -import os - from ... import engine as pe from ....interfaces import base as nib -from ....interfaces.utility import IdentityInterface +from ....interfaces.utility import IdentityInterface, Function, Merge from ....interfaces.base import traits, File @@ -612,3 +610,20 @@ def nested_wf(i, name='smallwf'): # there should be six nodes in total assert len(result.nodes()) == 6, \ "The number of expanded nodes is incorrect." + + +def test_name_prefix_join(tmpdir): + tmpdir.chdir() + + def sq(x): + return x ** 2 + + wf = pe.Workflow('wf', base_dir=tmpdir.strpath) + square = pe.Node(Function(function=sq), name='square') + square.iterables = [('x', [1, 2])] + square_join = pe.JoinNode(Merge(1, ravel_inputs=True), + name='square_join', + joinsource=square, + joinfield=['in1']) + wf.connect(square, 'out', square_join, "in1") + wf.run() From ac6c2c9eac3a2c1545d485a767fa6a7f791eaefd Mon Sep 17 00:00:00 2001 From: "Christopher J. Markiewicz" Date: Thu, 1 Mar 2018 17:21:24 -0500 Subject: [PATCH 2/6] FIX: Check for full node name match --- nipype/pipeline/engine/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nipype/pipeline/engine/utils.py b/nipype/pipeline/engine/utils.py index 2b6bb6ed39..e38dd443d5 100644 --- a/nipype/pipeline/engine/utils.py +++ b/nipype/pipeline/engine/utils.py @@ -1050,7 +1050,7 @@ def make_field_func(*pair): expansions = defaultdict(list) for node in graph_in.nodes(): for src_id in list(old_edge_dict.keys()): - if node.itername.startswith(src_id): + if node.itername.startswith(src_id + '.'): expansions[src_id].append(node) for in_id, in_nodes in list(expansions.items()): logger.debug("The join node %s input %s was expanded" From 4ceb4dd60ca378b082e5f29c71dbbbb28ce64948 Mon Sep 17 00:00:00 2001 From: "Christopher J. Markiewicz" Date: Thu, 1 Mar 2018 20:36:28 -0500 Subject: [PATCH 3/6] FIX: Check for exact matches and dotted prefixes --- nipype/pipeline/engine/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nipype/pipeline/engine/utils.py b/nipype/pipeline/engine/utils.py index e38dd443d5..08cf220588 100644 --- a/nipype/pipeline/engine/utils.py +++ b/nipype/pipeline/engine/utils.py @@ -1050,7 +1050,8 @@ def make_field_func(*pair): expansions = defaultdict(list) for node in graph_in.nodes(): for src_id in list(old_edge_dict.keys()): - if node.itername.startswith(src_id + '.'): + if any((node.itername.startswith(src_id + '.'), + node.itername == src_id)): expansions[src_id].append(node) for in_id, in_nodes in list(expansions.items()): logger.debug("The join node %s input %s was expanded" From 86df501b7715d3bdaa7c636b165661a6d0b59e68 Mon Sep 17 00:00:00 2001 From: "Christopher J. Markiewicz" Date: Thu, 1 Mar 2018 20:52:28 -0500 Subject: [PATCH 4/6] ENH: Use more robust regular expression to identify expansions --- nipype/pipeline/engine/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nipype/pipeline/engine/utils.py b/nipype/pipeline/engine/utils.py index 08cf220588..1178727a30 100644 --- a/nipype/pipeline/engine/utils.py +++ b/nipype/pipeline/engine/utils.py @@ -1050,8 +1050,7 @@ def make_field_func(*pair): expansions = defaultdict(list) for node in graph_in.nodes(): for src_id in list(old_edge_dict.keys()): - if any((node.itername.startswith(src_id + '.'), - node.itername == src_id)): + if re.match(src_id + r'(\.[a-z]\d+)?$', node.itername): expansions[src_id].append(node) for in_id, in_nodes in list(expansions.items()): logger.debug("The join node %s input %s was expanded" From 3c13412ba492124d57e30c1265163639463f3b82 Mon Sep 17 00:00:00 2001 From: "Christopher J. Markiewicz" Date: Thu, 1 Mar 2018 22:44:15 -0500 Subject: [PATCH 5/6] FIX: Drop JoinNodes, fully specify regex --- nipype/pipeline/engine/utils.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/nipype/pipeline/engine/utils.py b/nipype/pipeline/engine/utils.py index 1178727a30..301a35844e 100644 --- a/nipype/pipeline/engine/utils.py +++ b/nipype/pipeline/engine/utils.py @@ -1050,7 +1050,17 @@ def make_field_func(*pair): expansions = defaultdict(list) for node in graph_in.nodes(): for src_id in list(old_edge_dict.keys()): - if re.match(src_id + r'(\.[a-z]\d+)?$', node.itername): + # Drop the original JoinNodes; only concerned with + # generated Nodes + if hasattr(node, 'joinfield'): + continue + # Patterns: + # - src_id : Non-iterable node + # - src_id.[a-z]\d+ : IdentityInterface w/ iterables + # - src_id.[a-z]I.[a-z]\d+ : Non-IdentityInterface w/ iterables + # - src_idJ\d+ : JoinNode(IdentityInterface) + if re.match(src_id + r'((\.[a-z](I\.[a-z])?|J)\d+)?$', + node.itername): expansions[src_id].append(node) for in_id, in_nodes in list(expansions.items()): logger.debug("The join node %s input %s was expanded" From 94c298f0ff9fa8286ee977060fd7b143e5725d0b Mon Sep 17 00:00:00 2001 From: "Christopher J. Markiewicz" Date: Thu, 1 Mar 2018 22:44:42 -0500 Subject: [PATCH 6/6] TEST: joinsource is string --- nipype/pipeline/engine/tests/test_join.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nipype/pipeline/engine/tests/test_join.py b/nipype/pipeline/engine/tests/test_join.py index 5c2cb1c9d2..54ff15048f 100644 --- a/nipype/pipeline/engine/tests/test_join.py +++ b/nipype/pipeline/engine/tests/test_join.py @@ -623,7 +623,7 @@ def sq(x): square.iterables = [('x', [1, 2])] square_join = pe.JoinNode(Merge(1, ravel_inputs=True), name='square_join', - joinsource=square, + joinsource='square', joinfield=['in1']) wf.connect(square, 'out', square_join, "in1") wf.run()