diff --git a/nipype/pipeline/engine/nodes.py b/nipype/pipeline/engine/nodes.py index bf58934f5b..eeb47f6d7a 100644 --- a/nipype/pipeline/engine/nodes.py +++ b/nipype/pipeline/engine/nodes.py @@ -9,7 +9,7 @@ absolute_import) from builtins import range, str, bytes, open -from collections import OrderedDict +from collections import OrderedDict, defaultdict import os import os.path as op @@ -510,49 +510,60 @@ def _get_hashval(self): return self._hashed_inputs, self._hashvalue def _get_inputs(self): - """Retrieve inputs from pointers to results file + """ + Retrieve inputs from pointers to results files. This mechanism can be easily extended/replaced to retrieve data from other data sources (e.g., XNAT, HTTP, etc.,.) """ - if self._got_inputs: + if self._got_inputs: # Inputs cached + return + + if not self.input_source: # No previous nodes + self._got_inputs = True return - logger.debug('Setting node inputs') + prev_results = defaultdict(list) for key, info in list(self.input_source.items()): - logger.debug('input: %s', key) - results_file = info[0] - logger.debug('results file: %s', results_file) - outputs = _load_resultfile(results_file).outputs + prev_results[info[0]].append((key, info[1])) + + logger.debug( + '[Node] Setting %d connected inputs of node "%s" from %d previous nodes.', + len(self.input_source), self.name, len(prev_results)) + + for results_fname, connections in list(prev_results.items()): + outputs = None + try: + outputs = _load_resultfile(results_fname).outputs + except AttributeError as e: + logger.critical('%s', e) + if outputs is None: raise RuntimeError("""\ -Error populating the input "%s" of node "%s": the results file of the source node \ -(%s) does not contain any outputs.""" % (key, self.name, results_file)) - output_value = Undefined - if isinstance(info[1], tuple): - output_name = info[1][0] - value = getattr(outputs, output_name) - if isdefined(value): - output_value = evaluate_connect_function( - info[1][1], info[1][2], value) - else: - output_name = info[1] +Error populating the inputs of node "%s": the results file of the source node \ +(%s) does not contain any outputs.""" % (self.name, results_fname)) + + for key, conn in connections: + output_value = Undefined + if isinstance(conn, tuple): + value = getattr(outputs, conn[0]) + if isdefined(value): + output_value = evaluate_connect_function( + conn[1], conn[2], value) + else: + output_value = getattr(outputs, conn) + try: - output_value = outputs.trait_get()[output_name] - except AttributeError: - output_value = outputs.dictcopy()[output_name] - logger.debug('output: %s', output_name) - try: - self.set_input(key, deepcopy(output_value)) - except traits.TraitError as e: - msg = ( - e.args[0], '', 'Error setting node input:', - 'Node: %s' % self.name, 'input: %s' % key, - 'results_file: %s' % results_file, - 'value: %s' % str(output_value), - ) - e.args = ('\n'.join(msg), ) - raise + self.set_input(key, deepcopy(output_value)) + except traits.TraitError as e: + msg = ( + e.args[0], '', 'Error setting node input:', + 'Node: %s' % self.name, 'input: %s' % key, + 'results_file: %s' % results_fname, + 'value: %s' % str(output_value), + ) + e.args = ('\n'.join(msg), ) + raise # Successfully set inputs self._got_inputs = True